| | |
| | import hashlib |
| | import logging |
| | import os.path as osp |
| | import pickle |
| | from collections import deque |
| | from math import inf |
| | from pathlib import Path |
| | from typing import Callable, Dict, List, Optional, Sequence, Union |
| |
|
| | from mmengine.dist import is_main_process, master_only |
| | from mmengine.fileio import FileClient, get_file_backend |
| | from mmengine.logging import print_log |
| | from mmengine.registry import HOOKS |
| | from mmengine.utils import is_list_of, is_seq_of |
| | from .hook import Hook |
| |
|
| | DATA_BATCH = Optional[Union[dict, tuple, list]] |
| |
|
| |
|
| | @HOOKS.register_module() |
| | class CheckpointHook(Hook): |
| | """Save checkpoints periodically. |
| | |
| | Args: |
| | interval (int): The saving period. If ``by_epoch=True``, interval |
| | indicates epochs, otherwise it indicates iterations. |
| | Defaults to -1, which means "never". |
| | by_epoch (bool): Saving checkpoints by epoch or by iteration. |
| | Defaults to True. |
| | save_optimizer (bool): Whether to save optimizer state_dict in the |
| | checkpoint. It is usually used for resuming experiments. |
| | Defaults to True. |
| | save_param_scheduler (bool): Whether to save param_scheduler state_dict |
| | in the checkpoint. It is usually used for resuming experiments. |
| | Defaults to True. |
| | out_dir (str, Path, Optional): The root directory to save checkpoints. |
| | If not specified, ``runner.work_dir`` will be used by default. If |
| | specified, the ``out_dir`` will be the concatenation of ``out_dir`` |
| | and the last level directory of ``runner.work_dir``. For example, |
| | if the input ``out_dir`` is ``./tmp`` and ``runner.work_dir`` is |
| | ``./work_dir/cur_exp``, then the ckpt will be saved in |
| | ``./tmp/cur_exp``. Defaults to None. |
| | max_keep_ckpts (int): The maximum checkpoints to keep. |
| | In some cases we want only the latest few checkpoints and would |
| | like to delete old ones to save the disk space. |
| | Defaults to -1, which means unlimited. |
| | save_last (bool): Whether to force the last checkpoint to be |
| | saved regardless of interval. Defaults to True. |
| | save_best (str, List[str], optional): If a metric is specified, it |
| | would measure the best checkpoint during evaluation. If a list of |
| | metrics is passed, it would measure a group of best checkpoints |
| | corresponding to the passed metrics. The information about best |
| | checkpoint(s) would be saved in ``runner.message_hub`` to keep |
| | best score value and best checkpoint path, which will be also |
| | loaded when resuming checkpoint. Options are the evaluation metrics |
| | on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox |
| | detection and instance segmentation. ``AR@100`` for proposal |
| | recall. If ``save_best`` is ``auto``, the first key of the returned |
| | ``OrderedDict`` result will be used. Defaults to None. |
| | rule (str, List[str], optional): Comparison rule for best score. If |
| | set to None, it will infer a reasonable rule. Keys such as 'acc', |
| | 'top', etc. will be inferred by 'greater' rule. Keys containing |
| | 'loss' will be inferred by 'less' rule. If ``save_best`` is a list of |
| | metrics and ``rule`` is a str, all metrics in ``save_best`` will |
| | share the comparison rule. If ``save_best`` and ``rule`` are both |
| | lists, their length must be the same, and metrics in ``save_best`` |
| | will use the corresponding comparison rule in ``rule``. Options |
| | are 'greater', 'less', None and list which contains 'greater' and |
| | 'less'. Defaults to None. |
| | greater_keys (List[str], optional): Metric keys that will be |
| | inferred by 'greater' comparison rule. If ``None``, |
| | _default_greater_keys will be used. Defaults to None. |
| | less_keys (List[str], optional): Metric keys that will be |
| | inferred by 'less' comparison rule. If ``None``, _default_less_keys |
| | will be used. Defaults to None. |
| | file_client_args (dict, optional): Arguments to instantiate a |
| | FileClient. See :class:`mmengine.fileio.FileClient` for details. |
| | Defaults to None. It will be deprecated in future. Please use |
| | ``backend_args`` instead. |
| | filename_tmpl (str, optional): String template to indicate checkpoint |
| | name. If specified, must contain one and only one "{}", which will |
| | be replaced with ``epoch + 1`` if ``by_epoch=True`` else |
| | ``iteration + 1``. |
| | Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth" |
| | accordingly. |
| | backend_args (dict, optional): Arguments to instantiate the |
| | prefix of uri corresponding backend. Defaults to None. |
| | `New in version 0.2.0.` |
| | published_keys (str, List[str], optional): If ``save_last`` is ``True`` |
| | or ``save_best`` is not ``None``, it will automatically |
| | publish model with keys in the list after training. |
| | Defaults to None. |
| | `New in version 0.7.1.` |
| | save_begin (int): Control the epoch number or iteration number |
| | at which checkpoint saving begins. Defaults to 0, which means |
| | saving at the beginning. |
| | `New in version 0.8.3.` |
| | |
| | Examples: |
| | >>> # Save best based on single metric |
| | >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', |
| | >>> rule='less') |
| | >>> # Save best based on multi metrics with the same comparison rule |
| | >>> CheckpointHook(interval=2, by_epoch=True, |
| | >>> save_best=['acc', 'mIoU'], rule='greater') |
| | >>> # Save best based on multi metrics with different comparison rule |
| | >>> CheckpointHook(interval=2, by_epoch=True, |
| | >>> save_best=['FID', 'IS'], rule=['less', 'greater']) |
| | >>> # Save best based on single metric and publish model after training |
| | >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', |
| | >>> rule='less', published_keys=['meta', 'state_dict']) |
| | """ |
| | out_dir: str |
| |
|
| | priority = 'VERY_LOW' |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} |
| | init_value_map = {'greater': -inf, 'less': inf} |
| | _default_greater_keys = [ |
| | 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', |
| | 'mAcc', 'aAcc' |
| | ] |
| | _default_less_keys = ['loss'] |
| |
|
def __init__(self,
             interval: int = -1,
             by_epoch: bool = True,
             save_optimizer: bool = True,
             save_param_scheduler: bool = True,
             out_dir: Optional[Union[str, Path]] = None,
             max_keep_ckpts: int = -1,
             save_last: bool = True,
             save_best: Union[str, List[str], None] = None,
             rule: Union[str, List[str], None] = None,
             greater_keys: Optional[Sequence[str]] = None,
             less_keys: Optional[Sequence[str]] = None,
             file_client_args: Optional[dict] = None,
             filename_tmpl: Optional[str] = None,
             backend_args: Optional[dict] = None,
             published_keys: Union[str, List[str], None] = None,
             save_begin: int = 0,
             **kwargs) -> None:
    """Validate the arguments and store the hook configuration.

    ``save_best`` is normalized to a list of str (or kept as None),
    ``rule`` is normalized to a list, and a single str ``published_keys``
    is wrapped into a list. When ``save_best`` is set, the comparison
    functions are prepared through ``self._init_rule``. Extra keyword
    arguments are stored in ``self.args``.

    Raises:
        ValueError: If ``file_client_args`` and ``backend_args`` are both
            set, or if ``save_begin`` is negative.
        TypeError: If ``published_keys`` is not a str, a sequence of str
            or None.
        AssertionError: If ``save_best``, ``rule``, ``greater_keys`` or
            ``less_keys`` have an unsupported type, an inconsistent
            length, or (for ``save_best`` / ``published_keys``) contain
            duplicates.
    """
    self.interval = interval
    self.by_epoch = by_epoch
    self.save_optimizer = save_optimizer
    self.save_param_scheduler = save_param_scheduler
    self.out_dir = out_dir  # type: ignore
    self.max_keep_ckpts = max_keep_ckpts
    self.save_last = save_last
    self.args = kwargs

    if file_client_args is not None:
        print_log(
            '"file_client_args" will be deprecated in future. '
            'Please use "backend_args" instead',
            logger='current',
            level=logging.WARNING)
        if backend_args is not None:
            raise ValueError(
                '"file_client_args" and "backend_args" cannot be set '
                'at the same time.')

    self.file_client_args = file_client_args
    self.backend_args = backend_args

    # Fall back to the default template matching the saving unit.
    if filename_tmpl is None:
        if self.by_epoch:
            self.filename_tmpl = 'epoch_{}.pth'
        else:
            self.filename_tmpl = 'iter_{}.pth'
    else:
        self.filename_tmpl = filename_tmpl

    # save_best may be a str, a list of str or None.
    assert (isinstance(save_best, str) or is_list_of(save_best, str)
            or (save_best is None)), (
                '"save_best" should be a str or list of str or None, '
                f'but got {type(save_best)}')

    if isinstance(save_best, list):
        if 'auto' in save_best:
            assert len(save_best) == 1, (
                'Only support one "auto" in "save_best" list.')
        assert len(save_best) == len(
            set(save_best)), ('Find duplicate element in "save_best".')
    else:
        # Wrap a single metric name into a list.
        if save_best is not None:
            save_best = [save_best]  # type: ignore
    self.save_best = save_best

    # rule may be a str, a list of str or None.
    assert (isinstance(rule, str) or is_list_of(rule, str)
            or (rule is None)), (
                '"rule" should be a str or list of str or None, '
                f'but got {type(rule)}')
    if isinstance(rule, list):
        # A rule list is either shared (length 1) or one rule per metric.
        assert len(rule) in [
            1,
            len(self.save_best)  # type: ignore
        ], ('Number of "rule" must be 1 or the same as number of '
            f'"save_best", but got {len(rule)}.')
    else:
        # Wrap a single rule (or None) into a list.
        rule = [rule]  # type: ignore

    if greater_keys is None:
        self.greater_keys = self._default_greater_keys
    else:
        if not isinstance(greater_keys, (list, tuple)):
            greater_keys = (greater_keys, )  # type: ignore
        assert is_seq_of(greater_keys, str)
        self.greater_keys = greater_keys  # type: ignore

    if less_keys is None:
        self.less_keys = self._default_less_keys
    else:
        if not isinstance(less_keys, (list, tuple)):
            less_keys = (less_keys, )  # type: ignore
        assert is_seq_of(less_keys, str)
        self.less_keys = less_keys  # type: ignore

    if self.save_best is not None:
        self.is_better_than: Dict[str, Callable] = dict()
        self._init_rule(rule, self.save_best)
        # A single metric tracks one path; several metrics track one path
        # per metric in a dict.
        if len(self.key_indicators) == 1:
            self.best_ckpt_path: Optional[str] = None
        else:
            self.best_ckpt_path_dict: Dict = dict()

    # published_keys may be a str, a sequence of str or None.
    if not (isinstance(published_keys, str)
            or is_seq_of(published_keys, str) or published_keys is None):
        raise TypeError(
            '"published_keys" should be a str or a sequence of str or '
            f'None, but got {type(published_keys)}')

    if isinstance(published_keys, str):
        published_keys = [published_keys]
    elif isinstance(published_keys, (list, tuple)):
        assert len(published_keys) == len(set(published_keys)), (
            'Find duplicate elements in "published_keys".')
    self.published_keys = published_keys

    self.last_ckpt = None
    if save_begin < 0:
        # The message must be an f-string so the offending value is shown.
        raise ValueError(
            f'save_begin should not be less than 0, but got {save_begin}')
    self.save_begin = save_begin
