| |
| import importlib |
| import importlib.util |
| import logging |
| import numpy as np |
| import os |
| import random |
| import sys |
| from datetime import datetime |
| import torch |
| import socket |
| import subprocess |
| import time |
| from . import comm |
|
|
# Public API of this module; the remaining helpers are internal/setup-time.
__all__ = ["seed_all_rng"]
|
|
|
|
# (major, minor) of the installed PyTorch, e.g. (1, 8) for "1.8.1+cu111".
TORCH_VERSION = tuple(map(int, torch.__version__.split(".")[:2]))
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""
|
|
|
|
def seed_all_rng(seed=None):
    """
    Seed the random number generators of python, numpy and torch.

    Args:
        seed (int): seed value to use; if None, a fresh seed is derived
            from the process id, the current time, and two bytes of OS
            entropy (so concurrently launched processes differ).
    """
    if seed is None:
        # Mix several entropy sources; pid and urandom keep sibling
        # processes started in the same microsecond from colliding.
        entropy = int.from_bytes(os.urandom(2), "big")
        timestamp = int(datetime.now().strftime("%S%f"))
        seed = os.getpid() + timestamp + entropy
        logging.getLogger(__name__).info(
            "Using a generated random seed {}".format(seed)
        )
    # The three RNGs are independent, so seeding order does not matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
