| | |
| | from typing import Any, Dict, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from mmengine.registry import HOOKS |
| | from mmengine.utils import get_git_hash |
| | from mmengine.version import __version__ |
| | from .hook import Hook |
| |
|
# Type alias for one batch yielded by a dataloader: a dict, tuple or list,
# or ``None`` when no data is supplied to the hook callback.
DATA_BATCH = Optional[Union[dict, tuple, list]]
| |
|
| |
|
| | def _is_scalar(value: Any) -> bool: |
| | """Determine the value is a scalar type value. |
| | |
| | Args: |
| | value (Any): value of log. |
| | |
| | Returns: |
| | bool: whether the value is a scalar type value. |
| | """ |
| | if isinstance(value, np.ndarray): |
| | return value.size == 1 |
| | elif isinstance(value, (int, float, np.number)): |
| | return True |
| | elif isinstance(value, torch.Tensor): |
| | return value.numel() == 1 |
| | return False |
| |
|
| |
|
@HOOKS.register_module()
class RuntimeInfoHook(Hook):
    """A hook that updates runtime information into message hub.

    E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
    training state. Components that cannot access the runner can get runtime
    information through the message hub.
    """

    priority = 'VERY_HIGH'

    def before_run(self, runner) -> None:
        """Update metainfo.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info_dict({
            'cfg': runner.cfg.pretty_text,
            'seed': runner.seed,
            'experiment_name': runner.experiment_name,
            'mmengine_version': __version__ + get_git_hash(),
        })
        # Tracks which loop stage was active before a nested val loop so
        # ``after_val`` can restore it.
        self.last_loop_stage = None

    def before_train(self, runner) -> None:
        """Update resumed training state.

        Args:
            runner (Runner): The runner of the training process.
        """
        hub = runner.message_hub
        hub.update_info('loop_stage', 'train')
        # Push the (possibly resumed) training counters into the hub.
        for key in ('epoch', 'iter', 'max_epochs', 'max_iters'):
            hub.update_info(key, getattr(runner, key))
        dataset = runner.train_dataloader.dataset
        if hasattr(dataset, 'metainfo'):
            hub.update_info('dataset_meta', dataset.metainfo)

    def after_train(self, runner) -> None:
        """Remove the ``loop_stage`` entry once training ends."""
        runner.message_hub.pop_info('loop_stage')

    def before_train_epoch(self, runner) -> None:
        """Update current epoch information before every epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Update current iter and learning rate information before every
        iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
        """
        hub = runner.message_hub
        hub.update_info('iter', runner.iter)
        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict '
            'of learning rate when training with OptimWrapper(single '
            'optimizer) or OptimWrapperDict(multiple optimizer), '
            f'but got {type(lr_dict)} please check your optimizer '
            'constructor return an `OptimWrapper` or `OptimWrapperDict` '
            'instance')
        # Each entry holds the per-param-group lrs; log the first group's.
        for group_name, lrs in lr_dict.items():
            hub.update_scalar(f'train/{group_name}', lrs[0])

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ``log_vars`` in model outputs every iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is None:
            return
        for tag, value in outputs.items():
            runner.message_hub.update_scalar(f'train/{tag}', value)

    def before_val(self, runner) -> None:
        """Mark the val stage, remembering the stage it interrupted."""
        self.last_loop_stage = runner.message_hub.get_info('loop_stage')
        runner.message_hub.update_info('loop_stage', 'val')

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each validation epoch.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        hub = runner.message_hub
        for name, result in metrics.items():
            # Scalars go to the scalar log; everything else is plain info.
            if _is_scalar(result):
                hub.update_scalar(f'val/{name}', result)
            else:
                hub.update_info(f'val/{name}', result)

    def after_val(self, runner) -> None:
        """Restore the interrupted loop stage, or drop the entry."""
        if self.last_loop_stage == 'train':
            # Val ran inside training: hand 'loop_stage' back to train.
            runner.message_hub.update_info('loop_stage',
                                           self.last_loop_stage)
            self.last_loop_stage = None
        else:
            runner.message_hub.pop_info('loop_stage')

    def before_test(self, runner) -> None:
        """Mark the test stage in the message hub."""
        runner.message_hub.update_info('loop_stage', 'test')

    def after_test(self, runner) -> None:
        """Remove the ``loop_stage`` entry once testing ends."""
        runner.message_hub.pop_info('loop_stage')

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each test epoch.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        hub = runner.message_hub
        for name, result in metrics.items():
            # Scalars go to the scalar log; everything else is plain info.
            if _is_scalar(result):
                hub.update_scalar(f'test/{name}', result)
            else:
                hub.update_info(f'test/{name}', result)
|