| | """ |
| | This file runs the main training/val loop |
| | """ |
| | import os |
| | import json |
| | import math |
| | import sys |
| | import pprint |
| | import torch |
| | from argparse import Namespace |
| |
|
# Add the current working directory and its parent to the import path so the
# project-local packages imported below (`options`, `training`) can be resolved.
# NOTE(review): presumably this lets the script be launched from the repo root
# or one level below it — confirm against how the script is invoked.
sys.path.append(".")
sys.path.append("..")
| |
|
| | from options.train_options import TrainOptions |
| | from training.coach import Coach |
| |
|
| |
|
def main():
    """Parse training options, prepare the experiment and run the Coach.

    If ``opts.resume_training_from_ckpt`` is falsy, a fresh run is set up: the
    progressive-training schedule is derived and a new experiment directory is
    created. Otherwise the options and checkpoint are restored through
    ``load_train_checkpoint`` and handed to the ``Coach``.
    """
    opts = TrainOptions().parse()

    if not opts.resume_training_from_ckpt:
        # Fresh run: no checkpoint to pass along.
        setup_progressive_steps(opts)
        create_initial_experiment_dir(opts)
        previous_train_ckpt = None
    else:
        opts, previous_train_ckpt = load_train_checkpoint(opts)

    Coach(opts, previous_train_ckpt).train()
| |
|
| |
|
def load_train_checkpoint(opts):
    """Restore training options from a checkpoint so that training can resume.

    Args:
        opts: freshly parsed options; ``opts.resume_training_from_ckpt`` is the
            path of the checkpoint to load.

    Returns:
        A ``(opts, checkpoint)`` pair. ``opts`` is a ``Namespace`` built from
        ``checkpoint['opts']`` after merging in the new options via
        ``update_new_configs``. ``checkpoint`` is the object returned by
        ``torch.load``.
    """
    ckpt_path = opts.resume_training_from_ckpt
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    cli_opts = vars(opts)

    # The stored options dict is updated in place; the new checkpoint path is
    # recorded in it before merging.
    restored = checkpoint['opts']
    restored['resume_training_from_ckpt'] = ckpt_path
    update_new_configs(restored, cli_opts)
    pprint.pprint(restored)

    resumed_opts = Namespace(**restored)
    if resumed_opts.sub_exp_dir is not None:
        # Nest the resumed run under the original experiment directory.
        resumed_opts.exp_dir = os.path.join(resumed_opts.exp_dir, resumed_opts.sub_exp_dir)
    create_initial_experiment_dir(resumed_opts)
    return resumed_opts, checkpoint
| |
|
| |
|
def setup_progressive_steps(opts):
    """Fill in ``opts.progressive_steps`` from the CLI options and validate it.

    The number of style layers is ``2 * log2(opts.stylegan_size) - 2``. When
    ``opts.progressive_start`` is set, the schedule is built as
    ``[0, start, start + every, start + 2 * every, ...]`` with one entry per
    style layer, where ``every`` is ``opts.progressive_step_every``. Otherwise
    any existing ``opts.progressive_steps`` is left unchanged.

    Args:
        opts: parsed options; this function mutates ``opts.progressive_steps``.

    Raises:
        ValueError: if ``opts.progressive_steps`` is not None and does not have
            exactly one entry per style layer starting at 0. This used to be an
            ``assert``, which is skipped under ``python -O``.
    """
    # math.log2 is exact for powers of two. math.log(x, 2) computes
    # log(x)/log(2) and may land just below the integer, which int() would
    # then truncate to the wrong size.
    log_size = int(math.log2(opts.stylegan_size))
    num_style_layers = 2 * log_size - 2

    if opts.progressive_start is not None:
        num_deltas = num_style_layers - 1
        opts.progressive_steps = [0] + [
            opts.progressive_start + i * opts.progressive_step_every
            for i in range(num_deltas)
        ]

    if opts.progressive_steps is not None and not is_valid_progressive_steps(opts, num_style_layers):
        raise ValueError("Invalid progressive training input")
| |
|
| |
|
def is_valid_progressive_steps(opts, num_style_layers):
    """Check that the progressive schedule is well formed.

    Returns True only when ``opts.progressive_steps`` has exactly
    ``num_style_layers`` entries and its first entry is 0.
    """
    steps = opts.progressive_steps
    if len(steps) != num_style_layers:
        return False
    return steps[0] == 0
| |
|
| |
|
def create_initial_experiment_dir(opts):
    """Create ``opts.exp_dir`` and write the options into ``opt.json`` inside it.

    The options are also pretty-printed to stdout. Missing parent directories
    are created.

    Args:
        opts: an ``argparse.Namespace`` (anything accepted by ``vars``) whose
            ``exp_dir`` attribute names the directory to create. All of its
            attributes must be JSON-serialisable.

    Raises:
        FileExistsError: if ``opts.exp_dir`` already exists. It subclasses
            ``Exception``, so callers catching the previous generic exception
            still work.
    """
    try:
        # Let makedirs do the existence check itself rather than calling
        # os.path.exists() first, which leaves a window between check and create.
        os.makedirs(opts.exp_dir)
    except FileExistsError as err:
        raise FileExistsError('Oops... {} already exists'.format(opts.exp_dir)) from err

    opts_dict = vars(opts)
    pprint.pprint(opts_dict)
    with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
        json.dump(opts_dict, f, indent=4, sort_keys=True)
| |
|
| |
|
def update_new_configs(ckpt_opts, new_opts):
    """Merge freshly parsed options into options restored from a checkpoint.

    ``ckpt_opts`` is modified in place:

    * every key of ``new_opts`` that is missing from ``ckpt_opts`` is copied
      over, so options added since the checkpoint was written get their new value;
    * every option named in ``new_opts['update_param_list']`` is overwritten
      with its value from ``new_opts``.

    Args:
        ckpt_opts: dict of options loaded from the checkpoint.
        new_opts: dict of freshly parsed options.

    Raises:
        KeyError: if ``update_param_list`` names an option missing from ``new_opts``.
    """
    for key, value in new_opts.items():
        ckpt_opts.setdefault(key, value)
    # .get(): tolerate option sets that do not define update_param_list at all.
    for param in new_opts.get('update_param_list') or ():
        ckpt_opts[param] = new_opts[param]
| |
|
| |
|
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
| |
|