# Import-time progress message, printed before the torch/timm imports that follow.
# NOTE(review): consider switching to logging so importers can silence it.
print("Importing external...")
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from timm.layers import GELUTanh
from timm.models.efficientvit_mit import (
    ConvNormAct,
    FusedMBConv,
    MBConv,
    ResidualBlock,
    efficientvit_l1,
)
|
|
|
|
def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    A list or tuple is converted with ``list(x)`` (``repeat_time`` is then
    ignored); any other value is repeated ``repeat_time`` times.

    >>> val2list(3, 2)
    [3, 3]
    >>> val2list((1, 2), 5)
    [1, 2]
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
|
|
|
|
def resize(
    x: torch.Tensor,
    size: Any = None,
    scale_factor: Optional[Union[float, list[float]]] = None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize ``x`` with ``F.interpolate``, passing ``align_corners`` only where allowed.

    ``F.interpolate`` rejects ``align_corners`` for non-interpolating modes, so
    it is forwarded only for linear/bilinear/bicubic/trilinear.

    Args:
        x: Tensor to resize.
        size: Target output size (forwarded to ``F.interpolate``).
        scale_factor: Multiplier for the spatial size (forwarded).
        mode: Interpolation mode.
        align_corners: Used only by the interpolating modes.

    Raises:
        NotImplementedError: if ``mode`` is not a known ``F.interpolate`` mode.
    """
    if mode in {"linear", "bilinear", "bicubic", "trilinear"}:
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    if mode in {"nearest", "nearest-exact", "area"}:
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")
|
|
|
|
class UpSampleLayer(nn.Module):
    """Resize a feature map to a fixed spatial size or by a scale factor.

    Args:
        mode: Interpolation mode forwarded to ``resize``.
        size: Target spatial size. It is normalised with ``val2list(size, 2)``
            (an int becomes ``[size, size]``). When given, ``factor`` is ignored.
        factor: Scale factor used when ``size`` is None.
        align_corners: Forwarded to ``resize``.
    """

    def __init__(
        self,
        mode="bicubic",
        size: Optional[Union[int, tuple[int, int], list[int]]] = None,
        factor=2,
        align_corners=False,
    ):
        super(UpSampleLayer, self).__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        # A fixed target size makes the scale factor irrelevant.
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip interpolation when it would be a no-op. self.size is stored as a
        # list, so compare list-to-list: a tuple never equals a list in Python.
        if self.size is not None and list(x.shape[-2:]) == list(self.size):
            return x
        if self.factor == 1:
            return x
        return resize(x, self.size, self.factor, self.mode, self.align_corners)
|
|
|
|
class DAGBlock(nn.Module):
    """Merge several named features, transform them, and write named outputs.

    ``forward``:
    1. applies each input op to ``feature_dict[key]``;
    2. merges the results ("add": sum via ``list_sum``; "cat": concatenation
       along dim 1);
    3. applies the optional ``post_input`` and then ``middle``;
    4. stores each output op's result in ``feature_dict`` under its key.

    Args:
        inputs: Maps a feature key to the op applied to that feature; dict
            order is the merge order.
        merge: "add" or "cat" (checked in ``forward``).
        post_input: Optional module applied right after merging.
        middle: Module applied to the merged feature.
        outputs: Maps an output key to the op producing it from ``middle``'s
            result.
    """

    def __init__(
        self,
        inputs: dict[str, nn.Module],
        merge: str,
        post_input: Optional[nn.Module],
        middle: nn.Module,
        outputs: dict[str, nn.Module],
    ):
        super(DAGBlock, self).__init__()

        self.input_keys = list(inputs.keys())
        self.input_ops = nn.ModuleList(list(inputs.values()))
        self.merge = merge
        self.post_input = post_input

        self.middle = middle

        self.output_keys = list(outputs.keys())
        self.output_ops = nn.ModuleList(list(outputs.values()))

    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the block. ``feature_dict`` is updated in place and also returned."""
        feats = [
            op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
        ]
        if self.merge == "add":
            feat = list_sum(feats)
        elif self.merge == "cat":
            feat = torch.cat(feats, dim=1)
        else:
            raise NotImplementedError(
                f"DAGBlock merge={self.merge!r} not implemented (expected 'add' or 'cat')."
            )
        if self.post_input is not None:
            feat = self.post_input(feat)
        feat = self.middle(feat)
        for key, op in zip(self.output_keys, self.output_ops):
            feature_dict[key] = op(feat)
        return feature_dict
|
|
|
|
def list_sum(x: list) -> Any:
    """Return ``x[0] + (x[1] + (... + x[-1]))`` for a non-empty sequence.

    Folds from the right, so operands are combined in the same order (and with
    the same float rounding) as the recursive ``x[0] + list_sum(x[1:])``. It
    runs iteratively: no recursion-depth limit and no quadratic slice copies.
    A single-element input returns that element itself.

    Raises:
        IndexError: if ``x`` is empty.
    """
    items = reversed(x)
    try:
        total = next(items)
    except StopIteration:
        raise IndexError("list_sum() requires a non-empty sequence") from None
    for item in items:
        # Keep ``item`` on the left to preserve the original operand order.
        total = item + total
    return total
|
|
|
|
class SegHead(nn.Module):
    """Multi-scale segmentation head.

    Processing steps:
    1. Each input feature map is projected to ``head_width`` channels with a
       1x1 ConvNormAct.
    2. When ``stride // head_stride > 1``, the projection is upsampled by that
       factor.
    3. The projected maps are summed.
    4. They are refined by ``head_depth`` residual MBConv / FusedMBConv blocks.
    5. They are mapped to ``n_classes`` channels, optionally through a 1x1
       expansion to ``head_width * final_expand`` channels.

    Args:
        fid_list: Feature identifiers, one per input. Stored in ``in_keys``;
            ``forward`` matches inputs by position, not by name.
        in_channel_list: Channel count of each input feature map.
        stride_list: Stride of each input feature map.
        head_stride: Stride at which features are fused. Every entry of
            ``stride_list`` must be a positive multiple of it.
        head_width: Channel width inside the head.
        head_depth: Number of residual blocks in the middle stage.
        expand_ratio: Expansion ratio of the middle blocks.
        middle_op: "mbconv" or "fmbconv".
        final_expand: Width multiplier of the expansion layer before the
            classifier; ``None`` disables that layer.
        n_classes: Number of output channels.
        dropout: Dropout passed to the final ConvNormAct.
        norm: Normalization name; only "bn2d" (BatchNorm2d) is supported.
        act_func: One of "relu", "relu6", "hswish", "silu", "gelu" (GELUTanh).

    Raises:
        ValueError: for an unsupported ``act_func``/``norm``, mismatched list
            lengths, or a stride that is not a positive multiple of
            ``head_stride``.
        NotImplementedError: for an unsupported ``middle_op`` when
            ``head_depth > 0``.
    """

    def __init__(
        self,
        fid_list: list[str],
        in_channel_list: list[int],
        stride_list: list[int],
        head_stride: int,
        head_width: int,
        head_depth: int,
        expand_ratio: float,
        middle_op: str,
        final_expand: Optional[float],
        n_classes: int,
        dropout=0,
        norm="bn2d",
        act_func="hswish",
    ):
        super(SegHead, self).__init__()

        # Config string -> activation class. The default "hswish" must be
        # listed here, otherwise the default constructor always fails. Every
        # torch class below accepts an ``inplace`` argument, in case the timm
        # blocks construct activations with it (not verified here).
        act_layers = {
            "relu": nn.ReLU,
            "relu6": nn.ReLU6,
            "hswish": nn.Hardswish,
            "silu": nn.SiLU,
            "gelu": GELUTanh,
        }
        if act_func not in act_layers:
            raise ValueError(f"act_func {act_func} not supported")
        act_func = act_layers[act_func]
        if norm != "bn2d":
            raise ValueError(f"norm {norm} not supported")
        norm_layer = nn.BatchNorm2d

        if not len(fid_list) == len(in_channel_list) == len(stride_list):
            raise ValueError(
                "fid_list, in_channel_list and stride_list must have the same length"
            )

        # Input projections: 1x1 conv to head_width, then upsample to head_stride.
        inputs = {}
        for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
            # A non-multiple stride would give factor 0 or maps of different
            # sizes that cannot be summed in forward().
            if head_stride <= 0 or stride < head_stride or stride % head_stride != 0:
                raise ValueError(
                    f"stride {stride} of feature {fid!r} must be a positive "
                    f"multiple of head_stride {head_stride}"
                )
            factor = stride // head_stride
            proj = ConvNormAct(
                in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_func
            )
            inputs[fid] = (
                proj if factor == 1 else nn.Sequential(proj, UpSampleLayer(factor=factor))
            )
        self.in_keys = list(inputs.keys())
        self.in_ops = nn.ModuleList(inputs.values())

        # Middle stage: head_depth residual blocks at constant width.
        middle = []
        for _ in range(head_depth):
            if middle_op == "mbconv":
                block = MBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_func, act_func, None),
                )
            elif middle_op == "fmbconv":
                block = FusedMBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_func, None),
                )
            else:
                raise NotImplementedError(f"middle_op {middle_op} not supported")
            middle.append(ResidualBlock(block, nn.Identity()))
        self.middle = nn.Sequential(*middle)

        # Output stage. Only real modules go into nn.Sequential: a None entry
        # would be called in forward() and raise. round() keeps the channel
        # count an int for a float final_expand.
        if final_expand is None:
            out_width = head_width
        else:
            out_width = round(head_width * final_expand)
        out_layers = []
        if final_expand is not None:
            out_layers.append(
                ConvNormAct(
                    head_width,
                    out_width,
                    1,
                    norm_layer=norm_layer,
                    act_layer=act_func,
                )
            )
        out_layers.append(
            ConvNormAct(
                out_width,
                n_classes,
                1,
                bias=True,
                dropout=dropout,
                norm_layer=None,
                act_layer=None,
            )
        )
        self.out_layer = nn.Sequential(*out_layers)

    def forward(self, feature_map_list):
        """Fuse the feature maps, given in the same order as ``fid_list``."""
        projected = [
            self.in_ops[ind](feature_map)
            for ind, feature_map in enumerate(feature_map_list)
        ]
        fused = list_sum(projected)
        fused = self.middle(fused)
        return self.out_layer(fused)
