| | |
| | import copy |
| | import logging |
| | from collections import OrderedDict |
| | from typing import TYPE_CHECKING, Any, Optional, Union |
| |
|
| | import numpy as np |
| |
|
| | from mmengine.utils import ManagerMixin |
| | from .history_buffer import HistoryBuffer |
| | from .logger import print_log |
| |
|
| | if TYPE_CHECKING: |
| | import torch |
| |
|
| |
|
class MessageHub(ManagerMixin):
    """Message hub for component interaction. MessageHub is created and
    accessed in the same way as ManagerMixin.

    ``MessageHub`` will record log information and runtime information. The
    log information refers to the learning rate, loss, etc. of the model
    during training phase, which will be stored as ``HistoryBuffer``. The
    runtime information refers to the iter times, meta information of
    runner etc., which will be overwritten by next update.

    Args:
        name (str): Name of message hub used to get corresponding instance
            globally.
        log_scalars (dict, optional): Each key-value pair in the
            dictionary is the name of the log information such as "loss", "lr",
            "metric" and their corresponding values. The type of value must be
            HistoryBuffer. Defaults to None.
        runtime_info (dict, optional): Each key-value pair in the
            dictionary is the name of the runtime information and their
            corresponding values. Defaults to None.
        resumed_keys (dict, optional): Each key-value pair in the
            dictionary decides whether the key in :attr:`_log_scalars` and
            :attr:`_runtime_info` will be serialized.

    Note:
        Key in :attr:`_resumed_keys` belongs to :attr:`_log_scalars` or
        :attr:`_runtime_info`. The corresponding value cannot be set
        repeatedly.

    Examples:
        >>> # create empty `MessageHub`.
        >>> message_hub1 = MessageHub('name')
        >>> log_scalars = dict(loss=HistoryBuffer())
        >>> runtime_info = dict(task='task')
        >>> resumed_keys = dict(loss=True)
        >>> # create `MessageHub` from data.
        >>> message_hub2 = MessageHub(
        >>>     name='name',
        >>>     log_scalars=log_scalars,
        >>>     runtime_info=runtime_info,
        >>>     resumed_keys=resumed_keys)
    """

    def __init__(self,
                 name: str,
                 log_scalars: Optional[dict] = None,
                 runtime_info: Optional[dict] = None,
                 resumed_keys: Optional[dict] = None):
        super().__init__(name)
        self._log_scalars = self._parse_input('log_scalars', log_scalars)
        self._runtime_info = self._parse_input('runtime_info', runtime_info)
        self._resumed_keys = self._parse_input('resumed_keys', resumed_keys)

        for value in self._log_scalars.values():
            assert isinstance(value, HistoryBuffer), \
                ("The type of log_scalars'value must be HistoryBuffer, but "
                 f'got {type(value)}')

        # Every resumed key must refer to something actually stored here,
        # otherwise `state_dict` would silently drop it.
        for key in self._resumed_keys.keys():
            assert key in self._log_scalars or key in self._runtime_info, \
                ('Key in `resumed_keys` must contained in `log_scalars` or '
                 f'`runtime_info`, but got {key}')

    @classmethod
    def get_current_instance(cls) -> 'MessageHub':
        """Get latest created ``MessageHub`` instance.

        :obj:`MessageHub` can call :meth:`get_current_instance` before any
        instance has been created, and return a message hub with the instance
        name "mmengine".

        Returns:
            MessageHub: Empty ``MessageHub`` instance.
        """
        # Lazily create a default instance so callers never get an error
        # before the runner has registered one.
        if not cls._instance_dict:
            cls.get_instance('mmengine')
        return super().get_current_instance()

    def update_scalar(self,
                      key: str,
                      value: Union[int, float, np.ndarray, 'torch.Tensor'],
                      count: int = 1,
                      resumed: bool = True) -> None:
        """Update :attr:`_log_scalars`.

        Update ``HistoryBuffer`` in :attr:`_log_scalars`. If corresponding key
        ``HistoryBuffer`` has been created, ``value`` and ``count`` is the
        argument of ``HistoryBuffer.update``, Otherwise, ``update_scalar``
        will create an ``HistoryBuffer`` with value and count via the
        constructor of ``HistoryBuffer``.

        Examples:
            >>> message_hub = MessageHub(name='name')
            >>> # create loss `HistoryBuffer` with value=1, count=1
            >>> message_hub.update_scalar('loss', 1)
            >>> # update loss `HistoryBuffer` with value
            >>> message_hub.update_scalar('loss', 3)
            >>> message_hub.update_scalar('loss', 3, resumed=False)
            AssertionError: loss used to be true, but got false now. resumed
            keys cannot be modified repeatedly'

        Note:
            The ``resumed`` argument needs to be consistent for the same
            ``key``.

        Args:
            key (str): Key of ``HistoryBuffer``.
            value (torch.Tensor or np.ndarray or int or float): Value of log.
            count (int): Accumulation times of log, defaults to 1. `count`
                will be used in smooth statistics.
            resumed (bool): Whether the corresponding ``HistoryBuffer``
                could be resumed. Defaults to True.
        """
        self._set_resumed_keys(key, resumed)
        checked_value = self._get_valid_value(value)
        # NOTE: `{type(count): {count}}` in the original message was an
        # invalid nested format spec and raised TypeError while building
        # the assertion message itself.
        assert isinstance(count, int), (
            f'The type of count must be int. but got {type(count)}: {count}')
        if key in self._log_scalars:
            self._log_scalars[key].update(checked_value, count)
        else:
            self._log_scalars[key] = HistoryBuffer([checked_value], [count])

    def update_scalars(self, log_dict: dict, resumed: bool = True) -> None:
        """Update :attr:`_log_scalars` with a dict.

        ``update_scalars`` iterates through each pair of log_dict key-value,
        and calls ``update_scalar``. If type of value is dict, the value should
        be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in
        ``log_dict`` has the same resume option.

        Note:
            The ``resumed`` argument needs to be consistent for the same
            ``log_dict``.

        Args:
            log_dict (dict): Used for batch updating :attr:`_log_scalars`.
            resumed (bool): Whether all ``HistoryBuffer`` referred in
                log_dict should be resumed. Defaults to True.

        Examples:
            >>> message_hub = MessageHub.get_instance('mmengine')
            >>> log_dict = dict(a=1, b=2, c=3)
            >>> message_hub.update_scalars(log_dict)
            >>> # The default count of `a`, `b` and `c` is 1.
            >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
            >>> message_hub.update_scalars(log_dict)
            >>> # The count of `c` is 2.
        """
        assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, '
                                            f'but got {type(log_dict)}')
        for log_name, log_val in log_dict.items():
            if isinstance(log_val, dict):
                assert 'value' in log_val, \
                    f'value must be defined in {log_val}'
                count = self._get_valid_value(log_val.get('count', 1))
                value = log_val['value']
            else:
                count = 1
                value = log_val
            # Same f-string fix as in `update_scalar`: the nested format
            # spec made this assertion message raise TypeError.
            assert isinstance(count,
                              int), ('The type of count must be int. but got '
                                     f'{type(count)}: {count}')
            self.update_scalar(log_name, value, count, resumed)

    def update_info(self, key: str, value: Any, resumed: bool = True) -> None:
        """Update runtime information.

        The key corresponding runtime information will be overwritten each
        time calling ``update_info``.

        Note:
            The ``resumed`` argument needs to be consistent for the same
            ``key``.

        Examples:
            >>> message_hub = MessageHub(name='name')
            >>> message_hub.update_info('iter', 100)

        Args:
            key (str): Key of runtime information.
            value (Any): Value of runtime information.
            resumed (bool): Whether the corresponding ``HistoryBuffer``
                could be resumed.
        """
        self._set_resumed_keys(key, resumed)
        self._runtime_info[key] = value

    def pop_info(self, key: str, default: Optional[Any] = None) -> Any:
        """Remove runtime information by key. If the key does not exist, this
        method will return the default value.

        Args:
            key (str): Key of runtime information.
            default (Any, optional): The default returned value for the
                given key.

        Returns:
            Any: The runtime information if the key exists.
        """
        return self._runtime_info.pop(key, default)

    def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
        """Update runtime information with dictionary.

        The key corresponding runtime information will be overwritten each
        time calling ``update_info``.

        Note:
            The ``resumed`` argument needs to be consistent for the same
            ``info_dict``.

        Examples:
            >>> message_hub = MessageHub(name='name')
            >>> message_hub.update_info_dict({'iter': 100})

        Args:
            info_dict (dict): Runtime information dictionary.
            resumed (bool): Whether the corresponding ``HistoryBuffer``
                could be resumed.
        """
        # The message previously said `log_dict`, which is the parameter
        # name of `update_scalars`, not of this method.
        assert isinstance(info_dict, dict), ('`info_dict` must be a dict!, '
                                             f'but got {type(info_dict)}')
        for key, value in info_dict.items():
            self.update_info(key, value, resumed=resumed)

    def _set_resumed_keys(self, key: str, resumed: bool) -> None:
        """Set corresponding resumed keys.

        This method is called by ``update_scalar``, ``update_scalars`` and
        ``update_info`` to set the corresponding key is true or false in
        :attr:`_resumed_keys`.

        Args:
            key (str): Key of :attr:`_log_scalars` or :attr:`_runtime_info`.
            resumed (bool): Whether the corresponding ``HistoryBuffer``
                could be resumed.
        """
        if key not in self._resumed_keys:
            self._resumed_keys[key] = resumed
        else:
            # The second line previously lacked the `f` prefix, so
            # `{resumed}` was printed literally in the assertion message.
            assert self._resumed_keys[key] == resumed, \
                f'{key} used to be {self._resumed_keys[key]}, but got ' \
                f'{resumed} now. resumed keys cannot be modified repeatedly.'

    @property
    def log_scalars(self) -> OrderedDict:
        """Get all ``HistoryBuffer`` instances.

        Note:
            Considering the large memory footprint of history buffers in the
            post-training, :meth:`log_scalars` will return a reference of
            history buffers rather than a copy.

        Returns:
            OrderedDict: All ``HistoryBuffer`` instances.
        """
        return self._log_scalars

    @property
    def runtime_info(self) -> OrderedDict:
        """Get all runtime information.

        Returns:
            OrderedDict: A reference of all runtime information.
        """
        return self._runtime_info

    def get_scalar(self, key: str) -> HistoryBuffer:
        """Get ``HistoryBuffer`` instance by key.

        Note:
            Considering the large memory footprint of history buffers in the
            post-training, :meth:`get_scalar` will return a reference of
            history buffer rather than a copy.

        Args:
            key (str): Key of ``HistoryBuffer``.

        Returns:
            HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the
            key exists.

        Raises:
            KeyError: If ``key`` is not found in :attr:`_log_scalars`.
        """
        if key not in self.log_scalars:
            # `self.instance_name` instead of `MessageHub.instance_name`:
            # accessing the attribute on the class would interpolate the
            # property descriptor, not this instance's name.
            raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
                           f'instance name is: {self.instance_name}')
        return self.log_scalars[key]

    def get_info(self, key: str, default: Optional[Any] = None) -> Any:
        """Get runtime information by key. If the key does not exist, this
        method will return default information.

        Args:
            key (str): Key of runtime information.
            default (Any, optional): The default returned value for the
                given key.

        Returns:
            Any: Corresponding runtime information if the key exists. A
            reference is returned, not a copy.
        """
        if key not in self.runtime_info:
            return default
        else:
            # NOTE(review): returns a reference rather than a deep copy —
            # presumably because some stored objects are not deep-copyable
            # (see the fallbacks in `state_dict`/`load_state_dict`).
            return self._runtime_info[key]

    def _get_valid_value(
        self,
        value: Union['torch.Tensor', np.ndarray, np.number, int, float],
    ) -> Union[int, float]:
        """Convert value to python built-in type.

        Args:
            value (torch.Tensor or np.ndarray or np.number or int or float):
                value of log.

        Returns:
            float or int: python built-in type value.
        """
        if isinstance(value, (np.ndarray, np.number)):
            assert value.size == 1
            value = value.item()
        elif isinstance(value, (int, float)):
            value = value
        else:
            # Duck-typed check for a scalar torch.Tensor (`numel`/`item`)
            # so that torch does not have to be imported in this module.
            assert hasattr(value, 'numel') and value.numel() == 1
            value = value.item()
        return value

    def state_dict(self) -> dict:
        """Returns a dictionary containing log scalars, runtime information and
        resumed keys, which should be resumed.

        The returned ``state_dict`` can be loaded by :meth:`load_state_dict`.

        Returns:
            dict: A dictionary contains ``log_scalars``, ``runtime_info`` and
            ``resumed_keys``.
        """
        saved_scalars = OrderedDict()
        saved_info = OrderedDict()

        for key, value in self._log_scalars.items():
            if self._resumed_keys.get(key, False):
                saved_scalars[key] = copy.deepcopy(value)

        for key, value in self._runtime_info.items():
            if self._resumed_keys.get(key, False):
                try:
                    saved_info[key] = copy.deepcopy(value)
                # Narrowed from a bare `except:` so that
                # KeyboardInterrupt/SystemExit are not swallowed.
                except Exception:
                    print_log(
                        f'{key} in message_hub cannot be copied, '
                        f'just return its reference. ',
                        logger='current',
                        level=logging.WARNING)
                    saved_info[key] = value
        return dict(
            log_scalars=saved_scalars,
            runtime_info=saved_info,
            resumed_keys=self._resumed_keys)

    def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None:
        """Loads log scalars, runtime information and resumed keys from
        ``state_dict`` or ``message_hub``.

        If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it
        will only make copies of data which should be resumed from the source
        ``message_hub``.

        If ``state_dict`` is a ``message_hub`` instance, it will make copies of
        all data from the source message_hub. We suggest to load data from
        ``dict`` rather than a ``MessageHub`` instance.

        Args:
            state_dict (dict or MessageHub): A dictionary contains key
                ``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a
                MessageHub instance.
        """
        if isinstance(state_dict, dict):
            for key in ('log_scalars', 'runtime_info', 'resumed_keys'):
                assert key in state_dict, (
                    'The loaded `state_dict` of `MessageHub` must contain '
                    f'key: `{key}`')

            # Skip entries that are not HistoryBuffer rather than failing,
            # so a partially-compatible checkpoint can still be resumed.
            for key, value in state_dict['log_scalars'].items():
                if not isinstance(value, HistoryBuffer):
                    print_log(
                        f'{key} in message_hub is not HistoryBuffer, '
                        f'just skip resuming it.',
                        logger='current',
                        level=logging.WARNING)
                    continue
                self.log_scalars[key] = value

            for key, value in state_dict['runtime_info'].items():
                try:
                    self._runtime_info[key] = copy.deepcopy(value)
                # Narrowed from a bare `except:` so that
                # KeyboardInterrupt/SystemExit are not swallowed.
                except Exception:
                    print_log(
                        f'{key} in message_hub cannot be copied, '
                        f'just return its reference.',
                        logger='current',
                        level=logging.WARNING)
                    self._runtime_info[key] = value

            for key, value in state_dict['resumed_keys'].items():
                if key not in set(self.log_scalars.keys()) | \
                        set(self._runtime_info.keys()):
                    print_log(
                        f'resumed key: {key} is not defined in message_hub, '
                        f'just skip resuming this key.',
                        logger='current',
                        level=logging.WARNING)
                    continue
                elif not value:
                    print_log(
                        f'Although resumed key: {key} is False, {key} '
                        'will still be loaded this time. This key will '
                        'not be saved by the next calling of '
                        '`MessageHub.state_dict()`',
                        logger='current',
                        level=logging.WARNING)
                self._resumed_keys[key] = value
        # A `MessageHub` instance is accepted for backward compatibility:
        # copy its private state wholesale.
        else:
            self._log_scalars = copy.deepcopy(state_dict._log_scalars)
            self._runtime_info = copy.deepcopy(state_dict._runtime_info)
            self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)

    def _parse_input(self, name: str, value: Any) -> OrderedDict:
        """Parse input value.

        Args:
            name (str): name of input value.
            value (Any): Input value.

        Returns:
            dict: Parsed input value.

        Raises:
            TypeError: If ``value`` is neither ``None`` nor a dict.
        """
        if value is None:
            return OrderedDict()
        elif isinstance(value, dict):
            return OrderedDict(value)
        else:
            # The message previously reported `type(name)` (always `str`)
            # instead of the offending value's type.
            raise TypeError(f'{name} should be a dict or `None`, but '
                            f'got {type(value)}')
| |
|