import logging
from abc import abstractmethod
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmengine.logging import print_log
from mmengine.registry import MODELS


class BaseAveragedModel(nn.Module):
    """A base class for averaging model weights.

    Weight averaging, such as SWA and EMA, is a widely used technique for
    training neural networks. This class implements the averaging process
    for a model. All subclasses must implement the ``avg_func`` method.
    This class creates a copy of the provided module :attr:`model`
    on the :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.

    The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.

    Different from the ``AveragedModel`` in PyTorch, we use in-place
    operations to update the parameters, which makes updating about 5 times
    faster than the non-in-place version.

    In mmengine, we provide two ways to use the model averaging:

    1. Use the model averaging module in a hook:
       We provide an :class:`mmengine.hooks.EMAHook` to apply the model
       averaging during training. Add ``custom_hooks=[dict(type='EMAHook')]``
       to the config or the runner.

    2. Use the model averaging module directly in the algorithm. Take the ema
       teacher in semi-supervised learning as an example:

       >>> from mmengine.model import ExponentialMovingAverage
       >>> student = ResNet(depth=50)
       >>> # use ema model as teacher
       >>> ema_teacher = ExponentialMovingAverage(student)
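
       To keep the teacher updated during training, call
       ``update_parameters`` on the student after each optimizer step (a
       sketch, not from the original docs; assumes an ``optimizer`` and a
       ``data_loader`` exist):

       >>> for data in data_loader:
       ...     loss = student(data).sum()
       ...     loss.backward()
       ...     optimizer.step()
       ...     optimizer.zero_grad()
       ...     ema_teacher.update_parameters(student)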

    Args:
        model (nn.Module): The model to be averaged.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): If True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
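
    Example:
        >>> # A minimal sketch (not from the original docs) of a custom
        >>> # subclass: ``avg_func`` just keeps the latest source weights.
        >>> class KeepLatest(BaseAveragedModel):
        ...     def avg_func(self, averaged_param, source_param, steps):
        ...         averaged_param.copy_(source_param)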
| """ |

    def __init__(self,
                 model: nn.Module,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__()
        self.module = deepcopy(model).requires_grad_(False)
        self.interval = interval
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('steps',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.update_buffers = update_buffers
        if update_buffers:
            self.avg_parameters = self.module.state_dict()
        else:
            self.avg_parameters = dict(self.module.named_parameters())

    @abstractmethod
    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Use in-place operations to compute the average of the parameters.

        All subclasses must implement this method.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """

    def forward(self, *args, **kwargs):
        """Forward method of the averaged model."""
        return self.module(*args, **kwargs)
    def update_parameters(self, model: nn.Module) -> None:
        """Update the parameters of the averaged model.

        This method will execute ``avg_func`` to compute the new parameters
        of the averaged model from the given source ``model``.

        Args:
            model (nn.Module): The model whose parameters will be averaged.
        """
        src_parameters = (
            model.state_dict()
            if self.update_buffers else dict(model.named_parameters()))
        if self.steps == 0:
            # On the first call, initialize the averaged parameters with a
            # plain copy of the source parameters.
            for k, p_avg in self.avg_parameters.items():
                p_avg.data.copy_(src_parameters[k].data)
        elif self.steps % self.interval == 0:
            for k, p_avg in self.avg_parameters.items():
                if p_avg.dtype.is_floating_point:
                    device = p_avg.device
                    self.avg_func(p_avg.data,
                                  src_parameters[k].data.to(device),
                                  self.steps)
        if not self.update_buffers:
            # If the buffers are not covered by ``avg_func``, keep them in
            # sync with the source model instead.
            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
                b_avg.data.copy_(b_src.data.to(b_avg.device))
        self.steps += 1


@MODELS.register_module()
class StochasticWeightAverage(BaseAveragedModel):
    """Implements the stochastic weight averaging (SWA) of the model.

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization, UAI 2018.
    <https://arxiv.org/abs/1803.05407>`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.
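
    Example:
        >>> # A minimal sketch (not from the original docs): the first
        >>> # update copies the source weights; every later update moves the
        >>> # SWA weights to the running mean of all weights seen so far.
        >>> import torch.nn as nn
        >>> model = nn.Linear(2, 2)
        >>> swa_model = StochasticWeightAverage(model)
        >>> swa_model.update_parameters(model)  # step 0: plain copy
        >>> swa_model.update_parameters(model)  # step 1: running average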
| """ |

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the average of the parameters using stochastic weight
        average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        averaged_param.add_(
            source_param - averaged_param,
            alpha=1 / float(steps // self.interval + 1))


@MODELS.register_module()
class ExponentialMovingAverage(BaseAveragedModel):
    r"""Implements the exponential moving average (EMA) of the model.

    All parameters are updated by the formula as below:

    .. math::

        X_{\text{ema},t+1} = (1 - \text{momentum}) \times X_{\text{ema},t}
        + \text{momentum} \times X_t

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically,
        :math:`X_{\text{ema},t+1}` is the moving average and :math:`X_t` is
        the new observed value. The value of momentum is usually a small
        number, allowing observed values to slowly update the ema parameters.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameters.
            Defaults to 0.0002.
            Ema's parameters are updated with the formula
            :math:`averaged\_param = (1-momentum) * averaged\_param +
            momentum * source\_param`.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): If True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
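
    Example:
        >>> # A minimal sketch (not from the original docs): with momentum
        >>> # m, each update computes avg = (1 - m) * avg + m * src.
        >>> import torch.nn as nn
        >>> model = nn.Linear(2, 2)
        >>> ema_model = ExponentialMovingAverage(model, momentum=0.001)
        >>> ema_model.update_parameters(model)  # step 0: plain copy
        >>> ema_model.update_parameters(model)  # step 1: EMA update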
| """ |

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(model, interval, device, update_buffers)
        assert 0.0 < momentum < 1.0, \
            f'momentum must be in range (0.0, 1.0), but got {momentum}'
        if momentum > 0.5:
            print_log(
                'The value of momentum in EMA is usually a small number, '
                'which is different from the conventional notion of '
                f'momentum, but got {momentum}. Please make sure the '
                'value is correct.',
                logger='current',
                level=logging.WARNING)
        self.momentum = momentum

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using exponential
        moving average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        averaged_param.lerp_(source_param, self.momentum)


@MODELS.register_module()
class MomentumAnnealingEMA(ExponentialMovingAverage):
    r"""Exponential moving average (EMA) with momentum annealing strategy.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameters.
            Defaults to 0.0002.
            Ema's parameters are updated with the formula
            :math:`averaged\_param = (1-momentum) * averaged\_param +
            momentum * source\_param`.
        gamma (int): Use a larger momentum early in training and gradually
            anneal to a smaller value to update the ema model smoothly. The
            momentum is calculated as
            ``max(momentum, gamma / (gamma + steps))``. Defaults to 100.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): If True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
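
    Example:
        >>> # A sketch of the annealing schedule (not from the original
        >>> # docs): with gamma=100 the effective momentum starts at 1.0 and
        >>> # decays towards the ``momentum`` floor as training progresses.
        >>> gamma, momentum = 100, 0.0002
        >>> max(momentum, gamma / (gamma + 0))  # first update
        1.0
        >>> max(momentum, gamma / (gamma + 900))  # after 900 steps
        0.1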
| """ |

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 100,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the momentum
        annealing strategy.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        momentum = max(self.momentum,
                       self.gamma / (self.gamma + self.steps.item()))
        averaged_param.lerp_(source_param, momentum)