| | from rscd.models.backbones.vmamba import VSSM, LayerNorm2d |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class Backbone_VSSM(VSSM):
    """VSSM (VMamba) backbone exposing intermediate stage features.

    Wraps the VSSM classification model as a feature extractor: the
    classification head is removed and the features of the stages listed
    in ``out_indices`` are normalized and returned by :meth:`forward`.
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer='ln2d', **kwargs):
        """
        Args:
            out_indices: indices of the stages whose features are returned.
            pretrained: optional checkpoint path loaded via ``torch.load``.
            norm_layer: one of ``'ln'``, ``'ln2d'``, ``'bn'`` selecting the
                normalization applied to each returned stage feature.
            **kwargs: forwarded to ``VSSM.__init__``.

        Raises:
            ValueError: if ``norm_layer`` names an unsupported normalization.
        """
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        # 'bn' and 'ln2d' operate on (N, C, H, W); plain 'ln' expects
        # channels-last tensors, which forward() permutes back.
        self.channel_first = (norm_layer.lower() in ("bn", "ln2d"))

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        try:
            # Keep the string parameter intact; bind the class to a new name
            # instead of shadowing `norm_layer`.
            norm_cls = _NORMLAYERS[norm_layer.lower()]
        except KeyError:
            # Fail fast with a clear message rather than the later
            # "'NoneType' object is not callable" the old .get(..., None) gave.
            raise ValueError(
                f"Unsupported norm_layer {norm_layer!r}; expected one of {sorted(_NORMLAYERS)}"
            ) from None

        self.out_indices = out_indices
        # One output-normalization module per requested stage.
        for i in out_indices:
            self.add_module(f'outnorm{i}', norm_cls(self.dims[i]))

        # Feature extractor only: drop the classification head.
        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Load weights from checkpoint path ``ckpt`` (no-op when ``None``).

        Loading is best-effort: missing/unexpected keys are tolerated
        (``strict=False``) and any failure is reported without raising.

        Args:
            ckpt: path to a checkpoint file, or ``None`` to skip loading.
            key: dictionary key under which the state dict is stored.
        """
        if ckpt is None:
            return

        try:
            # Pass the path directly so torch.load manages the file handle;
            # the previous open(ckpt, "rb") leaked the descriptor.
            _ckpt = torch.load(ckpt, map_location=torch.device("cpu"))
            incompatible_keys = self.load_state_dict(_ckpt[key], strict=False)
            # Report success only after the state dict is actually applied.
            print(f"Successfully load ckpt {ckpt}")
            print(incompatible_keys)
        except Exception as e:
            # Best-effort: continue with the current (random) initialization.
            print(f"Failed loading checkpoint from {ckpt}: {e}")

    def forward(self, x):
        """Return normalized features of the stages listed in ``out_indices``.

        Each returned feature is channels-first ``(N, C, H, W)``. When
        ``out_indices`` is empty, the final stage output is returned instead.
        """
        def layer_forward(l, x):
            # Stage layout assumed from VSSM: per-stage .blocks then
            # .downsample; the pre-downsample tensor is the stage feature.
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)  # o: pre-downsample stage feature
            if i in self.out_indices:
                norm_layer = getattr(self, f'outnorm{i}')
                out = norm_layer(o)
                if not self.channel_first:
                    # (N, H, W, C) -> (N, C, H, W)
                    out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        if len(self.out_indices) == 0:
            return x

        return outs
| | |
class CMBackbone(nn.Module):
    """Siamese change-detection backbone.

    Runs one weight-shared ``Backbone_VSSM`` encoder over both the
    pre-change and post-change images.
    """

    def __init__(self, pretrained, **kwargs):
        """
        Args:
            pretrained: checkpoint path forwarded to ``Backbone_VSSM``.
            **kwargs: extra options forwarded to ``Backbone_VSSM`` / VSSM.
        """
        # Modern zero-arg super() (Python 3) instead of super(CMBackbone, self).
        super().__init__()
        # A single encoder instance => weights are shared across both inputs.
        self.encoder = Backbone_VSSM(out_indices=(0, 1, 2, 3), pretrained=pretrained, **kwargs)

    def forward(self, pre_data, post_data):
        """Encode the image pair.

        Args:
            pre_data: pre-change image batch, ``(N, C, H, W)``.
            post_data: post-change image batch, ``(N, C, H, W)``.

        Returns:
            ``[pre_features, post_features, (H, W)]`` where ``(H, W)`` is the
            spatial size of the input, kept so decoders can upsample back.
        """
        pre_features = self.encoder(pre_data)
        post_features = self.encoder(post_data)
        return [pre_features, post_features, pre_data.size()[-2:]]