| |
| import hashlib |
| import logging |
| import os.path as osp |
| import pickle |
| from collections import deque |
| from math import inf |
| from pathlib import Path |
| from typing import Callable, Dict, List, Optional, Sequence, Union |
|
|
| from mmengine.dist import is_main_process, master_only |
| from mmengine.fileio import FileClient, get_file_backend |
| from mmengine.logging import print_log |
| from mmengine.registry import HOOKS |
| from mmengine.utils import is_list_of, is_seq_of |
| from .hook import Hook |
|
|
# Type of the ``data_batch`` argument received by the per-iteration hook
# methods: a dict, tuple or list, or ``None``.
DATA_BATCH = Union[dict, tuple, list, None]
|
|
|
|
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Defaults to True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        save_param_scheduler (bool): Whether to save param_scheduler
            state_dict in the checkpoint. It is usually used for resuming
            experiments. Defaults to True.
        out_dir (str, Path, Optional): The root directory to save
            checkpoints. If not specified, ``runner.work_dir`` will be used
            by default. If specified, the ``out_dir`` will be the
            concatenation of ``out_dir`` and the last level directory of
            ``runner.work_dir``. For example, if the input ``out_dir`` is
            ``./tmp`` and ``runner.work_dir`` is ``./work_dir/cur_exp``,
            then the ckpt will be saved in ``./tmp/cur_exp``.
            Defaults to None.
        max_keep_ckpts (int): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Defaults to -1, which means unlimited.
        save_last (bool): Whether to force the last checkpoint to be
            saved regardless of interval. Defaults to True.
        save_best (str, List[str], optional): If a metric is specified, it
            would measure the best checkpoint during evaluation. If a list
            of metrics is passed, it would measure a group of best
            checkpoints corresponding to the passed metrics. The information
            about best checkpoint(s) would be saved in
            ``runner.message_hub`` to keep best score value and best
            checkpoint path, which will be also loaded when resuming
            checkpoint. Options are the evaluation metrics on the test
            dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox detection and
            instance segmentation. ``AR@100`` for proposal recall. If
            ``save_best`` is ``auto``, the first key of the returned
            ``OrderedDict`` result will be used. Defaults to None.
        rule (str, List[str], optional): Comparison rule for best score. If
            set to None, it will infer a reasonable rule. Keys such as
            'acc', 'top', etc. will be inferred by 'greater' rule. Keys
            containing 'loss' will be inferred by 'less' rule. If
            ``save_best`` is a list of metrics and ``rule`` is a str, all
            metrics in ``save_best`` will share the comparison rule. If
            ``save_best`` and ``rule`` are both lists, their length must be
            the same, and metrics in ``save_best`` will use the
            corresponding comparison rule in ``rule``. Options are
            'greater', 'less', None and list which contains 'greater' and
            'less'. Defaults to None.
        greater_keys (List[str], optional): Metric keys that will be
            inferred by 'greater' comparison rule. If ``None``,
            _default_greater_keys will be used. Defaults to None.
        less_keys (List[str], optional): Metric keys that will be
            inferred by 'less' comparison rule. If ``None``,
            _default_less_keys will be used. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None. It will be deprecated in future. Please use
            ``backend_args`` instead.
        filename_tmpl (str, optional): String template to indicate
            checkpoint name. If specified, must contain one and only one
            "{}", which will be replaced with ``epoch + 1`` if
            ``by_epoch=True`` else ``iteration + 1``.
            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
            accordingly.
        backend_args (dict, optional): Arguments to instantiate the
            prefix of uri corresponding backend. Defaults to None.
            `New in version 0.2.0.`
        published_keys (str, List[str], optional): If ``save_last`` is
            ``True`` or ``save_best`` is not ``None``, it will automatically
            publish model with keys in the list after training.
            Defaults to None.
            `New in version 0.7.1.`
        save_begin (int): Control the epoch number or iteration number
            at which checkpoint saving begins. Defaults to 0, which means
            saving at the beginning.
            `New in version 0.8.3.`

    Examples:
        >>> # Save best based on single metric
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        >>>                rule='less')
        >>> # Save best based on multi metrics with the same comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        >>>                save_best=['acc', 'mIoU'], rule='greater')
        >>> # Save best based on multi metrics with different comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        >>>                save_best=['FID', 'IS'], rule=['less', 'greater'])
        >>> # Save best based on single metric and publish model after training
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        >>>                rule='less', published_keys=['meta', 'state_dict'])
    """
    out_dir: str

    priority = 'VERY_LOW'

    # Class-level tables used by the "save best" logic. They are looked up
    # through ``self`` so a subclass can override them.
    # ``rule_map`` maps a rule name to the "is new score better" predicate;
    # ``init_value_map`` gives the score that any real score beats.
    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}
    # Keys used to infer the comparison rule when ``rule`` is None.
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 save_param_scheduler: bool = True,
                 out_dir: Optional[Union[str, Path]] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 save_best: Union[str, List[str], None] = None,
                 rule: Union[str, List[str], None] = None,
                 greater_keys: Optional[Sequence[str]] = None,
                 less_keys: Optional[Sequence[str]] = None,
                 file_client_args: Optional[dict] = None,
                 filename_tmpl: Optional[str] = None,
                 backend_args: Optional[dict] = None,
                 published_keys: Union[str, List[str], None] = None,
                 save_begin: int = 0,
                 **kwargs) -> None:
        """Validate and store the hook configuration.

        See the class docstring for the meaning of each argument. Extra
        keyword arguments are kept in ``self.args`` and forwarded to
        ``runner.save_checkpoint`` for the periodic checkpoints.

        Raises:
            ValueError: If both ``file_client_args`` and ``backend_args``
                are given, or if ``save_begin`` is negative.
            TypeError: If ``published_keys`` is not a str, a sequence of
                str or None.
            AssertionError: If ``save_best``, ``rule``, ``greater_keys``,
                ``less_keys`` or ``published_keys`` fail validation.
        """
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.save_param_scheduler = save_param_scheduler
        self.out_dir = out_dir  # type: ignore
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs

        if file_client_args is not None:
            print_log(
                '"file_client_args" will be deprecated in future. '
                'Please use "backend_args" instead',
                logger='current',
                level=logging.WARNING)
            if backend_args is not None:
                raise ValueError(
                    '"file_client_args" and "backend_args" cannot be set '
                    'at the same time.')

        self.file_client_args = file_client_args
        self.backend_args = backend_args

        if filename_tmpl is None:
            if self.by_epoch:
                self.filename_tmpl = 'epoch_{}.pth'
            else:
                self.filename_tmpl = 'iter_{}.pth'
        else:
            self.filename_tmpl = filename_tmpl

        # ---- save_best: normalise to ``None`` or a list of metric names.
        assert (isinstance(save_best, str) or is_list_of(save_best, str)
                or (save_best is None)), (
                    '"save_best" should be a str or list of str or None, '
                    f'but got {type(save_best)}')

        if isinstance(save_best, list):
            if 'auto' in save_best:
                assert len(save_best) == 1, (
                    'Only support one "auto" in "save_best" list.')
            assert len(save_best) == len(
                set(save_best)), ('Find duplicate element in "save_best".')
        else:
            # A single metric name becomes a one-element list.
            if save_best is not None:
                save_best = [save_best]
        self.save_best = save_best

        # ---- rule: normalise to a list (possibly ``[None]``).
        assert (isinstance(rule, str) or is_list_of(rule, str)
                or (rule is None)), (
                    '"rule" should be a str or list of str or None, '
                    f'but got {type(rule)}')
        if isinstance(rule, list):
            # The length can only be checked against ``save_best`` when
            # ``save_best`` is set; without it the rules are never used, so
            # there is nothing to validate (and ``len(None)`` would raise).
            if self.save_best is not None:
                assert len(rule) in [
                    1,
                    len(self.save_best)
                ], ('Number of "rule" must be 1 or the same as number of '
                    f'"save_best", but got {len(rule)}.')
        else:
            rule = [rule]

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )  # type: ignore
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys  # type: ignore

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )  # type: ignore
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys  # type: ignore

        if self.save_best is not None:
            self.is_better_than: Dict[str, Callable] = dict()
            self._init_rule(rule, self.save_best)
            # One tracked metric -> a single path attribute; otherwise a
            # dict keyed by metric name. Only one of the two attributes
            # exists on a given instance.
            if len(self.key_indicators) == 1:
                self.best_ckpt_path: Optional[str] = None
            else:
                self.best_ckpt_path_dict: Dict = dict()

        # ---- published_keys: normalise a single str to a list.
        if not (isinstance(published_keys, str)
                or is_seq_of(published_keys, str) or published_keys is None):
            raise TypeError(
                '"published_keys" should be a str or a sequence of str or '
                f'None, but got {type(published_keys)}')

        if isinstance(published_keys, str):
            published_keys = [published_keys]
        elif isinstance(published_keys, (list, tuple)):
            assert len(published_keys) == len(set(published_keys)), (
                'Find duplicate elements in "published_keys".')
        self.published_keys = published_keys

        self.last_ckpt = None
        if save_begin < 0:
            # f-string so the offending value is actually interpolated.
            raise ValueError(
                f'save_begin should not be less than 0, but got {save_begin}')
        self.save_begin = save_begin

    def before_train(self, runner) -> None:
        """Finish all operations related to checkpoint.

        This function will get the appropriate file client, and the directory
        to save these checkpoints of the model. It also restores the best
        checkpoint path(s) and the list of kept checkpoint ids from
        ``runner.message_hub`` when they are present there.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.out_dir is None:
            self.out_dir = runner.work_dir

        # ``self.file_client`` is always created, but it only serves as the
        # backend when ``file_client_args`` was given.
        # NOTE(review): presumably kept for backward compatibility - confirm
        # before removing.
        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        if self.file_client_args is None:
            self.file_backend = get_file_backend(
                self.out_dir, backend_args=self.backend_args)
        else:
            self.file_backend = self.file_client

        # A user-specified ``out_dir`` gets the last path component of
        # ``runner.work_dir`` appended, so each experiment has its own
        # sub-directory.
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_backend.join_path(
                self.out_dir, basename)  # type: ignore

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.')

        if self.save_best is not None:
            runtime_info = runner.message_hub.runtime_info
            if len(self.key_indicators) == 1:
                if 'best_ckpt' not in runtime_info:
                    self.best_ckpt_path = None
                else:
                    self.best_ckpt_path = runner.message_hub.get_info(
                        'best_ckpt')
            else:
                for key_indicator in self.key_indicators:
                    best_ckpt_name = f'best_ckpt_{key_indicator}'
                    if best_ckpt_name not in runtime_info:
                        self.best_ckpt_path_dict[key_indicator] = None
                    else:
                        self.best_ckpt_path_dict[
                            key_indicator] = runner.message_hub.get_info(
                                best_ckpt_name)

        if self.max_keep_ckpts > 0:
            keep_ckpt_ids = []
            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')

                # The restored list may be longer than the current limit;
                # drop (and delete on disk) the oldest entries first.
                while len(keep_ckpt_ids) > self.max_keep_ckpts:
                    step = keep_ckpt_ids.pop(0)
                    if is_main_process():
                        path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(step))
                        self._remove_checkpoint(path)

            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
                                              self.max_keep_ckpts)

    def after_train_epoch(self, runner) -> None:
        """Save the checkpoint after a training epoch when one is due.

        Does nothing unless ``by_epoch`` is True. A checkpoint is saved when
        the epoch matches ``interval``/``save_begin`` or, if ``save_last``
        is set, when this is the last training epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.by_epoch:
            return

        if self.every_n_epochs(runner, self.interval, self.save_begin) or (
                self.save_last and self.is_last_train_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            self._save_checkpoint(runner)

    def after_val_epoch(self, runner, metrics) -> None:
        """Update the best checkpoint(s) after each evaluation epoch.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics
        """
        if len(metrics) == 0:
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return

        self._save_best_checkpoint(runner, metrics)

    def after_train(self, runner) -> None:
        """Publish the checkpoint after training.

        Publishes the last checkpoint (when ``save_last`` is set and one was
        saved) and every best checkpoint that exists. Nothing is done when
        ``published_keys`` is None.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.published_keys is None:
            return

        if self.save_last and self.last_ckpt is not None:
            self._publish_model(runner, self.last_ckpt)

        # Only one of ``best_ckpt_path`` / ``best_ckpt_path_dict`` exists,
        # depending on how many metrics are tracked; hence ``getattr``.
        if getattr(self, 'best_ckpt_path', None) is not None:
            self._publish_model(runner, str(self.best_ckpt_path))
        if getattr(self, 'best_ckpt_path_dict', None) is not None:
            for best_ckpt in self.best_ckpt_path_dict.values():
                # ``before_train`` stores None for metrics without a best
                # checkpoint; there is nothing to publish for those.
                if best_ckpt is not None:
                    self._publish_model(runner, best_ckpt)

    @master_only
    def _publish_model(self, runner, ckpt_path: str) -> None:
        """Remove unnecessary keys from ckpt_path and save the new checkpoint.

        The stripped checkpoint is written next to the original, with the
        first 8 hex digits of the SHA-256 of its pickled content appended to
        the file name.

        Args:
            runner (Runner): The runner of the training process.
            ckpt_path (str): The checkpoint path that ought to be published.
        """
        from mmengine.runner import save_checkpoint
        from mmengine.runner.checkpoint import _load_checkpoint
        checkpoint = _load_checkpoint(ckpt_path)
        assert self.published_keys is not None
        removed_keys = []
        # Iterate over a snapshot of the keys because entries are popped.
        for key in list(checkpoint.keys()):
            if key not in self.published_keys:
                removed_keys.append(key)
                checkpoint.pop(key)
        if removed_keys:
            print_log(
                f'Key {removed_keys} will be removed because they are not '
                'found in published_keys. If you want to keep them, '
                f'please set `{removed_keys}` in published_keys',
                logger='current')
        checkpoint_data = pickle.dumps(checkpoint)
        sha = hashlib.sha256(checkpoint_data).hexdigest()
        final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth'
        save_checkpoint(checkpoint, final_path)
        print_log(
            f'The checkpoint ({ckpt_path}) is published to '
            f'{final_path}.',
            logger='current')

    def _remove_checkpoint(self, ckpt_path) -> bool:
        """Delete ``ckpt_path`` through ``self.file_backend`` if it exists.

        Args:
            ckpt_path (str): Path of the checkpoint to delete. It may be a
                regular file or a directory.

        Returns:
            bool: True if something was removed, False if the path is
            neither an existing file nor an existing directory.
        """
        if self.file_backend.isfile(ckpt_path):
            self.file_backend.remove(ckpt_path)
            return True
        if self.file_backend.isdir(ckpt_path):
            # A checkpoint may be a directory rather than a single file.
            self.file_backend.rmtree(ckpt_path)
            return True
        return False

    def _save_checkpoint_with_step(self, runner, step, meta):
        """Save the checkpoint for ``step`` and rotate old checkpoints.

        Args:
            runner (Runner): The runner of the training process.
            step (int): Number substituted into ``self.filename_tmpl``.
            meta (dict): Meta information passed to
                ``runner.save_checkpoint``.
        """
        # Rotate before saving so that the ``keep_ckpt_ids`` stored in the
        # message hub already includes this step when the checkpoint is
        # written.
        if self.max_keep_ckpts > 0:
            already_tracked = (
                len(self.keep_ckpt_ids) > 0
                and self.keep_ckpt_ids[-1] == step)
            # The same step can be saved twice (``_save_best_checkpoint``
            # re-saves the current step); do not rotate again in that case.
            if not already_tracked:
                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
                    _step = self.keep_ckpt_ids.popleft()
                    if is_main_process():
                        ckpt_path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(_step))
                        self._remove_checkpoint(ckpt_path)

                self.keep_ckpt_ids.append(step)
                runner.message_hub.update_info('keep_ckpt_ids',
                                               list(self.keep_ckpt_ids))

        ckpt_filename = self.filename_tmpl.format(step)
        self.last_ckpt = self.file_backend.join_path(self.out_dir,
                                                     ckpt_filename)
        runner.message_hub.update_info('last_ckpt', self.last_ckpt)

        # Called on every rank; only the bookkeeping file below is
        # restricted to the main process.
        runner.save_checkpoint(
            self.out_dir,
            ckpt_filename,
            self.file_client_args,
            save_optimizer=self.save_optimizer,
            save_param_scheduler=self.save_param_scheduler,
            meta=meta,
            by_epoch=self.by_epoch,
            backend_args=self.backend_args,
            **self.args)

        if not is_main_process():
            return

        # Record the path of the latest checkpoint in the work directory.
        save_file = osp.join(runner.work_dir, 'last_checkpoint')
        with open(save_file, 'w') as f:
            f.write(self.last_ckpt)  # type: ignore

    def _save_checkpoint(self, runner) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        Args:
            runner (Runner): The runner of the training process.
        """
        # The counter that names the checkpoint is reported as "+ 1"; the
        # other counter is stored as-is.
        if self.by_epoch:
            step = runner.epoch + 1
            meta = dict(epoch=step, iter=runner.iter)
        else:
            step = runner.iter + 1
            meta = dict(epoch=runner.epoch, iter=step)

        self._save_checkpoint_with_step(runner, step, meta=meta)

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save a new best checkpoint for every metric that improved.

        For each tracked metric whose score beats the stored best score, the
        previous best checkpoint is deleted, a new one is saved and the
        score/path are recorded in ``runner.message_hub``.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        meta = dict(epoch=runner.epoch, iter=runner.iter)

        # 'auto' is resolved to the first metric key the first time real
        # metrics are seen; ``_init_rule`` then replaces key_indicators.
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        best_ckpt_updated = False
        for key_indicator, rule in zip(self.key_indicators, self.rules):
            key_score = metrics[key_indicator]

            # Message-hub keys carry the metric name only when several
            # metrics are tracked.
            if len(self.key_indicators) == 1:
                best_score_key = 'best_score'
                runtime_best_ckpt_key = 'best_ckpt'
                best_ckpt_path = self.best_ckpt_path
            else:
                best_score_key = f'best_score_{key_indicator}'
                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

            if best_score_key not in runner.message_hub.runtime_info:
                best_score = self.init_value_map[rule]
            else:
                best_score = runner.message_hub.get_info(best_score_key)

            if key_score is None or not self.is_better_than[key_indicator](
                    key_score, best_score):
                continue

            best_ckpt_updated = True

            best_score = key_score
            runner.message_hub.update_info(best_score_key, best_score)

            # Delete the checkpoint this one supersedes (main process only).
            if best_ckpt_path and is_main_process():
                if self._remove_checkpoint(best_ckpt_path):
                    runner.logger.info(
                        f'The previous best checkpoint {best_ckpt_path} '
                        'is removed')

            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
            # Metric names may contain '/', which must not end up in a
            # file name.
            best_ckpt_name = best_ckpt_name.replace('/', '_')
            if len(self.key_indicators) == 1:
                self.best_ckpt_path = self.file_backend.join_path(  # type: ignore # noqa: E501
                    self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(runtime_best_ckpt_key,
                                               self.best_ckpt_path)
            else:
                self.best_ckpt_path_dict[
                    key_indicator] = self.file_backend.join_path(  # type: ignore # noqa: E501
                        self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(
                    runtime_best_ckpt_key,
                    self.best_ckpt_path_dict[key_indicator])
            # Best checkpoints are saved without optimizer and scheduler
            # state.
            runner.save_checkpoint(
                self.out_dir,
                filename=best_ckpt_name,
                file_client_args=self.file_client_args,
                save_optimizer=False,
                save_param_scheduler=False,
                meta=meta,
                by_epoch=False,
                backend_args=self.backend_args)
            runner.logger.info(
                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

        # Re-save the regular checkpoint of the current step after the best
        # info in the message hub changed.
        # NOTE(review): presumably so the regular checkpoint carries the
        # refreshed best score/path for resuming - confirm.
        if best_ckpt_updated and self.last_ckpt is not None:
            self._save_checkpoint_with_step(runner, cur_time, meta)

    def _init_rule(self, rules, key_indicators) -> None:
        """Initialize rule, key_indicator, comparison_func, and best score. If
        key_indicator is a list of string and rule is a string, all metric in
        the key_indicator will share the same rule.

        Here is the rule to determine which rule is used for key indicator when
        the rule is not specific (note that the key indicator matching is case-
        insensitive):

        1. If the key indicator is in ``self.greater_keys``, the rule
            will be specified as 'greater'.
        2. Or if the key indicator is in ``self.less_keys``, the rule
            will be specified as 'less'.
        3. Or if any one item in ``self.greater_keys`` is a substring of
            key_indicator, the rule will be specified as 'greater'.
        4. Or if any one item in ``self.less_keys`` is a substring of
            key_indicator, the rule will be specified as 'less'.

        Sets ``self.rules`` and ``self.key_indicators`` and fills
        ``self.is_better_than`` for every indicator with a resolved rule.

        Args:
            rules (List[Optional[str]]): Comparison rules for best score. A
                single-element list is shared by all key indicators.
            key_indicators (List[str]): Key indicators to determine
                the comparison rule.

        Raises:
            KeyError: If a rule is not 'greater', 'less' or None.
            ValueError: If a rule is None and cannot be inferred from the
                key indicator.
        """
        if len(rules) == 1:
            rules = rules * len(key_indicators)

        # Lower-cased once; the inference below is case-insensitive.
        greater_keys = {key.lower() for key in self.greater_keys}
        less_keys = {key.lower() for key in self.less_keys}

        self.rules = []
        for rule, key_indicator in zip(rules, key_indicators):

            if rule not in self.rule_map and rule is not None:
                raise KeyError('rule must be greater, less or None, '
                               f'but got {rule}.')

            # 'auto' stays unresolved here; it is re-initialised with the
            # real metric name by ``_save_best_checkpoint``.
            if rule is None and key_indicator != 'auto':
                key_indicator_lc = key_indicator.lower()

                if key_indicator_lc in greater_keys:
                    rule = 'greater'
                elif key_indicator_lc in less_keys:
                    rule = 'less'
                elif any(key in key_indicator_lc for key in greater_keys):
                    rule = 'greater'
                elif any(key in key_indicator_lc for key in less_keys):
                    rule = 'less'
                else:
                    raise ValueError('Cannot infer the rule for key '
                                     f'{key_indicator}, thus a specific rule '
                                     'must be specified.')
            if rule is not None:
                self.is_better_than[key_indicator] = self.rule_map[rule]
            self.rules.append(rule)

        self.key_indicators = key_indicators

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Save the checkpoint after a training iteration when one is due.

        Does nothing when ``by_epoch`` is True. A checkpoint is saved when
        the iteration matches ``interval``/``save_begin`` or, if
        ``save_last`` is set, when this is the last training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.by_epoch:
            return

        if self.every_n_train_iters(runner, self.interval,
                                    self.save_begin) or \
                (self.save_last and
                 self.is_last_train_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            self._save_checkpoint(runner)
|
|