| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | import logging |
| | import os |
| | import json |
| | import pathlib |
| | from os.path import basename |
| |
|
| | from timm.data import create_loader |
| | import torch |
| | import torch.utils.data |
| | import torch.distributed as dist |
| | import torchvision.datasets as datasets |
| | from torchvision.io import read_image |
| | import torch.distributed as dist |
| | from pathlib import Path |
| | from yacs.config import CfgNode as CN |
| |
|
| | from ..LangEncoder import build_tokenizer |
| |
|
| | from .tsv import TSVImageTextDatasetV2 |
| | from .tsv import TSVMeta |
| | from .transforms import build_transforms |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def build_dataset(cfg, is_train):
    """Construct the dataset selected by cfg['DATASET']['DATASET'].

    Args:
        cfg: configuration mapping with a 'DATASET' section.
        is_train: True to build the training split, False for evaluation.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if the configured dataset name is not recognized.
    """
    if cfg['DATASET']['DATASET'] != 'image_text_pairs_v2':
        raise ValueError(f'Unknown dataset: {cfg["DATASET"]["DATASET"]}')
    return _build_pairs_dataset_v2(cfg, is_train)
| |
|
| |
|
def _get_tsv_list(cfg, is_train):
    """Resolve the tsv file names configured for the requested split.

    Reads DATASET.TRAIN_TSV_LIST (training) or DATASET.TEST_TSV_LIST
    (evaluation). Entries ending in '.list' are treated as index files
    whose lines are expanded in place; other entries are used verbatim.

    Args:
        cfg: configuration mapping with a 'DATASET' section.
        is_train: selects the train vs test tsv list.

    Returns:
        A flat list of tsv file name strings (possibly empty).
    """
    dataset_cfg = cfg['DATASET']
    raw_entries = []
    if is_train and 'TRAIN_TSV_LIST' in dataset_cfg:
        raw_entries = dataset_cfg['TRAIN_TSV_LIST']
    elif 'TEST_TSV_LIST' in dataset_cfg:
        raw_entries = dataset_cfg['TEST_TSV_LIST']

    tsv_list = []
    for entry in raw_entries:
        if entry.endswith('.list'):
            # Index file: one tsv name per line, stripped of whitespace.
            with open(entry, 'r') as list_file:
                tsv_list.extend(line.strip() for line in list_file)
        else:
            tsv_list.append(entry)

    logger.info(f'tsv list: {tsv_list}')

    return tsv_list
| |
|
| |
|
def _get_token_file(cfg):
    """Select the SAS token file for this node, if one is usable.

    When DATASET.TOKEN_FILE is a list, token files are assigned to nodes
    round-robin by node index (rank // GPUs per node). The resolved path
    is joined under DATASET.ROOT.

    Returns:
        The token file path, or None when DATASET.LOADER is 'blobfuse'
        or the chosen file does not exist on disk.
    """
    gpus_per_node = torch.cuda.device_count()
    num_nodes = dist.get_world_size() // gpus_per_node
    token_cfg = cfg['DATASET']['TOKEN_FILE']

    if isinstance(token_cfg, list):
        if num_nodes == 1:
            logger.warning('=> Multi token files are provided, but only one node is used for training')
            sas_token_file = token_cfg[0]
        else:
            # Round-robin token files over nodes.
            node_idx = dist.get_rank() // gpus_per_node
            sas_token_file = token_cfg[node_idx % len(token_cfg)]
    else:
        sas_token_file = token_cfg

    sas_token_file = os.path.join(cfg['DATASET']['ROOT'], sas_token_file)

    if cfg['DATASET']['LOADER'] == 'blobfuse' or not os.path.isfile(sas_token_file):
        sas_token_file = None

    return sas_token_file
