| import sys |
| import torch.utils.data as data |
| from os import listdir |
| from utils.tools import default_loader, is_image_file, normalize |
| import os |
|
|
| import torchvision.transforms as transforms |
|
|
|
|
class Dataset(data.Dataset):
    """Image dataset that loads files from disk, crops/resizes them to a
    fixed shape, converts them to tensors and applies ``normalize``.

    Samples come either from the files directly inside ``data_path``
    (flat mode) or recursively from every sub-directory of ``data_path``
    (``with_subfolder=True``).
    """

    def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
        """
        Args:
            data_path (str): Root directory containing the images.
            image_shape (sequence): Target shape. The last entry is dropped
                and the remaining ``(height, width)`` pair is used for
                cropping/resizing (presumably ``[H, W, C]`` -- confirm
                against the caller's config).
            with_subfolder (bool): If True, collect images recursively from
                the sub-directories of ``data_path``; otherwise only from
                the files directly in ``data_path``.
            random_crop (bool): If True, take a random ``image_shape`` crop
                (upscaling first when the image is too small); otherwise
                resize the whole image to ``image_shape``.
            return_name (bool): If True, ``__getitem__`` returns
                ``(sample, tensor)`` instead of just the tensor.
        """
        super(Dataset, self).__init__()
        self.data_path = data_path
        # Remembered so __getitem__ knows whether samples already include data_path.
        self.with_subfolder = with_subfolder
        if with_subfolder:
            self.samples = self._find_samples_in_subfolders(data_path)
        else:
            # Sorted so the index -> file mapping does not depend on the
            # arbitrary order the file system returns entries in.
            self.samples = sorted(x for x in listdir(data_path) if is_image_file(x))
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name

    @staticmethod
    def _min_short_side(imgw, imgh, target_h, target_w):
        """Return the smallest short-side length for an aspect-preserving
        resize after which the image is at least ``target_h`` x ``target_w``.

        For a square target this equals the target side, i.e.
        ``min(target_h, target_w)``.
        """
        short = min(imgw, imgh)
        # Integer ceil division (-(-a // b)) avoids float rounding errors.
        need_for_h = -(-target_h * short // imgh)
        need_for_w = -(-target_w * short // imgw)
        return max(need_for_h, need_for_w)

    def __getitem__(self, index):
        """Load, crop/resize and normalize the sample at ``index``.

        Returns:
            The processed image tensor, or ``(sample, tensor)`` when
            ``return_name`` is set, where ``sample`` is the entry stored in
            ``self.samples``.
        """
        sample = self.samples[index]
        if self.with_subfolder:
            # Subfolder samples are built with os.path.join(data_path, ...),
            # so joining data_path again would duplicate a relative root
            # (e.g. "data/train/data/train/cls/x.png").
            path = sample
        else:
            path = os.path.join(self.data_path, sample)
        img = default_loader(path)

        if self.random_crop:
            imgw, imgh = img.size
            target_h, target_w = self.image_shape[0], self.image_shape[1]
            if imgh < target_h or imgw < target_w:
                # Upscale just enough (aspect ratio preserved) for both sides
                # to cover the crop; resizing the short side to
                # min(image_shape) is not enough for non-square targets.
                img = transforms.Resize(self._min_short_side(imgw, imgh, target_h, target_w))(img)
            img = transforms.RandomCrop(self.image_shape)(img)
        else:
            # Resize straight to the target shape; a same-size RandomCrop
            # afterwards would be a no-op.
            img = transforms.Resize(self.image_shape)(img)

        img = transforms.ToTensor()(img)
        img = normalize(img)

        if self.return_name:
            return sample, img
        return img

    def _find_samples_in_subfolders(self, dir):
        """Collect image files from every sub-directory of ``dir``.

        Each immediate sub-directory of ``dir`` is walked recursively.
        Sub-directories are visited in sorted order and the files of each
        walked directory are sorted, so the result is deterministic. Files
        placed directly in ``dir`` are not included.

        Args:
            dir (str): Root directory path.

        Returns:
            list[str]: Image file paths built with ``os.path.join`` starting
            from ``dir`` (so each path has ``dir`` as its prefix).
        """
        if sys.version_info >= (3, 5):
            # os.scandir avoids a separate isdir() stat call per entry.
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        samples = []
        for target in sorted(classes):
            for root, _, fnames in sorted(os.walk(os.path.join(dir, target))):
                samples.extend(
                    os.path.join(root, fname) for fname in sorted(fnames) if is_image_file(fname)
                )
        return samples

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)
|
|