| |
|
def before_train(self, runner) -> None:
    """Resolve the output directory and restore checkpoint bookkeeping.

    Picks the file backend, derives the final ``out_dir``, reloads the best
    checkpoint path(s) and the list of kept checkpoint ids from
    ``runner.message_hub``, and removes checkpoints exceeding
    ``max_keep_ckpts``.

    Args:
        runner (Runner): The runner of the training process.
    """
    if self.out_dir is None:
        self.out_dir = runner.work_dir

    # The deprecated FileClient is still instantiated; it is used as the
    # backend only when ``file_client_args`` was given.
    self.file_client = FileClient.infer_client(self.file_client_args,
                                               self.out_dir)
    if self.file_client_args is not None:
        self.file_backend = self.file_client
    else:
        self.file_backend = get_file_backend(
            self.out_dir, backend_args=self.backend_args)

    # A user-specified out_dir gets the last level of work_dir appended so
    # that different experiments do not share one directory.
    if self.out_dir != runner.work_dir:
        exp_name = osp.basename(runner.work_dir.rstrip(osp.sep))
        self.out_dir = self.file_backend.join_path(
            self.out_dir, exp_name)  # type: ignore

    runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.')

    if self.save_best is not None:
        if len(self.key_indicators) == 1:
            self.best_ckpt_path = (
                runner.message_hub.get_info('best_ckpt')
                if 'best_ckpt' in runner.message_hub.runtime_info else None)
        else:
            for indicator in self.key_indicators:
                info_key = f'best_ckpt_{indicator}'
                self.best_ckpt_path_dict[indicator] = (
                    runner.message_hub.get_info(info_key)
                    if info_key in runner.message_hub.runtime_info else None)

    if self.max_keep_ckpts <= 0:
        return

    stored_ids = []
    if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
        stored_ids = runner.message_hub.get_info('keep_ckpt_ids')

    # Drop the oldest entries until the stored list fits the limit.
    while len(stored_ids) > self.max_keep_ckpts:
        stale_step = stored_ids.pop(0)
        if not is_main_process():
            continue
        stale_path = self.file_backend.join_path(
            self.out_dir, self.filename_tmpl.format(stale_step))
        if self.file_backend.isfile(stale_path):
            self.file_backend.remove(stale_path)
        elif self.file_backend.isdir(stale_path):
            # The checkpoint may also be stored as a directory.
            self.file_backend.rmtree(stale_path)

    self.keep_ckpt_ids: deque = deque(stored_ids, self.max_keep_ckpts)
