import os
from typing import Callable, Optional

import rasterio
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import NonGeoDataset

from methan_text_dataset import MethaneTextDataset


class MethaneTextDataModule(NonGeoDataModule):
    """DataModule that wraps ``MethaneTextDataset`` for train/val/test loops.

    Every split is built from the same ``data_root``, ``paths`` and
    ``captions``; the splits differ only in the transform that is handed to
    the dataset.
    """

    def __init__(
        self,
        data_root: str,
        paths: list,
        captions: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: Optional[Callable] = None,
        val_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        **kwargs,
    ):
        """Store the dataset arguments and per-split transforms.

        Args:
            data_root: Passed to ``MethaneTextDataset`` as ``root_dir``.
            paths: Passed to ``MethaneTextDataset`` as ``paths``.
            captions: Passed to ``MethaneTextDataset`` as ``captions``.
            batch_size: Batch size used by every dataloader.
            num_workers: Worker count used by every dataloader.
            train_transform: Transform given to the training dataset.
            val_transform: Transform given to the validation dataset.
            test_transform: Transform given to the test/predict dataset.
            **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.
        """
        super().__init__(MethaneTextDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.paths = paths
        self.captions = captions
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _build_dataset(self, transform: Optional[Callable]) -> MethaneTextDataset:
        """Create a ``MethaneTextDataset`` over the stored data with ``transform``."""
        return MethaneTextDataset(
            root_dir=self.data_root,
            paths=self.paths,
            captions=self.captions,
            transform=transform,
        )

    def _build_loader(
        self, dataset, shuffle: bool, drop_last: bool = True
    ) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using the module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
        )

    def setup(self, stage: Optional[str] = None):
        """Instantiate the datasets required for ``stage``.

        Args:
            stage: ``"fit"``/``"train"`` builds the training dataset,
                ``"fit"``/``"validate"``/``"val"`` the validation dataset and
                ``"test"``/``"predict"`` the test dataset. ``None`` builds
                all three.
        """
        # ``None`` is the declared default; treat it as "prepare everything"
        # so the dataloaders never hit a missing ``*_dataset`` attribute.
        build_all = stage is None

        if build_all or stage in ("fit", "train"):
            self.train_dataset = self._build_dataset(self.train_transform)
        if build_all or stage in ("fit", "validate", "val"):
            self.val_dataset = self._build_dataset(self.val_transform)
        if build_all or stage in ("test", "predict"):
            self.test_dataset = self._build_dataset(self.test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset (full batches only)."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an ordered loader over the validation dataset."""
        # NOTE(review): drop_last=True discards the final partial batch, so
        # some validation samples are never evaluated - kept as in the
        # original; confirm this is intentional.
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return an ordered loader over the test dataset."""
        # NOTE(review): drop_last=True discards the final partial batch here
        # as well - kept as in the original; confirm this is intentional.
        return self._build_loader(self.test_dataset, shuffle=False)

    def predict_dataloader(self):
        """Return an ordered loader over the dataset built for ``"predict"``.

        ``setup("predict")`` stores its dataset in ``test_dataset``, so that
        is what is served here. The last partial batch is kept so that every
        sample receives a prediction.
        """
        # NOTE(review): defined explicitly because setup() files the predict
        # data under ``test_dataset``; confirm the inherited
        # ``predict_dataloader`` was not the intended entry point.
        return self._build_loader(self.test_dataset, shuffle=False, drop_last=False)