| import torch.utils.data |
| from data.base_data_loader import BaseDataLoader |
|
|
|
|
def CreateDataset(opt):
    """Create, announce, and initialize the test-time AlignedDataset.

    Args:
        opt: options object forwarded unchanged to ``dataset.initialize``.

    Returns:
        The ``AlignedDataset`` instance after ``initialize(opt)`` has run.
    """
    # Imported inside the function; presumably so the dataset module is only
    # loaded when a dataset is actually built (NOTE(review): confirm intent).
    from data.aligned_dataset_test import AlignedDataset

    dataset = AlignedDataset()
    # Report which dataset was built before its (possibly slow) initialization.
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
|
|
class CustomDatasetDataLoader(BaseDataLoader):
    """Data loader that wraps the dataset from ``CreateDataset`` in a
    non-shuffling ``torch.utils.data.DataLoader``."""

    def name(self):
        """Return the identifier string of this loader."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Build the dataset and the DataLoader that iterates over it.

        Args:
            opt: options object. This method reads ``opt.batchSize`` and
                ``opt.nThreads``, and passes ``opt`` to
                ``BaseDataLoader.initialize`` and ``CreateDataset``.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)

        loader_options = dict(
            batch_size=opt.batchSize,
            shuffle=False,  # iterate samples in dataset order
            num_workers=int(opt.nThreads),
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_options)

    def load_data(self):
        """Return the DataLoader created by ``initialize``."""
        return self.dataloader

    def __len__(self):
        """Return the dataset length, capped at ``self.opt.max_dataset_size``.

        Assumes ``BaseDataLoader.initialize`` stores ``opt`` as ``self.opt``
        (TODO confirm). NOTE(review): the DataLoader built in ``initialize``
        wraps the full dataset, so iterating it may yield more samples than
        this length reports when ``max_dataset_size`` is smaller; confirm
        that this is intended.
        """
        full_size = len(self.dataset)
        size_cap = self.opt.max_dataset_size
        # Same tie-breaking as min(full_size, size_cap): the cap wins only
        # when it is strictly smaller.
        return size_cap if size_cap < full_size else full_size
|
|