| |
|
def after_train_epoch(self, runner) -> None:
    """Save a checkpoint at the end of an epoch when one is due.

    Does nothing when the hook works by iteration.

    Args:
        runner (Runner): The runner of the training process.
    """
    if not self.by_epoch:
        return

    # Due either on the regular interval (counted from ``save_begin``) or
    # because this is the last epoch and ``save_last`` is enabled.
    should_save = self.every_n_epochs(runner, self.interval,
                                      self.save_begin)
    if not should_save:
        should_save = self.save_last and self.is_last_train_epoch(runner)

    if should_save:
        runner.logger.info(
            f'Saving checkpoint at {runner.epoch + 1} epochs')
        self._save_checkpoint(runner)
| |
|
def after_val_epoch(self, runner, metrics):
    """Update the best checkpoint(s) after an evaluation epoch.

    Args:
        runner (Runner): The runner of the training process.
        metrics (dict): Evaluation results of all metrics.
    """
    if len(metrics) != 0:
        self._save_best_checkpoint(runner, metrics)
        return

    # Nothing to compare against: warn and leave the best checkpoint as is.
    runner.logger.warning(
        'Since `metrics` is an empty dict, the behavior to save '
        'the best checkpoint will be skipped in this evaluation.')
| |
|
def after_train(self, runner) -> None:
    """Publish the last and best checkpoint(s) after training.

    Nothing is published when ``published_keys`` is None. Best checkpoint
    entries that are still unset (None) are skipped, because there is no
    file to publish for them.

    Args:
        runner (Runner): The runner of the training process.
    """
    if self.published_keys is None:
        return

    if self.save_last and self.last_ckpt is not None:
        self._publish_model(runner, self.last_ckpt)

    # Only one of the two attributes below exists, depending on whether a
    # single metric or several metrics are tracked.
    if getattr(self, 'best_ckpt_path', None) is not None:
        self._publish_model(runner, str(self.best_ckpt_path))
    if getattr(self, 'best_ckpt_path_dict', None) is not None:
        for best_ckpt in self.best_ckpt_path_dict.values():
            # Entries stay None until a best checkpoint has been saved for
            # that metric; publishing None would fail while loading.
            if best_ckpt is None:
                continue
            self._publish_model(runner, best_ckpt)
