| | |
| | import collections |
| | import os.path as osp |
| | import random |
| | from typing import Dict, List |
| |
|
| | import mmengine |
| | from mmengine.dataset import BaseDataset |
| |
|
| | from mmdet.registry import DATASETS |
| |
|
| |
|
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
    """RefCOCO dataset.

    The `Refcoco` and `Refcoco+` datasets are based on
    `ReferItGame: Referring to Objects in Photographs of Natural Scenes
    <http://tamaraberg.com/papers/referit.pdf>`_.

    The `Refcocog` dataset is based on
    `Generation and Comprehension of Unambiguous Object Descriptions
    <https://arxiv.org/abs/1511.02283>`_.

    Args:
        data_root (str): The root directory for ``data_prefix``,
            ``ann_file`` and ``split_file``.
        ann_file (str): Annotation file path. Loaded as JSON; must provide
            ``images`` and ``annotations`` lists.
        split_file (str): Split file path. Loaded as a pickle; must be an
            iterable of ref dicts with at least ``ref_id``, ``ann_id``,
            ``split``, ``image_id`` and ``sentences`` keys.
        data_prefix (dict): Prefix for training data. Must contain the key
            ``img_path``.
        split (str): Split name. Only refs whose ``split`` field equals this
            value are loaded. Defaults to 'train'.
        text_mode (str): How the sentences of each referred object are
            turned into text. One of:

            - ``'random'``: pick one sentence at random.
            - ``'concat'``: join all sentences, separated by a single space,
              into one text.
            - ``'select_first'``: use the first sentence.
            - ``'original'``: use every sentence; the object's instance is
              repeated once per sentence.

            Defaults to 'random'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 split_file: str,
                 data_prefix: Dict,
                 split: str = 'train',
                 text_mode: str = 'random',
                 **kwargs):
        # These must exist before super().__init__(): the overridden
        # _join_prefix() and load_data_list() read them, and the parent
        # constructor presumably invokes those hooks.
        self.split_file = split_file
        self.split = split

        valid_modes = ('original', 'random', 'concat', 'select_first')
        assert text_mode in valid_modes, f'Invalid text mode "{text_mode}".'
        self.text_mode = text_mode
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        """Prefix a relative ``split_file`` with ``data_root``, then
        delegate to the parent implementation for the remaining paths."""
        # Skip the join when there is no root (``osp.join`` would raise on
        # ``None``) or when the path is empty or already absolute.
        if (self.split_file and self.data_root
                and not mmengine.is_abs(self.split_file)):
            self.split_file = osp.join(self.data_root, self.split_file)
        return super()._join_prefix()

    def _init_refs(self):
        """Index the loaded refs and map each ref to its annotation.

        Sets ``self.refs`` (ref_id -> ref dict) and ``self.ref_to_ann``
        (ref_id -> annotation dict from ``ann_file``). Refs of every split
        are indexed, not only those of ``self.split``.

        Raises:
            KeyError: If a ref's ``ann_id`` is absent from the annotation
                file.
        """
        anns = {ann['id']: ann for ann in self.instances['annotations']}

        refs, ref_to_ann = {}, {}
        for ref in self.splits:
            ref_id = ref['ref_id']
            refs[ref_id] = ref
            ref_to_ann[ref_id] = anns[ref['ann_id']]

        self.refs = refs
        self.ref_to_ann = ref_to_ann

    def _select_texts(self, texts: List[str]) -> List[str]:
        """Select the text(s) for one object from all of its sentences,
        according to ``self.text_mode``.

        Raises:
            ValueError: If ``self.text_mode`` is not a supported mode.
        """
        if self.text_mode == 'random':
            return [texts[random.randint(0, len(texts) - 1)]]
        if self.text_mode == 'concat':
            # A space separator keeps the last word of one sentence from
            # fusing with the first word of the next.
            return [' '.join(texts)]
        if self.text_mode == 'select_first':
            return [texts[0]]
        if self.text_mode == 'original':
            return texts
        raise ValueError(f'Invalid text mode "{self.text_mode}".')

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Reads the split and annotation files, keeps the refs belonging to
        ``self.split`` and groups them by image.

        Returns:
            list[dict]: One dict per image with keys ``img_path``,
            ``img_id``, ``instances`` (one ``{'mask', 'ignore_flag'}`` dict
            per entry of ``text``) and ``text`` (the selected lower-cased
            sentences).

        Raises:
            ValueError: If no ref belongs to ``self.split``, or if
                ``self.text_mode`` is unsupported.
        """
        self.splits = mmengine.load(self.split_file, file_format='pkl')
        self.instances = mmengine.load(self.ann_file, file_format='json')
        self._init_refs()
        img_prefix = self.data_prefix['img_path']

        ref_ids = [
            ref['ref_id'] for ref in self.splits if ref['split'] == self.split
        ]
        # Merge each ref into a *copy* of its annotation (ref fields win) so
        # the loaded annotation dicts are not mutated.
        full_anno = [{
            **self.ref_to_ann[ref_id],
            **self.refs[ref_id]
        } for ref_id in ref_ids]

        # Deduplicate by annotation id: a later ref for the same annotation
        # replaces an earlier one (keeping the earlier one's position).
        image_id_list = []
        final_anno = {}
        for anno in full_anno:
            image_id_list.append(anno['image_id'])
            final_anno[anno['ann_id']] = anno
        annotations = list(final_anno.values())

        image_annot = {img['id']: img for img in self.instances['images']}
        images = [image_annot[image_id] for image_id in set(image_id_list)]

        grounding_dict = collections.defaultdict(list)
        for anno in annotations:
            grounding_dict[int(anno['image_id'])].append(anno)

        join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
        data_list = []
        for image in images:
            img_id = image['id']
            instances = []
            sentences = []
            for grounding_anno in grounding_dict[img_id]:
                texts = [x['raw'].lower() for x in grounding_anno['sentences']]
                text = self._select_texts(texts)
                # One distinct instance dict per text, so the entries are
                # not aliases of a single shared dict.
                instances.extend({
                    'mask': grounding_anno['segmentation'],
                    'ignore_flag': 0
                } for _ in text)
                sentences.extend(text)
            data_list.append({
                'img_path': join_path(img_prefix, image['file_name']),
                'img_id': img_id,
                'instances': instances,
                'text': sentences
            })

        if not data_list:
            raise ValueError(f'No sample in split "{self.split}".')

        return data_list
| |
|