| r""" Helper functions """ |
| import random |
|
|
| import torch |
| import numpy as np |
|
|
|
|
def fix_randseed(seed):
    r""" Seed every random number generator used here for reproducibility.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices), then puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed. If ``None``, a seed in ``[0, 1e5)`` is drawn from
            Python's ``random`` module, so runs are not reproducible across
            processes in that case.
    """
    if seed is None:
        seed = int(random.random() * 1e5)
    # Python's own generator must be seeded as well; otherwise any code that
    # relies on `random` (shuffling, sampling, augmentation) stays
    # non-deterministic even though NumPy and PyTorch are fixed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Turn off the cuDNN autotuner and request deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
|
|
def mean(x):
    r""" Return the arithmetic mean of a sized collection, or 0.0 when it is empty. """
    count = len(x)
    if count == 0:
        # Avoid dividing by zero: an empty input averages to 0.0 by convention.
        return 0.0
    return sum(x) / count
|
|
|
|
def to_cuda(batch):
    r""" Move every tensor value of a mapping to the GPU, in place.

    Non-tensor values are left untouched. The same mapping object that was
    passed in is returned.
    """
    for name, item in batch.items():
        if not isinstance(item, torch.Tensor):
            continue
        # Overwriting an existing key does not resize the mapping, so doing
        # it during iteration is safe.
        batch[name] = item.cuda()
    return batch
|
|
|
|
def to_cpu(tensor):
    r""" Return a detached copy of a tensor that lives on the CPU. """
    # Drop the autograd history first, then copy so the result never shares
    # storage with the input, and finally move the copy to the CPU.
    detached = tensor.detach()
    duplicate = detached.clone()
    return duplicate.cpu()
|
|