| | |
| | import bisect |
| | import logging |
| | import time |
| | from typing import Dict, List, Optional, Sequence, Tuple, Union |
| |
|
| | import torch |
| | from torch.utils.data import DataLoader |
| |
|
| | from mmengine.evaluator import Evaluator |
| | from mmengine.logging import HistoryBuffer, print_log |
| | from mmengine.registry import LOOPS |
| | from mmengine.structures import BaseDataElement |
| | from mmengine.utils import is_list_of |
| | from .amp import autocast |
| | from .base_loop import BaseLoop |
| | from .utils import calc_dynamic_intervals |
| |
|
| |
|
@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
    """Loop for epoch-based training.

    Runs training for ``max_epochs`` epochs and triggers the runner's
    validation loop at (possibly dynamic) epoch intervals.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_epochs (int): Total training epochs.
        val_begin (int): The epoch that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_epochs: int,
            val_begin: int = 1,
            val_interval: int = 1,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)
        # Reject non-integral values (e.g. 1.5) while still accepting
        # integer-valued floats such as 2.0.
        assert self._max_epochs == max_epochs, \
            f'`max_epochs` should be a integer number, but get {max_epochs}.'
        self._max_iters = self._max_epochs * len(self.dataloader)
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # External code (e.g. a hook) may set this flag to stop training
        # early; checked once per epoch in ``run``.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)

        # Precompute the (milestone, interval) schedule consumed by
        # ``_decide_current_val_interval``.
        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training.

        Returns:
            torch.nn.Module: The trained model.
        """
        self.runner.call_hook('before_train')

        while self._epoch < self._max_epochs and not self.stop_training:
            self.run_epoch()

            self._decide_current_val_interval()
            # Validate only when a val loop exists, the warm-up epochs have
            # passed, and the current epoch hits the interval (or is the
            # final epoch).
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and (self._epoch % self.val_interval == 0
                         or self._epoch == self._max_epochs)):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model

    def run_epoch(self) -> None:
        """Iterate one epoch."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        self.runner.call_hook('after_train_epoch')
        self._epoch += 1

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the batch within the current epoch.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)
        # Delegate forward/backward/optimizer stepping to the model's
        # ``train_step`` together with the runner's optim wrapper.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        # Find which milestone segment the upcoming epoch falls into and
        # adopt that segment's interval.
        step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
| |
|
| |
|
| | class _InfiniteDataloaderIterator: |
| | """An infinite dataloader iterator wrapper for IterBasedTrainLoop. |
| | |
| | It resets the dataloader to continue iterating when the iterator has |
| | iterated over all the data. However, this approach is not efficient, as the |
| | workers need to be restarted every time the dataloader is reset. It is |
| | recommended to use `mmengine.dataset.InfiniteSampler` to enable the |
| | dataloader to iterate infinitely. |
| | """ |
| |
|
| | def __init__(self, dataloader: DataLoader) -> None: |
| | self._dataloader = dataloader |
| | self._iterator = iter(self._dataloader) |
| | self._epoch = 0 |
| |
|
| | def __iter__(self): |
| | return self |
| |
|
| | def __next__(self) -> Sequence[dict]: |
| | try: |
| | data = next(self._iterator) |
| | except StopIteration: |
| | print_log( |
| | 'Reach the end of the dataloader, it will be ' |
| | 'restarted and continue to iterate. It is ' |
| | 'recommended to use ' |
| | '`mmengine.dataset.InfiniteSampler` to enable the ' |
| | 'dataloader to iterate infinitely.', |
| | logger='current', |
| | level=logging.WARNING) |
| | self._epoch += 1 |
| | if hasattr(self._dataloader, 'sampler') and hasattr( |
| | self._dataloader.sampler, 'set_epoch'): |
| | |
| | |
| | self._dataloader.sampler.set_epoch(self._epoch) |
| |
|
| | elif hasattr(self._dataloader, 'batch_sampler') and hasattr( |
| | self._dataloader.batch_sampler.sampler, 'set_epoch'): |
| | |
| | |
| | |
| | self._dataloader.batch_sampler.sampler.set_epoch(self._epoch) |
| | time.sleep(2) |
| | self._iterator = iter(self._dataloader) |
| | data = next(self._iterator) |
| | return data |
| |
|
| |
|
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
    """Loop for iter-based training.

    Runs training for ``max_iters`` iterations over an infinitely-iterating
    dataloader and triggers validation at (possibly dynamic) iteration
    intervals.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_iters (int): Total training iterations.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_iters: int,
            val_begin: int = 1,
            val_interval: int = 1000,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_iters = int(max_iters)
        # Reject non-integral values while still accepting integer-valued
        # floats such as 2.0.
        assert self._max_iters == max_iters, \
            f'`max_iters` should be a integer number, but get {max_iters}'
        # Iteration-based training is treated as one big (virtual) epoch.
        self._max_epochs = 1
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # External code (e.g. a hook) may set this flag to stop training
        # early; checked once per iteration in ``run``.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)
        # Wrap the dataloader so it can be iterated past its natural end.
        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

        # Precompute the (milestone, interval) schedule consumed by
        # ``_decide_current_val_interval``.
        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training.

        Returns:
            torch.nn.Module: The trained model.
        """
        # Fix: the original annotated this method ``-> None`` although it
        # returns ``self.runner.model`` (and ``EpochBasedTrainLoop.run`` is
        # annotated ``-> torch.nn.Module``).
        self.runner.call_hook('before_train')
        # The whole iter-based training is wrapped in a single pair of
        # epoch hooks.
        self.runner.call_hook('before_train_epoch')
        if self._iter > 0:
            # Resuming mid-run: fast-forward the dataloader so data that was
            # already trained on is not replayed.
            print_log(
                f'Advance dataloader {self._iter} steps to skip data '
                'that has already been trained',
                logger='current',
                level=logging.WARNING)
            for _ in range(self._iter):
                next(self.dataloader_iterator)
        while self._iter < self._max_iters and not self.stop_training:
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
            self.run_iter(data_batch)

            self._decide_current_val_interval()
            # Validate only when a val loop exists, the warm-up iterations
            # have passed, and the current iteration hits the interval (or
            # is the final iteration).
            if (self.runner.val_loop is not None
                    and self._iter >= self.val_begin
                    and (self._iter % self.val_interval == 0
                         or self._iter == self._max_iters)):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')
        return self.runner.model

    def run_iter(self, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
        # Delegate forward/backward/optimizer stepping to the model's
        # ``train_step`` together with the runner's optim wrapper.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=self._iter,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        # Find which milestone segment the upcoming iteration falls into and
        # adopt that segment's interval.
        step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
| |
|
| |
|
@LOOPS.register_module()
class ValLoop(BaseLoop):
    """Loop for validation.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)

        # Accept either a ready-made Evaluator or a config (dict/list of
        # dicts) that the runner can build one from.
        if not isinstance(evaluator, (dict, list)):
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator
        else:
            self.evaluator = runner.build_evaluator(evaluator)
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        # Per-loss-name history buffers, filled by ``run_iter``.
        self.val_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch validation.

        Returns:
            dict: Evaluation metrics, extended with averaged validation
            losses when the model reports them.
        """
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        # Drop losses recorded by a previous validation run.
        self.val_loss.clear()
        for batch_idx, batch in enumerate(self.dataloader):
            self.run_iter(batch_idx, batch)

        # Compute metrics over everything the evaluator has processed.
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.val_loss:
            metrics.update(_parse_losses(self.val_loss, 'val'))

        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the batch within the dataloader.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)

        # Forward pass under (optional) mixed precision.
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)

        # Split any trailing loss element off the outputs and record it.
        outputs, self.val_loss = _update_losses(outputs, self.val_loss)

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
| |
|
| |
|
@LOOPS.register_module()
class TestLoop(BaseLoop):
    """Loop for test.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 testing. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False):
        super().__init__(runner, dataloader)

        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)
        else:
            # Consistency fix: validate the type like ``ValLoop`` does; the
            # original silently accepted any object here and failed later.
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        # Per-loss-name history buffers, filled by ``run_iter``.
        self.test_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch test.

        Returns:
            dict: Evaluation metrics, extended with averaged test losses
            when the model reports them.
        """
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()

        # Drop losses recorded by a previous test run.
        self.test_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # Compute metrics over everything the evaluator has processed.
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.test_loss:
            loss_dict = _parse_losses(self.test_loss, 'test')
            metrics.update(loss_dict)

        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the batch within the dataloader.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)

        # Forward pass under (optional) mixed precision.
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.test_step(data_batch)

        # Split any trailing loss element off the outputs and record it.
        outputs, self.test_loss = _update_losses(outputs, self.test_loss)

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
| |
|
| |
|
def _parse_losses(losses: Dict[str, HistoryBuffer],
                  stage: str) -> Dict[str, float]:
    """Parses the raw losses of the network.

    Args:
        losses (dict): raw losses of the network.
        stage (str): The stage of loss, e.g., 'val' or 'test'.

    Returns:
        dict[str, float]: The key is the loss name, and the value is the
        average loss.
    """
    # Average each recorded buffer; insertion order of ``losses`` is kept.
    loss_dict: Dict[str, float] = {
        name: buffer.mean()
        for name, buffer in losses.items()
    }
    # Sum only the entries whose name mentions 'loss' into a stage total.
    total = sum(value for name, value in loss_dict.items() if 'loss' in name)
    loss_dict[f'{stage}_loss'] = total
    return loss_dict
| |
|
| |
|
| | def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]: |
| | """Update and record the losses of the network. |
| | |
| | Args: |
| | outputs (list): The outputs of the network. |
| | losses (dict): The losses of the network. |
| | |
| | Returns: |
| | list: The updated outputs of the network. |
| | dict: The updated losses of the network. |
| | """ |
| | if isinstance(outputs[-1], |
| | BaseDataElement) and outputs[-1].keys() == ['loss']: |
| | loss = outputs[-1].loss |
| | outputs = outputs[:-1] |
| | else: |
| | loss = dict() |
| |
|
| | for loss_name, loss_value in loss.items(): |
| | if loss_name not in losses: |
| | losses[loss_name] = HistoryBuffer() |
| | if isinstance(loss_value, torch.Tensor): |
| | losses[loss_name].update(loss_value.item()) |
| | elif is_list_of(loss_value, torch.Tensor): |
| | for loss_value_i in loss_value: |
| | losses[loss_name].update(loss_value_i.item()) |
| | return outputs, losses |
| |
|