| import os |
| from glob import glob |
|
|
| import cv2 |
| import h5py |
| import numpy as np |
| import torch |
| import torch.utils.data as data |
| from PIL import Image, ImageFilter |
| from torchvision.datasets import ImageNet |
|
|
|
|
class ImageNet_blur(ImageNet):
    """ImageNet variant that returns each sample together with a blurred copy.

    The blurred copy is a 50/50 blend of a Gaussian-blurred and a
    median-filtered version of the original image.
    """

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ((sample, blurred_img), target) where target is the
            class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)

        # Blend two blur variants (Gaussian radius 11, median size 11)
        # in equal proportion.
        blurred_img = Image.blend(
            sample.filter(ImageFilter.GaussianBlur(11)),
            sample.filter(ImageFilter.MedianFilter(11)),
            0.5,
        )

        if self.transform is not None:
            sample = self.transform(sample)
            blurred_img = self.transform(blurred_img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (sample, blurred_img), target
|
|
|
|
class Imagenet_Segmentation(data.Dataset):
    """Binary ImageNet segmentation dataset backed by a MATLAB v7.3 (HDF5) file."""

    CLASSES = 2

    def __init__(self, path, transform=None, target_transform=None):
        self.path = path
        self.transform = transform
        self.target_transform = target_transform

        # The real handle is opened lazily in __getitem__ (one handle per
        # DataLoader worker); here we only probe the dataset length.
        self.h5py = None
        with h5py.File(path, "r") as probe:
            self.data_length = len(probe["/value/img"])

    def __getitem__(self, index):
        # First access opens the file for this process/worker.
        if self.h5py is None:
            self.h5py = h5py.File(self.path, "r")
        f = self.h5py

        # /value/img and /value/gt hold HDF5 object references (MATLAB cell
        # arrays): dereference, then transpose out of MATLAB's column-major
        # layout.
        img = np.array(f[f["/value/img"][index, 0]]).transpose((2, 1, 0))
        gt_ref = f[f["/value/gt"][index, 0]][0, 0]
        target = np.array(f[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert("RGB")
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            # Convert the transformed mask to a LongTensor of class ids.
            target = np.array(self.target_transform(target)).astype("int32")
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        return self.data_length
|
|
|
|
class Imagenet_Segmentation_Blur(data.Dataset):
    """Like Imagenet_Segmentation, but also yields a blurred copy of each image.

    The blurred copy is a 50/50 blend of a Gaussian-blurred and a
    median-filtered version of the image.
    """

    CLASSES = 2

    def __init__(self, path, transform=None, target_transform=None):
        self.path = path
        self.transform = transform
        self.target_transform = target_transform

        # Handle is opened lazily in __getitem__ (one per DataLoader
        # worker); only the length is read here.
        self.h5py = None
        with h5py.File(path, "r") as probe:
            self.data_length = len(probe["/value/img"])

    def __getitem__(self, index):
        if self.h5py is None:
            self.h5py = h5py.File(self.path, "r")
        f = self.h5py

        # Dereference MATLAB cell-array object references and transpose from
        # column-major (MATLAB) to row-major layout.
        img = np.array(f[f["/value/img"][index, 0]]).transpose((2, 1, 0))
        gt_ref = f[f["/value/gt"][index, 0]][0, 0]
        target = np.array(f[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert("RGB")
        target = Image.fromarray(target)

        # Equal-weight blend of Gaussian (radius 11) and median (size 11)
        # blurs.
        blurred_img = Image.blend(
            img.filter(ImageFilter.GaussianBlur(11)),
            img.filter(ImageFilter.MedianFilter(11)),
            0.5,
        )

        if self.transform is not None:
            img = self.transform(img)
            blurred_img = self.transform(blurred_img)

        if self.target_transform is not None:
            # Convert the transformed mask to a LongTensor of class ids.
            target = np.array(self.target_transform(target)).astype("int32")
            target = torch.from_numpy(target).long()

        return (img, blurred_img), target

    def __len__(self):
        return self.data_length
|
|
|
|
class Imagenet_Segmentation_eval_dir(data.Dataset):
    """ImageNet segmentation dataset paired with a directory of precomputed
    per-sample results (``*.npy``).

    NOTE(review): ``self.results`` is only enumerated; ``__getitem__`` does
    not return the saved result for the index (an earlier version loaded it
    and discarded it).
    """

    CLASSES = 2

    def __init__(self, path, eval_path, transform=None, target_transform=None):
        """
        Args:
            path: HDF5 (MATLAB v7.3 ``.mat``) file with /value/img and /value/gt.
            eval_path: directory containing precomputed ``*.npy`` result files.
            transform: optional transform applied to the image.
            target_transform: optional transform applied to the mask.
        """
        self.transform = transform
        self.target_transform = target_transform
        # Open read-only: this dataset never writes to the file. The previous
        # "r+" mode risked accidental modification of the ground-truth file.
        self.h5py = h5py.File(path, "r")

        # glob order is platform-dependent; sort for a deterministic mapping
        # from index to result file.
        self.results = sorted(glob(os.path.join(eval_path, "*.npy")))

    def __getitem__(self, index):
        f = self.h5py
        # Dereference MATLAB cell-array object references and transpose from
        # column-major (MATLAB) to row-major layout.
        img = np.array(f[f["/value/img"][index, 0]]).transpose((2, 1, 0))
        target = np.array(
            f[f[f["/value/gt"][index, 0]][0, 0]]
        ).transpose((1, 0))
        # NOTE(review): np.load(self.results[index]) was performed here but
        # its value was never used; the dead read has been removed.

        img = Image.fromarray(img).convert("RGB")
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            # Convert the transformed mask to a LongTensor of class ids.
            target = np.array(self.target_transform(target)).astype("int32")
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        return len(self.h5py["/value/img"])
|
|
|
|
if __name__ == "__main__":
    import scipy.io as sio
    import torchvision.transforms as transforms
    from imageio import imsave
    from tqdm import tqdm

    # Image pipeline: resize to 224x224, tensorize, ImageNet-normalize.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    test_img_trans = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
    )
    # Labels: nearest-neighbour resize so class ids are never interpolated.
    test_lbl_trans = transforms.Compose(
        [transforms.Resize((224, 224), Image.NEAREST)]
    )

    ds = Imagenet_Segmentation(
        "/home/shirgur/ext/Data/Datasets/imagenet-seg/other/gtsegs_ijcv.mat",
        transform=test_img_trans,
        target_transform=test_lbl_trans,
    )

    # Dump every ground-truth mask as an 8-bit PNG.
    for i, (img, tgt) in enumerate(tqdm(ds)):
        tgt = (tgt.numpy() * 255).astype(np.uint8)
        imsave("/home/shirgur/ext/Code/C2S/run/imagenet/gt/{}.png".format(i), tgt)

    print("here")
|
|