| """ |
| A main training script. |
| """ |
|
|
|
|
| |
| import warnings |
| warnings.filterwarnings('ignore') |
| import logging |
| import os |
| from collections import OrderedDict |
| import torch |
| import uniperceiver.utils.comm as comm |
| from uniperceiver.config import get_cfg, CfgNode |
| from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments |
|
|
| |
| from uniperceiver.engine import hooks |
| from uniperceiver.modeling import add_config |
| from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile |
try:
    import deepspeed
    DEEPSPEED_INSTALLED = True
except ImportError:
    # deepspeed is an optional dependency; features needing it are gated on
    # this flag.  Narrowed from a bare `except:` so that unrelated errors
    # (or a present-but-broken deepspeed install) are not silently hidden.
    DEEPSPEED_INSTALLED = False
|
|
| import copy |
|
|
def add_data_prefix(cfg):
    """Prepend the ``$DATA_PATH`` directory to relative data paths in *cfg*.

    Rewrites dataloader/inference/model path entries at the global level, in
    every per-task override dict in ``cfg.TASKS``, and in every shared-target
    config.  A value is left untouched when it is empty, already anchored
    ('.', '/'), under ``work_dirs``/``cluster``, remote (``s3://``), or one of
    the whitelisted vocabulary names.

    Side effects: mutates *cfg* in place; normalizes ``cfg.SHARED_TARGETS``
    from ``None`` to ``[]``.  No-op when ``DATA_PATH`` is unset.
    """
    data_dir = os.getenv("DATA_PATH", None)
    whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]
    # Prefixes that mark a value as already resolved / not a local data path.
    skip_prefixes = ('.', '/', 'work_dirs', 'cluster', 's3://')

    def _needs_prefix(value):
        # Only non-empty, relative, non-whitelisted values get the prefix.
        return (value != ''
                and not value.startswith(skip_prefixes)
                and value not in whitelist)

    # Normalize SHARED_TARGETS before any early exit so callers can always
    # iterate it as a list.
    if cfg.SHARED_TARGETS is None:
        cfg.SHARED_TARGETS = []

    if not data_dir:
        # Nothing to rewrite.  (The original fell through to the shared-target
        # loop here and would crash on os.path.join(None, ...).)
        return

    # (config node, attribute name, key path inside a per-task override dict)
    mapping_list = [
        [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER']],
        [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER']],
        [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER']],
        [cfg.INFERENCE, 'VOCAB', ['INFERENCE']],
        [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE']],
        [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE']],
        [cfg.MODEL, 'WEIGHTS', ['MODEL']],
    ]

    # Global (non-task-specific) entries.
    for node, attr, _ in mapping_list:
        if _needs_prefix(node[attr]):
            setattr(node, attr, os.path.join(data_dir, node[attr]))

    # Per-task overrides: walk the key path as far as it exists, then rewrite
    # the leaf entry if present.
    for task in cfg.TASKS:
        for _, item, key_list in mapping_list:
            config_tmp = task
            for key in key_list:
                if key in config_tmp:
                    config_tmp = config_tmp[key]
            if item in config_tmp and _needs_prefix(config_tmp[item]):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])

    # Shared-target label files.
    for share_targets in cfg.SHARED_TARGETS:
        config_tmp = share_targets['SHARED_TARGETS_CFG']
        if 'FILE_PATH' in config_tmp and _needs_prefix(config_tmp['FILE_PATH']):
            config_tmp['FILE_PATH'] = os.path.join(data_dir, config_tmp['FILE_PATH'])
|
|
|
|
|
|
def add_default_setting_for_multitask_config(cfg):
    """Expand ``cfg.TASKS`` so that every task carries a full config.

    Each TASKS entry initially holds only the task-specific overrides; after
    this call every entry is a deep copy of the global config with those
    overrides merged on top, converted via ``to_dict_object``.
    """
    task_overrides = cfg.TASKS
    cfg.pop('TASKS', None)

    # One full deep copy of the global config (now without TASKS) per task.
    cfg.TASKS = [copy.deepcopy(cfg) for _ in task_overrides]

    # Layer each task's own settings over its copy, then flatten.
    for idx, override in enumerate(task_overrides):
        cfg.TASKS[idx].merge_from_other_cfg(CfgNode(override))
        cfg.TASKS[idx] = cfg.TASKS[idx].to_dict_object()
|
|
|
|
def setup(args):
    """
    Create configs and perform basic setups.
    """
    config = get_cfg()

    # Load the raw yaml once so model-specific default keys can be registered
    # before the real merge happens.
    raw_cfg = config.load_from_file_tmp(args.config_file)
    add_config(config, raw_cfg)

    config.merge_from_file(args.config_file)
    add_data_prefix(config)

    # Command-line overrides take precedence over everything in the file.
    config.merge_from_list(args.opts)
    add_default_setting_for_multitask_config(config)

    config.freeze()
    default_setup(config, args)
    return config
|
|
def _run_eval(trainer, args, data_loader, evaluator):
    """Evaluate one split, preferring the EMA weights when requested.

    Prints the chosen-model notice and the result on the main process and
    returns whatever ``trainer.test`` returns.
    """
    if trainer.model_ema is not None and args.eval_ema:
        if comm.is_main_process():
            print('using ema model for evaluation')
        model = trainer.model_ema.ema
    else:
        if args.eval_ema and comm.is_main_process():
            print('no ema model exists! using master model for evaluation')
        model = trainer.model
    res = trainer.test(trainer.cfg, model, data_loader, evaluator, epoch=-1)
    if comm.is_main_process():
        print(res)
    return res


def main(args):
    """Build the engine from the parsed args, then evaluate or train."""
    cfg = setup(args)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop (see plain_train_net.py) or
    # subclassing the trainer.
    trainer = build_engine(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.cast_layers()

    if args.eval_only:
        print('---------------------------')
        print('eval model only')
        print('---------------------------\n')
        # The val/test branches were copy-paste duplicates; both now share
        # _run_eval.  The last evaluated split's result is returned.
        res = None
        if trainer.val_data_loader is not None:
            res = _run_eval(trainer, args, trainer.val_data_loader, trainer.val_evaluator)
        if trainer.test_data_loader is not None:
            res = _run_eval(trainer, args, trainer.test_data_loader, trainer.test_evaluator)
        # BUG FIX: the original only returned inside the test-loader branch,
        # so --eval-only with no test loader fell through to trainer.train().
        return res

    return trainer.train()
|
|
def get_args_parser():
    """Build the command-line parser and return the parsed arguments.

    Starts from the default engine parser, optionally extends it with
    deepspeed/MoE options when deepspeed is importable, and adds the
    launcher-specific flags used by this script.
    """
    arg_parser = default_argument_parser()

    if DEEPSPEED_INSTALLED:
        arg_parser = deepspeed.add_config_arguments(arg_parser)
        arg_parser = add_moe_arguments(arg_parser)

    arg_parser.add_argument('--init_method', default='slurm', type=str)
    arg_parser.add_argument('--local_rank', default=0, type=int)
    arg_parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")

    return arg_parser.parse_args()
|
|
| if __name__ == "__main__": |
| args = get_args_parser() |
| print("Command Line Args:", args) |
| if args.init_method == 'slurm': |
| |
| check_dist_portfile() |
| init_distributed_mode(args) |
| main(args) |
| elif args.init_method == 'pytorch': |
| main(args) |
| else: |
| |
| print('using \'mp.spawn\' for dist init! ') |
| launch( |
| main, |
| args.num_gpus, |
| num_machines=args.num_machines, |
| machine_rank=args.machine_rank, |
| dist_url=args.dist_url, |
| args=(args,), |
| ) |
|
|