| import cv2 |
| from PIL import Image |
| import numpy as np |
| import os |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
|
|
|
|
def preprocess_img(img_dir, channels=3, shape_r=288, shape_c=384):
    """Load an image and letterbox it to a fixed (shape_r, shape_c) canvas.

    The image is resized preserving its aspect ratio, then centered on a
    padded canvas of the target size.

    Args:
        img_dir: Path to the image file.
        channels: 1 to load grayscale, 3 to load BGR color.
        shape_r: Target height (rows) of the padded output. Default 288.
        shape_c: Target width (cols) of the padded output. Default 384.

    Returns:
        np.uint8 array of shape (shape_r, shape_c, 3) when channels == 3,
        or (shape_r, shape_c) when channels == 1.

    Raises:
        ValueError: If channels is not 1 or 3, or the image cannot be read.
    """
    if channels == 1:
        img = cv2.imread(img_dir, 0)
        # Grayscale canvas is 2-D and zero-filled (black padding).
        img_padded = np.zeros((shape_r, shape_c), dtype=np.uint8)
    elif channels == 3:
        img = cv2.imread(img_dir)
        # NOTE(review): color canvas is ones (pixel value 1), preserved from
        # the original code — padding is near-black rather than pure black.
        img_padded = np.ones((shape_r, shape_c, channels), dtype=np.uint8)
    else:
        raise ValueError("channels must be 1 or 3, got %r" % (channels,))
    if img is None:
        # cv2.imread returns None on failure instead of raising.
        raise ValueError("could not read image: %s" % (img_dir,))

    original_shape = img.shape
    rows_rate = original_shape[0] / shape_r
    cols_rate = original_shape[1] / shape_c
    if rows_rate > cols_rate:
        # Height is the limiting dimension: fit height, center-pad width.
        new_cols = (original_shape[1] * shape_r) // original_shape[0]
        img = cv2.resize(img, (new_cols, shape_r))
        if new_cols > shape_c:
            new_cols = shape_c
        left = (img_padded.shape[1] - new_cols) // 2
        img_padded[:, left:left + new_cols] = img
    else:
        # Width is the limiting dimension: fit width, center-pad height.
        new_rows = (original_shape[0] * shape_c) // original_shape[1]
        img = cv2.resize(img, (shape_c, new_rows))
        if new_rows > shape_r:
            new_rows = shape_r
        top = (img_padded.shape[0] - new_rows) // 2
        img_padded[top:top + new_rows, :] = img

    return img_padded
|
|
|
|
def postprocess_img(pred, org_dir):
    """Resize a prediction map back to the original image's dimensions.

    Inverse of the letterbox preprocessing: the prediction is upscaled
    preserving aspect ratio, then the central crop matching the original
    image's height/width is returned.

    Args:
        pred: 2-D prediction map (anything np.array can convert).
        org_dir: Path to the original image; only its shape is used.

    Returns:
        np.ndarray with the same (height, width) as the original image.

    Raises:
        ValueError: If the original image cannot be read.
    """
    pred = np.array(pred)
    # Read grayscale purely to recover the original height/width.
    org = cv2.imread(org_dir, 0)
    if org is None:
        # cv2.imread returns None on failure instead of raising.
        raise ValueError("could not read image: %s" % (org_dir,))
    shape_r = org.shape[0]
    shape_c = org.shape[1]
    predictions_shape = pred.shape

    rows_rate = shape_r / predictions_shape[0]
    cols_rate = shape_c / predictions_shape[1]

    if rows_rate > cols_rate:
        # Height scales more: fit height exactly, crop the width center.
        new_cols = (predictions_shape[1] * shape_r) // predictions_shape[0]
        pred = cv2.resize(pred, (new_cols, shape_r))
        left = (pred.shape[1] - shape_c) // 2
        img = pred[:, left:left + shape_c]
    else:
        # Width scales more: fit width exactly, crop the height center.
        new_rows = (predictions_shape[0] * shape_c) // predictions_shape[1]
        pred = cv2.resize(pred, (shape_c, new_rows))
        top = (pred.shape[0] - shape_r) // 2
        img = pred[top:top + shape_r, :]

    return img
|
|
|
|
class MyDataset(Dataset):
    """Saliency dataset yielding a stimulus, saliency map and fixation map.

    Each sample is a dict with keys 'image' (3, H, W), 'saliency' (1, H, W)
    and 'fixation' (1, H, W); all are float64 torch tensors scaled to [0, 1].
    """

    def __init__(self, ids, stimuli_dir, saliency_dir, fixation_dir, transform=None):
        """
        Args:
            ids: DataFrame-like table indexed with .iloc whose columns 0, 1, 2
                hold the stimulus, saliency-map and fixation-map filenames.
            stimuli_dir (string): Directory of input images. Paths are built
                by string concatenation, so it must end with a separator.
            saliency_dir (string): Directory of ground-truth saliency maps.
            fixation_dir (string): Directory of ground-truth fixation maps.
            transform (callable, optional): Stored for API compatibility;
                NOTE(review): not currently applied in __getitem__.
        """
        self.ids = ids
        self.stimuli_dir = stimuli_dir
        self.saliency_dir = saliency_dir
        self.fixation_dir = fixation_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows in the ids table)."""
        return len(self.ids)

    def _load_map(self, path):
        """Load a single-channel map as a (1, H, W) float tensor in [0, 1]."""
        arr = np.expand_dims(np.array(Image.open(path)) / 255., axis=0)
        return torch.from_numpy(arr)

    def __getitem__(self, idx):
        """Return the sample dict for row *idx* of the ids table."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Stimulus: RGB image, scaled to [0, 1] and laid out channels-first.
        im_path = self.stimuli_dir + self.ids.iloc[idx, 0]
        image = Image.open(im_path).convert('RGB')
        img = np.array(image) / 255.
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img)

        # Ground-truth saliency and fixation maps, each (1, H, W) in [0, 1].
        smap = self._load_map(self.saliency_dir + self.ids.iloc[idx, 1])
        fmap = self._load_map(self.fixation_dir + self.ids.iloc[idx, 2])

        sample = {'image': img, 'saliency': smap, 'fixation': fmap}

        return sample
|
|
|
|
|
|
|
|
|
|
|
|