| |
|
@master_only
def _publish_model(self, runner, ckpt_path: str) -> None:
    """Write a copy of ``ckpt_path`` that keeps only the published keys.

    The published file name is the original name without extension plus
    the first 8 hex digits of the SHA-256 of the pickled checkpoint.

    Args:
        runner (Runner): The runner of the training process.
        ckpt_path (str): The checkpoint path that ought to be published.
    """
    from mmengine.runner import save_checkpoint
    from mmengine.runner.checkpoint import _load_checkpoint

    checkpoint = _load_checkpoint(ckpt_path)
    assert self.published_keys is not None

    # Collect first, then drop, so the mapping is not mutated while its
    # keys are being walked.
    removed_keys = [
        key for key in list(checkpoint.keys())
        if key not in self.published_keys
    ]
    for key in removed_keys:
        checkpoint.pop(key)

    if removed_keys:
        print_log(
            f'Key {removed_keys} will be removed because they are not '
            'found in published_keys. If you want to keep them, '
            f'please set `{removed_keys}` in published_keys',
            logger='current')

    digest = hashlib.sha256(pickle.dumps(checkpoint)).hexdigest()
    stem = osp.splitext(ckpt_path)[0]
    final_path = stem + f'-{digest[:8]}.pth'
    save_checkpoint(checkpoint, final_path)
    print_log(
        f'The checkpoint ({ckpt_path}) is published to '
        f'{final_path}.',
        logger='current')
