import copy
import inspect
import os
import os.path as osp
import time
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch.nn as nn
from torch.distributed.fsdp import (FullStateDictConfig,
                                    FullyShardedDataParallel,
                                    LocalStateDictConfig, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig, LocalOptimStateDictConfig, OptimStateDictConfig,
    StateDictConfig)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import get_rank, is_main_process
from mmengine.model import BaseDataPreprocessor, is_model_wrapper
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                            OptimWrapperDict, _ParamScheduler,
                            build_optim_wrapper)
from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
                               PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist
from .distributed import DDPStrategy
from .utils import MetaTensorContext

FSDP = FullyShardedDataParallel
FSDP_CONFIGS = Registry('fsdp configs')
FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=FullStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalStateDictConfig)


@STRATEGIES.register_module()
class FSDPStrategy(DDPStrategy):
| | """Support training model with FullyShardedDataParallel (FSDP). |
| | |
| | Keyword Args: |
| | model_wrapper (dict, optional): Config dict for model wrapper. The |
| | default configuration is: |
| | |
| | Examples: |
| | >>> model_wrapper = dict( |
| | >>> type='MMFullyShardedDataParallel', |
| | >>> use_orig_params=True, |
| | >>> ) |
| | |
| | See more configurable arguments in |
| | :class:`MMFullyShardedDataParallel`. Defaults to None |
| | skip_init_weights (bool, optional): Whether to skip initialization of |
| | weights. Defaults to False. This is useful when the parameters of |
| | the large model are loaded from a checkpoint, since skipping the |
| | initialization of weights can save a lot of time. |
        state_dict_cfg (str or dict): Configuration for how to save and load
            the state dict of the model, optimizer, and scheduler.

            - "local": save and load the sharded state dict in all ranks.
            - "full": save and load the full state dict in rank 0.
            - `dict` object: save and load the state dict more flexibly. For
              example, you can first offload the state dict to the 'cpu' and
              then save it to disk. This can help you load the checkpoint in
              a non-GPU environment:

              Examples:
                >>> state_dict_cfg=dict(
                >>>     state_dict_type='FULL_STATE_DICT',
                >>>     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
                >>>     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
                >>> )

              See more configurable arguments for ``state_dict_type``,
              ``state_dict_config``, and ``optim_state_dict_config`` in the
              `FSDP official api documents`_.
        kwargs (dict): Additional arguments passed to :class:`DDPStrategy`:

            - work_dir (str): The working directory to save checkpoints.
              The logs will be saved in the subdirectory of `work_dir` named
              :attr:`timestamp`. Defaults to 'work_dirs'.
            - experiment_name (str, optional): Name of the current experiment.
              If not specified, the timestamp will be used as
              :attr:`experiment_name`. Defaults to None.
            - env_kwargs (dict, optional): Environment config passed in
              :meth:`setup_env`. Defaults to None.
            - log_kwargs (dict, optional): Logger config passed in
              :meth:`build_logger`. Defaults to None.
        activation_checkpointing (dict, optional): Config dict for gradient
            checkpointing.

            Examples:
                >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
                >>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))

            The ``check_fn`` field should behave consistently with the
            ``auto_wrap_policy`` defined in ``model_wrapper``, and the other
            fields will be passed to ``apply_activation_checkpointing``.
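
            For instance, a ``check_fn`` registered in ``FUNCTIONS`` might
            look like the sketch below, where ``MyDecoderLayer`` is only a
            placeholder for whatever block the ``auto_wrap_policy`` targets:

            Examples:
                >>> from mmengine.registry import FUNCTIONS
                >>> @FUNCTIONS.register_module()
                >>> def check_decoder_layer(submodule):
                >>>     # Recompute activations only for the decoder layers.
                >>>     return isinstance(submodule, MyDecoderLayer)
                >>> activation_checkpointing = dict(check_fn='check_decoder_layer')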

            `New in version 0.9.0.`
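
    A typical configuration, as a minimal sketch (passing the strategy to
    ``FlexibleRunner`` and the chosen options are illustrative):

    Examples:
        >>> strategy = dict(
        >>>     type='FSDPStrategy',
        >>>     state_dict_cfg='full',
        >>>     model_wrapper=dict(use_orig_params=True))
        >>> # e.g. FlexibleRunner(..., strategy=strategy)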

    .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
    """

    def __init__(self,
                 *,
                 model_wrapper: Optional[dict] = None,
                 skip_init_weights=False,
                 state_dict_cfg: Union[str, dict] = 'local',
                 activation_checkpointing: Optional[dict] = None,
                 **kwargs):
        super().__init__(model_wrapper=model_wrapper, **kwargs)
        self._init_state_dict_cfg(state_dict_cfg)
        if not isinstance(skip_init_weights, bool):
            raise TypeError('skip_init_weights must be a boolean, but got '
                            f'{type(skip_init_weights)}')
        self.skip_init_weights = skip_init_weights
        self.activation_checkpointing = activation_checkpointing

    def _wrap_model(self, model: nn.Module) -> None:
        """Wrap the model with :obj:`MMFullyShardedDataParallel` or another
        custom fully sharded data parallel module wrapper.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            FullyShardedDataParallel: ``MMFullyShardedDataParallel`` or a
            subclass of ``FullyShardedDataParallel``.
        """
        try:
            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
                apply_activation_checkpointing
        except ImportError:
            apply_activation_checkpointing = None

        for module in model.modules():
            if isinstance(module, BaseDataPreprocessor):
                module.to(get_device())

        if is_model_wrapper(model):
            return

        if self.model_wrapper is None:
            self.model_wrapper = dict(type='MMFullyShardedDataParallel')

        default_args = dict(
            module=model,
            device_id=int(os.environ['LOCAL_RANK']),
            type='MMFullyShardedDataParallel')
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        model.set_state_dict_type(model, self.state_dict_type,
                                  self.state_dict_config,
                                  self.optim_state_dict_config)

        if self.activation_checkpointing is not None:
            if apply_activation_checkpointing is None:
                raise RuntimeError(
                    '`activation_checkpointing` is not supported by the '
                    'current PyTorch version. Please upgrade to PyTorch 2.0 '
                    'or newer to use `activation_checkpointing`.')
            cfg = copy.deepcopy(self.activation_checkpointing)
            with FUNCTIONS.switch_scope_and_registry(None):
                check_fn = cfg.pop('check_fn')
                if isinstance(check_fn, str):
                    check_fn = FUNCTIONS.get(check_fn)
                elif isinstance(check_fn, dict):
                    fn_type = check_fn.pop('type')
                    if isinstance(fn_type, str):
                        fn_type = FUNCTIONS.get(fn_type)
                    # Bind the remaining keys of the `check_fn` dict to the
                    # resolved function.
                    check_fn = partial(fn_type, **check_fn)

                if not callable(check_fn):
                    raise TypeError('`check_fn` must be a callable function')
                # The other fields of `activation_checkpointing` are passed
                # through to `apply_activation_checkpointing`.
                apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
        return model

    def _is_full_state_dict(self):
        """Whether to save and load the full state_dict in rank 0."""
        return self.state_dict_type == StateDictType.FULL_STATE_DICT

    def build_model(self, model: Union[nn.Module, dict]) -> nn.Module:
        """Build model.

        If ``skip_init_weights`` is True, the model will be built with empty
        weights. It means that :meth:`load_checkpoint` must be called to fill
        the weights before training.

        Args:
            model (nn.Module or dict): A ``nn.Module`` object or a dict used
                to build an ``nn.Module`` object. If ``model`` is already an
                ``nn.Module`` object, it is returned as-is.

        Returns:
            nn.Module: Model built from ``model``.
        """
        if self.skip_init_weights:
            if isinstance(model, dict):
                # Build the model on meta tensors to skip weight
                # initialization, then materialize empty weights on CPU.
                with MetaTensorContext():
                    model = super().build_model(model)
                model.to_empty(device='cpu')
        else:
            model = super().build_model(model)

        # Record a mapping from parameter id to parameter name before the
        # model is sharded. `build_optim_wrapper` uses it to rebuild an
        # optimizer that was built on the unsharded parameters.
        self.id_to_name = dict()
        for name, param in model.named_parameters():
            self.id_to_name[id(param)] = name
        return model

    def save_checkpoint(self,
                        filename: str,
                        *,
                        save_optimizer: bool = True,
                        save_param_scheduler: bool = True,
                        extra_ckpt: Optional[dict] = None,
                        callback: Optional[Callable] = None) -> None:
        """Save checkpoint to given ``filename``.

        If ``state_dict_type`` is ``full``, the checkpoint will only be saved
        in rank 0. The structure of the saved checkpoint is the same as the
        one saved by ``DDPStrategy``.

        If ``state_dict_type`` is ``local``, each rank will save its sharded
        state dict to a directory, which means the saved structure will look
        like this:

        .. code-block:: bash

            ── epoch_0.pth
                ├── rank0.pth
                ├── rank1.pth
                ├── ...
                └── rank8.pth

        Args:
            filename (str): Filename to save the checkpoint.

        Keyword Args:
            save_optimizer (bool): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            save_param_scheduler (bool): Whether to save the param_scheduler
                to the checkpoint. Defaults to True.
            extra_ckpt (dict, optional): Extra checkpoint to save.
                Defaults to None.
            callback (callable, optional): Callback function to modify the
                checkpoint before saving it. Defaults to None.
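
        For example, a callback could stamp an extra (hypothetical) meta
        field before the checkpoint is written; ``strategy`` below is a built
        ``FSDPStrategy``:

        Examples:
            >>> def add_note(checkpoint):
            >>>     checkpoint['meta']['note'] = 'trained-with-fsdp'
            >>> strategy.save_checkpoint('epoch_1.pth', callback=add_note)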
| | """ |
        from mmengine.runner.checkpoint import save_checkpoint

        state_dict: dict = dict()
        state_dict['state_dict'] = self.model_state_dict()

        # Save the optimizer state dict.
        if save_optimizer and hasattr(self, 'optim_wrapper'):
            state_dict['optimizer'] = self.optim_state_dict()

        # Save the param scheduler state dict.
        if save_param_scheduler and hasattr(self, 'param_schedulers'):
            state_dict['param_schedulers'] = self.scheduler_state_dict()

        # Merge the extra checkpoint passed by the user and stamp the meta
        # information.
        if extra_ckpt is None:
            extra_ckpt = dict()
        if 'meta' not in extra_ckpt:
            extra_ckpt['meta'] = dict()

        extra_ckpt['meta'].update(
            seed=self.seed,
            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
            mmengine=mmengine.__version__ + get_git_hash(),
        )
        state_dict.update(extra_ckpt)

        # Let users modify the checkpoint before it is saved.
        if callback is not None:
            callback(state_dict)

        # In non-FULL_STATE_DICT mode, each rank saves its own shard to a
        # separate file under the `filename` directory.
        if not self._is_full_state_dict():
            rank = get_rank()
            mkdir_or_exist(filename)
            ckpt_name = f'rank{rank}.pth'
            filename = osp.join(filename, ckpt_name)
            save_checkpoint(state_dict, filename)

        if is_main_process():
            save_checkpoint(state_dict, filename)

    def model_state_dict(self) -> dict:
        """Get model state dict based on the ``state_dict_type``.

        If ``state_dict_type`` is ``full``, the model state dict will be the
        same as that of the original unsharded model.

        If ``state_dict_type`` is ``local`` and ``use_orig_params`` is
        ``True`` in ``model_wrapper``, the keys of the state dict will be the
        same as those of the original unsharded model, but the values will be
        the sharded ones.

        If ``state_dict_type`` is ``local`` and ``use_orig_params`` is
        ``False`` in ``model_wrapper``, the flattened and sharded state dict
        will be returned.

        See more details in the `official api documents`_.

        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
        """
        # The state dict format has been configured by `set_state_dict_type`
        # in `_wrap_model`, so a plain `state_dict()` call already returns
        # the requested format.
        return self.model.state_dict()

    def optim_state_dict(self) -> dict:
        """Get optimizer state dict based on the ``state_dict_type``.

        If ``state_dict_type`` is ``full``, the optimizer state dict can be
        loaded by the original unsharded optimizer.

        Otherwise, the optimizer state dict can only be loaded by an
        optimizer built on the sharded parameters.

        Note:
            Even in ``full`` mode, the optimizer state dict is not identical
            to that of the original optimizer, although it can be loaded
            correctly.

        See more details in the `official api documents`_.

        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
        """
        return FSDP.optim_state_dict(self.model, self.optim_wrapper)

    def load_checkpoint(self, filename: str, **kwargs) -> dict:
        """Load checkpoint from given ``filename``.

        Note:
            If ``state_dict_type`` is ``local``, ``filename`` should be a
            directory containing ``rank{i}.pth`` files.

        Args:
            filename (str): Accepts a local filepath, URL,
                ``torchvision://xxx``, or ``open-mmlab://xxx``.

        Keyword Args:
            map_location (str or callable): A string or a callable function
                specifying how to remap storage locations.
                Defaults to 'cpu'.
            strict (bool): Whether to allow different params for the model
                and checkpoint.
            revise_keys (list): A list of customized keywords to modify the
                state_dict in checkpoint. Each item is a (pattern, replacement)
                pair of the regular expression operations. Defaults to strip
                the prefix 'module.' by [(r'^module\\.', '')].
            callback (callable, optional): Callback function to modify the
                checkpoint after loading it. Defaults to None.
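
        For example, with ``state_dict_type`` set to ``local``, the directory
        written by :meth:`save_checkpoint` is passed as ``filename`` (the
        path below is illustrative):

        Examples:
            >>> checkpoint = strategy.load_checkpoint('work_dirs/epoch_1.pth')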
| | """ |
| | if self._is_full_state_dict(): |
| | return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) |
| | else: |
| | rank = get_rank() |
| | filename = osp.join(filename, f'rank{rank}.pth') |
| | return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) |

    def load_model_state_dict(
        self,
        state_dict: dict,
        *,
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
    ) -> None:
        """Load model state from dict.

        Warning:
            ``revise_keys`` is not supported yet.

        Args:
            state_dict (dict): Model state dict returned by
                :meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type``
                is ``full``, ``state_dict`` could also be the result of
                ``model.state_dict()``.
            strict (bool): Whether to load model state dict strictly.
                Defaults to False.
        """
        # `set_state_dict_type` has already been called in `_wrap_model`, so
        # `load_state_dict` expects a state dict in the configured format.
        self.model.load_state_dict(state_dict, strict=strict)

    def load_optim_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state from dict.

        Args:
            state_dict (dict): The optimizer state dict. If
                ``state_dict_type`` is ``full``, ``state_dict`` could also be
                the result of ``optimizer.state_dict()``.
        """
        optim_state_dict = FSDP.optim_state_dict_to_load(
            state_dict, self.model, self.optim_wrapper.optimizer)
        self.optim_wrapper.load_state_dict(optim_state_dict)

    def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
        """Allow ``state_dict_type``, ``state_dict_config`` and
        ``optim_state_dict_config`` to be configured with a string or dict."""
        if isinstance(state_dict_cfg, str):
            if state_dict_cfg == 'full':
                self.state_dict_type = StateDictType.FULL_STATE_DICT
                self.state_dict_config = FullStateDictConfig(
                    rank0_only=True, offload_to_cpu=True)
                self.optim_state_dict_config = FullOptimStateDictConfig(
                    rank0_only=True, offload_to_cpu=True)
            elif state_dict_cfg == 'local':
                self.state_dict_type = StateDictType.LOCAL_STATE_DICT
                self.state_dict_config = LocalStateDictConfig()
                self.optim_state_dict_config = LocalOptimStateDictConfig()
            else:
                raise ValueError('FSDP only supports `full` and `local` '
                                 f'state_dict_type, but got {state_dict_cfg}')
        elif isinstance(state_dict_cfg, dict):
            if 'state_dict_type' not in state_dict_cfg:
                self.state_dict_type = StateDictType.LOCAL_STATE_DICT
            else:
                state_dict_type = state_dict_cfg['state_dict_type']
                if isinstance(state_dict_type, str):
                    self.state_dict_type = StateDictType[
                        state_dict_cfg['state_dict_type']]
                else:
                    self.state_dict_type = state_dict_type
            state_dict_config = state_dict_cfg.get('state_dict_config')
            if state_dict_config is None:
                self.state_dict_config = LocalStateDictConfig()
            elif isinstance(state_dict_config, dict):
                self.state_dict_config = FSDP_CONFIGS.build(
                    state_dict_cfg['state_dict_config'])
            else:
                self.state_dict_config = state_dict_config

            optim_state_dict_config = state_dict_cfg.get(
                'optim_state_dict_config')
            if optim_state_dict_config is None:
                self.optim_state_dict_config = LocalOptimStateDictConfig()
            elif isinstance(optim_state_dict_config, dict):
                self.optim_state_dict_config = FSDP_CONFIGS.build(
                    state_dict_cfg['optim_state_dict_config'])
            else:
                self.optim_state_dict_config = optim_state_dict_config
        else:
            raise TypeError('state_dict_cfg should be a `str` or a `dict`, '
                            f'but got {type(state_dict_cfg)}')

        if not isinstance(self.state_dict_type, StateDictType):
            raise TypeError('state_dict_type must be StateDictType, but got '
                            f'{type(self.state_dict_type)}')
        if not isinstance(self.state_dict_config, StateDictConfig):
            raise TypeError('state_dict_config must be StateDictConfig, but '
                            f'got {type(self.state_dict_config)}')
        if not isinstance(self.optim_state_dict_config, OptimStateDictConfig):
            raise TypeError('optim_state_dict_config must be '
                            'OptimStateDictConfig, but got '
                            f'{type(self.optim_state_dict_config)}')

    def build_optim_wrapper(
        self,
        optim_wrapper: Union[Optimizer, OptimWrapper, dict],
        model: Optional[nn.Module] = None,
    ) -> BaseOptimWrapper:
        """Support sharding the optimizer state dict given a built optimizer
        or optim_wrapper.

        See specific usage in :meth:`BaseStrategy.build_optim_wrapper`.
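
        For example, the config form might look like this (the optimizer type
        and learning rate are illustrative); the optimizer is then built on
        the sharded parameters of the wrapped model:

        Examples:
            >>> optim_wrapper = dict(
            >>>     type='AmpOptimWrapper',
            >>>     optimizer=dict(type='AdamW', lr=1e-4))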
| | """ |
        if isinstance(optim_wrapper, Optimizer):
            optim_wrapper = OptimWrapper(optim_wrapper)
        if isinstance(optim_wrapper, BaseOptimWrapper):
            assert model is not None
            # The optimizer was built on the parameters of the unsharded
            # model, so it has to be rebuilt on the sharded parameters of
            # the wrapped model.
            optimizer = optim_wrapper.optimizer
            param_groups = optimizer.param_groups
            optim_state_dict = optimizer.state_dict()
            assert not optim_state_dict['state'], (
                'Optimizer state_dict should be empty when a built '
                'optim_wrapper is passed to FSDPStrategy')

            # Replace the parameters in each param group with their names so
            # that they can be matched to the sharded parameters by name.
            new_param_groups = []
            for group in param_groups:
                new_group = {
                    key: value
                    for key, value in group.items() if key != 'params'
                }
                new_group['params'] = [
                    self.id_to_name[id(param)] for param in group['params']
                ]
                new_param_groups.append(new_group)
            optim_state_dict['param_groups'] = new_param_groups
            defaults = {
                k: v
                for k, v in optimizer.defaults.items()
                if k != 'differentiable'
            }

            # Map each (unwrapped) parameter name to its sharded parameter.
            params_dict = {}
            for k, v in model.named_parameters():
                if '_fsdp_wrapped_module' in k:
                    k = k.replace('_fsdp_wrapped_module.', '')
                params_dict[k] = v

            params = []
            for param_group in new_param_groups:
                _params = []
                for param_name in param_group['params']:
                    if param_name not in params_dict:
                        raise RuntimeError(
                            'Failed to reconstruct the sharded optimizer. '
                            'You can try to set `use_orig_params=True` in '
                            '`model_wrapper`')
                    _params.append(params_dict[param_name])
                param_group = {
                    k: v
                    for k, v in param_group.items() if k != 'params'
                }
                param_group['params'] = _params
                params.append(param_group)

            # Rebuild an optimizer of the same class on the sharded
            # parameters with the original defaults.
            new_optimizer = optimizer.__class__(params, **defaults)

            # Convert the (empty) optimizer state dict to the sharded format
            # and load it so that per-group hyper-parameters are preserved.
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                optim_state_dict = FSDP.optim_state_dict_to_load(
                    optim_state_dict, model, new_optimizer)
            new_optimizer.load_state_dict(optim_state_dict)
            optim_wrapper.optimizer = new_optimizer
            return optim_wrapper
        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
            assert model is not None
            # `optimizer` could be a built optimizer, an optimizer config,
            # or missing (multiple-optimizer training).
            optimizer = optim_wrapper.get('optimizer', None)
            optim_wrapper.setdefault('type', 'OptimWrapper')
            if optim_wrapper.get('type',
                                 'AmpOptimWrapper') in ('AmpOptimWrapper',
                                                        AmpOptimWrapper):
                optim_wrapper.setdefault('use_fsdp', True)

            # A built optimizer only needs to be wrapped.
            if isinstance(optimizer, Optimizer):
                return OPTIM_WRAPPERS.build(optim_wrapper)

            # If `optimizer` is a config dict or a `constructor` is given,
            # the optimizer has not been built yet and should be built by
            # `build_optim_wrapper`.
            if optimizer is not None or 'constructor' in optim_wrapper:
                return build_optim_wrapper(model, optim_wrapper)
            else:
                # Without `optimizer` and `constructor`, the dict is treated
                # as a collection of already-built optim wrappers for
                # training with multiple optimizers, and is converted to an
                # `OptimWrapperDict`.
                optim_wrappers = OrderedDict()
                for name, optim in optim_wrapper.items():
                    if not isinstance(optim, OptimWrapper):
                        raise ValueError(
                            'each item must be an optimizer object when '
                            '"type" and "constructor" are not in '
                            f'optimizer, but got {name}={optim}')
                    optim_wrappers[name] = optim
                return OptimWrapperDict(**optim_wrappers)
        else:
            raise TypeError('optimizer wrapper should be an OptimWrapper '
                            f'object or dict, but got {optim_wrapper}')

    def _build_param_scheduler(
        self,
        scheduler: Union[_ParamScheduler, Dict, List],
        optim_wrapper: BaseOptimWrapper,
        default_args: dict,
    ) -> List[_ParamScheduler]:
        """Override this method to update the scheduler with the
        reconstructed sharded optimizer."""
        if not isinstance(scheduler, Sequence):
            schedulers = [scheduler]
        else:
            schedulers = scheduler

        max_epochs = default_args.pop('max_epochs', None)
        max_iters = default_args.pop('max_iters', None)

        param_schedulers = []
        for scheduler in schedulers:
            # A built scheduler still references the optimizer that was
            # replaced in `build_optim_wrapper`, so rebuild it with the same
            # arguments against the new optim_wrapper.
            if isinstance(scheduler, (_ParamScheduler, LRScheduler)):
                parameter_keys = inspect.signature(
                    scheduler.__class__).parameters.keys()
                kwargs = {
                    k: v
                    for k, v in scheduler.state_dict().items()
                    if k in parameter_keys
                }
                scheduler = scheduler.__class__(optim_wrapper, **kwargs)
                param_schedulers.append(scheduler)
            elif isinstance(scheduler, dict):
                _scheduler = copy.deepcopy(scheduler)

                # Set the default `end` of the scheduler from the length of
                # the training loop.
                if _scheduler.get('by_epoch', True):
                    if max_epochs is None:
                        raise ValueError(
                            'max_epochs must be specified in default_args')
                    default_end = max_epochs
                else:
                    if max_iters is None:
                        raise ValueError(
                            'max_iters must be specified in default_args')
                    default_end = max_iters
                _scheduler.setdefault('end', default_end)
                self.logger.debug(
                    f'The `end` of {_scheduler["type"]} is not set. '
                    'Use the max epochs/iters of train loop as default.')

                param_schedulers.append(
                    PARAM_SCHEDULERS.build(
                        _scheduler,
                        default_args=dict(
                            optimizer=optim_wrapper, **default_args)))
            else:
                raise TypeError(
                    'scheduler should be a _ParamScheduler object or dict, '
                    f'but got {scheduler}')
        return param_schedulers