|
|
| import os |
| from torch.utils.data import Dataset |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
| import cv2 |
| import torch |
| import numpy as np |
| from torch.nn import functional as F |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from utils import train_transforms, get_boxes_from_mask, init_point_sampling |
| import json |
| import random |
|
|
|
|
class TestingDataset(Dataset):
    """Evaluation dataset pairing images with binary masks and prompt inputs
    (bounding boxes and sampled points) for a promptable segmentation model.

    Expects ``<data_path>/label2image_<mode>.json``: a JSON object mapping
    mask (label) path -> image path.
    """

    def __init__(self, data_path, image_size=256, mode='test', requires_name=True, point_num=1, return_ori_mask=True, prompt_path=None):
        """
        Initializes a TestingDataset object.
        Args:
            data_path (str): The path to the data.
            image_size (int, optional): The size of the image. Defaults to 256.
            mode (str, optional): The mode of the dataset. Defaults to 'test'.
            requires_name (bool, optional): Indicates whether the dataset requires image names. Defaults to True.
            point_num (int, optional): The number of points to retrieve. Defaults to 1.
            return_ori_mask (bool, optional): Indicates whether to return the original mask. Defaults to True.
            prompt_path (str, optional): The path to the prompt file. Defaults to None.
        """
        self.image_size = image_size
        self.return_ori_mask = return_ori_mask
        self.prompt_path = prompt_path
        # Pre-computed prompts keyed by mask filename; an empty dict means
        # prompts are derived from each mask on the fly in __getitem__.
        # NOTE: use `with` so the file handle is closed (original leaked it).
        if prompt_path is None:
            self.prompt_list = {}
        else:
            with open(prompt_path, "r") as f:
                self.prompt_list = json.load(f)
        self.requires_name = requires_name
        self.point_num = point_num

        # label2image_<mode>.json maps mask path -> image path.
        with open(os.path.join(data_path, f'label2image_{mode}.json'), "r") as f:
            dataset = json.load(f)

        self.image_paths = list(dataset.values())
        self.label_paths = list(dataset.keys())

        # Per-channel normalization statistics on the 0-255 scale
        # (applied channel-wise to the cv2-loaded image).
        self.pixel_mean = [123.675, 116.28, 103.53]
        self.pixel_std = [58.395, 57.12, 57.375]

    def __getitem__(self, index):
        """
        Retrieves and preprocesses an item from the dataset.
        Args:
            index (int): The index of the item to retrieve.
        Returns:
            dict: A dictionary containing the preprocessed image and associated information.
        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        image_path = self.image_paths[index]
        # cv2.imread does not raise on a missing/corrupt file -- it returns
        # None. The original bare try/except printed the path and then hit a
        # NameError on `image`; fail loudly with a clear error instead.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # Channel-wise normalization (broadcasts the per-channel lists).
        image = (image - self.pixel_mean) / self.pixel_std

        mask_path = self.label_paths[index]
        ori_np_mask = cv2.imread(mask_path, 0)
        if ori_np_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Accept 0/255-encoded masks by rescaling them to 0/1.
        if ori_np_mask.max() == 255:
            ori_np_mask = ori_np_mask / 255

        assert np.array_equal(ori_np_mask, ori_np_mask.astype(bool)), \
            f"Mask should only contain binary values 0 and 1. {mask_path}"

        h, w = ori_np_mask.shape
        # Keep an untransformed copy of the mask at the original resolution.
        ori_mask = torch.tensor(ori_np_mask).unsqueeze(0)

        transforms = train_transforms(self.image_size, h, w)
        augments = transforms(image=image, mask=ori_np_mask)
        image, mask = augments['image'], augments['mask'].to(torch.int64)

        if self.prompt_path is None:
            # Derive prompts directly from the (resized) mask.
            boxes = get_boxes_from_mask(mask)
            point_coords, point_labels = init_point_sampling(mask, self.point_num)
        else:
            # Look up pre-computed prompts by mask filename.
            prompt_key = mask_path.split('/')[-1]
            boxes = torch.as_tensor(self.prompt_list[prompt_key]["boxes"], dtype=torch.float)
            point_coords = torch.as_tensor(self.prompt_list[prompt_key]["point_coords"], dtype=torch.float)
            point_labels = torch.as_tensor(self.prompt_list[prompt_key]["point_labels"], dtype=torch.int)

        image_input = {
            "image": image,
            "label": mask.unsqueeze(0),
            "point_coords": point_coords,
            "point_labels": point_labels,
            "boxes": boxes,
            "original_size": (h, w),
            "label_path": '/'.join(mask_path.split('/')[:-1]),
        }

        if self.return_ori_mask:
            image_input["ori_label"] = ori_mask

        # Original had a redundant if/else that returned the same dict on
        # both branches; the name is simply added when requested.
        if self.requires_name:
            image_input["name"] = mask_path.split('/')[-1]
        return image_input

    def __len__(self):
        """Return the number of (mask, image) pairs in the dataset."""
        return len(self.label_paths)
|
|
|
|
| if __name__ == "__main__": |
| test_dataset = TestingDataset("data_demo", image_size = 256, mode='test', requires_name = True, point_num=1, return_ori_mask=True, prompt_path = None) |
| print("Dataset:", len(test_dataset)) |
| test_batch_sampler = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4) |
| for i, batched_image in enumerate(tqdm(test_batch_sampler)): |
| for k,v in batched_image.items(): |
| print(k, v) |
|
|
|
|