| |
|
|
| |
| |
| |
| |
|
|
| import torch |
| import torch.distributed as dist |
|
|
class GatherLayer(torch.autograd.Function):
    """
    All-gather a tensor from every process while keeping autograd connected.

    ``forward`` returns a tuple with one tensor per rank; entry ``r`` holds the
    tensor contributed by rank ``r``. ``backward`` sums the incoming gradients
    over all ranks and returns the slice that corresponds to this rank's own
    input, so losses evaluated on other ranks also contribute to the local
    gradient.

    Adapted from
    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py

    Requires an initialized default process group, and the input must have the
    same shape on every rank. Both ``forward`` and ``backward`` issue a
    collective, so every rank has to execute both passes.
    """

    @staticmethod
    def forward(ctx, input):
        # Collective backends such as NCCL reject non-contiguous tensors.
        input = input.contiguous()
        # The buffers are fully overwritten by all_gather, so no need to zero them.
        output = [torch.empty_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # grads[r] is the gradient of the *local* loss w.r.t. rank r's tensor.
        # The gradient for our own input is the sum over all ranks of their
        # gradient w.r.t. our tensor, hence the all_reduce before selecting
        # the local slice.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)
        return all_grads[dist.get_rank()]
|
|
|
|
def dist_gather(x: torch.Tensor):
    """
    Gather ``x`` from all processes and concatenate along dim 0, with autograd.

    Args:
        x: tensor to gather. A 0-dim tensor is reshaped to shape ``(1,)``
            before gathering.

    Returns:
        The concatenation of every rank's tensor along the first dimension.
        When no process group is available/initialized, ``x`` is returned
        unchanged (0-dim inputs are not reshaped in that case).
    """
    # is_available() is checked first so that PyTorch builds without
    # distributed support take the single-process path.
    if not (dist.is_available() and dist.is_initialized()):
        return x
    if x.dim() == 0:
        # torch.cat cannot concatenate 0-dim tensors.
        x = x.reshape(1)
    return torch.cat(GatherLayer.apply(x), dim=0)
|
|
|
|
@torch.no_grad()
def dist_gather_nograd(x: torch.Tensor):
    """
    Gather ``x`` from all processes and concatenate along dim 0, without autograd.

    Args:
        x: tensor to gather; must have the same shape on every rank. A 0-dim
            tensor is reshaped to shape ``(1,)`` before gathering.

    Returns:
        The concatenation of every rank's tensor along the first dimension.
        When no process group is available/initialized, ``x`` is returned
        unchanged.
    """
    # is_available() is checked first so that PyTorch builds without
    # distributed support take the single-process path.
    if not (dist.is_available() and dist.is_initialized()):
        return x
    if x.dim() == 0:
        # torch.cat cannot concatenate 0-dim tensors.
        x = x.reshape(1)
    # Collective backends such as NCCL reject non-contiguous tensors.
    x = x.contiguous()
    # The buffers are fully overwritten by all_gather, so no initial fill is needed.
    gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x, async_op=False)
    return torch.cat(gathered, dim=0)
|
|
|
|
def get_rank():
    """Return this process's rank, or 0 when torch.distributed is unavailable or uninitialized."""
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
|
|
|
|
def is_main():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
|
|
|
|
def get_world_size():
    """Return the number of processes in the default group, or 1 outside distributed runs."""
    # is_available() is checked first so that PyTorch builds without
    # distributed support report a single process.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()
|
|
def barrier():
    """Synchronize all processes; do nothing when no process group is available/initialized."""
    # is_available() is checked first so that PyTorch builds without
    # distributed support treat this as a no-op.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
|
|
|
|
@torch.no_grad()
def varsize_gather_nograd(x: torch.Tensor):
    """
    Gather tensors of different sizes along the first dimension.

    Each rank may contribute a different number of rows; trailing dimensions
    are used as given by the local ``x`` for the padded buffer, so they are
    expected to agree across ranks.

    Args:
        x: tensor with at least one dimension.

    Returns:
        The concatenation, in rank order, of every rank's rows. When no process
        group is available/initialized, ``x`` is returned unchanged.
    """
    # is_available() is checked first so that PyTorch builds without
    # distributed support take the single-process path.
    if not (dist.is_available() and dist.is_initialized()):
        return x

    world_size = dist.get_world_size()
    local_len = x.shape[0]

    # Step 1: exchange the per-rank lengths of the first dimension.
    size = torch.tensor([local_len], device=x.device, dtype=torch.int)
    allsizes = [torch.zeros_like(size) for _ in range(world_size)]
    dist.all_gather(allsizes, size)
    # Convert to Python ints in one transfer instead of one sync per rank.
    lengths = torch.cat(allsizes).tolist()
    max_len = max(lengths)

    # Step 2: pad to the common maximum so all_gather sees equal shapes.
    # The padding rows are left uninitialized; they are sliced off below.
    padded = torch.empty(max_len, *x.shape[1:], dtype=x.dtype, device=x.device)
    padded[:local_len] = x
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # Step 3: strip each rank's padding and concatenate in rank order.
    trimmed = [chunk[:n] for chunk, n in zip(gathered, lengths)]
    return torch.cat(trimmed, dim=0)
|
|