Spaces:
Runtime error
Runtime error
| # coding: utf-8 | |
| __author__ = "Ilya Kiselev (kiselecheck): https://github.com/kiselecheck" | |
| __version__ = "1.0.1" | |
| import warnings | |
| import torch | |
| import torch.distributed as dist | |
| import torch.multiprocessing as mp | |
| from utils.model_utils import load_start_checkpoint | |
| from utils.settings import ( | |
| get_model_from_config, | |
| initialize_environment_ddp, | |
| parse_args_valid, | |
| ) | |
| from valid import valid_multi_gpu | |
# Suppress all Python warnings from this point on.
# NOTE(review): this also hides library deprecation/future warnings
# (e.g. from torch); narrow the filter if those need to be seen.
warnings.filterwarnings("ignore")
def check_validation_single(rank: int, world_size: int, args=None):
    """Run validation inside a single DDP worker process.

    This is the per-process target handed to ``mp.spawn``. ``spawn`` passes
    the worker index as ``rank``, followed by ``world_size`` and ``args``.

    Args:
        rank: Index of this worker. Also used as the CUDA device index
            (``cuda:{rank}``) and as the sole entry of ``device_ids`` for
            ``DistributedDataParallel``.
        world_size: Total number of worker processes.
        args: Forwarded to ``parse_args_valid``. Presumably ``None`` means
            "parse the command line" — TODO confirm against
            ``parse_args_valid``.
    """
    args = parse_args_valid(args)
    try:
        initialize_environment_ddp(rank, world_size)
        model, config = get_model_from_config(args.model_type, args.config_path)
        if args.start_check_point:
            # Map checkpoint tensors to CPU on load. The model is moved to
            # this rank's GPU further below.
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(
                args.start_check_point, weights_only=False, map_location="cpu"
            )
            load_start_checkpoint(args, model, checkpoint, type_="valid")
        # Print from rank 0 only so the message is not repeated per worker.
        if dist.get_rank() == 0:
            print(f"Instruments: {config.training.instruments}")
        device = torch.device(f"cuda:{rank}")
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
        valid_multi_gpu(model, args, config, args.device_ids, verbose=False)
    finally:
        # Release the default process group. It is presumably created by
        # initialize_environment_ddp, since dist.get_rank() above requires
        # one. The guard covers the cases where it was never created or was
        # already destroyed by a callee.
        if dist.is_initialized():
            dist.destroy_process_group()
def check_validation(args=None):
    """Spawn one validation worker per visible CUDA device and wait for all.

    Args:
        args: Forwarded unchanged to every ``check_validation_single`` worker.

    Raises:
        RuntimeError: If no CUDA device is visible. Without this check,
            ``mp.spawn`` would be called with ``nprocs=0``, start no
            workers, and return without running any validation.
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError(
            "check_validation requires at least one CUDA device, "
            "but torch.cuda.device_count() returned 0."
        )
    # Calls check_validation_single(rank, world_size, args) for each
    # rank in range(world_size); join=True blocks until every worker exits.
    mp.spawn(
        check_validation_single, args=(world_size, args), nprocs=world_size, join=True
    )
# Script entry point. args=None is forwarded to each worker's
# parse_args_valid call.
if __name__ == "__main__":
    check_validation(args=None)