| import os |
| import torch |
| from openslide import OpenSlide |
| from utils.preprocessor import MacenkoNormalizer, preprocessor |
| from torch.utils.data import Dataset |
|
|
|
|
class WSIPatchDataset(Dataset):
    """Dataset of square RGB patches read from a single whole-slide image.

    Item ``idx`` is the ``patch_size`` x ``patch_size`` region whose origin is
    ``coords[idx]``, read at pyramid level ``patch_level``, converted to RGB and
    passed through the transform pipeline returned by ``preprocessor``.

    Args:
        coords: Indexable, sized collection of ``(x, y)`` patch origins. Each
            entry is used as the location argument of
            ``OpenSlide.read_region``, which OpenSlide defines in the level-0
            reference frame.
        wsi_path: Path of the slide file opened with ``OpenSlide``.
        pretrained: Forwarded to ``preprocessor``.
        patch_size: Side length in pixels of the square region that is read.
        patch_level: Pyramid level passed to ``read_region``.
        macenko: If true, a ``MacenkoNormalizer`` is built and handed to
            ``preprocessor``; otherwise ``normalizer=None`` is passed.
        return_coord: If true, ``__getitem__`` returns
            ``(img, torch.tensor(coord))`` instead of just ``img``.
        macenko_target: Optional path of the Macenko target parameter file.
            When ``None`` (the default) the file
            ``macenko_target/macenko_param.pt`` inside the parent of this
            module's directory is used, i.e. the previously hard-coded path.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
        return_coord=False,
        macenko_target=None,
    ):
        self.pretrained = pretrained
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level
        self.return_coord = return_coord

        if macenko:
            if macenko_target is None:
                # Parent of this module's directory; abspath makes the
                # resolved target path absolute.
                package_root = os.path.dirname(
                    os.path.dirname(os.path.abspath(__file__))
                )
                macenko_target = os.path.join(
                    package_root, "macenko_target", "macenko_param.pt"
                )
            normalizer = MacenkoNormalizer(target_path=macenko_target)
        else:
            normalizer = None

        self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)
        self.coords = coords
        self.length = len(self.coords)

    def __len__(self):
        """Return the number of patch coordinates."""
        return self.length

    def __getitem__(self, idx):
        """Read, convert and transform the patch at ``coords[idx]``.

        Returns:
            The transformed patch, or ``(patch, torch.tensor(coord))`` when
            ``return_coord`` is set.
        """
        coord = self.coords[idx]
        # Coordinates often come from numpy/torch arrays, and OpenSlide ends up
        # passing the location as C integers; normalize to plain Python ints
        # rather than relying on the bindings to coerce foreign scalar types.
        location = (int(coord[0]), int(coord[1]))
        img = self.wsi.read_region(
            location, self.patch_level, (self.patch_size, self.patch_size)
        ).convert("RGB")
        img = self.roi_transforms(img)
        if self.return_coord:
            return img, torch.tensor(coord)
        return img

    def close(self):
        """Release the underlying OpenSlide handle opened in ``__init__``."""
        self.wsi.close()
|
|