| | |
| | import logging |
| | from abc import abstractmethod |
| | from copy import deepcopy |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| |
|
| | from mmengine.logging import print_log |
| | from mmengine.registry import MODELS |
| |
|
| |
|
class BaseAveragedModel(nn.Module):
    """A base class for averaging model weights.

    Weight averaging, such as SWA and EMA, is a widely used technique for
    training neural networks. This class implements the averaging process
    for a model. All subclasses must implement the ``avg_func`` method.
    This class creates a copy of the provided module :attr:`model`
    on the :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.

    The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.

    Different from the ``AveragedModel`` in PyTorch, we use in-place
    operations to update the parameters, which is roughly 5 times faster
    than the non-in-place version.

    In mmengine, there are two ways to apply model averaging:

    1. Use the model averaging module in hook:
       :class:`mmengine.hooks.EMAHook` applies model averaging during
       training. Add ``custom_hooks=[dict(type='EMAHook')]`` to the config
       or the runner.

    2. Use the model averaging module directly in the algorithm. Take the
       ema teacher in semi-supervise as an example:

       >>> from mmengine.model import ExponentialMovingAverage
       >>> student = ResNet(depth=50)
       >>> # use ema model as teacher
       >>> ema_teacher = ExponentialMovingAverage(student)

    Args:
        model (nn.Module): The model to be averaged.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model
            will be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages
            for both the parameters and the buffers of the model.
            Defaults to False.
    """

    def __init__(self,
                 model: nn.Module,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__()
        # Keep an independent, frozen copy of the model to average into.
        self.module = deepcopy(model).requires_grad_(False)
        self.interval = interval
        if device is not None:
            self.module = self.module.to(device)
        # ``steps`` is registered as a buffer so it is saved/restored with
        # checkpoints.
        self.register_buffer('steps',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.update_buffers = update_buffers
        # When buffers are averaged too, track the full state dict;
        # otherwise only the named parameters.
        if update_buffers:
            self.avg_parameters = self.module.state_dict()
        else:
            self.avg_parameters = dict(self.module.named_parameters())

    @abstractmethod
    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the average of the parameters in-place. All subclasses
        must implement this method.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """

    def forward(self, *args, **kwargs):
        """Forward method of the averaged model."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: nn.Module) -> None:
        """Update the parameters of the averaged model.

        Executes ``avg_func`` to compute the new parameters and updates the
        internal copy of the model in-place.

        Args:
            model (nn.Module): The model whose parameters will be averaged.
        """
        if self.update_buffers:
            src_parameters = model.state_dict()
        else:
            src_parameters = dict(model.named_parameters())
        if self.steps == 0:
            # First call: initialize the average with a straight copy.
            for name, averaged in self.avg_parameters.items():
                averaged.data.copy_(src_parameters[name].data)
        elif self.steps % self.interval == 0:
            # Only floating-point tensors can be meaningfully averaged.
            for name, averaged in self.avg_parameters.items():
                if averaged.dtype.is_floating_point:
                    self.avg_func(averaged.data,
                                  src_parameters[name].data.to(
                                      averaged.device), self.steps)
        if not self.update_buffers:
            # Buffers (e.g. BN statistics) are not averaged; keep them in
            # sync with the source model instead.
            for dst_buffer, src_buffer in zip(self.module.buffers(),
                                              model.buffers()):
                dst_buffer.data.copy_(src_buffer.data.to(dst_buffer.device))
        self.steps += 1
| |
|
| |
|
@MODELS.register_module()
class StochasticWeightAverage(BaseAveragedModel):
    """Implements the stochastic weight averaging (SWA) of the model.

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization, UAI 2018.
    <https://arxiv.org/abs/1803.05407>`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.
    """

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Average the parameters with stochastic weight averaging.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # The running mean over n+1 observations: avg += (x - avg) / (n + 1),
        # where n is the number of completed averaging updates.
        n_averaged = steps // self.interval + 1
        weight = 1 / float(n_averaged)
        averaged_param.add_(source_param - averaged_param, alpha=weight)
| |
|
| |
|
@MODELS.register_module()
class ExponentialMovingAverage(BaseAveragedModel):
    r"""Implements the exponential moving average (EMA) of the model.

    All parameters are updated by the formula as below:

    .. math::

        Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically,
        :math:`Xema_{t+1}` is the moving average and :math:`X_t` is the
        new observed value. The value of momentum is usually a small number,
        allowing observed values to slowly update the ema parameters.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Defaults to 0.0002.
            Ema's parameter are updated with the formula
            :math:`averaged\_param = (1-momentum) * averaged\_param +
            momentum * source\_param`.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(model, interval, device, update_buffers)
        # Bugfix: the original implicit string concatenation lacked a
        # separator, producing "...(0.0, 1.0)but got X" in the message.
        assert 0.0 < momentum < 1.0, (
            f'momentum must be in range (0.0, 1.0), but got {momentum}')
        if momentum > 0.5:
            # Bugfix: the original concatenation rendered "number,which";
            # a trailing space restores a readable warning.
            print_log(
                'The value of momentum in EMA is usually a small number, '
                'which is different from the conventional notion of '
                f'momentum but got {momentum}. Please make sure the '
                'value is correct.',
                logger='current',
                level=logging.WARNING)
        self.momentum = momentum

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using exponential
        moving average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # lerp_(src, m) performs avg = avg + m * (src - avg) in-place,
        # which is exactly (1 - m) * avg + m * src.
        averaged_param.lerp_(source_param, self.momentum)
| |
|
| |
|
@MODELS.register_module()
class MomentumAnnealingEMA(ExponentialMovingAverage):
    r"""Exponential moving average (EMA) with momentum annealing strategy.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Defaults to 0.0002.
            Ema's parameter are updated with the formula
            :math:`averaged\_param = (1-momentum) * averaged\_param +
            momentum * source\_param`.
        gamma (int): Use a larger momentum early in training and gradually
            annealing to a smaller value to update the ema model smoothly.
            The momentum is calculated as
            max(momentum, gamma / (gamma + steps)). Defaults to 100.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 100,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Average the parameters with an annealed EMA momentum.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # Early in training gamma / (gamma + steps) dominates, yielding a
        # large momentum; it decays towards the configured floor
        # ``self.momentum`` as training progresses.
        annealed = self.gamma / (self.gamma + self.steps.item())
        momentum = max(self.momentum, annealed)
        averaged_param.lerp_(source_param, momentum)
|