|
|
|
|
class EfficientViT_l1_r224(nn.Module):
    """timm EfficientViT-L1 backbone (``features_only``) followed by a SegHead.

    Args:
        out_channels: Number of output channels (``n_classes`` of the head).
        out_ds_factor: ``head_stride`` of the SegHead. Presumably the output
            is at 1/out_ds_factor of the input resolution — TODO confirm.
        decoder_size: "small", "medium" or "large"; selects the head's
            width, depth and middle block type.
        pretrained: Forwarded to ``efficientvit_l1``.
        use_norm_params: When False, weight and bias of every LayerNorm,
            BatchNorm1d and BatchNorm2d module are frozen with
            ``requires_grad_(False)``. BatchNorm running statistics are not
            affected by this.

    Raises:
        ValueError: if ``decoder_size`` is not one of the supported names.
    """

    def __init__(
        self,
        out_channels,
        out_ds_factor=1,
        decoder_size="small",
        pretrained=False,
        use_norm_params=False,
    ):
        super(EfficientViT_l1_r224, self).__init__()

        # decoder_size -> (head_width, head_depth, middle_op)
        decoder_configs = {
            "small": (32, 1, "mbconv"),
            "medium": (64, 3, "mbconv"),
            "large": (256, 3, "fmbconv"),
        }
        if decoder_size not in decoder_configs:
            # Without this check, an unknown name would surface later as an
            # UnboundLocalError on head_width.
            raise ValueError(
                f"decoder_size {decoder_size!r} not supported; "
                f"expected one of {list(decoder_configs)}"
            )
        head_width, head_depth, middle_op = decoder_configs[decoder_size]

        self.bbone = efficientvit_l1(
            num_classes=0, features_only=True, pretrained=pretrained
        )
        self.head = SegHead(
            fid_list=["stage4", "stage3", "stage2"],
            in_channel_list=[512, 256, 128],
            stride_list=[32, 16, 8],
            head_stride=out_ds_factor,
            head_width=head_width,
            head_depth=head_depth,
            expand_ratio=4,
            middle_op=middle_op,
            final_expand=8,
            n_classes=out_channels,
            act_func="gelu",
        )

        if not use_norm_params:
            norm_types = (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)
            for module in self.modules():
                if not isinstance(module, norm_types):
                    continue
                # weight/bias are None for norms built without affine params.
                for param in (module.weight, module.bias):
                    if param is not None:
                        param.requires_grad_(False)

    def forward(self, x):
        feats = self.bbone(x)
        # Deepest first, matching the fid/in_channel/stride lists given to
        # SegHead (expected 512/256/128 channels at strides 32/16/8 —
        # assumes the backbone's feature list is ordered shallow to deep).
        # feats[0] is not used by the head.
        return self.head([feats[3], feats[2], feats[1]])
|
|