| |
|
def _save_checkpoint_with_step(self, runner, step, meta):
    """Save the checkpoint for ``step`` and rotate out the oldest one.

    Args:
        runner (Runner): The runner of the training process.
        step (int): Epoch or iteration number used in the file name.
        meta (dict): Meta information stored with the checkpoint.
    """
    if self.max_keep_ckpts > 0:
        # The same step may be saved twice (regular save and best-score
        # re-save); it must be tracked only once.
        already_tracked = (len(self.keep_ckpt_ids) > 0
                           and self.keep_ckpt_ids[-1] == step)
        if not already_tracked:
            if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
                evicted_step = self.keep_ckpt_ids.popleft()
                if is_main_process():
                    evicted_path = self.file_backend.join_path(
                        self.out_dir,
                        self.filename_tmpl.format(evicted_step))
                    if self.file_backend.isfile(evicted_path):
                        self.file_backend.remove(evicted_path)
                    elif self.file_backend.isdir(evicted_path):
                        # The checkpoint may also be a directory.
                        self.file_backend.rmtree(evicted_path)

            self.keep_ckpt_ids.append(step)
            runner.message_hub.update_info('keep_ckpt_ids',
                                           list(self.keep_ckpt_ids))

    ckpt_filename = self.filename_tmpl.format(step)
    self.last_ckpt = self.file_backend.join_path(self.out_dir,
                                                 ckpt_filename)
    runner.message_hub.update_info('last_ckpt', self.last_ckpt)

    runner.save_checkpoint(
        self.out_dir,
        ckpt_filename,
        self.file_client_args,
        save_optimizer=self.save_optimizer,
        save_param_scheduler=self.save_param_scheduler,
        meta=meta,
        by_epoch=self.by_epoch,
        backend_args=self.backend_args,
        **self.args)

    # Only the main process records the path of the latest checkpoint.
    if is_main_process():
        marker_file = osp.join(runner.work_dir, 'last_checkpoint')
        with open(marker_file, 'w') as f:
            f.write(self.last_ckpt)  # type: ignore
