| |
| |
| |
| |
| |
| |
| |
|
|
"""Train a GAN using the techniques described in the paper
"Diffusion-GAN: Training GANs with Diffusion".

Modified from the StyleGAN2-ADA PyTorch implementation.
"""
|
|
| import os |
| import click |
| import re |
| import json |
| import tempfile |
| import torch |
| import dnnlib |
|
|
| from training import training_loop |
| from metrics import metric_main |
| from torch_utils import training_stats |
| from torch_utils import custom_ops |
|
|
| |
|
|
class UserError(Exception):
    """An invalid user-supplied training option.

    Raised by ``setup_training_loop_kwargs`` when an option fails validation;
    ``main`` catches it and reports the message through ``ctx.fail``.
    """
|
|
| |
|
|
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus=None,
    snap=None,
    metrics=None,
    seed=None,

    # Dataset.
    data=None,
    cond=None,
    subset=None,
    mirror=None,

    # Base config.
    cfg=None,
    gamma=None,
    kimg=None,
    batch=None,

    # Diffusion / discriminator augmentation.
    beta_schedule=None,
    beta_start=None,
    beta_end=None,
    t_min=None,
    t_max=None,
    noise_sd=None,
    ts_dist=None,
    target=None,
    ada_kimg=None,
    aug=None,
    ada_maxp=None,

    # Transfer learning.
    resume=None,
    freezed=None,

    # Performance options (not included in desc).
    fp32=None,
    nhwc=None,
    allow_tf32=None,
    nobench=None,
    workers=None,

    # Free-form suffix for the run description.
    exp_id=None,
):
    """Validate command-line options and build the keyword arguments for training.

    Every parameter defaults to ``None``, meaning "use the built-in default".
    The dataset at ``data`` is instantiated once. This reads its resolution,
    label availability and size.

    Returns:
        ``(desc, args)``. ``desc`` is a run description: the dataset name plus a
        suffix for each non-default option. ``main`` uses it to name the output
        directory. ``args`` is a ``dnnlib.EasyDict`` whose entries
        ``subprocess_fn`` forwards as keyword arguments to
        ``training_loop.training_loop``.

    Raises:
        UserError: if an option has an invalid value or the dataset cannot be
            opened.
    """
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    # Power-of-two test: exactly one bit set.
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Open the dataset once, only to discover its properties.
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs)
        args.training_set_kwargs.resolution = training_set.resolution
        args.training_set_kwargs.use_labels = training_set.has_labels
        args.training_set_kwargs.max_size = len(training_set)
        desc = training_set.name
        del training_set
    except IOError as err:
        raise UserError(f'--data: {err}') from err

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            # Seed the subset selection so it is reproducible for a given --seed.
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    # -1 entries in 'auto' are placeholders that are computed below from the
    # dataset resolution and GPU count.
    cfg_specs = {
        'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2),
        'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),
        'paper256': dict(ref_gpus=4, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
        'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
        'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
        'cifar': dict(ref_gpus=4, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
        'stl10': dict(ref_gpus=4, kimg=50000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.05, ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)
        spec.mbstd = min(spec.mb // gpus, 4)
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0
        args.loss_kwargs.style_mixing_prob = 0
        args.D_kwargs.architecture = 'orig'

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Diffusion and discriminator augmentation options
    # ---------------------------------------------------

    if aug is None:
        aug = 'no'

    if target is not None:
        assert isinstance(target, float)
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    if ada_kimg is None:
        args.ada_kimg = 100
    else:
        # A non-positive adjustment duration is meaningless; reject it here
        # rather than letting it reach the training loop.
        if not ada_kimg > 0:
            raise UserError('--ada_kimg must be positive')
        desc += f'-ada_kimg{ada_kimg}'
        args.ada_kimg = ada_kimg

    # t_min/t_max bound the adaptively adjusted number of diffusion timesteps,
    # so an inverted pair can never be satisfied.
    if t_min is not None and t_max is not None and t_min > t_max:
        raise UserError('--t_min must not exceed --t_max')

    # ada_maxp caps the ADA augmentation probability p.
    if ada_maxp is not None and not 0 <= ada_maxp <= 1:
        raise UserError('--ada_maxp must be between 0 and 1')

    # Note the key rename: the CLI's noise_sd is passed as noise_std.
    diffusion_specs = dict(beta_schedule=beta_schedule, beta_start=beta_start, beta_end=beta_end,
                           t_min=t_min, t_max=t_max, noise_std=noise_sd,
                           aug=aug, ada_maxp=ada_maxp, ts_dist=ts_dist)

    desc += f"-ts_dist-{ts_dist}-image_aug{aug}-noise_sd{noise_sd}"
    args.diffusion_kwargs = dnnlib.EasyDict(class_name='training.diffusion.Diffusion', **diffusion_specs)

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or URL

    # When starting from a pretrained network, disable the EMA ramp-up.
    if resume != 'noresume':
        args.ema_rampup = None

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    if exp_id is not None:
        desc += f'-{exp_id}'

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
|
|
| |
|
|
def subprocess_fn(rank, args, temp_dir):
    """Run training in one process.

    Sets up logging, and torch.distributed when more than one GPU is used,
    then hands off to the training loop.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        args: Training options. ``run_dir`` and ``num_gpus`` are read here, and
            the whole mapping is forwarded as keyword arguments to
            ``training_loop.training_loop``.
        temp_dir: Directory that holds the file used as the distributed
            rendezvous point.
    """
    log_path = os.path.join(args.run_dir, 'log.txt')
    dnnlib.util.Logger(file_name=log_path, file_mode='a', should_flush=True)

    world_size = args.num_gpus
    is_distributed = world_size > 1

    if is_distributed:
        rendezvous_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        # Windows uses the gloo backend with a file:/// URL (forward slashes);
        # other platforms use NCCL.
        if os.name == 'nt':
            backend = 'gloo'
            init_method = 'file:///' + rendezvous_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{rendezvous_file}'
        torch.distributed.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)

    sync_device = torch.device('cuda', rank) if is_distributed else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)

    # Only rank 0 reports custom-op build output.
    if rank != 0:
        custom_ops.verbosity = 'none'

    training_loop.training_loop(rank=rank, **args)
|
|
| |
|
|
class CommaSeparatedList(click.ParamType):
    """Click parameter type that parses ``"a,b,c"`` into ``['a', 'b', 'c']``.

    ``None``, the empty string and ``"none"`` (in any case) all yield an empty
    list.
    """
    name = 'list'

    def convert(self, value, param, ctx):
        """Convert ``value`` to a list of strings.

        Click may call ``convert`` with a value that is already converted, for
        example a default or a value passed programmatically. Such list/tuple
        values are returned as a list instead of being treated as a string.
        """
        _ = param, ctx
        if isinstance(value, (list, tuple)):
            return list(value)
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')
|
|
| |
|
|
def _select_run_dir(outdir, run_desc):
    """Return the run directory to use for ``run_desc`` under ``outdir``.

    If exactly one existing subdirectory is named ``NNNNN-<run_desc>``, it is
    reused. This is the "resuming" case named in the assertion message.
    Otherwise a new name with the next free 5-digit run id is returned.
    Nothing is created on disk here.
    """
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]

    # run_desc contains '.' from float options (e.g. "noise_sd0.05") and
    # arbitrary dataset / --exp_id text, so escape it before using it as a regex.
    pattern = re.compile(r'\d{5}-' + re.escape(run_desc))
    matching_dirs = [x for x in prev_run_dirs if pattern.fullmatch(x) is not None]
    if matching_dirs:
        assert len(matching_dirs) == 1, f'Multiple directories found for resuming: {matching_dirs}'
        return os.path.join(outdir, matching_dirs[0])

    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
    assert not os.path.exists(run_dir)
    return run_dir


@click.command()
@click.pass_context
# General options.
@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
# Dataset.
@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
@click.option('--mirror', help='Enable dataset x-flips [default: true]', type=bool, metavar='BOOL', default=True)
# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'stl10']))
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')
# Diffusion.
@click.option('--beta_schedule', help='Forward diffusion beta schedule (we use linear always)', type=str, default='linear')
@click.option('--beta_start', help='Forward diffusion process beta_start', type=float, default=1e-4)
@click.option('--beta_end', help='Forward diffusion process beta_end', type=float, default=2e-2)
@click.option('--t_min', help='Minimum # of timesteps for adaptively modification', type=int, default=10)
@click.option('--t_max', help='Maximum # of timesteps for adaptively modification', type=int, default=1000)
@click.option('--noise_sd', help='Diffusion noise standard deviation', type=float, default=0.05)
@click.option('--ts_dist', help='Diffusion t sampling way', type=click.Choice(['priority', 'uniform']), default='priority')
@click.option('--target', help='Discriminator target value', type=float, default=0.6)
@click.option('--ada_kimg', help='# kimgs needed to push diffusion to maximum level', type=int, default=100)
# Image augmentation.
@click.option('--aug', help='Common image augmentation mode', type=click.Choice(['no', 'ada', 'diff']), default='no')
@click.option('--ada_maxp', help='The maximum value of p if adding ADA augmentation', type=float, default=0.25)
# Transfer learning.
@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')
# Performance options.
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
# Run naming.
@click.option('--exp_id', help='String to include in result dir name', metavar='STR', type=str)
def main(ctx, outdir, dry_run, **config_kwargs):
    """Train a GAN using the techniques described in the paper
    "Diffusion-GAN: Training GANs with Diffusion".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=no

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.
      stl10      Preset for STL-10 (batch 64, gamma 0.05, 50000 kimg).

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.
    """
    dnnlib.util.Logger(should_flush=True)

    # Setup training options. ctx.fail raises, so run_desc/args are always bound below.
    try:
        run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    except UserError as err:
        ctx.fail(str(err))

    # Pick the output directory (reusing a matching existing run if there is one).
    args.run_dir = _select_run_dir(outdir, run_desc)

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(args, indent=2))
    print()
    print(f'Output directory: {args.run_dir}')
    print(f'Training data: {args.training_set_kwargs.path}')
    print(f'Training duration: {args.total_kimg} kimg')
    print(f'Number of GPUs: {args.num_gpus}')
    print(f'Number of images: {args.training_set_kwargs.max_size}')
    print(f'Image resolution: {args.training_set_kwargs.resolution}')
    print(f'Conditional model: {args.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory and record the resolved options.
    print('Creating output directory...')
    os.makedirs(args.run_dir, exist_ok=True)
    with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(args, f, indent=2)

    # Launch one process per GPU (in-process when there is only one).
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
|
|
| |
|
|
if __name__ == "__main__":
    # click fills ctx, outdir, dry_run and the option kwargs from sys.argv.
    main()  # pylint: disable=no-value-for-parameter
|
|
| |
|
|