import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.registry import MODELS


class MLAModule(nn.Module):
    """Multi-level feature aggregation (MLA) module.

    Each input scale is first projected to ``out_channels`` by a 1x1
    ConvModule, the projected maps are accumulated top-down, and every
    partial sum is refined by a 3x3 ConvModule.
    """

    def __init__(self,
                 in_channels=[1024, 1024, 1024, 1024],
                 out_channels=256,
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
        # 1x1 convs that project every scale to the common output width.
        self.channel_proj = nn.ModuleList()
        for i in range(len(in_channels)):
            self.channel_proj.append(
                ConvModule(
                    in_channels=in_channels[i],
                    out_channels=out_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        # 3x3 convs that refine each aggregated feature map.
        self.feat_extract = nn.ModuleList()
        for i in range(len(in_channels)):
            self.feat_extract.append(
                ConvModule(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        # Project every scale to the common width:
        # feat_list -> [p2, p3, p4, p5]
        feat_list = []
        for x, conv in zip(inputs, self.channel_proj):
            feat_list.append(conv(x))

        # Flip to top-down order and accumulate:
        # feat_list -> [p5, p4, p3, p2]
        # mid_list  -> [m5, m4, m3, m2]
        feat_list = feat_list[::-1]
        mid_list = []
        for feat in feat_list:
            if len(mid_list) == 0:
                mid_list.append(feat)
            else:
                mid_list.append(mid_list[-1] + feat)
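        # Worked through with illustrative labels (pX = projected inputs):
        # m5 = p5, m4 = p5 + p4, m3 = p5 + p4 + p3, m2 = p5 + p4 + p3 + p2,
        # i.e. the top-down cumulative aggregation of SETR's MLA design.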

        # Refine each aggregated map with its own 3x3 conv; the outputs
        # keep mid_list's top-down order.
        out_list = []
        for mid, conv in zip(mid_list, self.feat_extract):
            out_list.append(conv(mid))

        return tuple(out_list)


@MODELS.register_module()
class MLANeck(nn.Module):
    """Multi-level Feature Aggregation.

    This neck is the multi-level feature aggregation construction of
    `SETR <https://arxiv.org/abs/2012.15840>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        norm_layer (dict): Config dict for input normalization.
            Default: dict(type='LN', eps=1e-6, requires_grad=True).
        norm_cfg (dict): Config dict for normalization layer in ConvModule.
            Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
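
    Example:
        A minimal shape check; the four equal input widths below follow the
        SETR-MLA defaults and are an illustrative assumption, not a
        requirement of this class:

        >>> import torch
        >>> neck = MLANeck(in_channels=[1024, 1024, 1024, 1024],
        ...                out_channels=256)
        >>> feats = [torch.randn(1, 1024, 16, 16) for _ in range(4)]
        >>> outs = neck(feats)
        >>> {tuple(o.shape) for o in outs}
        {(1, 256, 16, 16)}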
| """ |

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Per-scale input normalization (LayerNorm by default), applied in
        # token layout before the features enter the MLA module.
        self.norm = nn.ModuleList([
            build_norm_layer(norm_layer, in_channels[i])[1]
            for i in range(len(in_channels))
        ])
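        # Note (mmcv detail): build_norm_layer returns a (name, module)
        # tuple, so the trailing [1] keeps only the norm module itself.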

        self.mla = MLAModule(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # Normalize each scale: reshape nchw -> nlc so LayerNorm acts on the
        # channel dimension, then restore the nchw layout.
        outs = []
        for i in range(len(inputs)):
            x = inputs[i]
            n, c, h, w = x.shape
            x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
            x = self.norm[i](x)
            x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
            outs.append(x)

        outs = self.mla(outs)
        return tuple(outs)
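

# Config-style usage sketch (an illustrative assumption mirroring the SETR
# MLA configs; consult the actual config files for authoritative values):
#
#     neck=dict(
#         type='MLANeck',
#         in_channels=[1024, 1024, 1024, 1024],
#         out_channels=256,
#         norm_cfg=dict(type='SyncBN', requires_grad=True),
#         act_cfg=dict(type='ReLU'))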