| |
|
def _save_checkpoint(self, runner) -> None:
    """Save the checkpoint of the epoch or iteration that just finished.

    The counter matching ``by_epoch`` is advanced by one, both for the
    file name and for the stored meta information.

    Args:
        runner (Runner): The runner of the training process.
    """
    epoch, iteration = runner.epoch, runner.iter
    if self.by_epoch:
        epoch += 1
        step = epoch
    else:
        iteration += 1
        step = iteration

    self._save_checkpoint_with_step(
        runner, step, meta={'epoch': epoch, 'iter': iteration})
| |
|
def _save_best_checkpoint(self, runner, metrics) -> None:
    """Save a new best checkpoint for every metric that improved.

    For each tracked metric the current score is compared with the best
    score stored in ``runner.message_hub``. On improvement the previous
    best checkpoint is removed and a new one is written.

    Args:
        runner (Runner): The runner of the training process.
        metrics (dict): Evaluation results of all metrics.
    """
    if not self.save_best:
        return

    cur_type = 'epoch' if self.by_epoch else 'iter'
    cur_time = runner.epoch if self.by_epoch else runner.iter
    ckpt_filename = self.filename_tmpl.format(cur_time)

    meta = {'epoch': runner.epoch, 'iter': runner.iter}

    # 'auto' is resolved to the first metric returned by the evaluation.
    if 'auto' in self.key_indicators:
        first_metric = list(metrics.keys())[0]
        self._init_rule(self.rules, [first_metric])

    single_metric = len(self.key_indicators) == 1
    best_ckpt_updated = False
    for key_indicator, rule in zip(self.key_indicators, self.rules):
        key_score = metrics[key_indicator]

        if single_metric:
            best_score_key = 'best_score'
            runtime_best_ckpt_key = 'best_ckpt'
            best_ckpt_path = self.best_ckpt_path
        else:
            best_score_key = f'best_score_{key_indicator}'
            runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
            best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

        if best_score_key in runner.message_hub.runtime_info:
            best_score = runner.message_hub.get_info(best_score_key)
        else:
            best_score = self.init_value_map[rule]

        improved = (
            key_score is not None
            and self.is_better_than[key_indicator](key_score, best_score))
        if not improved:
            continue

        best_ckpt_updated = True

        best_score = key_score
        runner.message_hub.update_info(best_score_key, best_score)

        if best_ckpt_path and is_main_process():
            is_removed = True
            if self.file_backend.isfile(best_ckpt_path):
                self.file_backend.remove(best_ckpt_path)
            elif self.file_backend.isdir(best_ckpt_path):
                # The checkpoint may also be a directory.
                self.file_backend.rmtree(best_ckpt_path)
            else:
                is_removed = False

            if is_removed:
                runner.logger.info(
                    f'The previous best checkpoint {best_ckpt_path} '
                    'is removed')

        # '/' in a metric name is not valid inside a file name.
        best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'.replace(
            '/', '_')
        new_best_path = self.file_backend.join_path(
            self.out_dir, best_ckpt_name)
        if single_metric:
            self.best_ckpt_path = new_best_path  # type: ignore
            runner.message_hub.update_info(runtime_best_ckpt_key,
                                           self.best_ckpt_path)
        else:
            self.best_ckpt_path_dict[key_indicator] = new_best_path
            runner.message_hub.update_info(
                runtime_best_ckpt_key,
                self.best_ckpt_path_dict[key_indicator])

        runner.save_checkpoint(
            self.out_dir,
            filename=best_ckpt_name,
            file_client_args=self.file_client_args,
            save_optimizer=False,
            save_param_scheduler=False,
            meta=meta,
            by_epoch=False,
            backend_args=self.backend_args)
        runner.logger.info(
            f'The best checkpoint with {best_score:0.4f} {key_indicator} '
            f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

    # Re-save the regular checkpoint so that it carries the best score and
    # best checkpoint path that were just written to the message hub.
    if best_ckpt_updated and self.last_ckpt is not None:
        self._save_checkpoint_with_step(runner, cur_time, meta)
| |
|
def _init_rule(self, rules, key_indicators) -> None:
    """Initialize rules, key indicators and comparison functions.

    If ``rules`` holds a single item, every key indicator shares it.

    When a rule is None (and the key indicator is not ``'auto'``) the rule
    is inferred from the key indicator, case-insensitively, in this order:

    1. If the key indicator is in ``self.greater_keys``, the rule
       will be specified as 'greater'.
    2. Or if the key indicator is in ``self.less_keys``, the rule
       will be specified as 'less'.
    3. Or if any one item in ``self.greater_keys`` is a substring of
       key_indicator, the rule will be specified as 'greater'.
    4. Or if any one item in ``self.less_keys`` is a substring of
       key_indicator, the rule will be specified as 'less'.

    The results are stored in ``self.rules``, ``self.key_indicators`` and
    ``self.is_better_than``.

    Args:
        rules (List[Optional[str]]): Comparison rules for the best score.
        key_indicators (List[str]): Key indicators to determine
            the comparison rules.

    Raises:
        KeyError: If a rule is neither a key of ``self.rule_map`` nor None.
        ValueError: If a rule is None and cannot be inferred from the key
            indicator.
    """
    if len(rules) == 1:
        rules = rules * len(key_indicators)

    # The lower-cased key sets do not depend on the indicator, so they are
    # built once instead of once per loop iteration.
    greater_keys = {key.lower() for key in self.greater_keys}
    less_keys = {key.lower() for key in self.less_keys}

    self.rules = []
    for rule, key_indicator in zip(rules, key_indicators):

        if rule not in self.rule_map and rule is not None:
            raise KeyError('rule must be greater, less or None, '
                           f'but got {rule}.')

        # 'auto' keeps its rule undecided until the metric is known.
        if rule is None and key_indicator != 'auto':
            key_indicator_lc = key_indicator.lower()

            if key_indicator_lc in greater_keys:
                rule = 'greater'
            elif key_indicator_lc in less_keys:
                rule = 'less'
            elif any(key in key_indicator_lc for key in greater_keys):
                rule = 'greater'
            elif any(key in key_indicator_lc for key in less_keys):
                rule = 'less'
            else:
                raise ValueError('Cannot infer the rule for key '
                                 f'{key_indicator}, thus a specific rule '
                                 'must be specified.')
        if rule is not None:
            self.is_better_than[key_indicator] = self.rule_map[rule]
        self.rules.append(rule)

    self.key_indicators = key_indicators
| |
|
def after_train_iter(self,
                     runner,
                     batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     outputs: Optional[dict] = None) -> None:
    """Save a checkpoint at the end of an iteration when one is due.

    Does nothing when the hook works by epoch.

    Args:
        runner (Runner): The runner of the training process.
        batch_idx (int): The index of the current batch in the train loop.
        data_batch (dict or tuple or list, optional): Data from dataloader.
        outputs (dict, optional): Outputs from model.
    """
    if self.by_epoch:
        return

    # Due either on the regular interval (counted from ``save_begin``) or
    # because this is the last iteration and ``save_last`` is enabled.
    if self.every_n_train_iters(runner, self.interval,
                                self.save_begin) or \
            (self.save_last and
             self.is_last_train_iter(runner)):
        runner.logger.info(
            f'Saving checkpoint at {runner.iter + 1} iterations')
        self._save_checkpoint(runner)
| |
|