import logging
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 lets the OS pick a free ephemeral port.
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: the port is released before the caller binds to it, so another
    # process could still grab it in the meantime.
    return port
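
# NOTE: ``launch`` uses this helper when ``dist_url == "auto"`` to construct a
# "tcp://127.0.0.1:<port>" init URL for single-machine jobs.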


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-process or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called as ``main_func(*args)``
        num_gpus_per_machine (int): number of processes per machine. When
            using GPUs, this should be the number of GPUs.
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to ``main_func``
        timeout (timedelta): timeout of the distributed workers
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        # Spawn one worker process per GPU on this machine.
        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*args)
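

# A minimal usage sketch (the ``main`` function, address, and config path are
# hypothetical, not part of this module). Every machine runs the same script
# with identical arguments except ``machine_rank``:
#
#     def main(config_path):
#         ...  # build the model and run the training loop
#
#     launch(
#         main,
#         num_gpus_per_machine=8,
#         num_machines=2,
#         machine_rank=0,  # set to 1 on the second machine
#         dist_url="tcp://10.0.0.1:12345",
#         args=("my_config.yaml",),
#     )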


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        assert num_gpus_per_machine <= torch.cuda.device_count()
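    # Ranks are contiguous per machine: e.g. with 8 GPUs per machine,
    # machine_rank=1 and local_rank=3 give global rank 8 * 1 + 3 = 11.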
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL" if has_gpu else "GLOO",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Set up a process group spanning only the workers on this machine.
    comm.create_local_process_group(num_gpus_per_machine)
    if has_gpu:
        torch.cuda.set_device(local_rank)

    # Barrier: make sure every worker has finished setup before user code runs.
    comm.synchronize()

    main_func(*args)