| import argparse |
| import json |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| import torchio as tio |
| from torchvision.utils import save_image |
| from matplotlib.pyplot import get_cmap |
|
|
| from models import MSTRegression |
|
|
|
|
|
|
def minmax_norm(x):
    """Linearly rescale ``x`` so its global minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole input (all batches and
    channels together), not per batch/channel.

    Args:
        x: tensor or array supporting ``.min()``, ``.max()`` and arithmetic.

    Returns:
        The rescaled values. A constant input (zero range) is returned as all
        zeros instead of the NaNs a plain ``0 / 0`` division would produce.
    """
    lo = x.min()
    value_range = x.max() - lo
    if value_range == 0:
        # Avoid 0/0 -> NaN; (x - lo) is already all zeros in this case.
        value_range = 1
    return (x - lo) / value_range
|
|
def tensor2image(tensor, batch=0):
    """Flatten one sample of a 5-D tensor into a stack of single-channel 2-D images.

    Inputs with fewer than five dimensions are returned unchanged. For a 5-D
    input, sample ``tensor[batch]`` has its first two axes swapped, is reshaped
    so each leading element becomes its own 2-D image of size
    ``tensor.shape[-2:]``, and receives a singleton channel axis at dim 1.
    """
    if tensor.ndim >= 5:
        height, width = tensor.shape[-2:]
        sample = tensor[batch].transpose(0, 1)
        return sample.reshape(-1, height, width).unsqueeze(1)
    return tensor
|
|
def tensor_cam2image(tensor, cam, batch=0, alpha=0.5, color_map=get_cmap('jet')):
    """Blend an image tensor with a colour-mapped (grad-)CAM into RGB 2-D images.

    Args:
        tensor: image tensor; 5-D inputs are split into per-slice 2-D images by
            ``tensor2image`` (using sample ``batch``), lower-rank inputs are used as-is.
        cam: CAM/attention tensor laid out like ``tensor``; channel 0 of each
            resulting 2-D image is passed through ``color_map``. Presumably it
            should already be scaled to [0, 1] (colormap input range) — confirm
            at call sites.
        batch: sample index used for 5-D inputs.
        alpha: weight of the CAM colours in the blend; the image gets ``1 - alpha``.
        color_map: callable returning RGBA values in a new trailing axis.

    Returns:
        Blended images with 3 colour channels along dim 1, on ``tensor``'s device.
    """
    img = tensor2image(tensor, batch)
    if img.shape[1] != 3:
        # Replicate the channel axis three times so grey images can be blended with RGB.
        img = torch.cat([img] * 3, dim=1)

    cam_slices = tensor2image(cam, batch)[:, 0]
    # detach(): .numpy() raises on tensors that require grad (e.g. Grad-CAM outputs).
    cam_rgba = torch.tensor(color_map(cam_slices.detach().cpu().numpy()))
    # The colormap appends RGBA as the last axis: move it to dim 1 and drop alpha.
    # Move to the image's device so CUDA inputs do not fail with a device mismatch.
    cam_rgb = torch.moveaxis(cam_rgba, -1, 1)[:, :3].to(img.device)
    return (1 - alpha) * img + alpha * cam_rgb
|
|
|
|
|
|
def crop_breast_height(image, margin_top=10, target_height=256) -> tio.Crop:
    """Build a crop that reduces spatial axis 1 to ``target_height`` voxels.

    Bright voxels (above the 90th intensity percentile) are located, and the crop
    keeps the region ending roughly ``margin_top`` rows past the last row along
    axis 1 that contains any of them. Only as many rows as needed are removed
    from that end, and the rest are removed from the start of the axis.

    Args:
        image: torchio image whose ``data`` is (C, X, Y, Z); channel 0 is used to
            locate foreground rows.
        margin_top: rows kept beyond the last foreground row (heuristic slack).
        target_height: size of axis 1 after cropping (default keeps the original 256).

    Returns:
        ``tio.Crop`` with bounds only on axis 1. If no voxel exceeds the
        threshold, a centred crop is returned instead of failing.
    """
    data = image.data
    # Keep the percentile as a float: truncating it with int() would turn the
    # threshold into 0 for intensity-normalised (sub-1) images.
    threshold = float(np.quantile(data.float().numpy(), 0.9))
    foreground = data > threshold
    # Number of foreground voxels in each row along axis 1 (summing axes 0 and 2).
    fg_rows = foreground[0].sum(axis=(0, 2))

    height = fg_rows.shape[0]
    excess = max(height - target_height, 0)
    rows = torch.nonzero(fg_rows).flatten()
    if rows.numel() == 0:
        # No foreground found: fall back to a symmetric crop.
        top = excess // 2
    else:
        last_row = int(rows.max())
        top = min(max(height - last_row - margin_top, 0), excess)
    bottom = excess - top
    return tio.Crop((0, 0, bottom, top, 0, 0))
