| | from einops import rearrange |
| | import torch.nn as nn |
| | import torch |
| | import math |
| | from transformers import AutoModel, Dinov2WithRegistersModel, Dinov2WithRegistersConfig, DINOv3ViTConfig, DINOv3ViTModel |
| | from x_transformers import Encoder |
| |
|
| |
|
class _MST(nn.Module):
    """Encode every 2D slice of a volume with a DINO ViT, then fuse the slices.

    The input ``x`` of shape ``(B, C, D, H, W)`` is split into ``C * D``
    single-channel 2D slices. Each slice is replicated to 3 channels and
    encoded by a Hugging Face DINOv2-with-registers or DINOv3 ViT backbone.
    The pooled per-slice embeddings are then fused and projected by a linear
    head.

    Args:
        out_ch: Output width of the final ``nn.Linear`` head.
        backbone_type: ``"dinov2"`` or ``"dinov3"``.
        model_size: Size code. For ``"dinov2"`` one of ``"s"``, ``"b"``,
            ``"l"``. For ``"dinov3"`` with ``weights=True`` it is interpolated
            into the hub id ``facebook/dinov3-vit{model_size}16-pretrain-lvd1689m``.
            With ``weights=False`` it must be ``"s"``, ``"b"`` or ``"l"``.
        slice_fusion_type: How per-slice embeddings are fused:
            - ``"transformer"``: a one-layer x_transformers ``Encoder`` with a
              learnable CLS token appended after the slices.
            - ``"average"``: mean over slices.
            - ``"none"``: return the per-slice embeddings without the linear
              head.
        weights: If true, load pretrained weights via
            ``AutoModel.from_pretrained``. Otherwise build a randomly
            initialised backbone from a config.

    Raises:
        ValueError: If ``backbone_type``, ``model_size`` or
            ``slice_fusion_type`` is not supported.
    """

    def __init__(
        self,
        out_ch=1,
        backbone_type="dinov3",
        model_size="s",
        slice_fusion_type="transformer",
        weights=True,
    ):
        super().__init__()
        # Validate cheap arguments before loading the backbone so that a typo
        # fails immediately instead of after a model load.
        if backbone_type not in ("dinov2", "dinov3"):
            raise ValueError("Unknown backbone_type")
        if slice_fusion_type not in ("transformer", "average", "none"):
            raise ValueError("Unknown slice_fusion_type")
        self.backbone_type = backbone_type
        self.slice_fusion_type = slice_fusion_type

        self.backbone = self._build_backbone(backbone_type, model_size, weights)
        emb_ch = self.backbone.config.hidden_size
        self.emb_ch = emb_ch

        if slice_fusion_type == "transformer":
            self.slice_fusion = Encoder(
                dim=emb_ch,
                heads=12 if emb_ch % 12 == 0 else 8,
                ff_mult=1,
                attn_dropout=0.0,
                pre_norm=True,
                depth=1,
                attn_flash=True,
                ff_no_bias=True,
                rotary_pos_emb=True,
            )
            # Learnable summary token; forward() appends it after the slice
            # tokens and reads the prediction from that last position.
            self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch))

        self.linear = nn.Linear(emb_ch, out_ch)

    @staticmethod
    def _build_backbone(backbone_type, model_size, weights):
        """Return a DINOv2-with-registers or DINOv3 ViT backbone module.

        Raises:
            ValueError: For an unsupported ``backbone_type`` or ``model_size``.
        """
        if backbone_type == "dinov2":
            size_key = {"s": "small", "b": "base", "l": "large"}.get(model_size)
            if size_key is None:
                raise ValueError(f"Unknown model_size for dinov2: {model_size!r}")
            if weights:
                return AutoModel.from_pretrained(
                    f"facebook/dinov2-with-registers-{size_key}"
                )
            arch = {
                "small": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6),
                "base": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12),
                "large": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16),
            }[size_key]
            config = Dinov2WithRegistersConfig(**arch)
            config.image_size = 518
            config.patch_size = 14
            return Dinov2WithRegistersModel(config)

        if backbone_type == "dinov3":
            if weights:
                # model_size is passed through to the hub id unvalidated, so
                # any size the hub publishes under this pattern can be loaded.
                return AutoModel.from_pretrained(
                    f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m"
                )
            arch = {
                "s": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, intermediate_size=1536),
                "b": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072),
                "l": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096),
            }.get(model_size)
            if arch is None:
                raise ValueError(f"Unknown model_size for dinov3: {model_size!r}")
            config = DINOv3ViTConfig(patch_size=16, num_register_tokens=4, **arch)
            return DINOv3ViTModel(config)

        raise ValueError("Unknown backbone_type")

    def forward(self, x, output_attentions=False):
        """Encode all slices and fuse them into a prediction.

        Args:
            x: Volume of shape ``(B, C, D, H, W)``.
            output_attentions: If true (and fusion is not ``"none"``), also
                return attention maps.

        Returns:
            The return value depends on the fusion type and on
            ``output_attentions``:
            - ``"none"``: per-slice embeddings of shape
              ``(B, C * D, emb_ch)``. This is returned even when
              ``output_attentions`` is true.
            - Otherwise, with ``output_attentions`` false: ``self.linear``
              applied to the fused token, shape ``(B, out_ch)``.
            - Otherwise, with ``output_attentions`` true: the tuple
              ``(prediction, backbone_out.attentions, slice_attn_layers)``.
              ``slice_attn_layers`` is a list of post-softmax attention
              tensors from the fusion encoder. It is empty for
              ``"average"`` fusion.

        Raises:
            ValueError: If ``self.slice_fusion_type`` is unknown.
        """
        B, *_ = x.shape

        # A slice whose spatial mean equals its top-left pixel is flagged as
        # padding. NOTE(review): presumably meant to catch constant-filled pad
        # slices; a non-constant slice can also match by coincidence.
        x_pad = torch.isclose(x.mean(dim=(-1, -2)), x[:, :, :, 0, 0])
        x_pad = rearrange(x_pad, 'b c d -> b (c d)')

        # Fold channels and depth into the batch; replicate each grey slice to
        # 3 channels for the RGB backbone.
        x = rearrange(x, 'b c d h w -> (b c d) h w')
        x = x[:, None].repeat(1, 3, 1, 1)

        backbone_out = self.backbone(x, output_attentions=output_attentions)
        x = backbone_out.pooler_output
        x = rearrange(x, '(b d) e -> b d e', b=B)

        slice_hiddens = None  # only populated by transformer fusion
        if self.slice_fusion_type == 'none':
            return x
        if self.slice_fusion_type == 'transformer':
            # The CLS token is never padding.
            cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=x.device)
            pad = torch.concat([x_pad, cls_pad], dim=1)
            x = torch.concat([x, self.cls_token.repeat(B, 1, 1)], dim=1)
            if output_attentions:
                x, slice_hiddens = self.slice_fusion(x, mask=~pad, return_hiddens=True)
            else:
                x = self.slice_fusion(x, mask=~pad)
        elif self.slice_fusion_type == 'average':
            x = x.mean(dim=1, keepdim=True)
        else:
            raise ValueError("Unknown slice_fusion_type")

        # The last position holds the CLS token (transformer) or the mean (average).
        x = self.linear(x[:, -1])
        if output_attentions:
            slice_attn_layers = [
                interm.post_softmax_attn
                for interm in getattr(slice_hiddens, 'attn_intermediates', [])
                if interm is not None and getattr(interm, 'post_softmax_attn', None) is not None
            ]
            return x, backbone_out.attentions, slice_attn_layers
        return x

    def forward_attention(self, x) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a forward pass and return CLS-attention maps for interpretation.

        The backbone is temporarily switched to eager attention, and flash
        attention is disabled in the slice-fusion encoder, so that attention
        probabilities are materialised. Both settings are restored afterwards,
        even if the forward pass raises.

        Args:
            x: Volume of shape ``(B, C, D, H, W)``.

        Returns:
            ``(out, plane_attn, slice_attn)``:
            - ``out``: the model prediction.
            - ``slice_attn``: shape ``(B, D)``. The last fusion layer's
              head-averaged CLS-to-slice attention, averaged over channels.
            - ``plane_attn``: shape ``(B, C * D, side, side)``. The last
              backbone layer's head-averaged CLS-to-patch attention, scaled
              per slice by ``slice_attn``.

        Raises:
            ValueError: If ``slice_fusion_type`` is not ``"transformer"``.
            RuntimeError: If no attention maps were produced or the patch
                grid is not square.
        """
        if self.slice_fusion_type != "transformer":
            raise ValueError("forward_attention requires slice_fusion_type='transformer'")
        B, C, D, _, _ = x.shape

        attn_impl = self.backbone.config._attn_implementation
        # Remember every module's flash flag so it can be restored afterwards.
        flash_modules = [
            (module, module.flash)
            for module in self.slice_fusion.modules()
            if hasattr(module, 'flash')
        ]
        self.backbone.set_attn_implementation("eager")
        try:
            for module, _ in flash_modules:
                module.flash = False
            out, backbone_attn, slice_attn_layers = self.forward(x, output_attentions=True)
        finally:
            for module, previous in flash_modules:
                module.flash = previous
            self.backbone.set_attn_implementation(attn_impl)

        if not slice_attn_layers:
            raise RuntimeError("slice fusion produced no attention maps")
        # Last fusion layer, averaged over heads (dim 1; presumably
        # (B, heads, T, T) with T = C*D + 1).
        slice_attn = slice_attn_layers[-1].mean(dim=1)
        # Query row of the CLS token (appended last), dropping its self-attention.
        slice_attn = slice_attn[:, -1, :-1]
        slice_attn = slice_attn.view(B, C, D).mean(dim=1)

        plane_attn_layers = [att for att in (backbone_attn or ()) if att is not None]
        if not plane_attn_layers:
            raise RuntimeError("backbone produced no attention maps")
        plane_attn = plane_attn_layers[-1].mean(dim=1)  # last layer, head-averaged
        num_reg_tokens = getattr(self.backbone.config, 'num_register_tokens', 0)
        # CLS query row (token 0); skip the CLS and register key columns.
        plane_attn = plane_attn[:, 0, 1 + num_reg_tokens:]
        num_patches = plane_attn.shape[-1]

        # Scale each (channel, slice) map by that slice's channel-averaged
        # weight. Broadcasting over an explicit channel axis keeps this valid
        # for C > 1 (a flat (B, C*D, P) * (B, D, 1) product does not broadcast).
        plane_attn = plane_attn.reshape(B, C, D, num_patches)
        plane_attn = plane_attn * slice_attn[:, None, :, None]

        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise RuntimeError("number of patches is not a perfect square")
        plane_attn = plane_attn.reshape(B, C * D, side, side)

        return out, plane_attn, slice_attn
| |
|
| |
|
class MSTRegression(nn.Module):
    """Regression wrapper that delegates all work to an inner ``_MST`` model.

    ``in_ch``, ``spatial_dims`` and any extra keyword arguments are accepted
    but ignored. NOTE(review): presumably kept so the constructor matches a
    shared model-factory signature; confirm against callers.
    """

    def __init__(
        self,
        in_ch=1,
        out_ch=1,
        spatial_dims=3,
        backbone_type="dinov3",
        model_size="s",
        slice_fusion_type="transformer",
        weights=True,
        **kwargs,
    ):
        super().__init__()
        inner_config = {
            "out_ch": out_ch,
            "backbone_type": backbone_type,
            "model_size": model_size,
            "slice_fusion_type": slice_fusion_type,
            "weights": weights,
        }
        self.mst = _MST(**inner_config)

    def forward(self, x):
        """Return the wrapped model's prediction for ``x``."""
        inner = self.mst
        return inner(x)

    def forward_attention(self, x):
        """Return ``(out, plane_attn, slice_attn)`` from the wrapped model."""
        inner = self.mst
        return inner.forward_attention(x)