| """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. |
| |
| It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. |
| """ |
| import random |
| import numpy as np |
| import torch.utils.data as data |
| from PIL import Image, ImageOps |
| import torchvision.transforms as transforms |
| from abc import ABC, abstractmethod |
|
|
|
|
class BaseDataset(data.Dataset, ABC):
    """Abstract base class shared by the datasets in this project.

    Subclasses are expected to provide:
    -- <__init__>: set up the dataset; call BaseDataset.__init__(self, opt) first.
    -- <__len__>: report how many items the dataset holds.
    -- <__getitem__>: return one data point.
    -- <modify_commandline_options>: (optional) register dataset-specific options
       or change existing defaults.
    """

    def __init__(self, opt):
        """Keep a reference to the experiment options.

        Parameters:
            opt -- options object; must expose ``dataroot``, which is also stored
                   as ``self.root`` (presumably a BaseOptions-derived object — confirm)
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for adding dataset-specific options; the base version changes nothing.

        Parameters:
            parser -- the option parser to extend
            is_train (bool) -- True for the training phase, False for testing; lets
                               subclasses register phase-specific options

        Returns:
            the (possibly modified) parser — here, the very object passed in.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return how many items the dataset contains (the base version reports 0)."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return the data point stored at ``index``.

        Parameters:
            index -- integer position of the requested item

        Returns:
            a dictionary holding the data and, usually, its metadata
            (the base version returns None).
        """
|
|
|
|
def get_params(opt, size):
    """Draw the random crop position and flip decision for one image.

    The crop position is sampled in the coordinate frame the image will have
    after the resize step selected by ``opt.preprocess``. The result is meant to be
    passed to ``get_transform(params=...)``, which reads ``'crop_pos'`` and ``'flip'``.

    Parameters:
        opt  -- options object; reads ``preprocess``, ``load_size`` and ``crop_size``
        size -- (width, height) of the loaded image

    Returns:
        {'crop_pos': (x, y), 'flip': bool}. x (resp. y) is 0 whenever the resized
        image is not wider (resp. taller) than ``crop_size``.
    """
    w, h = size
    if opt.preprocess == 'resize_and_crop':
        new_w = new_h = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    else:
        # Any other mode keeps the original size for the purpose of crop sampling.
        new_w, new_h = w, h

    # Clamp the upper bound at 0 so randint always gets a valid (possibly
    # single-value) range when the image is smaller than the crop. Builtin max keeps
    # the bounds plain Python ints instead of numpy scalars.
    x = random.randint(0, max(0, new_w - opt.crop_size))
    y = random.randint(0, max(0, new_h - opt.crop_size))
    flip = random.random() > 0.5

    return {'crop_pos': (x, y), 'flip': flip}
|
|
|
|
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the torchvision preprocessing pipeline described by ``opt``.

    Parameters:
        opt              -- options object; reads ``preprocess``, ``load_size``,
                            ``crop_size`` and ``no_flip``
        params (dict)    -- optional output of ``get_params``; when given, the crop
                            position and flip decision come from it
        grayscale (bool) -- convert to a single channel as the first step
        method           -- resampling filter handed to the resize steps
        convert (bool)   -- append ToTensor plus a normalization with mean/std 0.5

    Returns:
        a ``transforms.Compose`` of the selected steps, in order:
        grayscale, resize/scale, crop, pad (preprocess == 'none'), flip, tensor/normalize.
    """
    steps = []

    if grayscale:
        steps.append(transforms.Grayscale(1))

    # A fixed square resize takes precedence over width-only scaling.
    if 'resize' in opt.preprocess:
        steps.append(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in opt.preprocess:
        steps.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    # Without precomputed params there is no sampled position, so crop the center.
    if 'crop' in opt.preprocess:
        if params is not None:
            steps.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
        else:
            steps.append(transforms.CenterCrop(opt.crop_size))

    # No resize/crop requested: pad via __make_power_2 with base 2**8 instead.
    if opt.preprocess == 'none':
        steps.append(transforms.Lambda(lambda img: __make_power_2(img, base=2**8, method=method)))

    if not opt.no_flip:
        if params is None:
            steps.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        steps.append(transforms.ToTensor())
        channels = 1 if grayscale else 3
        steps.append(transforms.Normalize((0.5,) * channels, (0.5,) * channels))

    return transforms.Compose(steps)
|
|
|
|
def __make_power_2(img, base, method=Image.BICUBIC):
    """Pad ``img`` on the right and bottom so both sides are multiples of ``base``.

    Despite the name, nothing is resized to a power of two: each side is rounded up
    to the next multiple of ``base`` and the new area is filled with 255 in every band
    (white for L/RGB/RGBA images).

    Parameters:
        img    -- PIL image to pad
        base   -- each output side is the smallest multiple of ``base`` >= the input side
        method -- unused (the image is padded, not resampled); kept for call compatibility

    Returns:
        ``img`` itself when no padding is needed, otherwise a new padded image.
    """
    ow, oh = img.size
    h = int((oh + base - 1) // base * base)
    w = int((ow + base - 1) // base * base)
    if h == oh and w == ow:
        return img

    __print_size_warning(ow, oh, w, h, base)
    # Pillow treats an int fill for a multi-band image as a packed pixel value, so a
    # bare 255 only sets the first band (red padding for RGB). Repeat 255 per band so
    # multi-band padding matches the single-band case; single-band keeps the int.
    bands = len(img.getbands())
    fill = 255 if bands == 1 else (255,) * bands
    return ImageOps.expand(img, (0, 0, w - ow, h - oh), fill=fill)


def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to ``target_width`` pixels wide, keeping its aspect ratio.

    The new height is truncated to an int. ``img`` is returned unchanged when it is
    already ``target_width`` wide.
    """
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` window whose top-left corner is ``pos`` = (x, y).

    The image is returned untouched only when it is no larger than ``size`` in both
    dimensions; otherwise the window is cropped even if it extends past the border.
    """
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if ow > tw or oh > th:
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    """Mirror ``img`` left-right when ``flip`` is truthy; otherwise return it as is."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h, base=4):
    """Print a notice that an image size was rounded up to a multiple of ``base``.

    Only the first call prints; the flag is stored as an attribute on this function.

    Parameters:
        ow, oh -- original width and height
        w, h   -- adjusted width and height
        base   -- the multiple the size was rounded up to. The default of 4 keeps the
                  message identical for callers that pass only four arguments.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return
    print("The image size needs to be a multiple of %d. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of %d" % (base, ow, oh, w, h, base))
    __print_size_warning.has_printed = True
|
|
|
|