| | |
| | import copy |
| | import itertools |
| | import logging |
| | from typing import Dict, Optional |
| |
|
| | from mmengine.logging import print_log |
| | from mmengine.model import is_model_wrapper |
| | from mmengine.registry import HOOKS, MODELS |
| | from .hook import DATA_BATCH, Hook |
| |
|
| |
|
@HOOKS.register_module()
class EMAHook(Hook):
    """A Hook to apply Exponential Moving Average (EMA) on the model during
    training.

    Note:
        - EMAHook takes priority over CheckpointHook.
        - The original model parameters are actually saved in ema field after
          train.
        - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.

    Args:
        ema_type (str): The type of EMA strategy to use. You can find the
            supported strategies in :mod:`mmengine.model.averaged_model`.
            Defaults to 'ExponentialMovingAverage'.
        strict_load (bool): Whether to strictly enforce that the keys of
            ``state_dict`` in checkpoint match the keys returned by
            ``self.module.state_dict``. Defaults to False.
            Changed in v0.3.0.
        begin_iter (int): The number of iteration to enable ``EMAHook``.
            Defaults to 0.
        begin_epoch (int): The number of epoch to enable ``EMAHook``.
            Defaults to 0.
        **kwargs: Keyword arguments passed to subclasses of
            :obj:`BaseAveragedModel`
    """

    priority = 'NORMAL'

    def __init__(self,
                 ema_type: str = 'ExponentialMovingAverage',
                 strict_load: bool = False,
                 begin_iter: int = 0,
                 begin_epoch: int = 0,
                 **kwargs):
        self.strict_load = strict_load
        # Only the config is stored here; the actual EMA model is built in
        # ``before_run`` once the (possibly wrapped) source model exists.
        self.ema_cfg = dict(type=ema_type, **kwargs)
        assert not (begin_iter != 0 and begin_epoch != 0), (
            '`begin_iter` and `begin_epoch` should not be both set.')
        assert begin_iter >= 0, (
            '`begin_iter` must larger than or equal to 0, '
            f'but got begin_iter: {begin_iter}')
        assert begin_epoch >= 0, (
            '`begin_epoch` must larger than or equal to 0, '
            f'but got begin_epoch: {begin_epoch}')
        self.begin_iter = begin_iter
        self.begin_epoch = begin_epoch
        # A non-zero ``begin_epoch`` selects epoch-based gating; otherwise
        # iteration-based gating is used (see ``_ema_started``).
        self.enabled_by_epoch = self.begin_epoch > 0

    def before_run(self, runner) -> None:
        """Create an ema copy of the model.

        Args:
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        # Unwrap DDP-style wrappers so the EMA model tracks the bare module.
        if is_model_wrapper(model):
            model = model.module
        self.src_model = model
        self.ema_model = MODELS.build(
            self.ema_cfg, default_args=dict(model=self.src_model))

    def before_train(self, runner) -> None:
        """Check the begin_epoch/iter is smaller than max_epochs/iters.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.enabled_by_epoch:
            assert self.begin_epoch <= runner.max_epochs, (
                'self.begin_epoch should be smaller than or equal to '
                f'runner.max_epochs: {runner.max_epochs}, but got '
                f'begin_epoch: {self.begin_epoch}')
        else:
            assert self.begin_iter <= runner.max_iters, (
                'self.begin_iter should be smaller than or equal to '
                f'runner.max_iters: {runner.max_iters}, but got '
                f'begin_iter: {self.begin_iter}')

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ema parameter.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if self._ema_started(runner):
            self.ema_model.update_parameters(self.src_model)
        else:
            # Before EMA kicks in, keep the EMA weights identical to the
            # source model so the first real EMA update starts from the
            # current weights rather than the stale initial copy.
            ema_params = self.ema_model.module.state_dict()
            src_params = self.src_model.state_dict()
            for k, p in ema_params.items():
                p.data.copy_(src_params[k].data)

    def before_val_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        validation.

        Args:
            runner (Runner): The runner of the training process.
        """
        # Swap (not copy), so ``after_val_epoch`` can swap back losslessly.
        self._swap_ema_parameters()

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after validation.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        self._swap_ema_parameters()

    def before_test_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before test.

        Args:
            runner (Runner): The runner of the training process.
        """
        # Swap (not copy), so ``after_test_epoch`` can swap back losslessly.
        self._swap_ema_parameters()

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after test.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        self._swap_ema_parameters()

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """Save ema parameters to checkpoint.

        Args:
            runner (Runner): The runner of the testing process.
        """
        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
        # Swap the two state dicts inside the checkpoint so that
        # ``checkpoint['state_dict']`` holds the averaged (EMA) weights and
        # can be loaded directly for deployment, while the raw training
        # weights move into ``ema_state_dict``. Only the checkpoint dicts
        # are modified here; the live models are untouched.
        self._swap_ema_state_dict(checkpoint)

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint.

        Args:
            runner (Runner): The runner of the testing process.
        """
        from mmengine.runner.checkpoint import load_state_dict
        # NOTE(review): relies on the runner's private ``_resume`` flag to
        # distinguish resuming training from plain weight loading.
        if 'ema_state_dict' in checkpoint and runner._resume:
            # ``before_save_checkpoint`` swapped the weights when saving, so
            # swap them back before loading to restore the ema state.
            self._swap_ema_state_dict(checkpoint)
            self.ema_model.load_state_dict(
                checkpoint['ema_state_dict'], strict=self.strict_load)

        # Support loading a checkpoint that has no ema state: initialize the
        # EMA weights from a deep copy of the plain ``state_dict``.
        else:
            if runner._resume:
                print_log(
                    'There is no `ema_state_dict` in checkpoint. '
                    '`EMAHook` will make a copy of `state_dict` as the '
                    'initial `ema_state_dict`', 'current', logging.WARNING)
            load_state_dict(
                self.ema_model.module,
                copy.deepcopy(checkpoint['state_dict']),
                strict=self.strict_load)

    def _swap_ema_parameters(self) -> None:
        """Swap the parameter of model with ema_model."""
        # Include buffers in the swap only when the EMA model averages them;
        # otherwise both models already share the same buffer values.
        avg_param = (
            itertools.chain(self.ema_model.module.parameters(),
                            self.ema_model.module.buffers())
            if self.ema_model.update_buffers else
            self.ema_model.module.parameters())
        src_param = (
            itertools.chain(self.src_model.parameters(),
                            self.src_model.buffers())
            if self.ema_model.update_buffers else self.src_model.parameters())
        # In-place three-step swap; both iterables come from models with the
        # same architecture, so ``zip`` pairs corresponding tensors.
        for p_avg, p_src in zip(avg_param, src_param):
            tmp = p_avg.data.clone()
            p_avg.data.copy_(p_src.data)
            p_src.data.copy_(tmp)

    def _swap_ema_state_dict(self, checkpoint):
        """Swap the state dict values of model with ema_model."""
        model_state = checkpoint['state_dict']
        ema_state = checkpoint['ema_state_dict']
        for k in ema_state:
            # Only keys prefixed with 'module.' are averaged weights of the
            # wrapped model; strip the prefix to find the matching key in
            # the plain model state dict, then swap the two entries.
            if k[:7] == 'module.':
                tmp = ema_state[k]
                ema_state[k] = model_state[k[7:]]
                model_state[k[7:]] = tmp

    def _ema_started(self, runner) -> bool:
        """Whether ``EMAHook`` has been initialized at current iteration or
        epoch.

        :attr:`ema_model` will be initialized when ``runner.iter`` or
        ``runner.epoch`` is greater than ``self.begin`` for the first time.

        Args:
            runner (Runner): Runner of the training, validation process.

        Returns:
            bool: Whether ``EMAHook`` has been initialized.
        """
        # ``+ 1`` because this is queried from ``after_train_iter``, where
        # the counter still holds the just-finished iteration/epoch index.
        if self.enabled_by_epoch:
            return runner.epoch + 1 >= self.begin_epoch
        else:
            return runner.iter + 1 >= self.begin_iter
| |
|