| | import random |
| | import numpy as np |
| | from PIL import Image, ImageOps, ImageFilter |
| | import torch |
| | import torch.utils.data as data |
| |
|
| | __all__ = ['BaseDataset'] |
| |
|
class BaseDataset(data.Dataset):
    """Base class for segmentation datasets.

    Subclasses are expected to define ``NUM_CLASS`` and implement
    ``__getitem__`` / ``__len__``; this class provides the shared joint
    image/mask augmentation pipelines for the train / val / testval modes.

    Args:
        root: dataset root directory (stored as-is, not validated here).
        split: dataset split name, e.g. ``'train'`` or ``'val'``.
        mode: augmentation mode; defaults to ``split`` when ``None``.
        transform: optional callable subclasses may apply to images.
        target_transform: optional callable subclasses may apply to targets.
        base_size: reference long-edge size for random train-time scaling.
        crop_size: side length of the square output crop.
    """

    def __init__(self, root, split, mode=None, transform=None,
                 target_transform=None, base_size=1024, crop_size=512):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size
        if self.mode == 'train':
            print('BaseDataset: base_size {}, crop_size {}'.format(
                base_size, crop_size))

    @property
    def num_class(self):
        """Number of semantic classes, read from the subclass ``NUM_CLASS``."""
        return self.NUM_CLASS

    def _resize_short_edge(self, img, size, resample):
        """Resize *img* so its shorter edge equals *size*, keeping aspect ratio.

        Shared by the val and testval pipelines (the train pipeline scales by
        the long edge with different rounding, so it does not use this).
        """
        w, h = img.size
        if w > h:
            oh = size
            ow = int(1.0 * w * oh / h)
        else:
            ow = size
            oh = int(1.0 * h * ow / w)
        return img.resize((ow, oh), resample)

    def _val_transform(self, img, mask):
        """Deterministic val pipeline: short-edge resize, then center crop."""
        outsize = self.crop_size
        img = self._resize_short_edge(img, outsize, Image.BILINEAR)
        # NEAREST keeps mask values as exact class ids (no interpolation).
        mask = self._resize_short_edge(mask, outsize, Image.NEAREST)
        # Center crop to (outsize, outsize).
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        box = (x1, y1, x1 + outsize, y1 + outsize)
        img = img.crop(box)
        mask = mask.crop(box)
        return img, self._mask_transform(mask)

    def _testval_transform(self, img, mask):
        """Testval pipeline: resize only the image.

        The mask deliberately stays at its original resolution, so evaluation
        is done against the full-size ground truth.
        """
        img = self._resize_short_edge(img, self.crop_size, Image.BILINEAR)
        return img, self._mask_transform(mask)

    def _train_transform(self, img, mask):
        """Random train augmentation applied jointly to image and mask.

        Pipeline: random horizontal flip, random long-edge rescale in
        ``[0.5, 2.0] * base_size``, zero-pad up to ``crop_size`` if needed,
        then a random ``crop_size`` square crop.
        """
        # Random horizontal flip (image and mask flipped together).
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        w, h = img.size
        # Random long-edge size; the short edge follows the aspect ratio
        # (note the +0.5 round-to-nearest, unlike the val-time truncation).
        long_size = random.randint(int(self.base_size * 0.5),
                                   int(self.base_size * 2.0))
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # Pad right/bottom so both dimensions reach crop_size.
        # NOTE(review): the mask is padded with 0, which aliases class 0 —
        # confirm this matches the dataset's ignore-label convention.
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # Random crop to (crop_size, crop_size); padding above guarantees
        # w >= crop_size and h >= crop_size here.
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        return img, self._mask_transform(mask)

    def _mask_transform(self, mask):
        """Convert a PIL mask (or any array-like) to a LongTensor of class ids."""
        return torch.from_numpy(np.array(mask)).long()
| |
|
| |
|