| import os |
| import random |
|
|
| import torch |
| from PIL import Image |
| from torchstain.base.normalizers.he_normalizer import HENormalizer |
| from torchstain.torch.utils import cov, percentile |
| from torchvision import transforms |
| from torchvision.transforms.functional import to_pil_image |
|
|
|
|
def preprocessor(pretrained=False, normalizer=None):
    """Build the image preprocessing pipeline.

    The pipeline resizes to 256, center-crops to 224, applies the optional
    ``normalizer``, converts the result to a tensor and standardizes every
    channel with a fixed mean/std pair.

    Args:
        pretrained: When true, standardize with mean (0.485, 0.456, 0.406)
            and std (0.229, 0.224, 0.225); otherwise use 0.5 for both
            statistics on every channel.
        normalizer: Optional callable placed between the center crop and
            ``ToTensor``. It receives the cropped image and its result is
            handed to ``ToTensor``. ``None`` installs an identity step so the
            pipeline always has the same number of stages.

    Returns:
        The assembled ``transforms.Compose`` pipeline.
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    # Identity check instead of ``== None``: a normalizer object may define its
    # own ``__eq__``, which must not decide whether the step is enabled.
    if normalizer is None:
        stain_step = transforms.Lambda(lambda x: x)
    else:
        stain_step = normalizer

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            stain_step,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    return preprocess
|
|
|
|
| """ |
| Source code ported from: https://github.com/schaugf/HEnorm_python |
| Original implementation: https://github.com/mitkovetta/staining-normalization |
| """ |
|
|
|
|
class TorchMacenkoNormalizer(HENormalizer):
    """Macenko H&E stain normalizer operating on torch tensors.

    The instance stores a reference stain matrix (``HERef``, one column per
    stain) and reference maximum stain concentrations (``maxCRef``). ``fit``
    replaces both with values estimated from a target image; ``normalize``
    re-renders another image using them.
    """

    def __init__(self):
        super().__init__()

        # Use torch.linalg.lstsq when this torch build provides it; otherwise
        # the legacy torch.lstsq call is used.
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

        # Default references, used until ``fit`` overwrites them.
        self.maxCRef = torch.tensor([1.9705, 1.0308])
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )

    def __convert_rgb2od(self, I, Io, beta):
        """Convert a channel-first image to optical density (OD).

        Returns:
            ``(OD, ODhat)``: ``OD`` has one row per pixel and one column per
            channel; ``ODhat`` keeps only the rows in which no channel is
            below ``beta``.
        """
        # Channel-last, then flatten the spatial dimensions into rows.
        pixels = I.permute(1, 2, 0)
        pixels = pixels.reshape((-1, pixels.shape[-1])).float()

        # The +1 keeps the logarithm's argument non-zero for zero intensities.
        OD = -torch.log((pixels + 1) / Io)

        # Drop rows where any channel falls below the threshold.
        below_threshold = torch.any(OD < beta, dim=1)
        ODhat = OD[~below_threshold]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the two stain vectors from the thresholded OD values.

        The OD rows are projected onto the plane spanned by ``eigvecs``; the
        angles at the ``alpha`` and ``100 - alpha`` percentiles are mapped
        back to OD space and returned as the two columns of the result.
        """
        projected = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(projected[:, 1], projected[:, 0])

        def to_od_space(angle):
            # Unit vector at ``angle`` in the plane, expressed in OD space.
            unit = torch.stack((torch.cos(angle), torch.sin(angle)))
            return torch.matmul(eigvecs, unit).unsqueeze(1)

        vMin = to_od_space(percentile(phi, alpha))
        vMax = to_od_space(percentile(phi, 100 - alpha))

        # Column order is decided by the first component: the vector with the
        # larger first component goes first.
        min_first = torch.cat((vMin, vMax), dim=1)
        max_first = torch.cat((vMax, vMin), dim=1)
        return torch.where(vMin[0] > vMax[0], min_first, max_first)

    def __find_concentration(self, OD, HE):
        """Solve ``HE @ C = OD.T`` for ``C`` in the least-squares sense."""
        Y = OD.T

        if self.updated_lstsq:
            return torch.linalg.lstsq(HE, Y)[0]

        # Legacy API: argument order is swapped, and only the first two rows
        # of the returned solution are kept.
        return torch.lstsq(Y, HE)[0][:2]

    def __compute_matrices(self, I, Io, alpha, beta):
        """Return the stain matrix, concentrations and their 99th percentiles."""
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # Only the eigenvectors are needed; columns 1 and 2 are selected
        # (eigh returns eigenvalues in ascending order).
        eigvecs = torch.linalg.eigh(cov(ODhat.T))[1]
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)
        C = self.__find_concentration(OD, HE)

        maxC = torch.stack([percentile(row, 99) for row in (C[0, :], C[1, :])])

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Estimate and store the reference stain matrix from image ``I``.

        Overwrites ``HERef`` and ``maxCRef`` with the values computed from
        ``I`` (a channel-first tensor, as expected by ``normalize``).
        """
        self.HERef, _, self.maxCRef = self.__compute_matrices(I, Io, alpha, beta)

    def __stain_image(self, index, C, Io, h, w, c):
        """Render the single-stain image for stain column ``index``."""
        stain_vector = -self.HERef[:, index].unsqueeze(-1)
        concentration = C[index, :].unsqueeze(0)

        rendered = torch.mul(Io, torch.exp(torch.matmul(stain_vector, concentration)))
        rendered[rendered > 255] = 255
        return rendered.T.reshape(h, w, c).int()

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize the staining appearance of an H&E stained image.

        Input:
            I: RGB input image: tensor of shape [C, H, W]
            Io: (optional) transmitted light intensity
            alpha: percentile
            beta: transparency threshold
            stains: if true, also return the H & E components
            form: accepted for interface compatibility; not read here
            dtype: accepted for interface compatibility; not read here

        Output:
            Inorm: normalized image, float tensor of shape [C, H, W] clipped
                to [0, 1]
            H: hematoxylin image, int tensor of shape [H, W, C], or None when
                ``stains`` is false
            E: eosin image, int tensor of shape [H, W, C], or None when
                ``stains`` is false

        Reference:
            A method for normalizing histology slides for quantitative analysis. M.
            Macenko et al., ISBI 2009
        """
        c, h, w = I.shape

        _, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # Rescale the concentrations so their maxima match the reference.
        scale = (self.maxCRef / maxC).unsqueeze(-1)
        C *= scale

        # Recreate the image with the reference stain matrix.
        recreated = Io * torch.exp(-torch.matmul(self.HERef, C))
        recreated = torch.clip(recreated, 0, 255)

        Inorm = recreated.reshape(c, h, w).float() / 255.0
        Inorm = torch.clip(Inorm, 0.0, 1.0)

        if not stains:
            return Inorm, None, None

        H = self.__stain_image(0, C, Io, h, w, c)
        E = self.__stain_image(1, C, Io, h, w, c)

        return Inorm, H, E
