| |
| import bisect |
| import logging |
| import time |
| from typing import Dict, List, Optional, Sequence, Tuple, Union |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from mmengine.evaluator import Evaluator |
| from mmengine.logging import HistoryBuffer, print_log |
| from mmengine.registry import LOOPS |
| from mmengine.structures import BaseDataElement |
| from mmengine.utils import is_list_of |
| from .amp import autocast |
| from .base_loop import BaseLoop |
| from .utils import calc_dynamic_intervals |
|
|
|
|
@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
    """Loop for epoch-based training.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_epochs (int): Total training epochs.
        val_begin (int): The epoch that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is a interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_epochs: int,
            val_begin: int = 1,
            val_interval: int = 1,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)
        # Reject non-integral values (e.g. 2.5) while still accepting
        # integral floats such as 2.0.
        assert self._max_epochs == max_epochs, \
            f'`max_epochs` should be a integer number, but get {max_epochs}.'
        self._max_iters = self._max_epochs * len(self.dataloader)
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # Checked in `run`; may be set to True externally (e.g. by a hook)
        # to end training before `max_epochs` is reached.
        self.stop_training = False
        # Propagate dataset metainfo to the visualizer when available;
        # otherwise warn that `dataset_meta` will be None.
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)

        # Precompute (milestone, interval) pairs used by
        # `_decide_current_val_interval` to adjust `val_interval` on the fly.
        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self) -> int:
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self) -> int:
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self) -> int:
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self) -> int:
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training.

        Returns:
            torch.nn.Module: The trained model.
        """
        self.runner.call_hook('before_train')

        while self._epoch < self._max_epochs and not self.stop_training:
            self.run_epoch()

            self._decide_current_val_interval()
            # Validate only when a val loop exists, training has reached
            # `val_begin`, and the current epoch hits the interval (or is
            # the final epoch).
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and (self._epoch % self.val_interval == 0
                         or self._epoch == self._max_epochs)):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model

    def run_epoch(self) -> None:
        """Iterate one epoch."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        self.runner.call_hook('after_train_epoch')
        # The epoch counter is advanced only after the full epoch (and its
        # `after_train_epoch` hooks) completed.
        self._epoch += 1

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the current batch within the epoch.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)

        # `outputs` is whatever `train_step` returns (presumably a dict of
        # losses — TODO confirm against the model wrapper); it is forwarded
        # to the `after_train_iter` hooks as-is.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        # Find the last milestone not exceeding the next epoch and use the
        # interval registered for it.
        step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
|
|
|
|
| class _InfiniteDataloaderIterator: |
| """An infinite dataloader iterator wrapper for IterBasedTrainLoop. |
| |
| It resets the dataloader to continue iterating when the iterator has |
| iterated over all the data. However, this approach is not efficient, as the |
| workers need to be restarted every time the dataloader is reset. It is |
| recommended to use `mmengine.dataset.InfiniteSampler` to enable the |
| dataloader to iterate infinitely. |
| """ |
|
|
| def __init__(self, dataloader: DataLoader) -> None: |
| self._dataloader = dataloader |
| self._iterator = iter(self._dataloader) |
| self._epoch = 0 |
|
|
| def __iter__(self): |
| return self |
|
|
| def __next__(self) -> Sequence[dict]: |
| try: |
| data = next(self._iterator) |
| except StopIteration: |
| print_log( |
| 'Reach the end of the dataloader, it will be ' |
| 'restarted and continue to iterate. It is ' |
| 'recommended to use ' |
| '`mmengine.dataset.InfiniteSampler` to enable the ' |
| 'dataloader to iterate infinitely.', |
| logger='current', |
| level=logging.WARNING) |
| self._epoch += 1 |
| if hasattr(self._dataloader, 'sampler') and hasattr( |
| self._dataloader.sampler, 'set_epoch'): |
| |
| |
| self._dataloader.sampler.set_epoch(self._epoch) |
|
|
| elif hasattr(self._dataloader, 'batch_sampler') and hasattr( |
| self._dataloader.batch_sampler.sampler, 'set_epoch'): |
| |
| |
| |
| self._dataloader.batch_sampler.sampler.set_epoch(self._epoch) |
| time.sleep(2) |
| self._iterator = iter(self._dataloader) |
| data = next(self._iterator) |
| return data |
|
|
|
|
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
    """Loop for iter-based training.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_iters (int): Total training iterations.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is a interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_iters: int,
            val_begin: int = 1,
            val_interval: int = 1000,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_iters = int(max_iters)
        # Reject non-integral values (e.g. 2.5) while still accepting
        # integral floats such as 2.0.
        assert self._max_iters == max_iters, \
            f'`max_iters` should be a integer number, but get {max_iters}'
        # The whole iter-based run is treated as a single logical epoch.
        self._max_epochs = 1
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # Checked in `run`; may be set to True externally (e.g. by a hook)
        # to end training before `max_iters` is reached.
        self.stop_training = False
        # Propagate dataset metainfo to the visualizer when available;
        # otherwise warn that `dataset_meta` will be None.
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)
        # Wrap the dataloader so that `next()` never raises StopIteration.
        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

        # Precompute (milestone, interval) pairs used by
        # `_decide_current_val_interval` to adjust `val_interval` on the fly.
        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self) -> int:
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self) -> int:
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self) -> int:
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self) -> int:
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training.

        Returns:
            torch.nn.Module: The trained model.
        """
        self.runner.call_hook('before_train')
        # The entire iter-based run counts as one epoch, so the epoch-level
        # hooks fire exactly once around the `while` loop below.
        self.runner.call_hook('before_train_epoch')
        if self._iter > 0:
            # Resuming: fast-forward the iterator past the batches that
            # were already consumed in the interrupted run.
            print_log(
                f'Advance dataloader {self._iter} steps to skip data '
                'that has already been trained',
                logger='current',
                level=logging.WARNING)
            for _ in range(self._iter):
                next(self.dataloader_iterator)
        while self._iter < self._max_iters and not self.stop_training:
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
            self.run_iter(data_batch)

            self._decide_current_val_interval()
            # Validate only when a val loop exists, training has reached
            # `val_begin`, and the iteration hits the interval (or is the
            # final iteration).
            if (self.runner.val_loop is not None
                    and self._iter >= self.val_begin
                    and (self._iter % self.val_interval == 0
                         or self._iter == self._max_iters)):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')
        return self.runner.model

    def run_iter(self, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)

        # `outputs` is whatever `train_step` returns (presumably a dict of
        # losses — TODO confirm against the model wrapper); it is forwarded
        # to the `after_train_iter` hooks as-is.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=self._iter,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        # Find the last milestone not exceeding the next iteration and use
        # the interval registered for it.
        step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
|
|
|
|
@LOOPS.register_module()
class ValLoop(BaseLoop):
    """Loop for validation.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)

        # Accept either a ready-made Evaluator or a config (dict/list)
        # from which the runner builds one.
        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)
        else:
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator
        # Propagate dataset metainfo to evaluator and visualizer when
        # available; otherwise warn that `dataset_meta` will be None.
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        # Per-loss-name history of values accumulated across iterations by
        # `_update_losses`; cleared at the start of each `run`.
        self.val_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch validation.

        Returns:
            dict: Evaluation metrics, extended with averaged validation
            losses when the model reported any.
        """
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        self.val_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # Compute metrics over everything fed to `self.evaluator.process`.
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.val_loss:
            loss_dict = _parse_losses(self.val_loss, 'val')
            metrics.update(loss_dict)

        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the current batch in the dataloader.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)

        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)

        # Detach a trailing loss element (if present) from `outputs` and
        # accumulate it into `self.val_loss`.
        outputs, self.val_loss = _update_losses(outputs, self.val_loss)

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
|
|
|
|
@LOOPS.register_module()
class TestLoop(BaseLoop):
    """Loop for test.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 testing. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)

        # Accept either a ready-made Evaluator or a config (dict/list)
        # from which the runner builds one.
        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)
        else:
            # Consistent with ``ValLoop``: fail fast on unsupported
            # evaluator types instead of silently accepting them.
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator
        # Propagate dataset metainfo to evaluator and visualizer when
        # available; otherwise warn that `dataset_meta` will be None.
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        # Per-loss-name history of values accumulated across iterations by
        # `_update_losses`; cleared at the start of each `run`.
        self.test_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch test.

        Returns:
            dict: Evaluation metrics, extended with averaged test losses
            when the model reported any.
        """
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()

        self.test_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # Compute metrics over everything fed to `self.evaluator.process`.
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.test_loss:
            loss_dict = _parse_losses(self.test_loss, 'test')
            metrics.update(loss_dict)

        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the current batch in the dataloader.
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)

        with autocast(enabled=self.fp16):
            outputs = self.runner.model.test_step(data_batch)

        # Detach a trailing loss element (if present) from `outputs` and
        # accumulate it into `self.test_loss`.
        outputs, self.test_loss = _update_losses(outputs, self.test_loss)

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
|
|
|
|
def _parse_losses(losses: Dict[str, HistoryBuffer],
                  stage: str) -> Dict[str, float]:
    """Parses the raw losses of the network.

    Args:
        losses (dict): raw losses of the network.
        stage (str): The stage of loss, e.g., 'val' or 'test'.

    Returns:
        dict[str, float]: The key is the loss name, and the value is the
        average loss.
    """
    # Average each recorded buffer, keeping the insertion order of `losses`.
    loss_dict: Dict[str, float] = {
        name: buffer.mean()
        for name, buffer in losses.items()
    }
    # Entries whose name mentions 'loss' contribute to the stage total.
    total = sum(avg for name, avg in loss_dict.items() if 'loss' in name)
    loss_dict[f'{stage}_loss'] = total
    return loss_dict
|
|
|
|
| def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]: |
| """Update and record the losses of the network. |
| |
| Args: |
| outputs (list): The outputs of the network. |
| losses (dict): The losses of the network. |
| |
| Returns: |
| list: The updated outputs of the network. |
| dict: The updated losses of the network. |
| """ |
| if isinstance(outputs[-1], |
| BaseDataElement) and outputs[-1].keys() == ['loss']: |
| loss = outputs[-1].loss |
| outputs = outputs[:-1] |
| else: |
| loss = dict() |
|
|
| for loss_name, loss_value in loss.items(): |
| if loss_name not in losses: |
| losses[loss_name] = HistoryBuffer() |
| if isinstance(loss_value, torch.Tensor): |
| losses[loss_name].update(loss_value.item()) |
| elif is_list_of(loss_value, torch.Tensor): |
| for loss_value_i in loss_value: |
| losses[loss_name].update(loss_value_i.item()) |
| return outputs, losses |
|
|