import datetime
import functools
import os
import subprocess
from collections.abc import Iterable, Mapping
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup

from mmengine.device import (is_mlu_available, is_npu_available,
                             is_musa_available)

_LOCAL_PROCESS_GROUP = None


def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    return torch_dist.is_available() and torch_dist.is_initialized()


def get_local_group() -> Optional[ProcessGroup]:
    """Return local process group."""
    if not is_distributed():
        return None

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to set up the local process '
                           'group.')

    return _LOCAL_PROCESS_GROUP


def get_default_group() -> Optional[ProcessGroup]:
    """Return default process group."""
    # Note: this relies on a private torch.distributed API to fetch the
    # default (world) process group.
    return torch_dist.distributed_c10d._get_default_group()


def infer_launcher():
    """Infer the launcher type from the environment variables."""
    if 'WORLD_SIZE' in os.environ:
        return 'pytorch'
    elif 'SLURM_NTASKS' in os.environ:
        return 'slurm'
    elif 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
        return 'mpi'
    else:
        return 'none'
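
# Hedged usage sketch for `infer_launcher`: the environment value below is
# hypothetical and would normally be set by the launcher itself (e.g.
# `torchrun`), not by user code:
#
#   >>> os.environ['WORLD_SIZE'] = '8'
#   >>> infer_launcher()
#   'pytorch'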


def init_dist(launcher,
              backend='nccl',
              init_backend='torch',
              **kwargs) -> None:
    """Initialize distributed environment.

    Args:
        launcher (str): Way to launch multiple processes. Supported launchers
            are 'pytorch', 'mpi' and 'slurm'.
        backend (str): Communication backend. Supported backends are 'nccl',
            'gloo' and 'mpi'. Defaults to 'nccl'.
        init_backend (str): The initialization backend. Supported values are
            'torch', 'deepspeed' and 'colossalai'. Defaults to 'torch'.
        **kwargs: Keyword arguments are passed to ``init_process_group``.
    """
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        # If a timeout (in seconds) is specified, it must be converted to a
        # `datetime.timedelta` before being forwarded to the backend.
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                f'Timeout for distributed training must be provided as '
                f"timeout in seconds, but we've received the type "
                f'{type(timeout)}. Please specify the timeout like this: '
                f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, init_backend=init_backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, init_backend=init_backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')
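
# Usage sketch (a hedged example, not part of the module's API): when
# processes are launched with `torchrun`, each process can simply call
#
#   init_dist('pytorch', backend='nccl', timeout=1800)
#
# after which the helpers below (get_rank, get_world_size, barrier, ...)
# reflect the initialized default process group.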


def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
    """Initialize distributed environment with PyTorch launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: Keyword arguments are passed to ``init_process_group``.
    """
    rank = int(os.environ['RANK'])
    # LOCAL_RANK is set by the PyTorch launcher (`torch.distributed.launch`
    # or `torchrun`).
    local_rank = int(os.environ['LOCAL_RANK'])
    if is_mlu_available():
        import torch_mlu  # noqa: F401
        torch.mlu.set_device(local_rank)
        torch_dist.init_process_group(
            backend='cncl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    elif is_npu_available():
        import torch_npu  # noqa: F401
        torch.npu.set_device(local_rank)
        torch_dist.init_process_group(
            backend='hccl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    elif is_musa_available():
        import torch_musa  # noqa: F401
        torch.musa.set_device(rank)
        torch_dist.init_process_group(
            backend='mccl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    else:
        torch.cuda.set_device(local_rank)

        if init_backend == 'torch':
            torch_dist.init_process_group(backend=backend, **kwargs)
        elif init_backend == 'deepspeed':
            import deepspeed
            deepspeed.init_distributed(dist_backend=backend, **kwargs)
        elif init_backend == 'colossalai':
            import colossalai
            colossalai.launch_from_torch(backend=backend, **kwargs)
        else:
            raise ValueError(
                'supported "init_backend" is "torch", "deepspeed" or '
                f'"colossalai", but got {init_backend}')


def _init_dist_mpi(backend, **kwargs) -> None:
    """Initialize distributed environment with MPI launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: Keyword arguments are passed to ``init_process_group``.
    """
    if backend == 'smddp':
        try:
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                'Please use an Amazon SageMaker DLC to access smdistributed: '
                'https://github.com/aws/deep-learning-containers/blob/master'
                '/available_images.md#sagemaker-framework-containers'
                '-sm-support-only') from e
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed's default port.
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    torch_dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend,
                     port=None,
                     init_backend='torch',
                     **kwargs) -> None:
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be read
    from the environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not
    set, the default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
        init_backend (str): The initialization backend. Supported values are
            'torch', 'deepspeed' and 'colossalai'. Defaults to 'torch'.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    # Prefer the local rank provided by Slurm; otherwise derive it from the
    # global rank and the number of visible GPUs.
    local_rank_env = os.environ.get('SLURM_LOCALID', None)
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
    # Resolve the first hostname in the node list as the master address.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # Specify the master port.
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use the existing MASTER_PORT
    else:
        # 29500 is torch.distributed's default port.
        os.environ['MASTER_PORT'] = '29500'
    # Use MASTER_ADDR from the environment if it already exists.
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['RANK'] = str(proc_id)

    if is_mlu_available():
        import torch_mlu  # noqa: F401
        torch.mlu.set_device(local_rank)
        torch_dist.init_process_group(backend='cncl', **kwargs)
    else:
        torch.cuda.set_device(local_rank)

        if init_backend == 'torch':
            torch_dist.init_process_group(backend=backend, **kwargs)
        elif init_backend == 'deepspeed':
            import deepspeed
            deepspeed.init_distributed(dist_backend=backend, **kwargs)
        elif init_backend == 'colossalai':
            import colossalai
            colossalai.launch_from_slurm(
                backend=backend,
                host=os.environ['MASTER_ADDR'],
                port=os.environ['MASTER_PORT'],
                **kwargs,
            )
        else:
            raise ValueError(
                'supported "init_backend" is "torch", "deepspeed" or '
                f'"colossalai", but got {init_backend}')
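
# Hedged sketch of the Slurm environment this helper consumes (all values
# hypothetical):
#
#   SLURM_PROCID=3  SLURM_NTASKS=8  SLURM_LOCALID=3  SLURM_NODELIST=node[0-1]
#
# `scontrol show hostname node[0-1] | head -n1` expands the node list and
# picks the first host (e.g. 'node0') as MASTER_ADDR.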


def init_local_group(node_rank: int, num_gpus_per_node: int):
    """Set up the local process group.

    Set up a process group which only includes processes on the same machine
    as the current process.

    The code is modified from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py

    Args:
        node_rank (int): Rank of the machine used for training.
        num_gpus_per_node (int): Number of GPUs used for training on a single
            machine.
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None

    ranks = list(
        range(node_rank * num_gpus_per_node,
              (node_rank + 1) * num_gpus_per_node))
    _LOCAL_PROCESS_GROUP = torch_dist.new_group(ranks)
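
# Hedged sketch (assuming 2 nodes with 4 GPUs each; values hypothetical):
# after `init_dist`, a process on the machine with rank 1 would call
#
#   init_local_group(node_rank=1, num_gpus_per_node=4)
#
# creating a group over global ranks [4, 5, 6, 7] that backs
# `get_local_size()` and `get_local_rank()` below. Note that
# `torch_dist.new_group` is collective, so every process in the job must
# reach this call.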


def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]:
    """Return the backend of the given process group.

    Note:
        Calling ``get_backend`` in a non-distributed environment will return
        None.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific
            group is specified, the calling process must be part of
            :attr:`group`. Defaults to None.

    Returns:
        str or None: Return the backend of the given process group as a lower
        case string if in a distributed environment, otherwise None.
    """
    if is_distributed():
        # Resolve the default group explicitly: old torch versions
        # (e.g. 1.5.0) do not accept ``group=None``.
        if group is None:
            group = get_default_group()
        return torch_dist.get_backend(group)
    else:
        return None
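
# Hedged example: under an NCCL-initialized run, `get_backend()` returns
# 'nccl'; without `init_dist` it returns None, so callers can branch safely:
#
#   if get_backend() == torch_dist.Backend.NCCL:
#       ...  # CUDA-specific communication path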


def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    Note:
        Calling ``get_world_size`` in a non-distributed environment will
        return 1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the number of processes of the given process group if in
        a distributed environment, otherwise 1.
    """
    if is_distributed():
        # Resolve the default group explicitly for torch versions that do
        # not accept ``group=None``.
        if group is None:
            group = get_default_group()
        return torch_dist.get_world_size(group)
    else:
        return 1


def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Note:
        Calling ``get_rank`` in a non-distributed environment will return 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the rank of the process group if in a distributed
        environment, otherwise 0.
    """
    if is_distributed():
        # Resolve the default group explicitly for torch versions that do
        # not accept ``group=None``.
        if group is None:
            group = get_default_group()
        return torch_dist.get_rank(group)
    else:
        return 0


def get_local_size() -> int:
    """Return the number of processes in the current node.

    Returns:
        int: Return the number of processes in the current node if in a
        distributed environment, otherwise 1.
    """
    if not is_distributed():
        return 1

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to set up the local process '
                           'group.')

    return torch_dist.get_world_size(_LOCAL_PROCESS_GROUP)


def get_local_rank() -> int:
    """Return the rank of the current process in the current node.

    Returns:
        int: Return the rank of the current process in the current node if in
        a distributed environment, otherwise 0.
    """
    if not is_distributed():
        return 0

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to set up the local process '
                           'group.')

    return torch_dist.get_rank(_LOCAL_PROCESS_GROUP)


def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in a non-distributed environment will
        return (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
        ``world_size``.
    """
    world_size = get_world_size(group)
    rank = get_rank(group)
    return rank, world_size
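
# Sketch of the common pattern built on `get_dist_info` (hedged example;
# `dataset` is a hypothetical placeholder, not part of this module):
#
#   rank, world_size = get_dist_info()
#   # e.g. shard work: each process takes indices rank, rank+world_size, ...
#   my_indices = list(range(len(dataset)))[rank::world_size]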


def is_main_process(group: Optional[ProcessGroup] = None) -> bool:
    """Whether the current rank of the given process group is equal to 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        bool: Return True if the current rank of the given process group is
        equal to 0, otherwise False.
    """
    return get_rank(group) == 0


def master_only(func: Callable) -> Callable:
    """Decorate those methods which should be executed in master process.

    Args:
        func (callable): Function to be decorated.

    Returns:
        callable: Return decorated function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Only the rank-0 process executes the wrapped function; other
        # ranks silently return None.
        if is_main_process():
            return func(*args, **kwargs)

    return wrapper
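
# Hedged usage sketch: guard side effects (logging, checkpointing) so they
# run once per job instead of once per process. `save_checkpoint` here is a
# hypothetical function used only for illustration:
#
#   @master_only
#   def save_checkpoint(model, path):
#       torch.save(model.state_dict(), path)
#
# On non-master ranks the decorated call returns None.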


def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in a non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
    if is_distributed():
        # Resolve the default group explicitly for torch versions that do
        # not accept ``group=None``.
        if group is None:
            group = get_default_group()
        torch_dist.barrier(group)
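
# Typical pattern (hedged sketch; `prepare_data` is a hypothetical helper):
# let rank 0 prepare shared state on disk, holding other ranks at the
# barrier until it is done:
#
#   if is_main_process():
#       prepare_data()
#   barrier()  # all ranks wait here before reading the prepared data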


def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device:
    """Return the device of ``data``.

    If ``data`` is a sequence of Tensor, all items in ``data`` should be on
    the same device.

    If ``data`` is a dict whose values are Tensor, all values should be on
    the same device.

    Args:
        data (Tensor or Sequence or dict): Data whose device will be
            inferred.

    Returns:
        torch.device: The device of ``data``.

    Examples:
        >>> import torch
        >>> from mmengine.dist import cast_data_device
        >>> # data is a Tensor
        >>> data = torch.tensor([0, 1])
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a list of Tensor
        >>> data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a dict
        >>> data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        >>> get_data_device(data)
        device(type='cpu')
    """
    if isinstance(data, Tensor):
        return data.device
    elif isinstance(data, Mapping):
        pre = None
        for v in data.values():
            cur = get_data_device(v)
            if pre is None:
                pre = cur
            elif cur != pre:
                raise ValueError(
                    'device type in data should be consistent, but got '
                    f'{cur} and {pre}')
        if pre is None:
            raise ValueError('data should not be empty.')
        return pre
    elif isinstance(data, Iterable) and not isinstance(data, str):
        pre = None
        for item in data:
            cur = get_data_device(item)
            if pre is None:
                pre = cur
            elif cur != pre:
                raise ValueError(
                    'device type in data should be consistent, but got '
                    f'{cur} and {pre}')
        if pre is None:
            raise ValueError('data should not be empty.')
        return pre
    else:
        raise TypeError('data should be a Tensor, sequence of tensor or '
                        f'dict, but got {data}')


def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
    """Return the device for communication among groups.

    Args:
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        torch.device: The device of backend.
    """
    backend = get_backend(group)
    if backend == 'hccl':
        import torch_npu  # noqa: F401
        return torch.device('npu', torch.npu.current_device())
    elif backend == torch_dist.Backend.NCCL:
        return torch.device('cuda', torch.cuda.current_device())
    elif backend == 'cncl':
        import torch_mlu  # noqa: F401
        return torch.device('mlu', torch.mlu.current_device())
    elif backend == 'smddp':
        return torch.device('cuda', torch.cuda.current_device())
    elif backend == 'mccl':
        import torch_musa
        return torch.device('musa', torch_musa.current_device())
    else:
        # GLOO and MPI backends use cpu device by default.
        return torch.device('cpu')
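
# Hedged example: collectives on an NCCL group must operate on CUDA tensors,
# so data is moved to the communication device first:
#
#   device = get_comm_device()
#   tensor = cast_data_device(torch.arange(4), device)
#   # under NCCL, `tensor` now lives on e.g. device(type='cuda', index=0)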


def cast_data_device(
    data: Union[Tensor, Mapping, Iterable],
    device: torch.device,
    out: Optional[Union[Tensor, Mapping, Iterable]] = None
) -> Union[Tensor, Mapping, Iterable]:
    """Recursively convert Tensor in ``data`` to ``device``.

    If ``data`` is already on the ``device``, it will not be cast again.

    Args:
        data (Tensor or list or dict): Inputs to be cast.
        device (torch.device): Destination device type.
        out (Tensor or list or dict, optional): If ``out`` is specified, it
            will be updated in place with the cast result. Defaults to None.

    Returns:
        Tensor or list or dict: ``data`` cast to ``device``.
    """
    if out is not None:
        if type(data) is not type(out):
            raise TypeError(
                'out should be the same type with data, but got data is '
                f'{type(data)} and out is {type(out)}')

        if isinstance(out, set):
            raise TypeError('out should not be a set')

    if isinstance(data, Tensor):
        if get_data_device(data) == device:
            data_on_device = data
        else:
            data_on_device = data.to(device)

        if out is not None:
            # Copy the cast result into ``out`` in place.
            out.copy_(data_on_device)

        return data_on_device
    elif isinstance(data, Mapping):
        data_on_device = {}
        if out is not None:
            data_len = len(data)
            out_len = len(out)
            if data_len != out_len:
                raise ValueError('length of data and out should be same, '
                                 f'but got {data_len} and {out_len}')

            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device, out[k])
        else:
            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device)

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        # Construct the output with the input's type so that the output type
        # matches the input type.
        return type(data)(data_on_device)
    elif isinstance(data, Iterable) and not isinstance(
            data, str) and not isinstance(data, np.ndarray):
        data_on_device = []
        if out is not None:
            for v1, v2 in zip(data, out):
                data_on_device.append(cast_data_device(v1, device, v2))
        else:
            for v in data:
                data_on_device.append(cast_data_device(v, device))

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        return type(data)(data_on_device)
    else:
        raise TypeError('data should be a Tensor, list of tensor or dict, '
                        f'but got {data}')
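
# Hedged round-trip sketch: the ``out`` argument lets callers reuse
# preallocated storage. All values below are illustrative:
#
#   data = {'x': torch.zeros(2), 'y': torch.ones(2)}
#   out = {'x': torch.empty(2), 'y': torch.empty(2)}
#   moved = cast_data_device(data, torch.device('cpu'), out=out)
#   # `moved` is a new dict of the same type as `data`; the tensors in
#   # `out` were updated in place via Tensor.copy_().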