| | |
| | |
| | |
| | import torch |
| | import random |
| | import numpy as np |
| | import os |
| |
|
| | |
# Pick the best available accelerator once at import time, together with the
# DataLoader settings that depend on it:
#   MP_CONTEXT — multiprocessing start method for worker processes
#                ('forkserver' on MPS, where fork is unsafe; default elsewhere)
#   PIN_MEM    — pinned host memory only pays off for CUDA transfers
if torch.cuda.is_available():
    DEVICE, MP_CONTEXT, PIN_MEM = torch.device('cuda'), None, True
elif torch.backends.mps.is_available():
    DEVICE, MP_CONTEXT, PIN_MEM = torch.device('mps'), 'forkserver', False
else:
    DEVICE, MP_CONTEXT, PIN_MEM = torch.device('cpu'), None, False
| |
|
| |
|
| | |
| | |
| | |
def set_seed(seed: int = 0) -> None:
    '''
    Sets random seed and deterministic settings for reproducibility across:
    - PyTorch (CPU and all CUDA devices, plus cuDNN)
    - NumPy
    - Python's random module

    Args:
        seed (int): The seed value to set.

    Raises:
        RuntimeError: later, at op call time, if a used op has no
            deterministic implementation (`use_deterministic_algorithms`).
    '''
    # cuBLAS reads this env var at kernel launch; set it before any CUDA
    # work so deterministic GEMM workspaces are actually used.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when CUDA is unavailable

    # Force deterministic kernels everywhere torch supports it.
    torch.use_deterministic_algorithms(True)
    # cuDNN: pin algorithm selection and disable autotuning, which would
    # otherwise pick (potentially non-deterministic) kernels per input shape.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
| |
|
def save_model(model: torch.nn.Module,
               save_dir: str,
               mod_name: str) -> None:
    '''
    Saves the `state_dict()` of a model to the directory `save_dir`.

    Args:
        model (torch.nn.Module): The PyTorch model whose state dict will be saved.
        save_dir (str): Directory to save the model to (created if missing).
        mod_name (str): Filename for the saved model. If this doesn't end with
            '.pth' or '.pt', '.pth' is appended.
    '''
    os.makedirs(save_dir, exist_ok=True)

    # Append the default extension unless the caller already supplied one.
    if not mod_name.endswith(('.pth', '.pt')):
        mod_name += '.pth'

    save_path = os.path.join(save_dir, mod_name)
    torch.save(obj=model.state_dict(), f=save_path)