| |
| |
| |
| |
| |
|
|
| import math |
| import warnings |
| import weakref |
| from collections import Counter |
| from functools import wraps |
| from typing import Callable, List, Optional, Sequence, Union |
|
|
| from torch.optim import Optimizer |
|
|
| from mmengine.logging import print_log |
| from mmengine.optim import BaseOptimWrapper |
| from mmengine.registry import PARAM_SCHEDULERS |
|
|
| INF = int(1e9) |
|
|
| OptimizerType = Union[BaseOptimWrapper, Optimizer] |
|
|
|
|
class _ParamScheduler:
    """Base class for parameter schedulers.

    It should be inherited by all schedulers that schedule parameters in the
    optimizer's ``param_groups``. All subclasses should overwrite the
    ``_get_value()`` according to their own schedule strategy.
    The implementation is motivated by
    https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py.

    Args:
        optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resuming without
            state dict. Default value ``-1`` means the ``step`` function is
            never be called before. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: OptimizerType,
                 param_name: str,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):

        # Fail fast on a mis-typed optimizer instead of erroring later at
        # the first attribute access.
        if not isinstance(optimizer, (Optimizer, BaseOptimWrapper)):
            raise TypeError('``optimizer`` should be an Optimizer,'
                            'but got {}'.format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.param_name = param_name

        # The scheduler is only active on steps in the half-open interval
        # ``[begin, end)`` (see ``step()``).
        if end <= begin:
            raise ValueError('end should be larger than begin, but got'
                             ' begin={}, end={}'.format(begin, end))
        self.begin = begin
        self.end = end

        self.by_epoch = by_epoch

        assert isinstance(last_step, int) and last_step >= -1
        # Initialize the base values for each param group.
        if last_step == -1:
            # Fresh start: remember the current value of every param group
            # as its base value so the schedule can be reconstructed later.
            for group in optimizer.param_groups:
                group.setdefault(f'initial_{param_name}', group[param_name])
        else:
            # Resuming without a state dict: the base values must already be
            # stored in every param group.
            for i, group in enumerate(optimizer.param_groups):
                if f'initial_{param_name}' not in group:
                    raise KeyError(
                        f"param 'initial_{param_name}' is not specified "
                        'in param_groups[{}] when resuming an optimizer'.
                        format(i))
        self.base_values = [
            group[f'initial_{param_name}'] for group in optimizer.param_groups
        ]
        self.last_step = last_step

        # Wrap ``optimizer.step()`` with a call counter so that ``step()``
        # below can warn when the scheduler is stepped before the optimizer
        # (mirrors PyTorch's lr_scheduler behaviour).
        def with_counter(method: Callable):
            if getattr(method, '_with_counter', False):
                # ``optimizer.step()`` has already been wrapped once;
                # do not wrap it a second time.
                return method

            # Keep only a weak reference to the bound method's instance to
            # avoid a reference cycle optimizer -> wrapper -> optimizer.
            instance_ref = weakref.ref(method.__self__)
            # The unbound function; re-bound on every call via ``__get__``.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                # Count every optimizer step taken.
                instance._global_step += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Mark the wrapper so repeated scheduler construction does not
            # stack multiple counters on the same optimizer.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._global_step = -1

        self._global_step = -1
        self.verbose = verbose

        # Perform the initial step so the scheduled value for step 0 is in
        # effect before training starts.
        self.step()

    def state_dict(self) -> dict:
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which is not
        the optimizer.

        Returns:
            dict: scheduler state.
        """
        return {
            key: value
            for key, value in self.__dict__.items() if key != 'optimizer'
        }

    def load_state_dict(self, state_dict: dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_value(self):
        """Return the last computed value by current scheduler.

        Returns:
            list: A list of the last computed value of the optimizer's
            ``param_group``.
        """
        return self._last_value

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        # Subclasses implement the actual schedule here.
        raise NotImplementedError

    def print_value(self, is_verbose: bool, group: int, value: float):
        """Display the current parameter value.

        Args:
            is_verbose (bool): Whether to print the value.
            group (int): The index of the current ``param_group``.
            value (float): The parameter value.
        """
        if is_verbose:
            print_log(
                f'Adjusting parameter value of group {group} to {value:.4e}.',
                logger='current')

    def step(self):
        """Adjusts the parameter value of each parameter group based on the
        specified schedule."""
        # Warn on the legacy call pattern (scheduler.step() before
        # optimizer.step()), detected via the counter installed in __init__.
        if self._global_step == 0:
            if not hasattr(self.optimizer.step, '_with_counter'):
                warnings.warn(
                    'Seems like `optimizer.step()` has been overridden after '
                    'parameter value scheduler initialization. Please, make '
                    'sure to call `optimizer.step()` before '
                    '`scheduler.step()`. See more details at '
                    'https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate',
                    UserWarning)

            # Just check if there were two first scheduler.step() calls
            # before optimizer.step().
            elif self.optimizer._global_step < 0:
                warnings.warn(
                    'Detected call of `scheduler.step()` before '
                    '`optimizer.step()`. In PyTorch 1.1.0 and later, you '
                    'should call them in the opposite order: '
                    '`optimizer.step()` before `scheduler.step()`. '
                    'Failure to do this will result in PyTorch skipping '
                    'the first value of the parameter value schedule. '
                    'See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate',
                    UserWarning)
        self._global_step += 1

        # Only update the parameter while inside the active ``[begin, end)``
        # window; outside it the values are left untouched.
        if self.begin <= self._global_step < self.end:
            self.last_step += 1
            values = self._get_value()

            for i, data in enumerate(zip(self.optimizer.param_groups, values)):
                param_group, value = data
                param_group[self.param_name] = value
                self.print_value(self.verbose, i, value)

        # Cache what is actually in the param groups so ``get_last_value``
        # also reflects modifications made outside this scheduler.
        self._last_value = [
            group[self.param_name] for group in self.optimizer.param_groups
        ]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class StepParamScheduler(_ParamScheduler):
    """Multiply the scheduled parameter of every parameter group by ``gamma``
    once per ``step_size`` epochs.

    The decay is applied on top of the current value, so modifications made
    to the parameter from outside this scheduler are preserved.

    Args:
        optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        step_size (int): Period of parameter value decay.
        gamma (float): Multiplicative factor of parameter value decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: OptimizerType,
                 param_name: str,
                 step_size: int,
                 gamma: float = 0.1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        # Store the schedule parameters before the base class triggers the
        # first ``step()``.
        self.gamma = gamma
        self.step_size = step_size
        super().__init__(
            optimizer=optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              step_size,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        Every epoch quantity is scaled by ``epoch_length`` (iterations per
        epoch); an ``end`` of ``INF`` is kept as-is.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        return cls(
            *args,
            step_size=step_size * epoch_length,
            begin=int(begin * epoch_length),
            end=end if end == INF else int(end * epoch_length),
            by_epoch=False,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        groups = self.optimizer.param_groups
        # Decay only on positive multiples of ``step_size``; otherwise the
        # current values are returned unchanged.
        if self.last_step != 0 and self.last_step % self.step_size == 0:
            return [group[self.param_name] * self.gamma for group in groups]
        return [group[self.param_name] for group in groups]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class MultiStepParamScheduler(_ParamScheduler):
    """Multiply the scheduled parameter of every parameter group by ``gamma``
    each time the step counter reaches one of the ``milestones``.

    A milestone that appears several times applies ``gamma`` that many times
    at once. The decay stacks with any change made to the parameter from
    outside this scheduler.

    Args:
        optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of parameter value decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: OptimizerType,
                 param_name: str,
                 milestones: List[int],
                 gamma: float = 0.1,
                 last_step: int = -1,
                 begin: int = 0,
                 end: int = INF,
                 by_epoch: bool = True,
                 verbose: bool = False):
        # ``Counter`` folds duplicated milestones into a multiplicity that
        # becomes the exponent of ``gamma`` in ``_get_value``.
        self.gamma = gamma
        self.milestones = Counter(milestones)
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              milestones,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        Every epoch quantity is scaled by ``epoch_length`` (iterations per
        epoch); an ``end`` of ``INF`` is kept as-is.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        return cls(
            *args,
            milestones=[epoch * epoch_length for epoch in milestones],
            begin=int(begin * epoch_length),
            end=end if end == INF else int(end * epoch_length),
            by_epoch=False,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        groups = self.optimizer.param_groups
        # ``Counter`` yields 0 for steps that are not milestones (without
        # inserting the key).
        hits = self.milestones[self.last_step]
        if not hits:
            return [group[self.param_name] for group in groups]
        return [
            group[self.param_name] * self.gamma**hits for group in groups
        ]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class ConstantParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group by a small constant
    factor until the number of epoch reaches a pre-defined milestone: ``end``.
    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        factor (float): The number we multiply parameter value until the
            milestone. Defaults to 1./3.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: OptimizerType,
                 param_name: str,
                 factor: float = 1.0 / 3,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        if factor > 1.0 or factor < 0:
            raise ValueError(
                'Constant multiplicative factor should between 0 and 1.')

        self.factor = factor
        # Index of the last scheduled step, at which the constant scaling is
        # undone (the value returns to its unscaled magnitude).
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        Every epoch quantity is scaled by ``epoch_length`` (iterations per
        epoch); an ``end`` of ``INF`` is kept as-is.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        by_epoch = False
        begin = int(begin * epoch_length)
        if end != INF:
            end = int(end * epoch_length)
        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler.

        Returns:
            list: The scheduled parameter value for each ``param_group``.
        """
        if self.last_step == 0:
            # First scheduled step: scale the base value down by ``factor``.
            return [
                group[self.param_name] * self.factor
                for group in self.optimizer.param_groups
            ]
        if self.last_step == self.total_iters:
            # Last scheduled step: undo the constant scaling.
            return [
                group[self.param_name] * (1.0 / self.factor)
                for group in self.optimizer.param_groups
            ]
        # Every other step keeps the current value unchanged.
        # NOTE: the previous condition ``last_step > total_iters or
        # last_step != total_iters`` was redundant (the first disjunct
        # implies the second) and left an implicit ``return None`` path;
        # this restructuring is behaviorally identical but exhaustive.
        return [
            group[self.param_name] for group in self.optimizer.param_groups
        ]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class ExponentialParamScheduler(_ParamScheduler):
    """Multiply the scheduled parameter of every parameter group by ``gamma``
    on each epoch.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        gamma (float): Multiplicative factor of parameter value decay.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: OptimizerType,
                 param_name: str,
                 gamma: float,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        # Record the decay factor before the base class runs the first step.
        self.gamma = gamma
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        ``begin``/``end`` are scaled by ``epoch_length`` (iterations per
        epoch); an ``end`` of ``INF`` is kept as-is.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        return cls(
            *args,
            begin=int(begin * epoch_length),
            end=end if end == INF else int(end * epoch_length),
            by_epoch=False,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        groups = self.optimizer.param_groups
        # Step 0 keeps the base values; later steps chain one more decay.
        if self.last_step == 0:
            return [group[self.param_name] for group in groups]
        return [group[self.param_name] * self.gamma for group in groups]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingParamScheduler(_ParamScheduler):
    r"""Set the parameter value of each parameter group using a cosine
    annealing schedule, where :math:`\eta_{max}` is set to the initial value
    and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max}; \\
        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    Notice that because the schedule
    is defined recursively, the parameter value can be simultaneously modified
    outside this scheduler by other operators. If the parameter value is set
    solely by this scheduler, the parameter value at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this
    only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        T_max (int, optional): Maximum number of iterations. If not specified,
            use ``end - begin``. Defaults to None.
        eta_min (float, optional): Minimum parameter value. Defaults to None.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
        eta_min_ratio (float, optional): The ratio of the minimum parameter
            value to the base parameter value. Either `eta_min` or
            `eta_min_ratio` should be specified. Defaults to None.
            New in version 0.3.2.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self,
                 optimizer: Union[Optimizer, BaseOptimWrapper],
                 param_name: str,
                 T_max: Optional[int] = None,
                 eta_min: Optional[float] = None,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False,
                 eta_min_ratio: Optional[float] = None):
        # If neither minimum is given, fall back to an absolute floor of 0.
        if eta_min is None and eta_min_ratio is None:
            eta_min = 0.
        # Exactly one of ``eta_min`` / ``eta_min_ratio`` may be set; the
        # other must remain None.
        assert (eta_min is None) ^ (eta_min_ratio is None), \
            'Either `eta_min` or `eta_min_ratio should be specified'
        # Default period is the active window length.
        self.T_max = T_max or (end - begin)
        self.eta_min = eta_min
        self.eta_min_ratio = eta_min_ratio
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              T_max=None,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config."""
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        by_epoch = False
        # Scale every epoch-based quantity to iterations; ``end=INF`` stays
        # untouched.
        if T_max is not None:
            T_max = T_max * epoch_length
        begin = int(begin * epoch_length)
        if end != INF:
            end = int(end * epoch_length)
        return cls(
            *args,
            T_max=T_max,
            begin=begin,
            end=end,
            by_epoch=by_epoch,
            **kwargs)

    def _get_value(self) -> list:
        """Compute value using chainable form of the scheduler."""

        def _get_eta_min(base_value):
            # Minimum is either the absolute ``eta_min`` or a ratio of the
            # per-group base value.
            if self.eta_min_ratio is None:
                return self.eta_min
            return base_value * self.eta_min_ratio

        if self.last_step == 0:
            # First scheduled step: keep the current (base) values.
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
            # Period boundary (T_cur hit an odd multiple of T_max): add back
            # one cosine increment instead of applying the generic ratio.
            return [
                group[self.param_name] +
                (base_value - _get_eta_min(base_value)) *
                (1 - math.cos(math.pi / self.T_max)) / 2
                for base_value, group in zip(self.base_values,
                                             self.optimizer.param_groups)
            ]
        # Generic chainable update: scale the distance from the minimum by
        # the ratio of consecutive cosine factors.
        return [(1 + math.cos(math.pi * self.last_step / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_step - 1) / self.T_max)) *
                (group[self.param_name] - _get_eta_min(base_value)) +
                _get_eta_min(base_value) for base_value, group in zip(
                    self.base_values, self.optimizer.param_groups)]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class LinearParamScheduler(_ParamScheduler):
    """Linearly interpolate the multiplicative factor applied to the
    scheduled parameter from ``start_factor`` to ``end_factor`` over the
    active window, ending at the pre-defined milestone ``end``.

    Changes made to the parameter from outside this scheduler are preserved,
    since each update is applied to the current value.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        start_factor (float): The number we multiply parameter value in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 1./3.
        end_factor (float): The number we multiply parameter value at the end
            of linear changing process. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Union[Optimizer, BaseOptimWrapper],
                 param_name: str,
                 start_factor: float = 1.0 / 3,
                 end_factor: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        # Both endpoints of the interpolation must lie in [0, 1].
        if start_factor > 1.0 or start_factor < 0:
            raise ValueError(
                'Starting multiplicative factor should between 0 and 1.')
        if end_factor > 1.0 or end_factor < 0:
            raise ValueError(
                'Ending multiplicative factor should between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        # Number of steps over which the factor is interpolated.
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        ``begin``/``end`` are scaled by ``epoch_length`` (iterations per
        epoch); an ``end`` of ``INF`` is kept as-is.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        return cls(
            *args,
            begin=int(begin * epoch_length),
            end=end if end == INF else int(end * epoch_length),
            by_epoch=False,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        groups = self.optimizer.param_groups
        if self.last_step == 0:
            # First scheduled step: scale the base value by ``start_factor``.
            return [
                group[self.param_name] * self.start_factor for group in groups
            ]
        # Chainable update: multiply the current value by the ratio of the
        # interpolated factor at this step over the previous step.
        span = self.end_factor - self.start_factor
        return [
            group[self.param_name] *
            (1. + span /
             (self.total_iters * self.start_factor +
              (self.last_step - 1) * span)) for group in groups
        ]
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class PolyParamScheduler(_ParamScheduler):
    """Anneal the scheduled parameter of each parameter group towards
    ``eta_min`` following a polynomial curve of the given ``power``.

    Changes made to the parameter from outside this scheduler are preserved,
    since each update is applied to the current value.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Union[Optimizer, BaseOptimWrapper],
                 param_name: str,
                 eta_min: float = 0,
                 power: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.eta_min = eta_min
        self.power = power
        # Number of steps over which the polynomial decay is spread.
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        ``begin``/``end`` are scaled by ``epoch_length`` (iterations per
        epoch); an ``end`` of ``INF`` is kept as-is.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        return cls(
            *args,
            begin=int(begin * epoch_length),
            end=end if end == INF else int(end * epoch_length),
            by_epoch=False,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        groups = self.optimizer.param_groups
        if self.last_step == 0:
            # First scheduled step: keep the current (base) values.
            return [group[self.param_name] for group in groups]
        # Chainable update: shrink the distance to ``eta_min`` by the ratio
        # of consecutive polynomial factors.
        decay = (1 - 1 /
                 (self.total_iters - self.last_step + 1))**self.power
        return [(group[self.param_name] - self.eta_min) * decay + self.eta_min
                for group in groups]
|
|
|
|
| @PARAM_SCHEDULERS.register_module() |
| class OneCycleParamScheduler(_ParamScheduler): |
| r"""Sets the parameters of each parameter group according to the |
| 1cycle learning rate policy. The 1cycle policy anneals the learning |
| rate from an initial learning rate to some maximum learning rate and then |
| from that maximum learning rate to some minimum learning rate much lower |
| than the initial learning rate. |
| This policy was initially described in the paper `Super-Convergence: |
| Very Fast Training of Neural Networks Using Large Learning Rates`_. |
| |
| The 1cycle learning rate policy changes the learning rate after every |
| batch. `step` should be called after a batch has been used for training. |
| |
| This scheduler is not chainable. |
| |
| Note also that the total number of steps in the cycle can be determined in |
| one of two ways (listed in order of precedence): |
| |
| #. A value for total_steps is explicitly provided. |
| #. If total_steps is not defined, begin and end of the ParamSchedul will |
| works for it. In this case, the number of total steps is inferred by |
| total_steps = end - begin |
| |
| The default behaviour of this scheduler follows the fastai implementation |
| of 1cycle, which claims that "unpublished work has shown even better |
| results by using only two phases". To mimic the behaviour of the original |
| paper instead, set ``three_phase=True``. |
| |
| Args: |
| optimizer (Optimizer): Wrapped optimizer. |
| param_name (str): Name of the parameter to be adjusted, such as |
| ``lr``, ``momentum``. |
| eta_max (float or list): Upper parameter value boundaries in the cycle |
| for each parameter group. |
| total_steps (int): The total number of steps in the cycle. Note that |
| if a value is not provided here, then it will be equal to |
| ``end - begin``. Defaults to None |
| pct_start (float): The percentage of the cycle (in number of steps) |
| spent increasing the learning rate. |
| Defaults to 0.3 |
| anneal_strategy (str): {'cos', 'linear'} |
| Specifies the annealing strategy: "cos" for cosine annealing, |
| "linear" for linear annealing. |
| Defaults to 'cos' |
| div_factor (float): Determines the initial learning rate via |
| initial_param = eta_max/div_factor |
| Defaults to 25 |
| final_div_factor (float): Determines the minimum learning rate via |
| eta_min = initial_param/final_div_factor |
| Defaults to 1e4 |
| three_phase (bool): If ``True``, use a third phase of the schedule to |
| annihilate the learning rate according to 'final_div_factor' |
| instead of modifying the second phase (the first two phases will be |
| symmetrical about the step indicated by 'pct_start'). |
| last_step (int): The index of last step. Used for resume without |
| state dict. Defaults to -1. |
| by_epoch (bool): Whether the scheduled parameters are updated by |
| epochs. Defaults to True. |
| verbose (bool): Whether to print the value for each update. |
| Defaults to False. |
| |
| .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: |
| https://arxiv.org/abs/1708.07120 |
| """ |
|
|
| def __init__(self, |
| optimizer: Union[Optimizer, BaseOptimWrapper], |
| param_name: str, |
| eta_max: float = 0, |
| total_steps: Optional[int] = None, |
| pct_start: float = 0.3, |
| anneal_strategy: str = 'cos', |
| div_factor: float = 25., |
| final_div_factor: float = 1e4, |
| three_phase: bool = False, |
| begin: int = 0, |
| end: int = INF, |
| last_step: int = -1, |
| by_epoch: bool = True, |
| verbose: bool = False): |
|
|
| assert param_name == 'lr', ('OneCycle only works for learning rate ' |
| 'updating, but got patam_name as ' |
| f'{param_name}') |
|
|
| self.eta_max = eta_max |
| self.div_factor = div_factor |
| self.final_div_factor = final_div_factor |
|
|
| |
| if total_steps is not None: |
| if total_steps <= 0 or not isinstance(total_steps, int): |
| raise ValueError('Expected positive integer total_steps, ' |
| f'but got {total_steps}') |
| self.total_steps = total_steps |
| else: |
| self.total_steps = end - begin |
|
|
| |
| if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): |
| raise ValueError('Expected float between 0 and 1 pct_start, ' |
| f'but got {pct_start}') |
|
|
| |
| if anneal_strategy not in ['cos', 'linear']: |
| raise ValueError( |
| 'anneal_strategy must by one of "cos" or "linear", ' |
| f'instead got {anneal_strategy}') |
| elif anneal_strategy == 'cos': |
| self.anneal_func = self._annealing_cos |
| elif anneal_strategy == 'linear': |
| self.anneal_func = self._annealing_linear |
|
|
| if three_phase: |
| self._schedule_phases = [ |
| { |
| 'end_step': float(pct_start * self.total_steps) - 1, |
| f'start_{param_name}': f'initial_{param_name}', |
| f'end_{param_name}': f'max_{param_name}' |
| }, |
| { |
| 'end_step': float(2 * pct_start * self.total_steps) - 2, |
| f'start_{param_name}': f'max_{param_name}', |
| f'end_{param_name}': f'initial_{param_name}' |
| }, |
| { |
| 'end_step': self.total_steps - 1, |
| f'start_{param_name}': f'initial_{param_name}', |
| f'end_{param_name}': f'min_{param_name}' |
| }, |
| ] |
| else: |
| self._schedule_phases = [ |
| { |
| 'end_step': float(pct_start * self.total_steps) - 1, |
| f'start_{param_name}': f'initial_{param_name}', |
| f'end_{param_name}': f'max_{param_name}' |
| }, |
| { |
| 'end_step': self.total_steps - 1, |
| f'start_{param_name}': f'max_{param_name}', |
| f'end_{param_name}': f'min_{param_name}' |
| }, |
| ] |
|
|
| |
| max_values = self._format_param(f'max_{param_name}', optimizer, |
| eta_max) |
| if last_step == -1: |
| for idx, group in enumerate(optimizer.param_groups): |
| group[f'initial_{param_name}'] = max_values[idx] / div_factor |
| group[f'max_{param_name}'] = max_values[idx] |
| group[f'min_{param_name}'] = \ |
| group[f'initial_{param_name}'] / final_div_factor |
|
|
| super().__init__( |
| optimizer=optimizer, |
| param_name=param_name, |
| begin=begin, |
| end=end, |
| last_step=last_step, |
| by_epoch=by_epoch, |
| verbose=verbose) |
|
|
| def _format_param(self, name, optimizer, param): |
| """Return correctly formatted lr/momentum for each param group.""" |
| if isinstance(param, (list, tuple)): |
| if len(param) != len(optimizer.param_groups): |
| raise ValueError( |
| f'expected {len(optimizer.param_groups)} values ' |
| f'for {name}, got {len(param)}') |
| return param |
| else: |
| return [param] * len(optimizer.param_groups) |
|
|
| @staticmethod |
| def _annealing_cos(start, end, pct): |
| """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" |
|
|
| cos_out = math.cos(math.pi * pct) + 1 |
| return end + (start - end) / 2.0 * cos_out |
|
|
| @staticmethod |
| def _annealing_linear(start, end, pct): |
| """Linearly anneal from `start` to `end` as pct goes from 0.0 to |
| 1.0.""" |
| return (end - start) * pct + start |
|
|
| @classmethod |
| def build_iter_from_epoch(cls, |
| *args, |
| begin=0, |
| end=INF, |
| total_steps=None, |
| by_epoch=True, |
| epoch_length=None, |
| **kwargs): |
| """Build an iter-based instance of this scheduler from an epoch-based |
| config.""" |
| assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ |
| 'be converted to iter-based.' |
| assert epoch_length is not None and epoch_length > 0, \ |
| f'`epoch_length` must be a positive integer, ' \ |
| f'but got {epoch_length}.' |
| by_epoch = False |
| begin = int(begin * epoch_length) |
| if end != INF: |
| end = int(end * epoch_length) |
| if total_steps is not None: |
| total_steps = total_steps * epoch_length |
| return cls( |
| *args, |
| begin=begin, |
| end=end, |
| total_steps=total_steps, |
| by_epoch=by_epoch, |
| **kwargs) |
|
|
| def _get_value(self): |
| """Compute value using chainable form of the scheduler.""" |
|
|
| params = [] |
| step_num = self.last_step |
|
|
| if step_num > self.total_steps: |
| raise ValueError( |
| f'Tried to step {step_num + 1} times. ' |
| f'The specified number of total steps is {self.total_steps}') |
|
|
| for group in self.optimizer.param_groups: |
| start_step = 0 |
| for i, phase in enumerate(self._schedule_phases): |
| end_step = phase['end_step'] |
| if step_num <= end_step or i == len(self._schedule_phases) - 1: |
| pct = (step_num - start_step) / (end_step - start_step) |
| computed_param = self.anneal_func( |
| group[phase['start_' + self.param_name]], |
| group[phase['end_' + self.param_name]], pct) |
| break |
| start_step = phase['end_step'] |
|
|
| params.append(computed_param) |
|
|
| return params |
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class CosineRestartParamScheduler(_ParamScheduler):
    """Sets the parameters of each parameter group according to the cosine
    annealing with restarts scheme. The cosine restart policy anneals the
    parameter from the initial value to `eta_min` with a cosine annealing
    schedule and then restarts another period from the maximum value multiplied
    with `restart_weight`.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        periods (list[int]): Periods for each cosine annealing cycle.
        restart_weights (list[float]): Restart weights at each
            restart iteration. Defaults to [1].
        eta_min (float, optional): Minimum parameter value at the end of
            scheduling. Defaults to None.
        eta_min_ratio (float, optional): The ratio of minimum parameter value
            to the base parameter value. Either `eta_min` or `eta_min_ratio`
            should be specified. Defaults to None.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Union[Optimizer, BaseOptimWrapper],
                 param_name: str,
                 periods: List[int],
                 restart_weights: Sequence[float] = (1, ),
                 eta_min: Optional[float] = None,
                 eta_min_ratio: Optional[float] = None,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        # Exactly one of ``eta_min`` and ``eta_min_ratio`` must be given.
        assert (eta_min is None) ^ (eta_min_ratio is None)
        self.periods = periods
        self.eta_min = eta_min
        self.eta_min_ratio = eta_min_ratio
        self.restart_weights = restart_weights
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'
        # Running sums of ``periods`` used to map a step index to the cycle
        # it falls in, e.g. periods [10, 20] -> cumulative_periods [10, 30].
        self.cumulative_periods = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]

        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              periods,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config.

        Overrides the base builder so that ``periods`` (an epoch-based
        quantity specific to this scheduler) is also rescaled to iterations.
        """
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        # Convert every epoch-based quantity to iterations.
        periods = [p * epoch_length for p in periods]
        by_epoch = False
        begin = int(begin * epoch_length)
        if end != INF:
            end = int(end * epoch_length)
        return cls(
            *args,
            periods=periods,
            begin=begin,
            end=end,
            by_epoch=by_epoch,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        idx = self.get_position_from_periods(self.last_step,
                                             self.cumulative_periods)
        # ``None`` means the step is beyond all periods: leave the current
        # values untouched.
        if idx is None:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        current_weight = self.restart_weights[idx]
        # Step index at which the current cycle started.
        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
        current_periods = self.periods[idx]
        step = self.last_step - nearest_restart
        values = []
        # NOTE(review): ``self.base_values`` is assumed to be populated by the
        # ``_ParamScheduler`` base class with each group's initial value --
        # confirm against the base class definition.
        for base_value, group in zip(self.base_values,
                                     self.optimizer.param_groups):
            # The peak of this cycle is the base value scaled by its weight.
            eta_max = base_value * current_weight
            if self.eta_min_ratio is None:
                eta_min = self.eta_min
            else:
                eta_min = base_value * self.eta_min_ratio
            if step == 0:
                # Restart point: jump back to the (weighted) maximum.
                values.append(eta_max)
            else:
                # Chainable cosine annealing: scale the distance of the
                # current value from ``eta_min`` by the ratio of consecutive
                # cosine factors, so repeated calls trace the cosine curve.
                values.append(
                    (1 + math.cos(math.pi * step / current_periods)) /
                    (1 + math.cos(math.pi * (step - 1) / current_periods)) *
                    (group[self.param_name] - eta_min) + eta_min)

        return values

    @staticmethod
    def get_position_from_periods(
            iteration: int, cumulative_periods: List[int]) -> Optional[int]:
        """Get the position from a period list.

        It will return the index of the right-closest number in the period
        list.
        For example, the cumulative_periods = [100, 200, 300, 400],
        if iteration == 50, return 0;
        if iteration == 210, return 2;
        if iteration == 300, return 3.

        Args:
            iteration (int): Current iteration.
            cumulative_periods (list[int]): Cumulative period list.

        Returns:
            Optional[int]: The position of the right-closest number in the
            period list. If not in the period, return None.
        """
        # Linear scan is fine: the list of periods is short in practice.
        for i, period in enumerate(cumulative_periods):
            if iteration < period:
                return i
        return None
|
|
|
|
@PARAM_SCHEDULERS.register_module()
class ReduceOnPlateauParamScheduler(_ParamScheduler):
    """Reduce the parameters of each parameter group when a metric has stopped
    improving. Models often benefit from reducing the parameters by a factor of
    2-10 once learning stagnates. This scheduler reads a metrics quantity and
    if no improvement is seen for a ``patience`` number of epochs, the
    parameters are reduced.

    The implementation is motivated by `PyTorch ReduceLROnPlateau`_.

    Args:
        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        monitor (str): The name of the metric to measure whether
            the performance of the model is improved.
        rule (str): One of `less`, `greater`. In `less` rule, parameters will
            be reduced when the quantity monitored has stopped
            decreasing; in `greater` rule it will be reduced when the
            quantity monitored has stopped increasing. Defaults to 'less'.
            The ``rule`` is the renaming of ``mode`` in pytorch.
        factor (float): Factor by which the parameters will be
            reduced. new_param = param * factor. Defaults to 0.1.
        patience (int): Number of epochs with no improvement after
            which parameters will be reduced. For example, if
            ``patience = 2``, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the parameters after
            the 3rd epoch if the monitor value still hasn't improved then.
            Defaults to 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Defaults to 1e-4.
        threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
            dynamic_threshold = best * ( 1 + threshold ) in 'greater'
            rule or best * ( 1 - threshold ) in `less` rule.
            In `abs` rule, dynamic_threshold = best + threshold in
            `greater` rule or best - threshold in `less` rule.
            Defaults to 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after parameters have been reduced. Defaults to 0.
        min_value (float or list[float]): A scalar or a sequence of scalars.
            A lower bound on the parameters of each parameter group
            respectively. Defaults to 0.
        eps (float): Minimal decay applied to parameters. If the difference
            between new and old parameters are smaller than eps, the update is
            ignored. Defaults to 1e-8.
        begin (int): Step at which to start triggering the scheduler
            to monitor in val within the interval calculated
            according to epoch of training. Defaults to 0.
        end (int): Step at which to stop triggering the scheduler
            to monitor in val within the interval calculated
            according to epoch of training. Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.

    .. _PyTorch ReduceLROnPlateau:
       https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py
    """

    # ``step`` consumes validation metrics, so callers must pass them in.
    need_val_args = True

    def __init__(self,
                 optimizer: OptimizerType,
                 param_name: str,
                 monitor: str = 'loss',
                 rule: str = 'less',
                 factor: float = 0.1,
                 patience: int = 10,
                 threshold: float = 1e-4,
                 threshold_rule: str = 'rel',
                 cooldown: int = 0,
                 min_value: Union[float, Sequence[float]] = 0.,
                 eps: float = 1e-8,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):

        # NOTE(review): this scheduler deliberately does not call
        # ``super().__init__`` -- it re-implements the base initialization
        # because its stepping is driven by validation metrics rather than a
        # fixed schedule. Keep the checks below in sync with the base class.
        if not isinstance(optimizer, (Optimizer, BaseOptimWrapper)):
            raise TypeError('``optimizer`` should be an Optimizer,'
                            'but got {}'.format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.param_name = param_name

        if end <= begin:
            raise ValueError('end should be larger than begin, but got'
                             ' begin={}, end={}'.format(begin, end))
        self.begin = begin
        self.end = end

        assert by_epoch, \
            f'Now {type(self).__name__} only support by_epoch=True'
        self.by_epoch = by_epoch

        assert isinstance(last_step, int) and last_step >= -1
        # Initialize valid step count and record the initial value of the
        # scheduled parameter; resuming requires it to be present already.
        if last_step == -1:
            for group in optimizer.param_groups:
                group.setdefault(f'initial_{param_name}', group[param_name])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if f'initial_{param_name}' not in group:
                    raise KeyError(
                        f"param 'initial_{param_name}' is not specified "
                        'in param_groups[{}] when resuming an optimizer'.
                        format(i))

        self.last_step = last_step

        # Counts every training step; the [begin, end) window is checked
        # against this counter, not against ``last_step``.
        self._global_step = 0
        self.verbose = verbose

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # ``BaseOptimWrapper`` may expose an extra param group holding base
        # settings; validate a user-provided ``min_value`` sequence against
        # the raw optimizer's groups only.
        if isinstance(optimizer, BaseOptimWrapper):
            raw_optimizer = optimizer.optimizer
        else:
            raw_optimizer = optimizer

        if isinstance(min_value, (list, tuple)):
            if len(min_value) != len(raw_optimizer.param_groups):
                raise ValueError('expected {} min_lrs, got {}'.format(
                    len(raw_optimizer.param_groups), len(min_value)))
            self.min_values = list(min_value)
            # Duplicate the last bound for the wrapper's extra base group.
            if isinstance(optimizer, BaseOptimWrapper) and \
                    optimizer.base_param_settings is not None:
                self.min_values.append(self.min_values[-1])

        else:
            # A scalar bound applies to every param group.
            self.min_values = [min_value] * len(
                optimizer.param_groups)

        self.patience = patience
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.rule_worse = None  # the "worst" value for the chosen rule
        self.best = None
        self.num_bad_epochs = 0
        self.eps = eps

        self.monitor = monitor
        self._init_is_better(
            rule=rule, threshold=threshold, threshold_rule=threshold_rule)
        self._reset()

        # Remember the values most recently applied, for inspection.
        self._last_value = [
            group[self.param_name] for group in self.optimizer.param_groups
        ]

    def step(self, metrics=None):
        """Adjusts the parameter value of each parameter group based on the
        specified schedule.

        Args:
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
                Defaults to None.
        """
        if metrics is None:
            # Called after a training step: only advance the global counter.
            self._global_step += 1
            return

        if not isinstance(metrics, dict):
            raise TypeError('metrics type should be dict,'
                            f' but got type {type(metrics)}')

        # Only react to metrics while inside the [begin, end) window.
        if self.begin <= self._global_step < self.end:
            self.last_step += 1

            # The monitored metric must be one of the reported metrics.
            metric = metrics.get(self.monitor, None)
            if metric is not None:
                if self._is_better(metric, self.best):
                    self.best = metric
                    self.num_bad_epochs = 0
                else:
                    self.num_bad_epochs += 1

                if self._in_cooldown():
                    self.cooldown_counter -= 1
                    self.num_bad_epochs = 0  # ignore bad epochs in cooldown

                if self.num_bad_epochs > self.patience:
                    values = self._get_value()

                    for i, (param_group, value) in enumerate(
                            zip(self.optimizer.param_groups, values)):
                        # Ignore updates smaller than ``eps``.
                        if param_group[self.param_name] - value > self.eps:
                            param_group[self.param_name] = value
                            self.print_value(self.verbose, i, value)
                    self.cooldown_counter = self.cooldown
                    self.num_bad_epochs = 0

            else:
                # Fixed typo: 'Excepted' -> 'Expected'.
                raise KeyError(f'Expected key in {list(metrics.keys())},'
                               f' but got key {self.monitor} is not in dict')

        self._last_value = [
            group[self.param_name] for group in self.optimizer.param_groups
        ]

    def print_value(self, is_verbose: bool, group: int, value: float) -> None:
        """Display the current parameter value.

        Args:
            is_verbose (bool): Whether to print the value.
            group (int): The index of the current ``param_group``.
            value (float): The parameter value.
        """
        if is_verbose:
            step_name = 'epoch' if self.by_epoch else 'iter'
            print_log(
                f'Adjusting parameter value of group {group} to {value:.4e} '
                f'in {step_name} {self.last_step}.',
                logger='current')

    def _get_value(self):
        """Compute the reduced value for each param group, clipped to its
        lower bound."""
        values = [
            float(group[self.param_name]) * self.factor
            for group in self.optimizer.param_groups
        ]
        return [max(v, min_v) for v, min_v in zip(values, self.min_values)]

    def _in_cooldown(self):
        """Judge whether it is in cooldown."""
        return self.cooldown_counter > 0

    def _is_better(self, a, best):
        """Judge whether the monitor value ``a`` improves on ``best`` given
        the configured rule and threshold rule."""
        if self.rule == 'less' and self.threshold_rule == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.rule == 'less' and self.threshold_rule == 'abs':
            return a < best - self.threshold

        elif self.rule == 'greater' and self.threshold_rule == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # rule == 'greater' and threshold_rule == 'abs'
            return a > best + self.threshold

    def _init_is_better(self, rule, threshold, threshold_rule):
        """Validate and store the rule and its associated values."""
        if threshold < 0:
            raise ValueError(f'threshold {threshold} should be >= 0.')
        if rule not in {'less', 'greater'}:
            raise ValueError(f'mode {rule} is unknown!')
        if threshold_rule not in {'rel', 'abs'}:
            raise ValueError(f'threshold mode {threshold_rule}'
                             ' is unknown!')

        # The initial "best" is the worst possible value for the rule.
        if rule == 'less':
            self.rule_worse = INF
        else:
            self.rule_worse = -INF

        self.rule = rule
        self.threshold = threshold
        self.threshold_rule = threshold_rule

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.rule_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0
|