| from typing import List |
| from torch import distributed |
|
|
|
|
def barrier():
    """Call ``distributed.barrier()`` when the process group is initialized.

    When ``distributed.is_initialized()`` is false this is a no-op, so the
    helper can be called from code that does not set up torch.distributed.
    """
    if not distributed.is_initialized():
        return
    distributed.barrier()
|
|
|
|
def broadcast(data, src):
    """Forward to ``distributed.broadcast(data, src)`` when initialized.

    Args:
        data: Object passed through as the first argument of
            ``distributed.broadcast``.
        src: Passed through as the second argument of
            ``distributed.broadcast``.

    When ``distributed.is_initialized()`` is false nothing happens and
    ``data`` is left untouched. The function always returns ``None``.
    """
    if not distributed.is_initialized():
        return
    distributed.broadcast(data, src)
|
|
|
|
def all_gather(data: List, src):
    """Gather ``src`` into ``data``, with a single-process fallback.

    Args:
        data: Output list. It is handed to ``distributed.all_gather`` as the
            first argument when the process group is initialized.
        src: The local value; passed as the second argument of
            ``distributed.all_gather``.

    When ``distributed.is_initialized()`` is false, ``src`` is stored in
    ``data[0]`` and the remaining entries of ``data`` are left as they were.
    The function always returns ``None``; results are delivered through
    ``data``.
    """
    if not distributed.is_initialized():
        # Single-process fallback: the only contribution is our own value.
        data[0] = src
        return
    distributed.all_gather(data, src)
|
|
|
|
def get_rank():
    """Return ``distributed.get_rank()``, or ``0`` when not initialized."""
    return distributed.get_rank() if distributed.is_initialized() else 0
|
|
|
|
def get_world_size():
    """Return ``distributed.get_world_size()``, or ``1`` when not initialized."""
    if not distributed.is_initialized():
        return 1
    return distributed.get_world_size()
|
|
|
|
def chunk_size(size, rank, world_size):
    """Return the share of ``size`` that falls to ``rank``.

    The result is ``size // world_size``, plus one when
    ``rank < size % world_size``; i.e. the remainder of the division is
    handed out one unit each to the lowest-numbered ranks.

    Args:
        size: Total amount being divided.
        rank: Index of the participant whose share is computed.
        world_size: Number of participants; a value of zero raises
            ``ZeroDivisionError`` from the modulo operation.
    """
    remainder = size % world_size
    gets_extra = rank < remainder  # bool, adds 0 or 1 below
    return size // world_size + gets_extra