| | import os |
| | import random |
| | import math |
| | import numpy as np |
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| | from torchvision import transforms |
| |
|
| |
|
class CustomCocoDataset(Dataset):
    """Image-only dataset yielding a random square crop and a resized "hint" copy.

    Every file directly inside ``img_folder`` whose name ends in ``.jpg``,
    ``.jpeg`` or ``.png`` (case-insensitive) becomes one sample. Samples are
    ordered by file name, so indices are stable across runs.

    Each item is a dict with:

    * ``jpg``  -- ``to_tensor`` of an ``img_size`` x ``img_size`` random crop,
    * ``txt``  -- an empty prompt string,
    * ``hint`` -- the same crop bicubically resized to ``hint_size`` x
      ``hint_size`` and passed through ``to_tensor``.

    Args:
        img_folder: Directory containing the images (not searched recursively).
        img_size: Side length of the square crop returned under ``jpg``.
        hint_size: Side length of the resized image returned under ``hint``.
    """

    # Accepted file suffixes, compared against the lower-cased file name.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the real file names so each sample is opened with its own
        # extension; rebuilding paths from stems would break .jpg/.jpeg files.
        # os.listdir order is arbitrary, hence the sort.
        self.filenames = sorted(
            f for f in os.listdir(img_folder)
            if f.lower().endswith(self.IMG_EXTENSIONS)
        )
        # Extension-less stems, kept for callers that read ``ids``.
        self.ids = [os.path.splitext(f)[0] for f in self.filenames]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``dict(jpg=..., txt="", hint=...)``."""
        img_path = os.path.join(self.img_folder, self.filenames[index])
        # Close the file handle promptly; convert() returns a loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Random rescale + square crop of img_size x img_size (numpy array).
        cropped_image = Image.fromarray(
            random_crop_arr(image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0)
        )

        jpg_image = transforms.functional.to_tensor(cropped_image)
        hint_image = transforms.functional.resize(
            cropped_image,
            (self.hint_size, self.hint_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        hint_image = transforms.functional.to_tensor(hint_image)

        prompt = ""

        return dict(jpg=jpg_image, txt=prompt, hint=hint_image)
| |
|
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Randomly rescale ``pil_image`` and cut a random square window from it.

    The shorter side is resized to an integer length drawn uniformly from
    ``[ceil(image_size / max_crop_frac), ceil(image_size / min_crop_frac)]``.
    An ``image_size`` x ``image_size`` window is then taken at a uniformly
    random offset.

    Args:
        pil_image: Source PIL image.
        image_size: Side length of the returned square crop.
        min_crop_frac: Used to compute the largest allowed shorter-side length.
        max_crop_frac: Used to compute the smallest allowed shorter-side length.

    Returns:
        A slice of ``np.array(resized_image)`` whose first two axes have
        length ``image_size``.
    """
    shortest_allowed = math.ceil(image_size / max_crop_frac)
    longest_allowed = math.ceil(image_size / min_crop_frac)
    target_short = random.randrange(shortest_allowed, longest_allowed + 1)

    # Halve with BOX filtering while the shorter side is still at least twice
    # the target, so the final bicubic resize changes scale by less than 2x.
    img = pil_image
    while True:
        if min(img.size) < 2 * target_short:
            break
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    factor = target_short / min(img.size)
    rescaled = tuple(round(side * factor) for side in img.size)
    img = img.resize(rescaled, resample=Image.BICUBIC)

    pixels = np.array(img)
    height, width = pixels.shape[0], pixels.shape[1]
    top = random.randrange(height - image_size + 1)
    left = random.randrange(width - image_size + 1)
    return pixels[top:top + image_size, left:left + image_size]
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: load one batch and save side-by-side previews of the
    # cropped image and its hint as output_image_<n>.png.
    import sys

    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # The image folder may be given as the first CLI argument; otherwise fall
    # back to the author's local LSDIR training directory.
    default_folder = "/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train"
    img_folder = sys.argv[1] if len(sys.argv) > 1 else default_folder

    dataset = CustomCocoDataset(img_folder)
    print(len(dataset))
    # Print shapes rather than dumping whole tensors.
    sample = dataset[0]
    print({
        key: tuple(value.shape) if hasattr(value, "shape") else value
        for key, value in sample.items()
    })

    dataloader = DataLoader(
        dataset, batch_size=4, num_workers=2,
        pin_memory=True, drop_last=True)

    # NOTE: with drop_last=True this raises StopIteration if the folder holds
    # fewer than batch_size images.
    batch = next(iter(dataloader))

    jpg_images = batch['jpg']
    hint_images = batch['hint']
    prompts = batch['txt']

    print(f"Prompt: {prompts}")

    for i, (jpg, hint) in enumerate(zip(jpg_images, hint_images), start=1):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title(f"JPG Image {i} ({dataset.img_size}x{dataset.img_size})")
        # (C, H, W) -> (H, W, C) for matplotlib.
        plt.imshow(jpg.permute(1, 2, 0).numpy())

        plt.subplot(1, 2, 2)
        plt.title(f"Hint Image {i} ({dataset.hint_size}x{dataset.hint_size})")
        plt.imshow(hint.permute(1, 2, 0).numpy())

        plt.savefig(f'output_image_{i}.png')
        plt.close()
| |
|
| |
|
| |
|