| from torch.utils.data import Dataset |
| from PIL import Image |
| from utils import data_utils |
|
|
|
|
class InferenceDataset(Dataset):
    """Index-addressable dataset yielding one loaded sample per stored path.

    The path list is whatever ``data_utils.make_dataset(root)`` returns,
    sorted (presumably image file paths found under ``root`` -- confirm
    against ``data_utils``).
    """

    def __init__(self, root, opts, transform=None, preprocess=None):
        """Collect the sample paths and remember the loading hooks.

        Args:
            root: Location handed to ``data_utils.make_dataset``.
            opts: Options object; stored on the instance, not read by any
                method of this class.
            transform: Optional callable applied to each loaded sample.
            preprocess: Optional callable that receives a path and returns
                the loaded sample; when given, it replaces the default
                PIL-based loading.
        """
        found = list(data_utils.make_dataset(root))
        found.sort()
        self.paths = found
        self.opts = opts
        self.preprocess = preprocess
        self.transform = transform

    def __len__(self):
        """Return how many paths were collected."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the sample at ``index`` and return it, transformed if set.

        With a ``preprocess`` hook the path is passed to that hook;
        otherwise the file is opened with PIL and converted to RGB.
        """
        path = self.paths[index]

        if self.preprocess is None:
            sample = Image.open(path).convert('RGB')
        else:
            sample = self.preprocess(path)

        # A falsy transform (e.g. None) means the sample is returned untouched.
        if not self.transform:
            return sample
        return self.transform(sample)
|
|