|
|
|
| |
| def _import_file(module_name, file_path, make_importable=False): |
| spec = importlib.util.spec_from_file_location(module_name, file_path) |
| module = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(module) |
| if make_importable: |
| sys.modules[module_name] = module |
| return module |
|
|
|
|
def _configure_libraries():
    """
    Configure third-party libraries before use.

    - When DETECTRON2_DISABLE_CV2 is set truthy, stub out ``cv2`` entirely.
    - Otherwise disable OpenCV's OpenCL backend, which can conflict with
      CUDA contexts.
    - Assert minimum required versions of torch, fvcore and pyyaml.
    """
    if int(os.environ.get("DETECTRON2_DISABLE_CV2", False)):
        # Register a dummy entry so "import cv2" yields None rather than
        # importing the real library.
        sys.modules["cv2"] = None
    else:
        # Disable OpenCL both via the environment and, when cv2 is
        # importable, via its API.
        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
        try:
            import cv2

            # ocl module exists from OpenCV 3 on.
            if int(cv2.__version__.split(".")[0]) >= 3:
                cv2.ocl.setUseOpenCL(False)
        except ModuleNotFoundError:
            # cv2 is optional; nothing to configure without it.
            pass

    def version_of(module, digit=2):
        # First `digit` components of module.__version__ as an int tuple.
        return tuple(map(int, module.__version__.split(".")[:digit]))

    assert version_of(torch) >= (1, 4), "Requires torch>=1.4"
    import fvcore
    assert version_of(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2"
    import yaml
    assert version_of(yaml) >= (5, 1), "Requires pyyaml>=5.1"
| |
|
|
|
|
# Guard so setup_environment() performs its one-time work only once per process.
_ENV_SETUP_DONE = False
|
|
|
|
def setup_environment():
    """Perform environment setup work. The default setup is a no-op, but this
    function allows the user to specify a Python source file or a module in
    the $DETECTRON2_ENV_MODULE environment variable, that performs
    custom setup work that may be necessary to their computing environment.

    Idempotent: only the first call in a process does any work.
    """
    global _ENV_SETUP_DONE
    if _ENV_SETUP_DONE:
        return
    _ENV_SETUP_DONE = True

    _configure_libraries()

    # Optional user hook: a module name or a path to a .py file.
    custom_module_path = os.environ.get("DETECTRON2_ENV_MODULE")
    if custom_module_path:
        setup_custom_environment(custom_module_path)
|
|
|
|
def setup_custom_environment(custom_module):
    """
    Load custom environment setup by importing a Python source file or a
    module, and run the setup function.

    Args:
        custom_module (str): either a path ending in ``.py`` or an
            importable module name; it must expose a callable
            ``setup_environment``.
    """
    if custom_module.endswith(".py"):
        env_module = _import_file("detectron2.utils.env.custom_module", custom_module)
    else:
        env_module = importlib.import_module(custom_module)
    setup_fn = getattr(env_module, "setup_environment", None)
    assert callable(setup_fn), (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    setup_fn()
| |
def check_dist_portfile():
    """
    Remove a stale ``dist_url_<jobid>.txt`` port file from a previous run.

    Only the first SLURM task (``SLURM_PROCID`` 0) performs the cleanup,
    so concurrent tasks of the same job do not race on the file.  No-op
    outside of SLURM.
    """
    # BUG FIX: the original gated on SLURM_JOB_ID but then read
    # SLURM_PROCID and SLURM_JOBID unconditionally, raising KeyError when
    # those variables are absent.  Default PROCID to "0" and fall back to
    # SLURM_JOB_ID for the job id.
    if "SLURM_JOB_ID" not in os.environ:
        return
    if int(os.environ.get("SLURM_PROCID", "0")) != 0:
        return
    job_id = os.environ.get("SLURM_JOBID", os.environ["SLURM_JOB_ID"])
    hostfile = "dist_url_" + job_id + ".txt"
    if os.path.exists(hostfile):
        os.remove(hostfile)
|
|
def find_free_port():
    """
    Return a TCP port number that was free at the time of the call.

    Binds port 0 (kernel picks a free ephemeral port), reads the assigned
    port, and closes the probe socket so the caller can bind the port.
    Note the inherent race: the port may be taken by another process
    between this call and the caller's own bind.

    Returns:
        int: a port number in the ephemeral range.
    """
    # BUG FIX: the original never closed the socket, leaking it and
    # leaving the returned port bound (unusable) until garbage collection.
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]
|
|
def init_distributed_mode(args):
    """
    Initialize torch.distributed from the launcher environment and populate
    ``args.rank``, ``args.world_size``, ``args.gpu``, ``args.dist_url`` and
    ``args.distributed`` in place.

    Two launch modes are supported:
      * torchrun/launch style: RANK / WORLD_SIZE / LOCAL_RANK already set.
      * SLURM: rank, master address and port are derived from SLURM_*
        variables; task 0 picks a free TCP port and publishes it via a
        ``dist_url_<jobid>.txt`` file that the other tasks poll.

    If neither variable set is present, ``args.distributed`` is set to
    False and the function returns without touching torch.distributed.
    As a side effect, also creates per-machine process groups in
    ``comm._LOCAL_PROCESS_GROUP``.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched by torchrun / torch.distributed.launch: env:// rendezvous
        # variables are already present.
        if int(os.environ["RANK"])==0:
            print('this task is not running on cluster!')
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        addr = socket.gethostname()

    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM: derive everything from SLURM_* variables.
        proc_id = int(os.environ['SLURM_PROCID'])
        if proc_id==0:
            print('Init dist using slurm!')
            print("Job Id is {} on {} ".format(os.environ["SLURM_JOBID"], os.environ['SLURM_NODELIST']))
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # First hostname in the node list acts as the rendezvous master.
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        jobid = os.environ["SLURM_JOBID"]
        hostfile = "dist_url_" + jobid + ".txt"
        if proc_id == 0:
            # Task 0 chooses the port and publishes it through a shared file
            # (assumes a shared working directory across nodes).
            args.tcp_port = str( find_free_port())
            print('write port {} to file: {} '.format(args.tcp_port, hostfile))
            with open(hostfile, "w") as f:
                f.write(args.tcp_port)
        else:
            # Other tasks poll until task 0 has written the port file.
            print('read port from file: {}'.format(hostfile))
            while not os.path.exists(hostfile):
                time.sleep(1)
            # Extra sleep to let the write on the shared filesystem settle.
            time.sleep(2)
            with open(hostfile, "r") as f:
                args.tcp_port = f.read()

        # Populate the env:// rendezvous variables expected by
        # init_process_group.
        os.environ['MASTER_PORT'] =str(args.tcp_port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        # NOTE(review): assumes one task per GPU and homogeneous GPU counts
        # across nodes — confirm against the sbatch configuration.
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return


    args.distributed = True

    # Bind this process to its GPU before creating the process group.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('rank: {} addr: {} port: {}'.format(args.rank, addr, os.environ['MASTER_PORT']))
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # All ranks are up once the barrier passes; the port file is no longer
    # needed, so rank 0 removes it (SLURM mode only, where it was created).
    torch.distributed.barrier()
    if 'SLURM_PROCID' in os.environ and args.rank == 0:
        if os.path.isfile(hostfile):
            os.remove(hostfile)
    if args.world_size >= 1:
        # Build one subgroup per machine so that per-node collectives
        # (e.g. local all-reduce) have a group to use.
        assert comm._LOCAL_PROCESS_GROUP is None
        num_gpus = torch.cuda.device_count()
        num_machines = args.world_size // num_gpus
        for i in range(num_machines):
            ranks_on_i = list(range(i * num_gpus, (i + 1) * num_gpus))
            print('new_group: {}'.format(ranks_on_i))
            # new_group must be called by ALL ranks, even those not in it.
            pg = torch.distributed.new_group(ranks_on_i)
            if args.rank in ranks_on_i:
                # Remember only the group this rank belongs to.
                comm._LOCAL_PROCESS_GROUP = pg
|
|