| | """ |
| | Helpers for distributed training. |
| | """ |
| |
|
| | import io |
| | import os |
| | import socket |
| |
|
| | import blobfile as bf |
| | from mpi4py import MPI |
| | import torch as th |
| | import torch.distributed as dist |

# Number of GPUs per node; the GPU assigned to a rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
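# For example (assuming MPI ranks are laid out node by node): with 16 ranks and
# GPUS_PER_NODE = 8, ranks 0-7 map to GPUs 0-7 on the first node and ranks 8-15
# map to GPUs 0-7 on the second node via CUDA_VISIBLE_DEVICES in setup_dist().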

SETUP_RETRY_COUNT = 3


def setup_dist():
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return
    # Pin each rank to a single GPU before torch initializes CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"

    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    # Rank 0 picks a free TCP port and broadcasts it to the other ranks.
    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        # CUDA_VISIBLE_DEVICES is set per rank in setup_dist(), so "cuda"
        # maps to this rank's single visible GPU.
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    # Broadcast in ~1 GiB chunks; a single bcast of a very large message can
    # overflow MPI's 32-bit message-size limits.
    chunk_size = 2 ** 30
    if MPI.COMM_WORLD.Get_rank() == 0:
        # Rank 0 reads the file once and streams it to every other rank.
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    # Bind to port 0 so the OS assigns an unused port, then release it so the
    # process group can bind to it via MASTER_PORT. The socket is created
    # outside the try block so the finally clause never sees an unbound name.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
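

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; an assumption about typical use, not
    # part of the original helpers). Launch one process per GPU, e.g.
    # `mpirun -n <num_ranks> python <this file>`; the Linear layer stands in
    # for a real model, and load_state_dict() would be called the same way with
    # a real checkpoint path.
    setup_dist()
    model = th.nn.Linear(8, 8).to(dev())
    sync_params(model.parameters())  # every rank now holds rank 0's weights
    print(f"rank {dist.get_rank()}/{dist.get_world_size()} using {dev()}")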