| | |
| | from typing import List, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from .typing_utils import SampleList |
| |
|
| |
|
def add_prefix(inputs: dict, prefix: str) -> dict:
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict in which every key ``name`` of ``inputs`` is
        replaced by ``'{prefix}.{name}'``. The values are carried over
        unchanged and ``inputs`` itself is not modified.
    """
    return {f'{prefix}.{name}': value for name, value in inputs.items()}
| |
|
| |
|
def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional['SampleList'] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255) -> tuple:
    """Stack multiple inputs to form a batch and pad the images and ground
    truth maps to a common shape using right-bottom padding.

    Exactly one of ``size`` and ``size_divisor`` must be given.

    Args:
        inputs (List[Tensor]): The input tensors. Each must be a 3D tensor
            (the last two dims are treated as height and width) and all of
            them must have the same size in dim 0.
        data_samples (list[:obj:`SegDataSample`], optional): The list of
            data samples, indexed in parallel with ``inputs``. The
            ``gt_sem_seg``, ``gt_edge_map`` and ``gt_depth_map`` fields
            that are present are padded in place, and ``img_shape``,
            ``pad_shape`` and ``padding_size`` are written to the sample's
            meta info.
        size (tuple, optional): Fixed padding size; its last two entries
            are read as (height, width). An input that already exceeds
            ``size`` along an axis gets no padding on that axis (it is not
            cropped).
        size_divisor (int, optional): The divisor of padded size. The
            largest height and width in the batch are rounded up to a
            multiple of it and every input is padded to that shape.
        pad_val (int, float): The padding value for ``inputs``.
            Defaults to 0.
        seg_pad_val (int, float): The padding value for the ground truth
            maps in ``data_samples``. Defaults to 255.

    Returns:
        tuple: A pair ``(batch, padded_samples)``.

        - batch (Tensor): The 4D tensor made by stacking the padded inputs
          along a new leading dim.
        - padded_samples (list): The updated data samples when
          ``data_samples`` is given; otherwise one dict per input with the
          keys ``img_padding_size`` and ``pad_shape``.

    Raises:
        AssertionError: If ``inputs`` is not a list, the inputs are not all
            3D tensors with the same size in dim 0, or it is not the case
            that exactly one of ``size`` and ``size_divisor`` is given.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs must be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs must be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    # only one of size and size_divisor should be valid
    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'

    # Resolve the target (height, width) once, as plain Python ints, so the
    # padding tuples stored in the meta info do not carry numpy scalars.
    if size is not None:
        target_h, target_w = int(size[-2]), int(size[-1])
    else:
        target_h = max(tensor.shape[-2] for tensor in inputs)
        target_w = max(tensor.shape[-1] for tensor in inputs)
        if size_divisor is not None and size_divisor > 1:
            # round up to the nearest multiple of size_divisor
            target_h = (target_h + size_divisor - 1) \
                // size_divisor * size_divisor
            target_w = (target_w + size_divisor - 1) \
                // size_divisor * size_divisor

    gt_fields = ('gt_sem_seg', 'gt_edge_map', 'gt_depth_map')
    padded_inputs = []
    padded_samples = []

    for i, tensor in enumerate(inputs):
        # F.pad order for the last two dims: (left, right, top, bottom);
        # only right and bottom are ever padded.
        pad_w = max(target_w - tensor.shape[-1], 0)
        pad_h = max(target_h - tensor.shape[-2], 0)
        padding_size = (0, pad_w, 0, pad_h)

        # pad img
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)

        if data_samples is None:
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))
            continue

        # pad ground truth maps
        data_sample = data_samples[i]
        # Fall back to the padded image shape so ``pad_shape`` is never
        # None when the sample has no ground truth field; this matches
        # what the ``data_samples is None`` branch reports.
        pad_shape = pad_img.shape[-2:]
        for field in gt_fields:
            if field not in data_sample:
                continue
            gt = getattr(data_sample, field)
            gt_data = gt.data
            # NOTE(review): the old ``data`` is deleted before the padded
            # one is assigned, presumably so the container does not check
            # the new shape against the old one - confirm against the
            # data element implementation.
            del gt.data
            # NOTE(review): ``gt_depth_map`` is padded with ``seg_pad_val``
            # as well - confirm this is the intended fill for depth.
            gt.data = F.pad(gt_data, padding_size, value=seg_pad_val)
            # When several fields are present, the last one wins.
            pad_shape = gt.shape
        data_sample.set_metainfo({
            'img_shape': tensor.shape[-2:],
            'pad_shape': pad_shape,
            'padding_size': padding_size
        })
        padded_samples.append(data_sample)

    return torch.stack(padded_inputs, dim=0), padded_samples
| |
|