| | |
| | import logging |
| | import random |
| | from typing import List, Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.utils.data import DataLoader |
| |
|
| | from mmengine.device import is_cuda_available, is_musa_available |
| | from mmengine.dist import get_rank, sync_random_seed |
| | from mmengine.logging import print_log |
| | from mmengine.utils import digit_version, is_list_of |
| | from mmengine.utils.dl_utils import TORCH_VERSION |
| |
|
| |
|
def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Split a dynamic interval schedule into milestones and intervals.

    Args:
        start_interval (int): The interval in effect from milestone 0.
        dynamic_interval_list (List[Tuple[int, int]], optional): Pairs of
            ``(milestone, interval)``. Each interval applies once its
            milestone has been reached. Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: Two parallel lists. The first holds
        the milestones (always starting with 0), the second holds the
        interval paired with each milestone (always starting with
        ``start_interval``).
    """
    # Without a schedule, the start interval applies from milestone 0 on.
    if dynamic_interval_list is None:
        return [0], [start_interval]

    assert is_list_of(dynamic_interval_list, tuple)

    # Index access (rather than tuple unpacking) keeps tuples with extra
    # trailing elements acceptable.
    milestones = [0] + [entry[0] for entry in dynamic_interval_list]
    intervals = [start_interval] + [
        entry[1] for entry in dynamic_interval_list
    ]
    return milestones, intervals
| |
|
| |
|
def set_random_seed(seed: Optional[int] = None,
                    deterministic: bool = False,
                    diff_rank_seed: bool = False) -> int:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int, optional): Seed to be used. When None, the seed is
            obtained from ``sync_random_seed()``. Defaults to None.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set
            `torch.backends.cudnn.deterministic` to True and
            `torch.backends.cudnn.benchmark` to False. With PyTorch
            1.10.0 or newer, `torch.use_deterministic_algorithms(True)`
            is called as well. Defaults to False.
        diff_rank_seed (bool): Whether to add the value returned by
            ``get_rank()`` to the seed, so that each rank ends up with a
            different seed. Defaults to False.

    Returns:
        int: The seed that was actually applied (including the rank
        offset when ``diff_rank_seed`` is True).
    """
    if seed is None:
        seed = sync_random_seed()
    if diff_rank_seed:
        seed += get_rank()

    # Same seed for every host-side generator, in a fixed order.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # Seed the accelerator generators of whichever backend is present.
    if is_cuda_available():
        torch.cuda.manual_seed_all(seed)
    elif is_musa_available():
        torch.musa.manual_seed_all(seed)

    if not deterministic:
        return seed

    # Warn only when benchmark mode is currently on, since it is about to
    # be switched off.
    if torch.backends.cudnn.benchmark:
        print_log(
            'torch.backends.cudnn.benchmark is going to be set as '
            '`False` to cause cuDNN to deterministically select an '
            'algorithm',
            logger='current',
            level=logging.WARNING)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # The global determinism switch is only used from PyTorch 1.10.0 on.
    if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
        torch.use_deterministic_algorithms(True)
    return seed
| |
|
| |
|
| | def _get_batch_size(dataloader: dict): |
| | if isinstance(dataloader, dict): |
| | if 'batch_size' in dataloader: |
| | return dataloader['batch_size'] |
| | elif ('batch_sampler' in dataloader |
| | and 'batch_size' in dataloader['batch_sampler']): |
| | return dataloader['batch_sampler']['batch_size'] |
| | else: |
| | raise ValueError('Please set batch_size in `Dataloader` or ' |
| | '`batch_sampler`') |
| | elif isinstance(dataloader, DataLoader): |
| | return dataloader.batch_sampler.batch_size |
| | else: |
| | raise ValueError('dataloader should be a dict or a Dataloader ' |
| | f'instance, but got {type(dataloader)}') |
| |
|