import os
import cv2
import torch
import albumentations as A

import config as CFG


class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; if there are
        multiple captions for each image, image_filenames must repeat the
        corresponding file names.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Tokenize all captions up front so __getitem__ only has to index into
        # the pre-computed input_ids / attention_mask lists.
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        # OpenCV reads images as BGR; convert to RGB before applying the transforms.
        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]
        # HWC -> CHW float tensor, the layout expected by torch image encoders.
        item["image"] = torch.tensor(image).permute(2, 0, 1).float()
        item["caption"] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)


def get_transforms(mode="train"):
    if mode == "train":
        # Train and validation currently share the same resize/normalize pipeline;
        # add augmentations to this branch if they are needed.
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
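# Minimal usage sketch, assuming config.py provides image_path, size, and
# max_length as used above. The tokenizer name and image file names below are
# placeholders for illustration only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    # One image with two captions -> the file name is repeated, per the
    # docstring of CLIPDataset.__init__.
    dataset = CLIPDataset(
        image_filenames=["example.jpg", "example.jpg"],
        captions=["a photo of a dog", "a dog playing in a park"],
        tokenizer=tokenizer,
        transforms=get_transforms(mode="valid"),
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    batch = next(iter(loader))
    # Expected shapes: image (B, 3, CFG.size, CFG.size),
    # input_ids / attention_mask (B, padded_caption_length).
    print({k: v.shape for k, v in batch.items() if k != "caption"})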