import os
import logging
import json
import base64
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path

import lmdb
import torch
import torchvision.datasets as datasets
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from clip import tokenize


def _convert_to_rgb(image):
    return image.convert('RGB')


def _preprocess_text(text):
    # lowercase the text and normalize curly quotes to plain ASCII quotes
    text = text.lower().replace("“", "\"").replace("”", "\"")
    return text


class EvalTxtDataset(Dataset):
    def __init__(self, jsonl_filename, max_txt_length=24):
        assert os.path.exists(jsonl_filename), "The annotation datafile {} does not exist!".format(jsonl_filename)

        logging.debug(f'Loading jsonl data from {jsonl_filename}.')
        self.texts = []
        with open(jsonl_filename, "r", encoding="utf-8") as fin:
            for line in fin:
                obj = json.loads(line.strip())
                text_id = obj['text_id']
                text = obj['text']
                self.texts.append((text_id, text))
        logging.debug(f'Finished loading jsonl data from {jsonl_filename}.')

        self.max_txt_length = max_txt_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text_id, text = self.texts[idx]
        text = tokenize([_preprocess_text(str(text))], context_length=self.max_txt_length)[0]
        return text_id, text


class EvalImgDataset(Dataset):
    def __init__(self, lmdb_imgs, resolution=224):
        assert os.path.isdir(lmdb_imgs), "The image LMDB directory {} does not exist!".format(lmdb_imgs)

        logging.debug(f'Loading image LMDB from {lmdb_imgs}.')

        self.env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn_imgs = self.env_imgs.begin(buffers=True)
        self.cursor_imgs = self.txn_imgs.cursor()
        self.iter_imgs = iter(self.cursor_imgs)
        self.number_images = int(self.txn_imgs.get(key=b'num_images').tobytes().decode('utf-8'))
        logging.info("The specified LMDB directory contains {} images.".format(self.number_images))

        self.transform = self._build_transform(resolution)

    def _build_transform(self, resolution):
        normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        return Compose([
            Resize((resolution, resolution), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])

    def __len__(self):
        return self.number_images

    def __getitem__(self, idx):
        # Note: idx is ignored; samples are read by advancing the LMDB cursor, so this
        # dataset must be used with a SequentialSampler and num_workers=0
        # (as done in get_eval_img_dataset below).
        img_id, image_b64 = next(self.iter_imgs)
        if img_id == b"num_images":  # skip the metadata entry
            img_id, image_b64 = next(self.iter_imgs)

        img_id = img_id.tobytes()
        image_b64 = image_b64.tobytes()

        img_id = int(img_id.decode(encoding="utf8", errors="ignore"))
        image_b64 = image_b64.decode(encoding="utf8", errors="ignore")
        image = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64)))
        image = self.transform(image)

        return img_id, image
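

# A minimal sketch (illustrative only, not part of the original module) of one way the
# image LMDB read above could be produced. The key/value layout is inferred from
# EvalImgDataset: stringified integer image ids map to urlsafe-base64-encoded image
# bytes, plus a b'num_images' entry storing the total count. The function name,
# arguments, and map_size are hypothetical.
def _build_eval_img_lmdb(image_paths, lmdb_dir, map_size=1 << 34):
    env = lmdb.open(lmdb_dir, map_size=map_size)
    with env.begin(write=True) as txn:
        for img_id, path in enumerate(image_paths):
            with open(path, "rb") as f:
                # store each image as a urlsafe-base64 string, as expected by the reader
                txn.put(str(img_id).encode("utf-8"), base64.urlsafe_b64encode(f.read()))
        # metadata entry that EvalImgDataset uses for __len__ and skips during iteration
        txn.put(b"num_images", str(len(image_paths)).encode("utf-8"))
    env.close()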


@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler


def get_eval_txt_dataset(args, max_txt_length=24):
    input_filename = args.text_data
    dataset = EvalTxtDataset(
        input_filename,
        max_txt_length=max_txt_length)
    num_samples = len(dataset)
    sampler = SequentialSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=args.text_batch_size,
        num_workers=0,
        pin_memory=True,
        sampler=sampler,
        drop_last=False,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def fetch_resolution(vision_model):
    # look up the input image resolution from the vision model's JSON config
    # (a file containing, among other fields, {"image_resolution": 224})
    vision_model_config_file = Path(__file__).parent.parent / f"clip/model_configs/{vision_model.replace('/', '-')}.json"
    with open(vision_model_config_file, 'r') as fv:
        model_info = json.load(fv)
    return model_info["image_resolution"]


def get_eval_img_dataset(args):
    lmdb_imgs = args.image_data
    dataset = EvalImgDataset(
        lmdb_imgs, resolution=fetch_resolution(args.vision_model))
    num_samples = len(dataset)
    sampler = SequentialSampler(dataset)

    dataloader = DataLoader(
        dataset,
        batch_size=args.img_batch_size,
        num_workers=0,
        pin_memory=True,
        sampler=sampler,
        drop_last=False,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_zeroshot_dataset(args, preprocess_fn):
    dataset = datasets.ImageFolder(args.datapath, transform=preprocess_fn)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.img_batch_size,
        num_workers=args.num_workers,
        sampler=None,
    )

    return DataInfo(dataloader, None)
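

# Usage sketch (illustrative; not part of the original module). A minimal example of
# how the loaders above might be wired together for retrieval evaluation. The `args`
# fields mirror the attributes referenced in this file; the concrete paths and model
# name are placeholders.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        text_data="path/to/texts.jsonl",   # one JSON object per line with 'text_id' and 'text'
        image_data="path/to/image_lmdb",   # LMDB directory as read by EvalImgDataset
        vision_model="ViT-B-16",           # must match a config file under clip/model_configs/
        text_batch_size=64,
        img_batch_size=32,
    )

    text_info = get_eval_txt_dataset(args, max_txt_length=24)
    img_info = get_eval_img_dataset(args)

    for text_ids, token_ids in text_info.dataloader:
        pass  # encode token_ids with the CLIP text encoder
    for img_ids, images in img_info.dataloader:
        pass  # encode images with the CLIP vision encoder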