| import numpy as np |
| import torch |
| import torch.nn as nn |
| from models.aggregator import make_model |
| from models.cls_modules import LinearClassifier |
| from datasets.dataset_WSI import WSIPatchDataset |
| from models.feature_extractor import vit_base |
| from utils.wsi_utils import extract_tissue_patch_coords |
| from torch.utils.data import DataLoader |
|
|
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    """Slide-level pipeline: tissue patching -> patch features -> aggregation.

    The module owns two sub-models:

    * ``feature_extractor`` -- built by ``vit_base()``; applied to every
      tissue patch of a slide.
    * ``agg_model`` -- the model returned by ``make_model(...)`` wrapped in a
      ``LinearClassifier`` with mean pooling; applied once to the sampled
      patch features and their coordinates.

    NOTE(review): ``self.device`` is chosen at construction time ("cuda" when
    available, else "cpu") and every input tensor is sent there, but nothing
    in this class moves the sub-models to that device. Presumably the caller
    is expected to do ``model.to(model.device)`` -- confirm.
    """

    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        """Build the feature extractor and the aggregator.

        Args:
            step_size: Passed to ``extract_tissue_patch_coords`` as the
                stride between patch origins.
            patch_size: Patch size passed to both the coordinate extraction
                and ``WSIPatchDataset``.
            num_sampled_patch: Upper bound on the number of patches handed to
                the aggregator in ``forward``.
            macenko: Forwarded to ``WSIPatchDataset`` (presumably toggles
                Macenko stain normalisation -- confirm in the dataset).
        """
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.num_sampled_patch = num_sampled_patch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Constructor arguments kept together as a plain dict.
        self.config = {
            "step_size": step_size,
            "patch_size": patch_size,
            "macenko": macenko,
            "num_sampled_patch": num_sampled_patch,
        }

        self.feature_extractor = vit_base()

        # embed_dim presumably matches the feature extractor's output width
        # -- TODO confirm against vit_base().
        aggregator = make_model(
            embed_dim=768,
            droprate=0.0,
            num_registers=0,
            depth=4,
        )
        self.agg_model = LinearClassifier(aggregator, pool='mean')

    @torch.no_grad()
    def _extract_features(
        self, svs_path: str, feature_extractor_batch_size: int, return_coord: bool
    ):
        """Run the feature extractor over every tissue patch of one slide.

        Shared implementation behind ``forward`` and
        ``forward_feature_extractor``. Puts ``feature_extractor`` in eval
        mode as a side effect.

        Args:
            svs_path: Path of the slide file.
            feature_extractor_batch_size: Patches per batch. When
                ``self.device`` is CUDA the loader also uses twice this many
                worker processes.
            return_coord: Forwarded to ``WSIPatchDataset``. When true each
                batch is unpacked as ``(patches, coords)`` and the
                coordinates are collected as well.

        Returns:
            ``(features, coords)``: the L2-normalised features of all patches
            concatenated on the CPU, and the concatenated per-patch
            coordinates (``None`` when ``return_coord`` is false).

        Raises:
            RuntimeError: If the loader yields no batches, i.e. there is
                nothing to concatenate.
        """
        tissue_coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=tissue_coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=return_coord,
        )
        on_cuda = self.device.type == "cuda"
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=feature_extractor_batch_size * 2 if on_cuda else 0,
            pin_memory=on_cuda,
        )
        num_batches = len(patch_loader)

        features_list = []
        coords_list = []
        for count, items in enumerate(patch_loader):
            if return_coord:
                patches, batch_coords = items
                # Coordinates stay on the CPU next to the features; they are
                # moved to the compute device once, right before aggregation.
                coords_list.append(batch_coords)
            else:
                patches = items
            print(
                f"batch {count+1}/{num_batches}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            feature = self.feature_extractor(patches)
            feature /= feature.norm(dim=-1, keepdim=True)
            # Blocking copy on purpose: a non_blocking device-to-host copy is
            # not synchronised with the CPU, so the later torch.cat could
            # read the buffer before the transfer has finished.
            features_list.append(feature.cpu())

        print("")
        print("Feature extraction finished")

        if not features_list:
            raise RuntimeError(
                f"No patches were extracted from {svs_path!r}; nothing to encode."
            )

        features = torch.cat(features_list)
        coords = torch.cat(coords_list) if return_coord else None
        return features, coords

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Run the full pipeline on one slide.

        Extracts features for every tissue patch, draws at most
        ``self.num_sampled_patch`` of them in random order
        (``torch.randperm``), and feeds the sampled features together with
        their coordinates to ``agg_model`` as a batch of one slide. Puts both
        sub-models in eval mode as a side effect.

        Args:
            svs_path: Path of the slide file.
            feature_extractor_batch_size: Patches per feature-extraction batch.

        Returns:
            The first (only) entry of the aggregator's second output
            (``Y_prob``), moved to the CPU.

        Raises:
            RuntimeError: If no patches could be extracted from the slide.
        """
        features, coords = self._extract_features(
            svs_path, feature_extractor_batch_size, return_coord=True
        )

        total_samples = features.shape[0]
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]

        self.agg_model.eval()
        # [None] adds the leading batch axis for the single slide.
        _, Y_prob, _ = self.agg_model(
            sampled_features[None].to(self.device),
            sampled_coords[None].to(self.device),
        )
        return Y_prob[0].cpu()

    @torch.no_grad()
    def forward_feature_extractor(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Return the normalised patch features of one slide, without aggregation.

        Args:
            svs_path: Path of the slide file.
            feature_extractor_batch_size: Patches per feature-extraction batch.

        Returns:
            A CPU tensor holding the L2-normalised features of all tissue
            patches, concatenated along the first axis in loader order.

        Raises:
            RuntimeError: If no patches could be extracted from the slide.
        """
        features, _ = self._extract_features(
            svs_path, feature_extractor_batch_size, return_coord=False
        )
        return features