| |
|
| |
|
def _tsv_files_with_tag(tsv_filenames, tag):
    """Return the tsv files whose basename marks them as `tag` data.

    A file matches when its basename contains the tag adjacent to a
    separator, i.e. any of '<tag>-', '<tag>_', '_<tag>', '-<tag>',
    or '<tag>s-'.
    """
    patterns = (tag + '-', tag + '_', '_' + tag, '-' + tag, tag + 's-')
    return [
        filename
        for filename in tsv_filenames
        if any(p in basename(filename) for p in patterns)
    ]


def _build_pairs_dataset_v2(cfg, is_train):
    """Build a TSVImageTextDatasetV2 over paired image/text tsv shards.

    Shards come either from an explicit tsv list (DATASET.TRAIN_TSV_LIST /
    TEST_TSV_LIST, via _get_tsv_list) or from globbing '**/*.tsv' under
    DATASET.ROOT/<set name>. Image and text shards are identified by
    their basenames, then combined with transforms, the tokenizer, an
    optional SAS token, and optional classification metadata.

    Args:
        cfg: full configuration mapping.
        is_train: selects TRAIN_SET vs TEST_SET and training transforms.

    Returns:
        A TSVImageTextDatasetV2 instance.

    Raises:
        ValueError: if DATASET.DATA_FORMAT is not 'tsv'.
    """
    transforms = build_transforms(cfg, is_train)
    logger.info('transforms: {}'.format(transforms))

    dataset_name = cfg['DATASET']['TRAIN_SET'] \
        if is_train else cfg['DATASET']['TEST_SET']

    tokenobj = build_tokenizer(cfg['LANG_ENCODER'])

    if cfg['DATASET']['DATA_FORMAT'] != 'tsv':
        raise ValueError('Only support tsv format for pairs dataset v2')

    tsv_list = _get_tsv_list(cfg, is_train)

    if len(tsv_list) > 0:
        tsv_filenames = sorted(
            os.path.join(cfg['DATASET']['ROOT'], dataset_name, f)
            for f in tsv_list
        )
    else:
        # No explicit list: pick up every tsv under the dataset directory.
        dataset_path = os.path.join(cfg['DATASET']['ROOT'], dataset_name)
        tsv_filenames = sorted(
            str(path) for path in Path(dataset_path).glob('**/*.tsv')
        )

    image_tsv_files = _tsv_files_with_tag(tsv_filenames, 'image')
    text_tsv_files = _tsv_files_with_tag(tsv_filenames, 'text')

    logger.info(
        "=> found %d/%d tsv file(s) to load.",
        len(image_tsv_files), len(text_tsv_files)
    )

    # Evaluation may score multiple captions per image; training uses one.
    num_captions = 1 \
        if is_train else cfg['DATASET'].get('NUM_CAPTIONS', 1)
    text_format = cfg['DATASET'].get('TEXT_FORMAT', 'json')

    sas_token_file = _get_token_file(cfg)
    logger.info("=> SAS token path: %s", sas_token_file)

    metas = []
    cfg_data = cfg['DATASET']
    if 'CLASSIFICATION_SETS' in cfg_data and 'NUM_CLASSES' in cfg_data:
        # Attach per-source classification metadata when configured.
        for source, num_classes in zip(cfg_data['CLASSIFICATION_SETS'], cfg_data['NUM_CLASSES']):
            metas.append(
                TSVMeta(
                    source=source,
                    num_classes=num_classes,
                    task='classification'
                )
            )
            logger.info('=> add meta: {}'.format(metas[-1]))

    if 'coco-caption' in dataset_name:
        # COCO caption evaluation overrides caption count/format and
        # disables the SAS token regardless of configuration.
        logger.info('=> coco caption data is used')
        logger.info('=> update num_captions: 5, text_format: json')
        logger.warning('=> set sas token to None for coco evaluation')
        sas_token_file = None
        num_captions = 5
        text_format = 'json'

    dataset = TSVImageTextDatasetV2(
        image_tsv_files, text_tsv_files,
        transform=transforms,
        tokenize=tokenobj,
        context_length=cfg['LANG_ENCODER']['CONTEXT_LENGTH'],
        num_captions=num_captions,
        text_format=text_format,
        is_train=is_train,
        sas_token_path=sas_token_file,
        metas=metas,
        prompt_engineering=cfg['DATASET'].get('PROMPT_ENGINEERING', True),
        concat_queries=cfg['DATASET'].get('CONCAT_QUERIES', False)
    )

    logger.info(
        "=> %s set size: %d", 'train'
        if is_train else 'val', len(dataset)
    )

    return dataset
