| from cmath import nan |
| import csv |
| import json |
| import logging |
| import os |
| import sys |
| import pydicom |
|
|
| from abc import abstractmethod |
| from itertools import islice |
| from typing import List, Tuple, Dict, Any |
| from torch.utils.data import DataLoader |
| import PIL |
| from torch.utils.data import Dataset |
| import numpy as np |
| import pandas as pd |
| from torchvision import transforms |
| from PIL import Image |
| from skimage import exposure |
| import torch |
|
|
| |
| from torchvision.transforms import InterpolationMode |
|
|
|
|
class RSNA2018_Dataset(Dataset):
    """Map-style dataset pairing DICOM images with labels and box masks.

    Samples are described by the rows of a CSV file whose columns are
    addressed by position: column 1 holds the DICOM path, column 2 the
    bounding-box string and column 3 the class label.

    Each item is a dict with the keys ``"image"`` (normalized 224x224
    tensor), ``"label"`` (one-element NumPy array), ``"image_path"`` and
    ``"seg_map"`` (box mask resized to 224x224 with nearest-neighbour
    interpolation).
    """

    # Shape (rows, columns) of the canvas the boxes are painted on before the
    # mask is resized; box coordinates are interpreted in this pixel space.
    SEG_MAP_SHAPE = (1024, 1024)

    def __init__(self, csv_path):
        """Load the sample table and build the image and mask transforms.

        Args:
            csv_path: Path of the CSV file describing the samples.
        """
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 1])
        self.class_list = np.asarray(data_info.iloc[:, 3])
        self.bbox = np.asarray(data_info.iloc[:, 2])

        normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )
        self.transform = transforms.Compose(
            [transforms.Resize([224, 224]), transforms.ToTensor(), normalize]
        )
        # NOTE: the attribute name (including its spelling) is kept unchanged
        # because code outside this class may access it.
        self.seg_transfrom = transforms.Compose(
            [
                transforms.ToTensor(),
                # Nearest-neighbour keeps the resized mask strictly binary.
                transforms.Resize(
                    [224, 224], interpolation=InterpolationMode.NEAREST
                ),
            ]
        )

    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Args:
            index: Position of the sample in the CSV table.

        Returns:
            dict: ``"image"``, ``"label"``, ``"image_path"`` and ``"seg_map"``.
        """
        img_path = self.img_path_list[index]
        class_value = self.class_list[index]
        class_label = np.array([class_value])

        image = self.transform(self.read_dcm(img_path))

        # Compare the scalar label rather than relying on the truth value of
        # a one-element array; only positive samples carry boxes.
        if class_value == 1:
            seg_map = self._boxes_to_mask(self.bbox[index], self.SEG_MAP_SHAPE)
        else:
            seg_map = np.zeros(self.SEG_MAP_SHAPE)
        seg_map = self.seg_transfrom(seg_map)

        return {
            "image": image,
            "label": class_label,
            "image_path": img_path,
            "seg_map": seg_map,
        }

    @staticmethod
    def _boxes_to_mask(bbox, shape=(1024, 1024)):
        """Paint the boxes encoded in ``bbox`` onto a zero-filled mask.

        Args:
            bbox: Boxes separated by ``"|"``; each box is four numbers
                separated by ``";"`` in the order column offset, row offset,
                width, height. A non-string value (for example the NaN that
                pandas produces for an empty cell) is treated as "no boxes".
            shape: ``(rows, columns)`` of the mask to create.

        Returns:
            numpy.ndarray: Float array of ``shape`` with 1 inside every box.

        Raises:
            ValueError: If a coordinate cannot be parsed as a number.
            IndexError: If a non-empty box has fewer than four fields.
        """
        mask = np.zeros(shape)
        if not isinstance(bbox, str):
            return mask

        for box in bbox.split("|"):
            if not box.strip():
                # Tolerate empty segments such as a trailing separator.
                continue
            fields = box.split(";")
            left = int(float(fields[0]))
            top = int(float(fields[1]))
            right = left + int(float(fields[2]))
            bottom = top + int(float(fields[3]))
            # Clamp at zero: a negative slice bound would otherwise be
            # interpreted as an offset from the end of the axis.
            mask[
                max(top, 0) : max(bottom, 0),
                max(left, 0) : max(right, 0),
            ] = 1
        return mask

    def read_dcm(self, dcm_path):
        """Read a DICOM file and return it as a histogram-equalized RGB image.

        Args:
            dcm_path: Path of the DICOM file.

        Returns:
            PIL.Image.Image: 8-bit image converted to RGB mode.
        """
        # ``dcmread`` replaces the deprecated ``read_file`` alias, which is no
        # longer available in current pydicom releases.
        dcm_data = pydicom.dcmread(dcm_path)
        img = dcm_data.pixel_array.astype(float) / 255.0
        img = exposure.equalize_hist(img)

        # Scale the equalized image to 8 bits before handing it to PIL.
        img = (255 * img).astype(np.uint8)
        return PIL.Image.fromarray(img).convert("RGB")

    def __len__(self):
        """Return the number of samples listed in the CSV file."""
        return len(self.img_path_list)
|
|
|
|
def create_loader_RSNA(
    datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
    """Build one ``DataLoader`` per entry of the parallel argument sequences.

    The six arguments are walked in lockstep; iteration stops at the shortest
    of them. A training loader drops the last incomplete batch and shuffles
    only when it has no sampler. A non-training loader never shuffles and
    keeps the last batch. Every loader uses pinned memory.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler (or ``None``) for each dataset.
        batch_size: Batch size for each dataset.
        num_workers: Worker-process count for each dataset.
        is_trains: Truthy for each dataset that is used for training.
        collate_fns: Collate function (or ``None``) for each dataset.

    Returns:
        list: The created loaders, in input order.
    """

    def _make_loader(data, sampler, per_batch, workers, train_flag, collate):
        # Training loaders drop the trailing partial batch and shuffle only
        # when no sampler was supplied; evaluation loaders do neither.
        training = bool(train_flag)
        return DataLoader(
            data,
            batch_size=per_batch,
            num_workers=workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=training and sampler is None,
            collate_fn=collate,
            drop_last=training,
        )

    return [
        _make_loader(*spec)
        for spec in zip(
            datasets, samplers, batch_size, num_workers, is_trains, collate_fns
        )
    ]
|
|