| import os |
| import random |
| import shutil |
| from argparse import ArgumentParser |
|
|
| import numpy as np |
| import torch |
| import yaml |
|
|
|
|
def clean_dir(path):
    """Delete the directory tree rooted at *path*, if something exists there.

    A path for which ``os.path.exists`` is False is silently ignored; otherwise
    ``shutil.rmtree`` removes it and any error it raises propagates.
    """
    if not os.path.exists(path):
        return
    shutil.rmtree(path)
|
|
|
|
def get_latest_ckpt_step(load_path):
    """Return the highest checkpoint step found in *load_path*, or -1.

    Checkpoints are recognised as entries whose name ends with ``.pt`` and whose
    stem ends with ``-<step>`` (for example ``model-1200.pt``). ``.pt`` entries
    that do not follow this pattern (e.g. ``best.pt``) are ignored instead of
    aborting the whole lookup.

    Args:
        load_path: Directory to scan (non-recursively).

    Returns:
        The largest parsed step, or -1 when no matching checkpoint exists.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. missing directory).
    """
    saved_steps = []
    for name in os.listdir(load_path):
        if not name.endswith(".pt"):
            continue
        step_text = os.path.splitext(name)[0].split("-")[-1]
        try:
            saved_steps.append(int(step_text))
        except ValueError:
            # Not a "<prefix>-<step>.pt" checkpoint; skip it.
            continue
    return max(saved_steps, default=-1)
|
|
|
|
def set_random_seed(seed):
    """Seed all random number generators and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator, torch's
    generator, the current CUDA device and every CUDA device. It then sets
    ``torch.backends.cudnn.deterministic`` to True and
    ``torch.backends.cudnn.benchmark`` to False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
def load_cfg(cfg_path: str, parser: ArgumentParser) -> ArgumentParser:
    """Register one command-line option per top-level key of a YAML config.

    Each key ``k`` becomes an option ``--k`` whose default is the config value,
    so the YAML file provides defaults that the command line can override.

    * bool values: ``--k`` sets True and ``--k``'s companion ``--no-k`` sets
      False, so a default of ``true`` can still be switched off.
    * list values: accept zero or more items (``--k a b c``). Items are
      converted with the element type when every element shares one scalar
      type (int, float or str); otherwise they are kept as strings.
    * other values: the command-line string is converted with ``type(value)``.

    Args:
        cfg_path: Path to a UTF-8 YAML file whose top level is a mapping.
        parser: Parser to extend; it is modified in place.

    Returns:
        The same ``parser`` instance.

    Raises:
        ValueError: If the top level of the file is not a mapping, or if any
            value is null.
    """
    with open(cfg_path, "r", encoding="utf-8") as file:
        cfg = yaml.safe_load(file)
    # An empty YAML document loads as None; treat it as "no options".
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(
            "config file must contain a mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    for key, value in cfg.items():
        if value is None:
            raise ValueError("'None' is not a supported value in the config file")
        # bool must be checked before the generic branch: bool("False") is True.
        if isinstance(value, bool):
            flag = parser.add_argument(f"--{key}", action="store_true", default=value)
            # store_true alone can never turn a `true` default off, so add the
            # negated form writing to the same destination.
            parser.add_argument(
                f"--no-{key}", dest=flag.dest, action="store_false", default=value
            )
        elif isinstance(value, list):
            # type=list would split a single CLI string into characters; parse
            # each item separately instead.
            item_types = {type(item) for item in value}
            item_type = item_types.pop() if len(item_types) == 1 else str
            if item_type not in (int, float, str):
                item_type = str
            parser.add_argument(f"--{key}", nargs="*", type=item_type, default=value)
        else:
            parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser
|
|
|
|
def save_cfg(path: str, args, mode="w"):
    """Write a banner line followed by ``vars(args)`` as block-style YAML.

    Args:
        path: Destination file, opened with UTF-8 encoding.
        args: Object whose ``vars()`` mapping is dumped (e.g. an argparse
            Namespace).
        mode: File mode passed to ``open``; ``"a"`` appends instead of
            overwriting.
    """
    banner = "#################### Training Config ####################"
    with open(path, mode=mode, encoding="utf-8") as out_file:
        # The banner starts with '#', so YAML readers treat it as a comment.
        print(banner, file=out_file)
        yaml.dump(vars(args), out_file, default_flow_style=False)
|
|