| |
|
| |
|
def build_dataloader(cfg, is_train=True, distributed=False):
    """Create a data loader for the configured dataset.

    Uses timm's `create_loader` when training with AUG.TIMM_AUG.USE_LOADER
    enabled; otherwise builds a standard torch DataLoader, optionally
    wrapping the dataset in a DistributedSampler.

    Args:
        cfg: full configuration mapping.
        is_train: build the training loader (affects batch size, shuffle
            defaults, drop_last, and timm-loader eligibility).
        distributed: when True (or cfg ALWAYS_ENABLE_SAMPLER is set), use
            a DistributedSampler for the plain DataLoader path.

    Returns:
        A torch.utils.data.DataLoader (or timm loader) over the dataset.
    """
    dataset = build_dataset(cfg, is_train)

    use_timm_loader = (
        is_train
        and 'TIMM_AUG' in cfg['AUG']
        and cfg['AUG']['TIMM_AUG']['USE_LOADER']
    )

    if use_timm_loader:
        logger.info('=> use timm loader for training')
        timm_cfg = CN(init_dict=cfg['AUG']['TIMM_AUG'])
        data_loader = create_loader(
            dataset,
            input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
            batch_size=cfg['TRAIN']['BATCH_SIZE_PER_GPU'],
            is_training=True,
            use_prefetcher=True,
            no_aug=False,
            re_prob=timm_cfg.RE_PROB,
            re_mode=timm_cfg.RE_MODE,
            re_count=timm_cfg.RE_COUNT,
            re_split=timm_cfg.RE_SPLIT,
            scale=cfg['AUG']['SCALE'],
            ratio=cfg['AUG']['RATIO'],
            hflip=timm_cfg.HFLIP,
            vflip=timm_cfg.VFLIP,
            color_jitter=timm_cfg.COLOR_JITTER,
            auto_augment=timm_cfg.AUTO_AUGMENT,
            num_aug_splits=0,
            interpolation=cfg['AUG']['INTERPOLATION'],
            mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
            std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
            num_workers=cfg['WORKERS'],
            distributed=distributed,
            collate_fn=None,
            pin_memory=cfg['PIN_MEMORY'],
            use_multi_epochs_loader=True
        )
    else:
        if is_train:
            batch_size_per_gpu = cfg['TRAIN']['BATCH_SIZE_PER_GPU']
            shuffle = cfg['TRAIN'].get('SHUFFLE', True)
        else:
            batch_size_per_gpu = cfg['TEST']['BATCH_SIZE_PER_GPU']
            shuffle = cfg['TEST'].get('SHUFFLE', False)

        if distributed or cfg.get('ALWAYS_ENABLE_SAMPLER', False):
            # The sampler owns shuffling; DataLoader must not also shuffle.
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
            shuffle = False
        else:
            sampler = None

        num_workers = cfg['WORKERS']
        loader_kwargs = dict(
            batch_size=batch_size_per_gpu,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=cfg['PIN_MEMORY'],
            sampler=sampler,
            # Drop the ragged final batch only during training.
            drop_last=is_train,
        )
        if num_workers > 0:
            # prefetch_factor is only valid with worker processes;
            # DataLoader raises if it is set while num_workers == 0.
            loader_kwargs['prefetch_factor'] = cfg.get('PREFETCH_FACTOR', 2)

        data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    return data_loader
| |
|
| |
|
| |
|
| |
|
| |
|