| import datetime |
| import argparse, importlib |
| from pytorch_lightning import seed_everything |
|
|
| import torch |
| import torch.distributed as dist |
|
|
| def setup_dist(local_rank): |
| if dist.is_initialized(): |
| return |
| torch.cuda.set_device(local_rank) |
| torch.distributed.init_process_group('nccl', init_method='env://') |
|
|
|
|
| def get_dist_info(): |
| if dist.is_available(): |
| initialized = dist.is_initialized() |
| else: |
| initialized = False |
| if initialized: |
| rank = dist.get_rank() |
| world_size = dist.get_world_size() |
| else: |
| rank = 0 |
| world_size = 1 |
| return rank, world_size |
|
|
|
|
| if __name__ == '__main__': |
| now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--module", type=str, help="module name", default="inference") |
| parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) |
| args, unknown = parser.parse_known_args() |
| inference_api = importlib.import_module(args.module, package=None) |
|
|
| inference_parser = inference_api.get_parser() |
| inference_args, unknown = inference_parser.parse_known_args() |
|
|
| seed_everything(inference_args.seed) |
| setup_dist(args.local_rank) |
| torch.backends.cudnn.benchmark = True |
| rank, gpu_num = get_dist_info() |
|
|
| |
| print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now)) |
| inference_api.run_inference(inference_args, gpu_num, rank) |