import copy
import functools
import gc
import logging
import pickle
from collections.abc import Mapping
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
from torch.utils.data import Dataset

from mmengine.config import Config
from mmengine.fileio import join_path, list_from_file, load
from mmengine.logging import print_log
from mmengine.registry import TRANSFORMS
from mmengine.utils import is_abs


class Compose:
    """Compose multiple transforms sequentially.

    Args:
        transforms (Sequence[dict or callable], optional): Sequence of
            transform objects or config dicts to be composed.
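
    Examples:
        >>> # Illustrative sketch: plain callables need no registry; config
        >>> # dicts would additionally require the transform type to be
        >>> # registered in ``TRANSFORMS``.
        >>> pipeline = Compose([lambda results: dict(results, flipped=False)])
        >>> pipeline(dict(img_path='demo.jpg'))
        {'img_path': 'demo.jpg', 'flipped': False}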
| | """ |
| |
|
| | def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]): |
| | self.transforms: List[Callable] = [] |
| |
|
| | if transforms is None: |
| | transforms = [] |
| |
|
| | for transform in transforms: |
| | |
| | |
| | if isinstance(transform, dict): |
| | transform = TRANSFORMS.build(transform) |
| | if not callable(transform): |
| | raise TypeError(f'transform should be a callable object, ' |
| | f'but got {type(transform)}') |
| | self.transforms.append(transform) |
| | elif callable(transform): |
| | self.transforms.append(transform) |
| | else: |
| | raise TypeError( |
| | f'transform must be a callable object or dict, ' |
| | f'but got {type(transform)}') |
| |
|
    def __call__(self, data: dict) -> Optional[dict]:
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict that contains the data to transform.

        Returns:
            dict or None: Transformed data, or None if any transform
            returns None.
        """
        for t in self.transforms:
            data = t(data)
            # A transform returns None when it fails to load an image or
            # cannot find suitable augmentation parameters. Propagate the
            # None so the dataset can randomly pick another sample.
            if data is None:
                return None
        return data

    def __repr__(self):
        """Print ``self.transforms`` in sequence.

        Returns:
            str: Formatted string.
        """
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += f'    {t}'
        format_string += '\n)'
        return format_string


def force_full_init(old_func: Callable) -> Any:
    """Those methods decorated by ``force_full_init`` will be forced to call
    ``full_init`` if the instance has not been fully initialized.

    Args:
        old_func (Callable): Decorated function; make sure the first arg is
            an instance with a ``full_init`` method.

    Returns:
        Any: Depends on ``old_func``.
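
    Examples:
        >>> # Minimal sketch of the intended usage; ``ToyData`` is a
        >>> # hypothetical class. The first call to a decorated method
        >>> # triggers ``full_init`` automatically.
        >>> class ToyData:
        ...     _fully_initialized = False
        ...     def full_init(self):
        ...         self.data = [1, 2, 3]
        ...     @force_full_init
        ...     def __len__(self):
        ...         return len(self.data)
        >>> len(ToyData())
        3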
| | """ |
| |
|
| | @functools.wraps(old_func) |
| | def wrapper(obj: object, *args, **kwargs): |
| | |
| | if not hasattr(obj, 'full_init'): |
| | raise AttributeError(f'{type(obj)} does not have full_init ' |
| | 'method.') |
| | |
| | |
| | |
| | if not getattr(obj, '_fully_initialized', False): |
| | print_log( |
| | f'Attribute `_fully_initialized` is not defined in ' |
| | f'{type(obj)} or `type(obj)._fully_initialized is ' |
| | 'False, `full_init` will be called and ' |
| | f'{type(obj)}._fully_initialized will be set to True', |
| | logger='current', |
| | level=logging.WARNING) |
| | obj.full_init() |
| | obj._fully_initialized = True |
| |
|
| | return old_func(obj, *args, **kwargs) |
| |
|
| | return wrapper |


class BaseDataset(Dataset):
    r"""BaseDataset for open source projects in OpenMMLab.

    The annotation format is shown as follows.

    .. code-block:: none

        {
            "metainfo":
            {
              "dataset_type": "test_dataset",
              "task_name": "test_task"
            },
            "data_list":
            [
              {
                "img_path": "test_img.jpg",
                "height": 604,
                "width": 640,
                "instances":
                [
                  {
                    "bbox": [0, 0, 10, 20],
                    "bbox_label": 1,
                    "mask": [[0,0],[0,10],[10,20],[20,0]],
                    "extra_anns": [1,2,3]
                  },
                  {
                    "bbox": [10, 10, 110, 120],
                    "bbox_label": 2,
                    "mask": [[10,10],[10,110],[110,120],[120,10]],
                    "extra_anns": [4,5,6]
                  }
                ]
              },
            ]
        }

    Args:
        ann_file (str, optional): Annotation file path. Defaults to ''.
        metainfo (Mapping or Config, optional): Meta information for
            dataset, such as class information. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(img_path='').
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the
            first few data in the annotation file to facilitate
            training/testing on a smaller dataset. Defaults to None, which
            means using all data in the annotation file.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects; when enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so the annotation
            file does not have to be loaded. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
            None image, the maximum number of extra cycles used to fetch a
            valid image. Defaults to 1000.

    Note:
        BaseDataset collects meta information from the ``annotation file``
        (the lowest priority), ``BaseDataset.METAINFO`` (medium) and the
        ``metainfo`` argument (highest) passed to the constructor. Lower
        priority meta information will be overwritten by higher priority
        meta information.

    Note:
        Dataset wrappers such as ``ConcatDataset``, ``RepeatDataset``, etc.
        should not inherit from ``BaseDataset`` since ``get_subset`` and
        ``get_subset_`` could produce an ambiguous sub-dataset which
        conflicts with the original dataset.

    Examples:
        >>> # Assume the annotation file is given above.
        >>> class CustomDataset(BaseDataset):
        >>>     METAINFO: dict = dict(task_name='custom_task',
        >>>                           dataset_type='custom_type')
        >>> metainfo = dict(task_name='custom_task_name')
        >>> custom_dataset = CustomDataset(
        >>>     'path/to/ann_file',
        >>>     metainfo=metainfo)
        >>> # Meta information of the annotation file will be overwritten by
        >>> # `CustomDataset.METAINFO`. The merged meta information will
        >>> # further be overwritten by the argument `metainfo`.
        >>> custom_dataset.metainfo
        {'task_name': 'custom_task_name', 'dataset_type': 'custom_type'}
    """

    METAINFO: dict = dict()
    _fully_initialized: bool = False

    def __init__(self,
                 ann_file: Optional[str] = '',
                 metainfo: Union[Mapping, Config, None] = None,
                 data_root: Optional[str] = '',
                 data_prefix: dict = dict(img_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000):
        self.ann_file = ann_file
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Join paths.
        self._join_prefix()

        # Build the processing pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset unless lazy initialization is asked.
        if not lazy_init:
            self.full_init()

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index and automatically call ``full_init`` if
        the dataset has not been fully initialized.

        Args:
            idx (int): The index of data.

        Returns:
            dict: The idx-th annotation of the dataset.
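
        Examples:
            >>> # Illustrative sketch, assuming the annotation file shown in
            >>> # the class docstring.
            >>> dataset = BaseDataset('path/to/ann_file')
            >>> data_info = dataset.get_data_info(0)
            >>> data_info['sample_idx']
            0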
| | """ |
| | if self.serialize_data: |
| | start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() |
| | end_addr = self.data_address[idx].item() |
| | bytes = memoryview( |
| | self.data_bytes[start_addr:end_addr]) |
| | data_info = pickle.loads(bytes) |
| | else: |
| | data_info = copy.deepcopy(self.data_list[idx]) |
| | |
| | |
| | if idx >= 0: |
| | data_info['sample_idx'] = idx |
| | else: |
| | data_info['sample_idx'] = len(self) + idx |
| |
|
| | return data_info |

    def full_init(self):
        """Load the annotation file and set
        ``BaseDataset._fully_initialized`` to True.

        If ``lazy_init=False``, ``full_init`` will be called during
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_data_list: Load annotations from the annotation file.
            - filter data information: Filter annotations according to
              ``filter_cfg``.
            - slice_data: Slice dataset according to ``self._indices``.
            - serialize_data: Serialize ``self.data_list`` if
              ``self.serialize_data`` is True.
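
        Examples:
            >>> # Illustrative sketch: with ``lazy_init=True`` the annotation
            >>> # file is not read until ``full_init`` is called explicitly
            >>> # (or implicitly by a method decorated with
            >>> # ``force_full_init``).
            >>> dataset = BaseDataset('path/to/ann_file', lazy_init=True)
            >>> dataset.full_init()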
| | """ |
| | if self._fully_initialized: |
| | return |
| | |
| | self.data_list = self.load_data_list() |
| | |
| | self.data_list = self.filter_data() |
| | |
| | if self._indices is not None: |
| | self.data_list = self._get_unserialized_subset(self._indices) |
| |
|
| | |
| | if self.serialize_data: |
| | self.data_bytes, self.data_address = self._serialize_data() |
| |
|
| | self._fully_initialized = True |

    @property
    def metainfo(self) -> dict:
        """Get meta information of the dataset.

        Returns:
            dict: Meta information collected from ``BaseDataset.METAINFO``,
            the annotation file and the ``metainfo`` argument during
            instantiation.
        """
        return copy.deepcopy(self._metainfo)

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        This method should return a dict or a list of dicts. Each dict
        contains the data information of a training sample. If the protocol
        of the sample annotations is changed, this function can be overridden
        to update the parsing logic while keeping compatibility.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``.

        Returns:
            dict or list[dict]: Parsed annotation.
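
        Examples:
            >>> # Illustrative sketch: each key in ``self.data_prefix`` is
            >>> # joined in front of the corresponding relative path.
            >>> self.data_prefix = dict(img_path='data/imgs')
            >>> self.parse_data_info(dict(img_path='a.jpg', height=604))
            {'img_path': 'data/imgs/a.jpg', 'height': 604}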
| | """ |
| | for prefix_key, prefix in self.data_prefix.items(): |
| | assert prefix_key in raw_data_info, ( |
| | f'raw_data_info: {raw_data_info} dose not contain prefix key' |
| | f'{prefix_key}, please check your data_prefix.') |
| | raw_data_info[prefix_key] = join_path(prefix, |
| | raw_data_info[prefix_key]) |
| | return raw_data_info |

    def filter_data(self) -> List[dict]:
        """Filter annotations according to ``filter_cfg``. Defaults to
        returning all of ``data_list``.

        If some items in ``data_list`` should be filtered out according to
        specific logic, the subclass should override this method.

        Returns:
            list[dict]: Filtered results.
        """
        return self.data_list

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index. Datasets wrapped by
        ``ClassBalancedDataset`` must implement this method.

        Args:
            idx (int): The index of data.

        Returns:
            list[int]: All categories in the image of the specified index.
        """
        raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` '
                                  'method')

    def __getitem__(self, idx: int) -> dict:
        """Get the idx-th image and data information of the dataset after
        ``self.pipeline``, and ``full_init`` will be called if the dataset
        has not been fully initialized.

        During the training phase, if ``self.pipeline`` gets ``None``,
        ``self._rand_another`` will be called until a valid image is fetched
        or the maximum number of refetches is reached.

        Args:
            idx (int): The index of ``self.data_list``.

        Returns:
            dict: The idx-th image and data information of the dataset after
            ``self.pipeline``.
        """
        # Performing full initialization through `__getitem__` consumes extra
        # memory: if a dataset with `lazy_init=True` is fed into the
        # dataloader without calling `full_init` first, every worker reads
        # and parses the annotation file on its own. It is therefore
        # recommended to call `full_init` manually beforehand so that all
        # workers can share the serialized data of the master process.
        if not self._fully_initialized:
            print_log(
                'Please call `full_init()` method manually to accelerate '
                'the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        if self.test_mode:
            data = self.prepare_data(idx)
            if data is None:
                raise Exception('Test time pipeline should not get `None` '
                                'data_sample')
            return data

        for _ in range(self.max_refetch + 1):
            data = self.prepare_data(idx)
            # Broken images or random augmentations may cause the returned
            # data to be None.
            if data is None:
                idx = self._rand_another()
                continue
            return data

        raise Exception(f'Cannot find valid image after {self.max_refetch} '
                        'refetches! Please check your image path and pipeline')

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named ``self.ann_file``.

        If the annotation file does not follow the `OpenMMLab 2.0 dataset
        format <https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html>`_,
        the subclass must override this method to load annotations. The meta
        information of the annotation file will be overwritten by
        :attr:`METAINFO` and the ``metainfo`` argument of the constructor.

        Returns:
            list[dict]: A list of annotations.
        """
        # `self.ann_file` has already been joined with `self.data_root` in
        # `_join_prefix` if it was a relative path.
        annotations = load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations or 'metainfo' not in annotations:
            raise ValueError('Annotation must have data_list and metainfo '
                             'keys')
        metainfo = annotations['metainfo']
        raw_data_list = annotations['data_list']

        # Meta information loaded from the annotation file has the lowest
        # priority: it does not override meta information already collected
        # from `BaseDataset.METAINFO` or the `metainfo` constructor argument.
        for k, v in metainfo.items():
            self._metainfo.setdefault(k, v)

        # Load and parse data information.
        data_list = []
        for raw_data_info in raw_data_list:
            # Parse raw data information to the target format.
            data_info = self.parse_data_info(raw_data_info)
            if isinstance(data_info, dict):
                # For image tasks, `data_info` holds the information of a
                # single image, e.g. dict(img_path='xxx', width=640, ...).
                data_list.append(data_info)
            elif isinstance(data_info, list):
                # For video tasks, `data_info` can hold the information of
                # several frames, e.g. [dict(video_path='xxx', ...), ...].
                for item in data_info:
                    if not isinstance(item, dict):
                        raise TypeError('data_info must be list of dict, but '
                                        f'got {type(item)}')
                data_list.extend(data_info)
            else:
                raise TypeError('data_info should be a dict or list of dict, '
                                f'but got {type(data_info)}')

        return data_list

    @classmethod
    def _load_metainfo(cls,
                       metainfo: Union[Mapping, Config, None] = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (Mapping or Config, optional): Meta information dict.
                If ``metainfo`` contains an existing filename, it will be
                parsed by ``list_from_file``.

        Returns:
            dict: Parsed meta information.
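
        Examples:
            >>> # Illustrative sketch: plain values are kept as-is, while a
            >>> # string value pointing to an existing file would be expanded
            >>> # by ``list_from_file``.
            >>> BaseDataset._load_metainfo(dict(classes=('cat', 'dog')))
            {'classes': ('cat', 'dog')}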
| | """ |
| | |
| | cls_metainfo = copy.deepcopy(cls.METAINFO) |
| | if metainfo is None: |
| | return cls_metainfo |
| | if not isinstance(metainfo, (Mapping, Config)): |
| | raise TypeError('metainfo should be a Mapping or Config, ' |
| | f'but got {type(metainfo)}') |
| |
|
| | for k, v in metainfo.items(): |
| | if isinstance(v, str): |
| | |
| | |
| | try: |
| | cls_metainfo[k] = list_from_file(v) |
| | except (TypeError, FileNotFoundError): |
| | print_log( |
| | f'{v} is not a meta file, simply parsed as meta ' |
| | 'information', |
| | logger='current', |
| | level=logging.WARNING) |
| | cls_metainfo[k] = v |
| | else: |
| | cls_metainfo[k] = v |
| | return cls_metainfo |

    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.

        Examples:
            >>> # self.data_prefix contains relative paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='a/b/c/d/e')
            >>> self.ann_file
            'a/b/c/f'
            >>> # self.data_prefix contains absolute paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='/d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='/d/e')
            >>> self.ann_file
            'a/b/c/f'
        """
        # Automatically join the annotation file path with `self.data_root`
        # if `self.ann_file` is not an absolute path.
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)
        # Automatically join each data directory with `self.data_root` if the
        # prefix is not an absolute path.
        for data_key, prefix in self.data_prefix.items():
            if not isinstance(prefix, str):
                raise TypeError('prefix should be a string, but got '
                                f'{type(prefix)}')
            if not is_abs(prefix) and self.data_root:
                self.data_prefix[data_key] = join_path(self.data_root, prefix)
            else:
                self.data_prefix[data_key] = prefix

    @force_full_init
    def get_subset_(self, indices: Union[Sequence[int], int]) -> None:
        """The in-place version of ``get_subset`` to convert the dataset to a
        subset of the original dataset.

        This method converts the original dataset to a subset of itself. If
        the type of indices is int, ``get_subset_`` keeps the first or last
        few data samples according to whether indices is positive or
        negative. If the type of indices is a sequence of int, the subset
        contains the data samples at the given indices.

        Examples:
            >>> dataset = BaseDataset('path/to/ann_file')
            >>> len(dataset)
            100
            >>> dataset.get_subset_(90)
            >>> len(dataset)
            90
            >>> # if type of indices is sequence, extract the corresponding
            >>> # index data information
            >>> dataset.get_subset_([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            >>> len(dataset)
            10
            >>> dataset.get_subset_(-3)
            >>> len(dataset)  # Get the last few data samples.
            3

        Args:
            indices (int or Sequence[int]): If the type of indices is int,
                indices represents the first or last few data samples of the
                dataset according to whether indices is positive or negative.
                If the type of indices is a Sequence, indices represents the
                target data indices of the dataset.
        """
        # Get the subset of the serialized data or of the data information
        # list, according to `self.serialize_data`.
        if self.serialize_data:
            self.data_bytes, self.data_address = \
                self._get_serialized_subset(indices)
        else:
            self.data_list = self._get_unserialized_subset(indices)

    @force_full_init
    def get_subset(self, indices: Union[Sequence[int], int]) -> 'BaseDataset':
        """Return a subset of the dataset.

        This method returns a subset of the original dataset. If the type of
        indices is int, ``get_subset`` returns a sub-dataset which contains
        the first or last few data samples according to whether indices is
        positive or negative. If the type of indices is a sequence of int,
        the sub-dataset contains the data samples at the given indices.

        Examples:
            >>> dataset = BaseDataset('path/to/ann_file')
            >>> len(dataset)
            100
            >>> sub_dataset = dataset.get_subset(90)
            >>> len(sub_dataset)
            90
            >>> # if type of indices is list, extract the corresponding
            >>> # index data information
            >>> sub_dataset = dataset.get_subset([0, 1, 2, 3, 4, 5, 6, 7,
            >>>                                   8, 9])
            >>> len(sub_dataset)
            10
            >>> sub_dataset = dataset.get_subset(-3)
            >>> len(sub_dataset)  # Get the last few data samples.
            3

        Args:
            indices (int or Sequence[int]): If the type of indices is int,
                indices represents the first or last few data samples of the
                dataset according to whether indices is positive or negative.
                If the type of indices is a Sequence, indices represents the
                target data indices of the dataset.

        Returns:
            BaseDataset: A subset of the dataset.
        """
        # Copy all attributes except the annotation-related ones
        # (`data_list`, `data_bytes` and `data_address`), then fill the copy
        # with the selected subset so the original dataset is not modified.
        sub_dataset = self._copy_without_annotation()
        # Get the subset of the serialized data or of the data information
        # list, according to `self.serialize_data`.
        if self.serialize_data:
            data_bytes, data_address = \
                self._get_serialized_subset(indices)
            sub_dataset.data_bytes = data_bytes.copy()
            sub_dataset.data_address = data_address.copy()
        else:
            data_list = self._get_unserialized_subset(indices)
            sub_dataset.data_list = copy.deepcopy(data_list)
        return sub_dataset

    def _get_serialized_subset(self, indices: Union[Sequence[int], int]) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Get a subset of the serialized data information.

        Args:
            indices (int or Sequence[int]): If the type of indices is int,
                indices represents the first or last few entries of the
                serialized data information. If the type of indices is a
                Sequence, indices represents the target entries which make
                up the subset.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Subset of the serialized data
            information.
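
        Examples:
            >>> # Illustrative sketch: an int keeps a contiguous head (or
            >>> # tail, if negative) of the serialized buffer, while a
            >>> # sequence gathers the selected samples and rebuilds the
            >>> # address array.
            >>> data_bytes, data_address = self._get_serialized_subset([0, 2])
            >>> len(data_address)
            2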
| | """ |
| | sub_data_bytes: Union[List, np.ndarray] |
| | sub_data_address: Union[List, np.ndarray] |
| | if isinstance(indices, int): |
| | if indices >= 0: |
| | assert indices < len(self.data_address), \ |
| | f'{indices} is out of dataset length({len(self)}' |
| | |
| | end_addr = self.data_address[indices - 1].item() \ |
| | if indices > 0 else 0 |
| | |
| | |
| | sub_data_bytes = self.data_bytes[:end_addr] |
| | |
| | |
| | sub_data_address = self.data_address[:indices] |
| | else: |
| | assert -indices <= len(self.data_address), \ |
| | f'{indices} is out of dataset length({len(self)}' |
| | |
| | ignored_bytes_size = self.data_address[indices - 1] |
| | start_addr = self.data_address[indices - 1].item() |
| | sub_data_bytes = self.data_bytes[start_addr:] |
| | sub_data_address = self.data_address[indices:] |
| | sub_data_address = sub_data_address - ignored_bytes_size |
| | elif isinstance(indices, Sequence): |
| | sub_data_bytes = [] |
| | sub_data_address = [] |
| | for idx in indices: |
| | assert len(self) > idx >= -len(self) |
| | start_addr = 0 if idx == 0 else \ |
| | self.data_address[idx - 1].item() |
| | end_addr = self.data_address[idx].item() |
| | |
| | sub_data_bytes.append(self.data_bytes[start_addr:end_addr]) |
| | |
| | sub_data_address.append(end_addr - start_addr) |
| | |
| | if sub_data_bytes: |
| | sub_data_bytes = np.concatenate(sub_data_bytes) |
| | sub_data_address = np.cumsum(sub_data_address) |
| | else: |
| | sub_data_bytes = np.array([]) |
| | sub_data_address = np.array([]) |
| | else: |
| | raise TypeError('indices should be a int or sequence of int, ' |
| | f'but got {type(indices)}') |
| | return sub_data_bytes, sub_data_address |

    def _get_unserialized_subset(self, indices: Union[Sequence[int],
                                                      int]) -> list:
        """Get a subset of the data information list.

        Args:
            indices (int or Sequence[int]): If the type of indices is int,
                indices represents the first or last few entries of the data
                information list. If the type of indices is a Sequence,
                indices represents the target entries which make up the
                subset.

        Returns:
            list: Subset of the data information list.
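
        Examples:
            >>> # Illustrative sketch of the slicing rules, with three fake
            >>> # entries in ``data_list``.
            >>> self.data_list = [dict(a=0), dict(a=1), dict(a=2)]
            >>> self._get_unserialized_subset(2)    # first two samples
            [{'a': 0}, {'a': 1}]
            >>> self._get_unserialized_subset(-1)   # last sample
            [{'a': 2}]
            >>> self._get_unserialized_subset([0, 2])
            [{'a': 0}, {'a': 2}]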
| | """ |
| | if isinstance(indices, int): |
| | if indices >= 0: |
| | |
| | sub_data_list = self.data_list[:indices] |
| | else: |
| | |
| | sub_data_list = self.data_list[indices:] |
| | elif isinstance(indices, Sequence): |
| | |
| | sub_data_list = [] |
| | for idx in indices: |
| | sub_data_list.append(self.data_list[idx]) |
| | else: |
| | raise TypeError('indices should be a int or sequence of int, ' |
| | f'but got {type(indices)}') |
| | return sub_data_list |

    def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Serialize ``self.data_list`` to save memory when launching
        multiple workers during data loading. This function will be called
        in ``full_init``.

        Hold memory using serialized objects, so that data loader workers
        can use shared RAM from the master process instead of making a copy.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Serialized result and the
            corresponding addresses.
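
        Examples:
            >>> # Illustrative sketch of the layout: each sample is pickled
            >>> # separately and the cumulative end offsets are kept, so that
            >>> # sample ``i`` lives in
            >>> # ``data_bytes[data_address[i - 1]:data_address[i]]``.
            >>> self.data_list = [dict(a=0), dict(a=1)]
            >>> data_bytes, data_address = self._serialize_data()
            >>> pickle.loads(data_bytes[:data_address[0]].tobytes())
            {'a': 0}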
| | """ |
| |
|
| | def _serialize(data): |
| | buffer = pickle.dumps(data, protocol=4) |
| | return np.frombuffer(buffer, dtype=np.uint8) |
| |
|
| | |
| | |
| | |
| | data_list = [_serialize(x) for x in self.data_list] |
| | address_list = np.asarray([len(x) for x in data_list], dtype=np.int64) |
| | data_address: np.ndarray = np.cumsum(address_list) |
| | |
| | data_bytes = np.concatenate(data_list) |
| | |
| | |
| | self.data_list.clear() |
| | gc.collect() |
| | return data_bytes, data_address |

    def _rand_another(self) -> int:
        """Get a random index.

        Returns:
            int: Random index from 0 to ``len(self) - 1``.
        """
        return np.random.randint(0, len(self))

    def prepare_data(self, idx) -> Any:
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)

    @force_full_init
    def __len__(self) -> int:
        """Get the length of the filtered dataset and automatically call
        ``full_init`` if the dataset has not been fully initialized.

        Returns:
            int: The length of the filtered dataset.
        """
        if self.serialize_data:
            return len(self.data_address)
        else:
            return len(self.data_list)

    def _copy_without_annotation(self, memo=None) -> 'BaseDataset':
        """Deepcopy all attributes other than ``data_list``, ``data_address``
        and ``data_bytes``.

        Args:
            memo: Memo dict used to reconstruct complex objects correctly.
                Defaults to None, in which case a fresh dict is created.
        """
        # A mutable default argument would be shared across calls, so the
        # memo dict is created per call instead.
        if memo is None:
            memo = dict()
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other

        for key, value in self.__dict__.items():
            if key in ['data_list', 'data_address', 'data_bytes']:
                continue
            super(BaseDataset, other).__setattr__(key,
                                                  copy.deepcopy(value, memo))

        return other