| import os |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
|
|
class TinyImageNetDataset(Dataset):
    """Image-only Dataset over a Tiny ImageNet directory tree.

    Expected layout under ``root_dir``:
        train/<class_dir>/images/*.JPEG   (used when ``train=True``)
        val/images/*.JPEG                 (used when ``train=False``)

    Every sample is returned as ``(image, 0)``. The label is a constant
    placeholder, not the image's class.
    """

    def __init__(self, root_dir, transform=None, train=True):
        """Collect the image file paths for the requested split.

        Args:
            root_dir: Dataset root containing the ``train`` and ``val`` folders.
            transform: Optional callable applied to each RGB PIL image.
            train: If True, index the training split; otherwise the validation split.

        Raises:
            FileNotFoundError: If the split directory (``train/`` or
                ``val/images/``) does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []

        if train:
            train_dir = os.path.join(root_dir, 'train')
            # Sort the listing so the sample order is reproducible: os.listdir
            # returns entries in an arbitrary, filesystem-dependent order.
            for cls in sorted(os.listdir(train_dir)):
                cls_dir = os.path.join(train_dir, cls, 'images')
                # Skip stray entries such as .DS_Store, or class folders
                # without an images/ subdirectory, instead of crashing.
                if not os.path.isdir(cls_dir):
                    continue
                self.image_paths.extend(self._list_jpegs(cls_dir))
        else:
            val_images_dir = os.path.join(root_dir, 'val', 'images')
            self.image_paths.extend(self._list_jpegs(val_images_dir))

    @staticmethod
    def _list_jpegs(directory):
        """Return the sorted full paths of the ``.JPEG`` files directly inside *directory*."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith('.JPEG')
        ]

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image *idx* as RGB, apply the transform, and return ``(image, 0)``."""
        # Use a context manager so the underlying file handle is closed promptly.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, 0
|
|
def get_dataloaders(config):
    """Build the Tiny ImageNet training and validation DataLoaders.

    Args:
        config: Object exposing ``dataset_path``, ``image_size``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and
        applies random horizontal flips. The validation loader is
        deterministic: it does not shuffle and applies no random augmentation.
    """
    # Training pipeline: resize, random flip for augmentation, convert to a
    # tensor, then normalize each channel with mean 0.5 / std 0.5.
    train_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Validation uses the same preprocessing without the random flip, so
    # evaluation results are reproducible.
    val_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = TinyImageNetDataset(config.dataset_path, transform=train_transform, train=True)
    val_dataset = TinyImageNetDataset(config.dataset_path, transform=val_transform, train=False)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,  # match the training loader
    )

    return train_loader, val_loader