| | |
| | from abc import abstractmethod |
| | from typing import Dict, List, Optional, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from mmengine.registry import MODELS |
| | from mmengine.structures import BaseDataElement |
| | from .base_model import BaseModel |
| |
|
| | |
# Inputs of a test-time-augmentation batch. The outer list has one entry per
# augmentation; each entry is either a single tensor or a list of tensors.
EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]]

# Data samples of a test-time-augmentation batch: an outer list whose
# entries are lists of ``BaseDataElement``.
EnhancedBatchDataSamples = List[List[BaseDataElement]]

# Flat list of data elements, used as the return type of ``merge_preds``.
MergedDataSamples = List[BaseDataElement]

# Accepted container types for a batch handed to ``test_step``.
DATA_BATCH = Union[
    Dict[str, Union[EnhancedBatchInputs, EnhancedBatchDataSamples]],
    tuple,
    dict,
]
| |
|
| |
|
@MODELS.register_module()
class BaseTTAModel(BaseModel):
    """Base model for inference with test-time augmentation.

    ``BaseTTAModel`` is a wrapper for inference given multi-batch data.
    It implements the :meth:`test_step` for multi-batch data inference.
    ``multi-batch`` data means data processed by different augmentation
    from the same batch.

    During test time augmentation, the data processed by
    :obj:`mmcv.transforms.TestTimeAug`, and then collated by
    ``pseudo_collate`` will have the following format:

    .. code-block::

        result = dict(
            inputs=[
                [image1_aug1, image2_aug1],
                [image1_aug2, image2_aug2]
            ],
            data_samples=[
                [data_sample1_aug1, data_sample2_aug1],
                [data_sample1_aug2, data_sample2_aug2],
            ]
        )

    ``image{i}_aug{j}`` means the i-th image of the batch, which is
    augmented by the j-th augmentation.

    ``BaseTTAModel`` will collate the data to:

    .. code-block::

        data1 = dict(
            inputs=[image1_aug1, image2_aug1],
            data_samples=[data_sample1_aug1, data_sample2_aug1]
        )

        data2 = dict(
            inputs=[image1_aug2, image2_aug2],
            data_samples=[data_sample1_aug2, data_sample2_aug2]
        )

    ``data1`` and ``data2`` will be passed to model, and the results will be
    merged by :meth:`merge_preds`.

    Note:
        :meth:`merge_preds` is an abstract method, all subclasses should
        implement it.

    Warning:
        If ``module`` is a config dict and ``data_preprocessor`` is not
        None, it replaces the ``data_preprocessor`` entry of that config
        before the model is built. If ``module`` is already an
        ``nn.Module`` instance, ``data_preprocessor`` is not applied.

    Args:
        module (dict or nn.Module): Tested model, or the config used to
            build it through ``MODELS``.
        data_preprocessor (dict or :obj:`BaseDataPreprocessor`, optional):
            Data preprocessor (or its config) written into ``module`` when
            ``module`` is a config dict. Defaults to None.
    """

    def __init__(
        self,
        module: Union[dict, nn.Module],
        data_preprocessor: Union[dict, nn.Module, None] = None,
    ):
        super().__init__()
        if isinstance(module, nn.Module):
            self.module = module
        elif isinstance(module, dict):
            if data_preprocessor is not None:
                # Work on a shallow copy so the caller's config dict is not
                # modified as a side effect of constructing the wrapper.
                module = module.copy()
                module['data_preprocessor'] = data_preprocessor
            self.module = MODELS.build(module)
        else:
            raise TypeError('The type of module should be a `nn.Module` '
                            f'instance or a dict, but got {type(module)}')
        assert hasattr(self.module, 'test_step'), (
            'Model wrapped by BaseTTAModel must implement `test_step`!')

    @abstractmethod
    def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \
            -> MergedDataSamples:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (EnhancedBatchDataSamples): List of predictions
                of all enhanced data.

        Returns:
            List[BaseDataElement]: Merged prediction.
        """

    def test_step(self, data: DATA_BATCH) -> MergedDataSamples:
        """Run the wrapped model on every augmented view and merge results.

        The incoming batch is split along the augmentation axis, each
        per-augmentation batch is passed to ``self.module.test_step``, and
        the predictions are regrouped with ``zip`` before being handed to
        :meth:`merge_preds`.

        Args:
            data (DataBatch): Enhanced data batch sampled from dataloader.
                Either a dict whose values are indexable by augmentation,
                or a tuple/list of such indexable fields.

        Returns:
            MergedDataSamples: Merged prediction.

        Raises:
            TypeError: If ``data`` is not a dict, tuple or list.
        """
        data_list: Union[List[dict], List[list]]
        if isinstance(data, dict):
            # The number of augmentations is taken from the first field.
            num_augs = len(data[next(iter(data))])
            data_list = [{key: value[idx]
                          for key, value in data.items()}
                         for idx in range(num_augs)]
        elif isinstance(data, (tuple, list)):
            num_augs = len(data[0])
            data_list = [[field[idx] for field in data]
                         for idx in range(num_augs)]
        else:
            raise TypeError('data given by dataLoader should be a dict, '
                            f'tuple or a list, but got {type(data)}')

        predictions = [
            self.module.test_step(aug_data) for aug_data in data_list
        ]
        # Transpose the per-augmentation results so that each group passed
        # to ``merge_preds`` holds the i-th prediction of every augmentation.
        return self.merge_preds(list(zip(*predictions)))

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """``BaseTTAModel.forward`` should not be called.

        Raises:
            NotImplementedError: Always; use :meth:`test_step` instead.
        """
        # Trailing spaces keep the adjacent string literals from running
        # together when they are concatenated.
        raise NotImplementedError(
            '`BaseTTAModel.forward` will not be called during training or '
            'testing. Please call `test_step` instead. If you want to use '
            '`BaseTTAModel.forward`, please implement this method')
| |
|