| import os |
| from os import path |
|
|
| import torch |
| from torch.utils.data.dataset import Dataset |
| from torchvision import transforms |
| from torchvision.transforms import InterpolationMode |
| from PIL import Image |
| import numpy as np |
|
|
| from dataset.range_transform import im_normalization, im_mean |
| from dataset.tps import random_tps_warp |
| from dataset.reseed import reseed |
|
|
|
|
class StaticTransformDataset(Dataset):
    """
    Generate pseudo VOS data by applying random transforms on static images.
    Single-object only.

    Method 0 - FSS style (class/1.jpg class/1.png)
    Method 1 - Others style (XXX.jpg XXX.png)
    """
    def __init__(self, parameters, num_frames=3, max_num_obj=1):
        """
        parameters: iterable of (root, method, multiplier) triples.
            root       -- dataset root directory
            method     -- 0: class-subfolder layout, 1: flat layout
            multiplier -- number of times each image path is repeated,
                          weighting datasets relative to each other
        num_frames: length of the pseudo video sequence generated per sample
        max_num_obj: maximum number of objects merged into one sample
        """
        self.num_frames = num_frames
        self.max_num_obj = max_num_obj

        self.im_list = []
        for parameter in parameters:
            root, method, multiplier = parameter
            if method == 0:
                # Method 0 - each class lives in its own subdirectory
                classes = os.listdir(root)
                for c in classes:
                    imgs = os.listdir(path.join(root, c))
                    jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]

                    joint_list = [path.join(root, c, im) for im in jpg_list]
                    self.im_list.extend(joint_list * multiplier)

            elif method == 1:
                # Method 1 - all images directly under root
                self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)

        print(f'{len(self.im_list)} images found.')

        # "pair" transforms get a fresh random seed for every frame, so they
        # differ frame-to-frame and act as per-frame object/camera motion.
        self.pair_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.05, 0.05, 0),
        ])

        self.pair_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
            transforms.Resize(384, InterpolationMode.BICUBIC),
            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
        ])

        self.pair_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
            # NEAREST keeps the mask's label values intact (no interpolation blur)
            transforms.Resize(384, InterpolationMode.NEAREST),
            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
        ])

        # "all" transforms share one seed across the whole sequence, so they
        # apply identically to every frame (a global, sequence-level transform).
        self.all_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
            transforms.RandomGrayscale(0.05),
        ])

        self.all_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
            transforms.RandomHorizontalFlip(),
        ])

        self.all_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
            transforms.RandomHorizontalFlip(),
        ])

        # Final conversion: image becomes a normalized float tensor,
        # mask becomes an un-normalized float tensor in [0, 1].
        self.final_im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])

        self.final_gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def _get_sample(self, idx):
        """
        Load one image/mask pair and synthesize a num_frames-long sequence.

        Returns:
            images: (num_frames, 3, 384, 384) normalized float tensor
            masks:  (num_frames, 1, 384, 384) numpy float array in [0, 1]
        """
        im = Image.open(self.im_list[idx]).convert('RGB')
        # Ground-truth mask shares the image's path with a .png extension
        gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')

        sequence_seed = np.random.randint(2147483647)

        images = []
        masks = []
        for _ in range(self.num_frames):
            # Re-seeding with the same value before the image and the mask
            # transform keeps their random spatial parameters in lockstep.
            reseed(sequence_seed)
            this_im = self.all_im_dual_transform(im)
            this_im = self.all_im_lone_transform(this_im)
            reseed(sequence_seed)
            this_gt = self.all_gt_dual_transform(gt)

            # Fresh seed per frame -> apparent motion between frames
            pairwise_seed = np.random.randint(2147483647)
            reseed(pairwise_seed)
            this_im = self.pair_im_dual_transform(this_im)
            this_im = self.pair_im_lone_transform(this_im)
            reseed(pairwise_seed)
            this_gt = self.pair_gt_dual_transform(this_gt)

            # Occasionally add a non-rigid thin-plate-spline deformation
            if np.random.rand() < 0.33:
                this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)

            this_im = self.final_im_transform(this_im)
            this_gt = self.final_gt_transform(this_gt)

            images.append(this_im)
            masks.append(this_gt)

        images = torch.stack(images, 0)
        masks = torch.stack(masks, 0)

        return images, masks.numpy()

    def __getitem__(self, idx):
        """
        Compose 1..max_num_obj independently sampled objects into one pseudo
        video and build the training targets.

        Returns a dict with:
            rgb:            (num_frames, 3, 384, 384) composited image tensor
            first_frame_gt: (1, max_num_obj, 384, 384) per-object binary masks
            cls_gt:         (num_frames, 1, 384, 384) integer label map
            selector:       (max_num_obj,) float tensor; 1 = valid object slot
            info:           auxiliary metadata (image name, object count)
        """
        # Object at `idx` is always included; extra objects are drawn randomly
        additional_objects = np.random.randint(self.max_num_obj)
        indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]

        merged_images = None
        # np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # np.int64 is the explicit equivalent.
        merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)

        for i, list_id in enumerate(indices):
            images, masks = self._get_sample(list_id)
            if merged_images is None:
                merged_images = images
            else:
                # Later objects occlude earlier ones where their mask is on
                merged_images = merged_images*(1-masks) + images*masks
            merged_masks[masks[:,0]>0.5] = (i+1)

        masks = merged_masks

        # Only objects visible in the first frame count as targets
        labels = np.unique(masks[0])
        labels = labels[labels!=0]
        target_objects = labels.tolist()

        cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
        first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
        for i, l in enumerate(target_objects):
            this_mask = (masks==l)
            cls_gt[this_mask] = i+1
            first_frame_gt[0,i] = (this_mask[0])
        # Add a channel dimension: (num_frames, 384, 384) -> (num_frames, 1, 384, 384)
        cls_gt = np.expand_dims(cls_gt, 1)

        info = {}
        info['name'] = self.im_list[idx]
        info['num_objects'] = max(1, len(target_objects))

        # 1 if the object-slot is in use, 0 for padding slots
        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
        selector = torch.FloatTensor(selector)

        data = {
            'rgb': merged_images,
            'first_frame_gt': first_frame_gt,
            'cls_gt': cls_gt,
            'selector': selector,
            'info': info
        }

        return data

    def __len__(self):
        return len(self.im_list)
|
|