# coding: utf-8
__author__ = "Ilya Kiselev (kiselecheck): https://github.com/kiselecheck"
__version__ = "1.0.1"


import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from utils.model_utils import load_start_checkpoint
from utils.settings import (
    get_model_from_config,
    initialize_environment_ddp,
    parse_args_valid,
)
from valid import valid_multi_gpu

warnings.filterwarnings("ignore")


def check_validation_single(rank: int, world_size: int, args=None):
    """Per-process DDP validation worker (one process per GPU).

    Spawned by ``mp.spawn``; sets up the distributed environment, loads the
    model (and optional checkpoint), wraps it in DDP on this rank's device,
    runs multi-GPU validation, then tears the process group down.

    Parameters
    ----------
    rank : int
        This process's rank, also used as the local CUDA device index.
    world_size : int
        Total number of spawned processes (one per visible GPU).
    args : optional
        Raw CLI-style arguments forwarded to ``parse_args_valid``; ``None``
        means parse from ``sys.argv``.
    """
    args = parse_args_valid(args)

    initialize_environment_ddp(rank, world_size)
    model, config = get_model_from_config(args.model_type, args.config_path)

    if args.start_check_point:
        # NOTE(security): weights_only=False deserializes arbitrary pickled
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(
            args.start_check_point, weights_only=False, map_location="cpu"
        )
        load_start_checkpoint(args, model, checkpoint, type_="valid")

    # Print config details once, from the rank-0 process only.
    if dist.get_rank() == 0:
        print(f"Instruments: {config.training.instruments}")

    device = torch.device(f"cuda:{rank}")
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    try:
        valid_multi_gpu(model, args, config, args.device_ids, verbose=False)
    finally:
        # Release the NCCL/Gloo process group even if validation raises,
        # so spawned workers exit cleanly.
        dist.destroy_process_group()


def check_validation(args=None):
    """Launch distributed validation with one worker process per CUDA device.

    Parameters
    ----------
    args : optional
        Arguments forwarded unchanged to each ``check_validation_single``
        worker; ``None`` lets each worker parse ``sys.argv``.

    Raises
    ------
    RuntimeError
        If no CUDA devices are visible (``mp.spawn`` with ``nprocs=0`` would
        otherwise fail with an unhelpful error).
    """
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError(
            "No CUDA devices available: distributed validation requires at least one GPU."
        )
    mp.spawn(
        check_validation_single, args=(world_size, args), nprocs=world_size, join=True
    )


# Script entry point: spawn one validation process per visible GPU.
if __name__ == "__main__":
    check_validation()