| |
|
|
| from copy import deepcopy |
| from pathlib import Path |
| from typing import Any, Dict, List |
| |
| import numpy as np |
| |
| |
| |
| from omegaconf import DictConfig, OmegaConf |
| import pytorch_lightning as pl |
| from dataset.UAV.dataset import UavMapPair |
| |
| |
| from torch.utils.data import Dataset, ConcatDataset |
| from torch.utils.data import Dataset, DataLoader, random_split |
| import torchvision.transforms as tvf |
|
|
| |
class UavMapDatasetModule(pl.LightningDataModule):
    """Lightning data module serving UAV/map pairs for a set of cities.

    One ``UavMapPair`` dataset is built per configured city; the per-city
    datasets are concatenated into a training set and a validation set.

    Config values read from ``cfg`` (accessed as an OmegaConf ``DictConfig``):
        root: dataset root directory, passed to ``UavMapPair`` as a ``Path``.
        image_size: forwarded to ``torchvision.transforms.Resize``.
        augmentation.image.apply: if truthy, colour jitter is added to the
            training transforms only; ``brightness``/``contrast``/
            ``saturation``/``hue`` under the same node (those present) are
            forwarded to ``ColorJitter``.
        train_citys / val_citys: city names used for each split.
        train.batch_size / train.num_workers: training DataLoader settings.
        val.batch_size / val.num_workers: validation DataLoader settings.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        # Deterministic preprocessing shared by both splits.
        base_tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]

        # tvf.Compose keeps a reference to the list it receives, so each split
        # gets its own copy; appending the training augmentation to a shared
        # list would silently apply it to validation images as well.
        self.val_tfs = tvf.Compose(list(base_tfs))

        train_tfs = list(base_tfs)
        if cfg.augmentation.image.apply:
            jitter_cfg = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            # Convert to plain Python containers: ColorJitter only accepts real
            # tuple/list ranges, not OmegaConf ListConfig nodes.
            jitter_args = OmegaConf.to_container(jitter_cfg, resolve=True)
            train_tfs.append(tvf.ColorJitter(**jitter_args))
        # Always defined, whether or not augmentation is enabled.
        self.train_tfs = tvf.Compose(train_tfs)

        self.init()

    def init(self):
        """Build ``self.train_dataset`` and ``self.val_dataset``.

        Each is a ``ConcatDataset`` of one ``UavMapPair`` per city listed in
        ``cfg.train_citys`` / ``cfg.val_citys`` respectively.
        """
        root = Path(self.cfg.root)
        self.train_dataset = ConcatDataset([
            UavMapPair(root=root, city=city, training=True, transform=self.train_tfs)
            for city in self.cfg.train_citys
        ])
        self.val_dataset = ConcatDataset([
            UavMapPair(root=root, city=city, training=False, transform=self.val_tfs)
            for city in self.cfg.val_citys
        ])

    def _make_loader(self, dataset, loader_cfg, shuffle):
        """Wrap *dataset* in a DataLoader configured from *loader_cfg*.

        *loader_cfg* must provide ``batch_size`` and ``num_workers``.
        """
        return DataLoader(
            dataset,
            batch_size=loader_cfg.batch_size,
            num_workers=loader_cfg.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset."""
        return self._make_loader(self.train_dataset, self.cfg.train, shuffle=True)

    def val_dataloader(self):
        """Return a DataLoader over the validation dataset.

        Shuffling is disabled: evaluation does not need random order, a fixed
        order keeps per-batch logs comparable across epochs, and Lightning
        warns when eval loaders shuffle.
        """
        return self._make_loader(self.val_dataset, self.cfg.val, shuffle=False)
|
|