|
|
|
|
def get_bilateral_transform(img: tio.ScalarImage, ref_img=None, target_spacing=(0.7, 0.7, 3), target_shape=(512, 512, 32)):
    """Preprocess a bilateral volume onto a fixed grid and build its inverse.

    The reference grid is ``ref_img`` (defaults to ``img``) reoriented with
    ``ToCanonical`` and resampled to ``target_spacing``. ``img`` is resampled
    onto that grid, cropped/padded to ``target_shape`` and finally cropped along
    axis 1 by ``crop_breast_height`` (computed on the reference, not on ``img``).

    Returns:
        (transformed ``img``, inverse transform). The inverse undoes the height
        crop, restores the resampled grid's shape, and resamples back onto
        ``img``'s original grid.
    """
    ref_img = img if ref_img is None else ref_img

    # Reference grid: canonical orientation at the target spacing.
    ref_img = tio.ToCanonical()(ref_img)
    ref_img = tio.Resample(target_spacing)(ref_img)
    resample = tio.Resample(ref_img)
    # Shape of the resampled grid before CropOrPad; the inverse must restore
    # exactly this shape (not img.spatial_shape, which lives on a different
    # voxel grid and would cut away field of view for coarse inputs).
    resampled_shape = ref_img.spatial_shape

    ref_img = tio.CropOrPad(target_shape, padding_mode='minimum')(ref_img)
    crop_height = crop_breast_height(ref_img)

    trans = tio.Compose([
        resample,
        tio.CropOrPad(target_shape, padding_mode='minimum'),
        crop_height,
    ])

    trans_inv = tio.Compose([
        crop_height.inverse(),
        tio.CropOrPad(resampled_shape, padding_mode='minimum'),
        tio.Resample(img),
    ])
    return trans(img), trans_inv
|
|
def get_unilateral_transform(img: tio.ScalarImage, target_shape=(224, 224, 32)):
    """Prepare one half of the bilateral volume for the model.

    Forward: flip spatial axes 1 and 0, crop/pad to ``target_shape``, then
    z-normalise using only voxels strictly between the volume's min and max.
    Inverse (for model outputs): crop/pad back to ``img.spatial_shape`` and
    undo the flip. The intensity normalisation is not inverted.

    Returns:
        (transformed image, inverse transform)
    """
    def strictly_inside_range(x):
        return (x > x.min()) & (x < x.max())

    forward_steps = [
        tio.Flip((1, 0)),
        tio.CropOrPad(target_shape),
        tio.ZNormalization(masking_method=strictly_inside_range),
    ]
    backward_steps = [
        tio.CropOrPad(img.spatial_shape),
        tio.Flip((1, 0)),
    ]
    forward = tio.Compose(forward_steps)
    backward = tio.Compose(backward_steps)
    return forward(img), backward
