# coding: utf-8
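"""Multi-GPU validation entry point.

Spawns one DDP process per visible GPU, loads the model (and optional
checkpoint) defined by the given config, and runs valid_multi_gpu on each rank.

Example invocation (a sketch: the exact flag names come from parse_args_valid
in utils.settings, and the model type and paths below are placeholders):

    python <this_script>.py --model_type mdx23c \
        --config_path configs/config.yaml \
        --start_check_point results/model.ckpt
"""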
__author__ = "Ilya Kiselev (kiselecheck): https://github.com/kiselecheck"
__version__ = "1.0.1"
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from utils.model_utils import load_start_checkpoint
from utils.settings import (
    get_model_from_config,
    initialize_environment_ddp,
    parse_args_valid,
)
from valid import valid_multi_gpu

warnings.filterwarnings("ignore")


def check_validation_single(rank: int, world_size: int, args=None):
    """Validate the model on a single DDP process (one process per GPU)."""
    args = parse_args_valid(args)
    initialize_environment_ddp(rank, world_size)

    model, config = get_model_from_config(args.model_type, args.config_path)

    # Optionally restore weights from a checkpoint before validation.
    if args.start_check_point:
        checkpoint = torch.load(
            args.start_check_point, weights_only=False, map_location="cpu"
        )
        load_start_checkpoint(args, model, checkpoint, type_="valid")

    # Print dataset info from the main process only, to avoid duplicated output.
    if dist.get_rank() == 0:
        print(f"Instruments: {config.training.instruments}")

    # Move the model to this rank's GPU and wrap it for distributed validation.
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    valid_multi_gpu(model, args, config, args.device_ids, verbose=False)


def check_validation(args=None):
    """Spawn one validation worker per available GPU and wait for them to finish."""
    world_size = torch.cuda.device_count()
    mp.spawn(
        check_validation_single, args=(world_size, args), nprocs=world_size, join=True
    )


if __name__ == "__main__":
    check_validation()