Spaces:
Runtime error
Runtime error
| """This file contains functions to prepare dataloader in the way lightning expects""" | |
| import pytorch_lightning as pl | |
| import torchvision.datasets as datasets | |
| from lightning_fabric.utilities.seed import seed_everything | |
| from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations | |
| from torch.utils.data import DataLoader, random_split | |
class CIFARDataModule(pl.LightningDataModule):
    """Lightning DataModule for CIFAR10 dataset.

    Downloads CIFAR10 into ``data_path``, optionally carves a validation
    subset out of the training split, and exposes train/val/test dataloaders
    that share one set of ``DataLoader`` keyword arguments.

    Args:
        data_path: Root directory passed to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size used by every dataloader.
        seed: Seed for the train/validation split and for ``_init_fn``.
        val_split: Fraction of the training split held out for validation.
            ``0`` (the default) disables the split; the CIFAR10 test split is
            then used for validation as well.
        num_workers: Number of worker processes per dataloader.
    """

    def __init__(self, data_path, batch_size, seed, val_split=0, num_workers=0):
        super().__init__()

        self.data_path = data_path
        self.batch_size = batch_size
        self.seed = seed
        self.val_split = val_split
        self.num_workers = num_workers

        # Keyword arguments shared by all dataloaders. `shuffle` is supplied
        # per loader because only the training loader shuffles.
        self.dataloader_dict = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            # DataLoader rejects persistent_workers=True without workers.
            "persistent_workers": self.num_workers > 0,
        }

        # Lightning flag: run prepare_data() once overall instead of once per
        # node (assumes data_path is visible to every node -- confirm for
        # multi-node runs).
        self.prepare_data_per_node = False

        # Populated by setup(); declared here so the attributes always exist.
        self.training_dataset = None
        self.validation_dataset = None
        self.testing_dataset = None

    def _split_train_val(self, dataset):
        """Split ``dataset`` into train and validation subsets.

        The split is driven by a dedicated generator seeded with ``self.seed``
        so that every call (and every process calling ``setup``) produces the
        same partition, independent of the global RNG state.

        Args:
            dataset: Sized dataset to partition.

        Returns:
            Tuple ``(train_dataset, val_dataset)`` of ``random_split`` subsets.

        Raises:
            ValueError: If ``self.val_split`` is not strictly between 0 and 1.
        """
        # Local import: the file only imports names from `torch.utils.data`.
        import torch

        if not 0 < self.val_split < 1:
            raise ValueError("Validation split must be between 0 and 1")

        # Validation takes whatever remains after truncating the train length,
        # so the two lengths always sum to len(dataset).
        total_length = len(dataset)
        train_length = int((1 - self.val_split) * total_length)
        val_length = total_length - train_length

        # Without an explicit generator the split depends on the global RNG
        # state at call time, which can differ between processes/calls and
        # leak validation samples into training.
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(int(self.seed))

        train_dataset, val_dataset = random_split(dataset, [train_length, val_length], **split_kwargs)
        return train_dataset, val_dataset

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data
    def prepare_data(self):
        """Download both CIFAR10 splits into ``data_path`` if they are missing."""
        datasets.CIFAR10(self.data_path, train=True, download=True)
        datasets.CIFAR10(self.data_path, train=False, download=True)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.hooks.DataHooks.html#lightning.pytorch.core.hooks.DataHooks.setup
    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Args:
            stage: Lightning stage name. ``"fit"`` and ``"validate"`` build the
                training and validation datasets, ``"test"`` builds the test
                dataset, and ``None`` builds all of them. Other stages build
                nothing.
        """
        train_transforms, test_transforms = apply_cifar_image_transformations()
        # Validation uses the same transforms as testing.
        val_transforms = test_transforms

        # "validate" is included so that Trainer.validate() finds a populated
        # validation dataset instead of None.
        if stage in ("fit", "validate", None):
            if self.val_split != 0:
                # Hold out part of the training split for validation.
                data_train, data_val = self._split_train_val(datasets.CIFAR10(self.data_path, train=True))
                self.training_dataset = CIFAR10Transforms(data_train, train_transforms)
                self.validation_dataset = CIFAR10Transforms(data_val, val_transforms)
            else:
                # No hold-out: train on the full training split ...
                self.training_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=True), train_transforms
                )
                # ... and validate on the same data as the test split.
                self.validation_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=False), val_transforms
                )

        if stage in ("test", None):
            self.testing_dataset = CIFAR10Transforms(datasets.CIFAR10(self.data_path, train=False), test_transforms)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#train-dataloader
    def train_dataloader(self):
        """Return the shuffled dataloader over the training dataset."""
        return DataLoader(self.training_dataset, **self.dataloader_dict, shuffle=True)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#val-dataloader
    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation dataset."""
        return DataLoader(self.validation_dataset, **self.dataloader_dict, shuffle=False)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#test-dataloader
    def test_dataloader(self):
        """Return the unshuffled dataloader over the test dataset."""
        return DataLoader(self.testing_dataset, **self.dataloader_dict, shuffle=False)

    def _init_fn(self, worker_id):
        """Seed a dataloader worker with ``seed + worker_id``.

        Not wired into ``dataloader_dict``; kept as an optional
        ``worker_init_fn`` for callers that want per-worker seeding.
        """
        seed_everything(int(self.seed) + worker_id)