|
|
|
|
def run_prediction(img: tio.ScalarImage, model: MSTRegression):
    """Predict per-side probabilities and an attention map for a bilateral image.

    The image is brought onto the bilateral grid, split along spatial axis 0 into
    a 'right' half (second 256 voxels) and a 'left' half (first 256 voxels), and
    each half is run through ``model.forward_attention`` separately.

    Returns:
        probs: dict ``{'right': ..., 'left': ...}`` of softmax probabilities with
            the batch axis squeezed.
        weight: ``tio.ScalarImage`` of the attention, mapped back onto ``img``'s
            grid and min-max normalised.
    """
    bilateral_img, undo_bilateral = get_bilateral_transform(img)
    device = next(model.parameters()).device
    half_crops = {
        'right': tio.Crop((256, 0, 0, 0, 0, 0)),
        'left': tio.Crop((0, 256, 0, 0, 0, 0)),
    }

    probs = {}
    side_weights = {}
    for side in half_crops:
        half_img = half_crops[side](bilateral_img)
        half_img, undo_unilateral = get_unilateral_transform(half_img)
        # (C, X, Y, Z) -> (1, C, Z, Y, X): swap first and last spatial axes, add batch axis.
        volume = half_img.data.swapaxes(1, -1).unsqueeze(0)

        with torch.no_grad():
            logits, attention, _ = model.forward_attention(volume.to(device))

        # Upsample attention to the model input's spatial size.
        attention = F.interpolate(
            attention.unsqueeze(1),
            size=volume.shape[2:],
            mode='trilinear',
            align_corners=False,
        ).cpu()
        probs[side] = F.softmax(logits, dim=-1).cpu().squeeze(0)

        # Back to torchio axis order, then undo the unilateral preprocessing.
        attention = attention.squeeze(0).swapaxes(1, -1)
        side_weights[side] = undo_unilateral(attention)

    merged = torch.concat([side_weights['left'], side_weights['right']], dim=1)
    attention_img = tio.ScalarImage(tensor=merged, affine=bilateral_img.affine)
    attention_img = undo_bilateral(attention_img)
    attention_img.set_data(minmax_norm(attention_img.data))
    return probs, attention_img
|
|
def load_model(repo_id="ODELIA-AI/MST") -> MSTRegression:
    """Download config and weights from a Hugging Face Hub repo and build the model.

    Reads ``hparams`` from ``model_config.json`` (empty if missing), constructs
    ``MSTRegression`` from them, and strictly loads ``state_dict.pt`` onto the CPU.

    Args:
        repo_id: Hub model repository to download from.

    Returns:
        The model with weights loaded (device placement and eval mode are left
        to the caller).
    """
    config_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="model_config.json")
    with open(config_path, "r", encoding="utf-8") as fp:
        config = json.load(fp)

    hparams = config.get("hparams", {})
    # NOTE(review): weights=False presumably skips pretrained backbone weights,
    # since they are replaced by the checkpoint below — confirm in MSTRegression.
    model = MSTRegression(weights=False, **hparams)

    state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered remote checkpoint cannot execute arbitrary code while loading.
    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    return model
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run MST prediction on an image and save the input, attention map and overlay."
    )
    parser.add_argument(
        '--path_img',
        default='/home/homesOnMaster/gfranzes/Documents/datasets/ODELIA/UKA/data/UKA_2/Sub_1.nii.gz',
        type=str,
        help='Input image file readable by torchio.',
    )
    parser.add_argument(
        '--path_out_dir',
        default=None,
        type=Path,
        help='Output directory (default: ./results/test_attention).',
    )
    args = parser.parse_args()

    # Output directory: explicit argument, otherwise the historical default under cwd.
    path_out_dir = args.path_out_dir if args.path_out_dir is not None else Path.cwd() / 'results' / 'test_attention'
    path_out_dir.mkdir(parents=True, exist_ok=True)

    # Load input.
    path_img = Path(args.path_img)
    img = tio.ScalarImage(path_img)

    # Load model on GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model()
    model.to(device)
    model.eval()

    # Inference.
    probs, weight = run_prediction(img, model)

    img.save(path_out_dir / "input.nii.gz")
    weight.save(path_out_dir / "attention.nii.gz")

    # Swap first and last spatial axes and add a batch axis so tensor2image
    # yields one 2-D image per slice for the overlay.
    weight = weight.data.swapaxes(1, -1).unsqueeze(0)
    img = img.data.swapaxes(1, -1).unsqueeze(0)
    save_image(
        tensor_cam2image(minmax_norm(img), minmax_norm(weight), alpha=0.5),
        path_out_dir / "overlay.png",
        normalize=False,
    )

    for side in ['left', 'right']:
        print(f"{side} breast predicted probabilities: {probs[side]}")
|
|