| """ |
| Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. |
| """ |
|
|
| import torch |
| from segment_anything import sam_model_registry |
| torch.backends.cuda.matmul.allow_tf32 = True |
| from torch import nn |
| import torch.nn.functional as F |
|
|
class Transformer(nn.Module):
    """Segment Anything (SAM) ViT image encoder adapted for dense prediction.

    The pretrained SAM encoder uses 16x16 patches; this module subsamples
    its patch-embedding weights and positional embeddings so the same
    backbone runs on ``bsize`` x ``bsize`` inputs with ``ps`` x ``ps``
    patches, then adds a 1x1-conv head whose channels are rearranged back
    to full resolution by a fixed transposed convolution.
    """

    def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
                 checkpoint=None, dtype=torch.float32):
        """Build the encoder and output head.

        Args:
            backbone (str): key into ``sam_model_registry`` ("vit_l", ...).
            ps (int): patch size; must divide 16 so the pretrained 16x16
                patch kernel can be subsampled with stride ``16 // ps``.
            nout (int): number of output channels per pixel.
            bsize (int): expected square input size in pixels; ``bsize // ps``
                must divide the pretrained 64-token grid (1024 // 16).
            rdrop (float): maximum per-block drop probability for stochastic
                depth during training; 0 disables it.
            checkpoint (str or None): SAM checkpoint path passed to the
                registry constructor.
            dtype (torch.dtype): dtype the model is cast to after loading
                weights in ``load_model``.

        For reference, the pretrained vit_l submodules being modified:
            encoder.patch_embed.proj: Conv2d(3, 1024, kernel_size=16, stride=16)
            encoder.neck: Conv2d(1024, 256, 1) -> LayerNorm2d ->
                          Conv2d(256, 256, 3, padding=1) -> LayerNorm2d
        """
        super().__init__()

        # Only the image encoder of SAM is kept; prompt/mask modules unused.
        self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
        w = self.encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # Replace the 16x16 patch embedding with a ps x ps one, initialized
        # by subsampling the pretrained kernel on a stride-(16//ps) grid.
        self.ps = ps
        self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps]

        # Subsample the pretrained positional-embedding grid (1024//16 = 64
        # tokens per side) down to the bsize//ps token grid used here.
        ds = (1024 // 16) // (bsize // ps)
        self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:,::ds,::ds], requires_grad=True)

        # 1x1 conv producing nout * ps**2 channels per token; W2 is a fixed
        # identity kernel so conv_transpose2d rearranges those channels into
        # a ps x ps pixel block (pixel-shuffle expressed as a convolution).
        self.nout = nout
        self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
        self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps),
                               requires_grad=False)

        self.rdrop = rdrop

        # Diameter statistics stored as frozen Parameters so they are saved
        # and restored with the state_dict.
        self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)

        # Disable windowed attention so every block attends globally.
        for blk in self.encoder.blocks:
            blk.window_size = 0

        self.dtype = dtype

    def forward(self, x, feat=None):
        """Run the encoder and head.

        Args:
            x (Tensor): input images; presumably (batch, 3, H, W) with
                H == W == bsize — TODO confirm against callers.
            feat (Tensor or None): optional second image used to modulate
                the patch embedding of ``x`` multiplicatively.

        Returns:
            tuple: (full-resolution map with ``nout`` channels,
                    random (batch, 256) placeholder for a style vector).
        """
        x = self.encoder.patch_embed(x)
        if feat is not None:
            # Multiplicative modulation in token space (scaled residual).
            feat = self.encoder.patch_embed(feat)
            x = x + x * feat * 0.5

        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed

        if self.training and self.rdrop > 0:
            # Stochastic depth: block i is replaced by the identity with a
            # probability rising linearly from 0 to rdrop across layers,
            # sampled independently per example.
            nlay = len(self.encoder.blocks)
            rdrop = (torch.rand((len(x), nlay), device=x.device) < 
                        torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                x = x * mask + blk(x) * (1-mask)
        else:
            for blk in self.encoder.blocks:
                x = blk(x)

        # Token grid comes out channels-last; the neck is convolutional.
        x = self.encoder.neck(x.permute(0, 3, 1, 2))

        x1 = self.out(x)
        # Fixed transposed conv expands each token's nout*ps**2 channels
        # into a ps x ps spatial block -> full-resolution output.
        x1 = F.conv_transpose2d(x1, self.W2, stride = self.ps, padding = 0)

        return x1, torch.randn((x.shape[0], 256), device=x.device)

    def load_model(self, PATH, device, strict = False):
        """Load a checkpoint, stripping any DataParallel "module." prefix,
        then cast the model to ``self.dtype`` if it is not float32.

        Args:
            PATH (str): path to the saved state_dict.
            device (torch.device or str): map_location for ``torch.load``.
            strict (bool): forwarded to ``load_state_dict``.
        """
        state_dict = torch.load(PATH, map_location = device, weights_only=True)
        # Checkpoints saved from nn.DataParallel prefix every key with
        # "module."; detect that from the first key (empty dict -> no strip,
        # previously an IndexError).
        first_key = next(iter(state_dict), "")
        if first_key.startswith("module."):
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # drop the "module." prefix
                new_state_dict[name] = v
            self.load_state_dict(new_state_dict, strict = strict)
        else:
            self.load_state_dict(state_dict, strict = strict)

        if self.dtype != torch.float32:
            # nn.Module.to converts parameters/buffers in place and returns
            # the module itself.
            self = self.to(self.dtype)

    @property
    def device(self):
        """
        Get the device of the model.

        Returns:
            torch.device: The device of the model.
        """
        return next(self.parameters()).device

    def save_model(self, filename):
        """
        Save the model to a file.

        Args:
            filename (str): The path to the file where the model will be saved.
        """
        torch.save(self.state_dict(), filename)
|
|
|
|