|
|
|
|
class MacenkoNormalizer:
    """Callable transform applying Macenko stain normalization to an image.

    The image is converted to a channel-first tensor scaled to 0-255,
    normalized with ``TorchMacenkoNormalizer`` and converted back with
    ``to_pil_image``. Whenever normalization fails or yields NaNs, the input
    image is returned unchanged.

    Args:
        target_path: Optional reference. A ``.jpg``/``.jpeg``/``.png`` file is
            used to fit the normalizer; a ``.pt`` file must hold a mapping
            with the keys ``"HERef"`` and ``"maxCRef"``. ``None`` keeps the
            normalizer's built-in reference values.
        prob: Probability with which normalization is attempted on a call.
            With the default of 1 every call attempts it.

    Raises:
        ValueError: If ``target_path`` has an unsupported extension.
    """

    def __init__(self, target_path=None, prob=1):
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()
        self.prob = prob

        # Without a target the default reference of the normalizer is used;
        # previously the default argument crashed in os.path.splitext(None).
        if target_path is not None:
            self._load_reference(target_path)

    def _load_reference(self, target_path):
        """Set the normalizer's reference from an image or a ``.pt`` file."""
        ext = os.path.splitext(target_path)[1].lower()
        if ext in (".jpg", ".jpeg", ".png"):
            # The context manager closes the file handle; converting to RGB
            # guarantees three channels (e.g. for RGBA or palette PNGs), so
            # the fitted stain matrix matches RGB inputs.
            with Image.open(target_path) as target:
                self.normalizer.fit(
                    self.transform_before_macenko(target.convert("RGB"))
                )
        elif ext == ".pt":
            # NOTE(security): torch.load unpickles the file; only load
            # reference files from trusted sources.
            target = torch.load(target_path)
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]
        else:
            raise ValueError(f"Invalid extension: {ext}")

    def __call__(self, image):
        """Return the stain-normalized image, or ``image`` itself on failure."""
        # Skip normalization with probability 1 - prob. No random number is
        # drawn when prob >= 1, so the default leaves the RNG stream untouched.
        if self.prob < 1 and random.random() >= self.prob:
            return image

        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            if torch.any(torch.isnan(image_macenko)):
                return image
            return to_pil_image(image_macenko)
        except Exception as e:
            # Deliberate best-effort fallback: failures raised from kthvalue()
            # or linalg.eigh are silenced; anything else is reported before
            # falling back to the unmodified image.
            message = str(e)
            if "kthvalue()" not in message and "linalg.eigh" not in message:
                print(message)
            return image
|
|