| import os |
| import random |
| import math |
| import numpy as np |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
|
|
|
|
class CustomCocoDataset(Dataset):
    """Dataset over a flat folder of images that yields random square crops.

    Each item is a dict with three entries:

    * ``jpg``  -- tensor of a random ``img_size`` x ``img_size`` crop,
    * ``hint`` -- the same crop resized (bicubic) to ``hint_size`` x
      ``hint_size`` and converted to a tensor,
    * ``txt``  -- an empty prompt string.

    Args:
        img_folder: Directory that directly contains the image files.
        img_size: Side length of the random square crop.
        hint_size: Side length the crop is resized to for the hint image.
    """

    # File-name suffixes (case-sensitive) that mark a file as an image.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        super().__init__()
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the real file names (extension included) so that __getitem__
        # opens the file that was actually listed instead of assuming '.png'.
        # Sorted because os.listdir returns entries in arbitrary order and a
        # stable index -> file mapping makes runs reproducible.
        self.filenames = sorted(
            f for f in os.listdir(img_folder) if f.endswith(self.IMG_EXTENSIONS)
        )
        # Extension-less names, parallel to ``self.filenames``; kept for
        # callers that read ``ids``.
        self.ids = [os.path.splitext(f)[0] for f in self.filenames]

    def __len__(self):
        """Return the number of image files found in ``img_folder``."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load image ``index`` and return its crop, hint and prompt.

        Returns:
            dict with keys ``jpg`` (crop tensor), ``txt`` (empty string) and
            ``hint`` (resized crop tensor).
        """
        img_path = os.path.join(self.img_folder, self.filenames[index])
        # The context manager releases the file handle as soon as the pixel
        # data has been converted; ``convert`` returns an independent image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Random crop comes back as a numpy array; wrap it in a PIL image
        # again so the torchvision functional transforms can consume it.
        crop_arr = random_crop_arr(
            image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0
        )
        cropped_image = Image.fromarray(crop_arr)

        jpg_image = transforms.functional.to_tensor(cropped_image)
        hint_image = transforms.functional.resize(
            cropped_image,
            (self.hint_size, self.hint_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        hint_image = transforms.functional.to_tensor(hint_image)

        # No captions are available, so the prompt is always empty.
        prompt = ""

        return dict(jpg=jpg_image, txt=prompt, hint=hint_image)
|
|
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Rescale an image to a random size and cut a random square out of it.

    A target length for the shorter image side is drawn uniformly from the
    integers between ``ceil(image_size / max_crop_frac)`` and
    ``ceil(image_size / min_crop_frac)`` (both inclusive). The image is
    resized so its shorter side equals that target, then a randomly placed
    ``image_size`` x ``image_size`` window is returned.

    Args:
        pil_image: PIL image to crop.
        image_size: Side length of the square crop.
        min_crop_frac: Smallest fraction of the shorter side the crop covers.
        max_crop_frac: Largest fraction of the shorter side the crop covers.

    Returns:
        The cropped region as a numpy array.
    """
    shortest_lo = math.ceil(image_size / max_crop_frac)
    shortest_hi = math.ceil(image_size / min_crop_frac)
    target_short_side = random.randint(shortest_lo, shortest_hi)

    # Halve with a box filter for as long as the shorter side is at least
    # twice the target, so the final bicubic resize never has to shrink the
    # image by more than a factor of two in one step.
    while min(pil_image.size) >= 2 * target_short_side:
        halved_size = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved_size, resample=Image.BOX)

    # Final resize: bring the shorter side exactly to the drawn target.
    factor = target_short_side / min(pil_image.size)
    scaled_size = tuple(round(side * factor) for side in pil_image.size)
    pil_image = pil_image.resize(scaled_size, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    height, width = pixels.shape[0], pixels.shape[1]
    # Row offset is drawn before the column offset.
    top = random.randrange(height - image_size + 1)
    left = random.randrange(width - image_size + 1)
    bottom = top + image_size
    right = left + image_size
    return pixels[top:bottom, left:right]
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the dataset, fetch one batch through a
    # DataLoader and save each crop/hint pair side by side as a PNG.
    dataset = CustomCocoDataset("/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train")
    print(len(dataset))
    print(dataset[0])

    from torch.utils.data import DataLoader

    loader_options = dict(batch_size=4, num_workers=2, pin_memory=True, drop_last=True)
    dataloader = DataLoader(dataset, **loader_options)

    # Only the first batch is inspected.
    batch = next(iter(dataloader))
    jpg_images, hint_images, prompts = batch['jpg'], batch['hint'], batch['txt']

    print(f"Prompt: {prompts}")

    import matplotlib.pyplot as plt

    # One figure per sample; file names and titles are numbered from 1.
    for number, (jpg_tensor, hint_tensor) in enumerate(zip(jpg_images, hint_images), start=1):
        plt.figure(figsize=(10, 5))

        # Left panel: the crop. permute moves axis 0 to the end for imshow.
        plt.subplot(1, 2, 1)
        plt.title(f"JPG Image {number} (512x512)")
        plt.imshow(jpg_tensor.permute(1, 2, 0))

        # Right panel: the resized hint image.
        plt.subplot(1, 2, 2)
        plt.title(f"Hint Image {number} (448x448)")
        plt.imshow(hint_tensor.permute(1, 2, 0))

        plt.savefig(f'output_image_{number}.png')

        # Close the figure so open figures do not accumulate across samples.
        plt.close()
|
|
|
|
|
|