| from tqdm import tqdm |
| import numpy as np |
| import os |
| import torch.nn as nn |
| import torch.optim as optim |
| import torch.utils.data as data |
| import torchvision.transforms as transforms |
| import medmnist |
| from medmnist import INFO, Evaluator |
| from PIL import Image |
| from torch.utils.data import Dataset |
| import matplotlib.pyplot as plt |
| from medmnist.utils import montage2d |
| from medimeta import MedIMeta |
|
|
class DataClass(Dataset):
    """Map-style dataset over an ``.npz`` archive holding ``images`` and ``labels``.

    Both arrays are read once at construction time. Individual samples are
    converted to PIL images on access so that an optional transform can be
    applied to them.
    """

    def __init__(self, root, transform=None, size=224):
        """Load the image and label arrays stored at ``root``.

        Args:
            root (str): Path to the .npz file (e.g., 'data_root/breastmnist.npz').
                The archive must contain the keys ``images`` and ``labels``.
            transform (callable, optional): A function/transform that takes in
                a PIL image and returns a transformed version.
            size (int, optional): Image size. It is only stored on the
                instance; this class itself never resizes images.
                Defaults to 224.

        Raises:
            FileNotFoundError: If ``root`` does not exist.
        """
        if not os.path.exists(root):
            raise FileNotFoundError(f"Dataset file not found at {root}")

        self.root = root
        self.transform = transform
        self.size = size

        # Open the archive as a context manager so its file handle is released
        # as soon as both arrays have been read, instead of whenever the
        # archive object happens to be garbage-collected.
        with np.load(self.root, mmap_mode="r") as npz_file:
            self.imgs = npz_file["images"]
            self.labels = npz_file["labels"]

        # A 4-D array whose last axis has length 3 is treated as 3-channel;
        # every other layout is treated as single-channel.
        self.n_channels = 3 if len(self.imgs.shape) == 4 and self.imgs.shape[-1] == 3 else 1

    def __len__(self):
        """Return the number of samples (length of the first image axis)."""
        return self.imgs.shape[0]

    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Returns:
            img: The image converted to a PIL image, then passed through
                ``self.transform`` when one was provided.
            target: The corresponding label entry cast to an integer dtype.
        """
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, target

    def montage(self, length=10, replace=False, save_folder=None):
        """Create a ``length`` x ``length`` montage of dataset images.

        The selection is deterministic: images are taken in index order
        starting at 0, wrapping around to the beginning when the dataset holds
        fewer than ``length * length`` images.

        Args:
            length (int): Number of images per row and column (default=10).
            replace (bool): Accepted for interface compatibility; the current
                selection logic does not read it.
            save_folder (str, optional): If provided, the folder is created
                when missing and the montage is saved there as
                ``montage1.jpg``.

        Returns:
            The montage object produced by ``montage2d``.
        """
        n_sel = length * length
        # Modulo keeps every index valid when n_sel exceeds the dataset size.
        indices = np.arange(n_sel) % len(self)

        montage_img = montage2d(imgs=self.imgs, n_channels=self.n_channels, sel=indices)

        if save_folder:
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, "montage1.jpg")
            montage_img.save(save_path)
            print(f"Montage saved at {save_path}")

        return montage_img
|
|
def build_medmnist_dataset(data_root, transform):
    """Build a ``DataClass`` dataset from the archive at ``data_root``.

    Args:
        data_root: Path forwarded to ``DataClass`` as ``root``.
        transform: Transform forwarded to ``DataClass`` unchanged.

    Returns:
        The constructed ``DataClass`` instance, created with ``size=224``.
    """
    return DataClass(root=data_root, transform=transform, size=224)
|
|
def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Build a ``MedIMeta`` dataset.

    ``data_root``, ``task`` and ``disease`` are passed to ``MedIMeta``
    positionally, in that order; ``transform`` is passed by keyword.

    Args:
        data_root: First positional argument for ``MedIMeta``.
        task: Second positional argument (default ``'bus'``).
            NOTE(review): presumably a dataset identifier -- confirm against
            the MedIMeta constructor.
        disease: Third positional argument (default ``'Disease'``).
            NOTE(review): presumably a task/label name -- confirm against the
            MedIMeta constructor.
        transform: Optional transform forwarded unchanged.

    Returns:
        The constructed ``MedIMeta`` instance.
    """
    return MedIMeta(data_root, task, disease, transform=transform)