| |
|
|
| import argparse |
| import os, shutil, sys |
| import time |
| import warnings |
|
|
# Suppress every Python warning for the rest of the process.
warnings.filterwarnings("ignore")


# Put the current working directory on the import path so that project-local
# modules resolve: `opt` below, and the `train_*` modules that process()
# imports lazily. This presumably assumes the script is launched from the
# repository root — TODO confirm.
root_path = os.path.abspath('.')
sys.path.append(root_path)
# Training options object; process() indexes it as options['architecture'],
# so it is presumably dict-like — verify against opt.py.
from opt import opt
|
|
|
|
def storage_manage(src="runs/", backup_root="runs_last/"):
    """Archive a copy of the run directory under a timestamped backup folder.

    The whole ``src`` tree is copied to ``<backup_root>/<unix-seconds>``, and
    ``backup_root`` is created first if it is missing. If that destination
    already exists (e.g. two archives taken within the same second), a
    numeric suffix ``_1``, ``_2``, ... is appended instead of letting
    ``shutil.copytree`` fail on an existing target.

    Args:
        src: Directory to archive. Defaults to ``"runs/"``.
        backup_root: Parent directory collecting the archives. Defaults to
            ``"runs_last/"``.

    Returns:
        The destination path the tree was copied to.

    Errors raised by ``shutil.copytree`` (for example a missing ``src``)
    propagate to the caller.
    """
    os.makedirs(backup_root, exist_ok=True)

    stamp = str(int(time.time()))
    new_address = os.path.join(backup_root, stamp)
    # copytree refuses an existing destination, so pick the first free name.
    suffix = 0
    while os.path.exists(new_address):
        suffix += 1
        new_address = os.path.join(backup_root, "{}_{}".format(stamp, suffix))

    shutil.copytree(src, new_address)
    return new_address
|
|
|
|
def parse_args(argv=None):
    """Parse command-line flags into the module-level ``args``.

    The function also prepares the ``./runs`` directory for a fresh run.

    Flags:
        --auto_resume_closest  resume from the nearest saved state.
        --auto_resume_best     resume from the best saved state.
        --pretrained_path      string forwarded to the trainer (default "").

    The two resume flags are mutually exclusive. Passing both makes argparse
    print a usage error to stderr and exit with a non-zero status.

    When neither resume flag is given, the run is treated as fresh. An
    existing ``./runs`` directory is archived via storage_manage() and then
    deleted.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``, also stored in the global ``args``.
    """
    # process() reads this module-level name, so it must stay a global.
    global args

    parser = argparse.ArgumentParser()
    # Let argparse enforce the exclusivity. It reports the conflict and exits
    # with status 2, whereas os._exit(0) would signal success to the shell
    # and skip flushing buffered stdout.
    resume_group = parser.add_mutually_exclusive_group()
    resume_group.add_argument('--auto_resume_closest', action='store_true')
    resume_group.add_argument('--auto_resume_best', action='store_true')
    parser.add_argument('--pretrained_path', type=str, default="")
    args = parser.parse_args(argv)

    # Fresh run: archive the previous ./runs (if any) before wiping it, so the
    # new run starts from an empty directory without losing old results.
    if not (args.auto_resume_closest or args.auto_resume_best):
        if os.path.exists("./runs"):
            storage_manage()
            shutil.rmtree("./runs")

    return args
|
|
|
|
def folder_prepare():
    """Make sure the working folders used by the script exist.

    Folders in ``keep`` are created only when missing, so existing contents
    are left untouched. Folders in ``reset`` are wiped and recreated from
    scratch; that list is currently empty, so nothing is deleted.
    """
    keep = ("saved_models/", "saved_models/checkpoints/", "datasets/")
    reset = ()

    for path in keep:
        if os.path.exists(path):
            continue
        os.makedirs(path)

    for path in reset:
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
|
|
| |
|
|
# Maps each supported ``options['architecture']`` value to the project-local
# module providing its trainer. Each module exposes a trainer callable with
# the same name as the module.
_TRAINER_MODULES = {
    "ESRNET": "train_esrnet",
    "ESRGAN": "train_esrgan",
    "GRL": "train_grl",
    "GRLGAN": "train_grlgan",
    "CUNET": "train_cunet",
    "CUGAN": "train_cugan",
}


def _split_duration(seconds):
    """Split a duration in seconds into whole ``(hours, minutes, seconds)``."""
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return hours, minutes, secs


def process(options):
    """Build the trainer selected by ``options['architecture']`` and run it.

    The module-level ``args`` (set by parse_args()) is printed first, then
    passed to the trainer together with ``options``. After the trainer's
    ``run()`` returns, the total wall-clock time is printed.

    Args:
        options: Training configuration. Only ``options['architecture']`` is
            read here; the whole object is forwarded to the trainer.

    Raises:
        NotImplementedError: If the architecture is not one of the keys of
            ``_TRAINER_MODULES``.
    """
    import importlib

    print(args)
    start = time.time()

    module_name = _TRAINER_MODULES.get(options['architecture'])
    if module_name is None:
        raise NotImplementedError("This is not a supported model architecture")
    # Import lazily so only the selected trainer's dependencies get loaded.
    trainer_factory = getattr(importlib.import_module(module_name), module_name)
    obj = trainer_factory(options, args)

    obj.run()

    # The seconds field must be the remainder within the minute, not within
    # the hour.
    hours, minutes, secs = _split_duration(time.time() - start)
    print("All programs spent {} hour {} min {} s".format(hours, minutes, secs))
|
|
|
|
def main():
    """Script entry point: parse flags, prepare folders, then train with ``opt``.

    The setup order matters: parse_args() populates the module-level ``args``
    that process() prints and hands to the trainer.
    """
    for setup_step in (parse_args, folder_prepare):
        setup_step()
    process(opt)


if __name__ == "__main__":
    main()