import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS


class Transition(BaseModule):
    """Base class for transition.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        raise NotImplementedError


class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-samples the input by interpolation, then refines the feature with a
    convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool): Whether to align corners when interpolating.
            Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
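
    Example:
        A minimal sketch; the channel and spatial sizes here are arbitrary
        illustrative values:

        >>> import torch
        >>> trans = UpInterpolationConv(8, 8)
        >>> trans(torch.rand(1, 8, 16, 16)).shape
        torch.Size([1, 8, 32, 32])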
| """ |
|
|
| def __init__(self, |
| in_channels, |
| out_channels, |
| scale_factor=2, |
| mode='nearest', |
| align_corners=None, |
| kernel_size=3, |
| init_cfg=None, |
| **kwargs): |
| super().__init__(in_channels, out_channels, init_cfg) |
| self.mode = mode |
| self.scale_factor = scale_factor |
| self.align_corners = align_corners |
| self.conv = ConvModule( |
| in_channels, |
| out_channels, |
| kernel_size, |
| padding=(kernel_size - 1) // 2, |
| **kwargs) |
|
|
| def forward(self, x): |
| x = F.interpolate( |
| x, |
| scale_factor=self.scale_factor, |
| mode=self.mode, |
| align_corners=self.align_corners) |
| x = self.conv(x) |
| return x |
|
|
|
|
class LastConv(Transition):
    """A transition used for refining the output of the last stage.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Number of input feature maps (one per stacked
            stage); only the last one is refined.
        kernel_size (int): Kernel size for the conv. Default: 3.
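
    Example:
        A minimal sketch; with ``num_inputs=3``, only the last of the three
        input feature maps is convolved:

        >>> import torch
        >>> trans = LastConv(8, 8, num_inputs=3)
        >>> feats = [torch.rand(1, 8, 16, 16) for _ in range(3)]
        >>> trans(feats).shape
        torch.Size([1, 8, 16, 16])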
| """ |
|
|
| def __init__(self, |
| in_channels, |
| out_channels, |
| num_inputs, |
| kernel_size=3, |
| init_cfg=None, |
| **kwargs): |
| super().__init__(in_channels, out_channels, init_cfg) |
| self.num_inputs = num_inputs |
| self.conv_out = ConvModule( |
| in_channels, |
| out_channels, |
| kernel_size, |
| padding=(kernel_size - 1) // 2, |
| **kwargs) |
|
|
| def forward(self, inputs): |
| assert len(inputs) == self.num_inputs |
| return self.conv_out(inputs[-1]) |
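

# A minimal pass-through transition, sketched on the assumption that the
# 'identity' type referenced by FPG's default ``across_skip_trans`` should
# simply forward its input unchanged; it is registered in
# ``FPG.transition_types`` below so that the default config dict resolves.
class Identity(Transition):
    """A transition that returns its input unchanged."""

    def forward(self, x):
        return x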


@MODELS.register_module()
class FPG(BaseModule):
| """FPG. |
| |
| Implementation of `Feature Pyramid Grids (FPG) |
| <https://arxiv.org/abs/2004.03580>`_. |
| This implementation only gives the basic structure stated in the paper. |
| But users can implement different type of transitions to fully explore the |
| the potential power of the structure of FPG. |
| |
| Args: |
| in_channels (int): Number of input channels (feature maps of all levels |
| should have the same channels). |
| out_channels (int): Number of output channels (used at each scale) |
| num_outs (int): Number of output scales. |
| stack_times (int): The number of times the pyramid architecture will |
| be stacked. |
| paths (list[str]): Specify the path order of each stack level. |
| Each element in the list should be either 'bu' (bottom-up) or |
| 'td' (top-down). |
| inter_channels (int): Number of inter channels. |
| same_up_trans (dict): Transition that goes down at the same stage. |
| same_down_trans (dict): Transition that goes up at the same stage. |
| across_lateral_trans (dict): Across-pathway same-stage |
| across_down_trans (dict): Across-pathway bottom-up connection. |
| across_up_trans (dict): Across-pathway top-down connection. |
| across_skip_trans (dict): Across-pathway skip connection. |
| output_trans (dict): Transition that trans the output of the |
| last stage. |
| start_level (int): Index of the start input backbone level used to |
| build the feature pyramid. Default: 0. |
| end_level (int): Index of the end input backbone level (exclusive) to |
| build the feature pyramid. Default: -1, which means the last level. |
| add_extra_convs (bool): It decides whether to add conv |
| layers on top of the original feature maps. Default to False. |
| If True, its actual mode is specified by `extra_convs_on_inputs`. |
| norm_cfg (dict): Config dict for normalization layer. Default: None. |
| init_cfg (dict or list[dict], optional): Initialization config dict. |
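
    Example:
        A minimal sketch; the transition dicts are illustrative choices in
        the style of typical FPG configs, not mandated values:

        >>> import torch
        >>> in_channels = [8, 16, 32]
        >>> inputs = [torch.rand(1, c, 32 // 2**i, 32 // 2**i)
        ...           for i, c in enumerate(in_channels)]
        >>> self = FPG(
        ...     in_channels=in_channels,
        ...     out_channels=8,
        ...     inter_channels=8,
        ...     num_outs=5,
        ...     stack_times=2,
        ...     paths=['bu', 'td'],
        ...     across_down_trans=dict(
        ...         type='interpolation_conv', mode='nearest', kernel_size=3),
        ...     across_skip_trans=dict(type='conv', kernel_size=1),
        ...     skip_inds=[(0,), (), (), (), ()])
        >>> outputs = self.forward(inputs)
        >>> for i, out in enumerate(outputs):
        ...     print(f'outputs[{i}].shape = {out.shape}')
        outputs[0].shape = torch.Size([1, 8, 32, 32])
        outputs[1].shape = torch.Size([1, 8, 16, 16])
        outputs[2].shape = torch.Size([1, 8, 8, 8])
        outputs[3].shape = torch.Size([1, 8, 4, 4])
        outputs[4].shape = torch.Size([1, 8, 2, 2])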
| """ |
|
|
    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
        'identity': Identity,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super().__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None

        # skip_inds must be specified when across-pathway skip connections
        # are in use
        if self.across_skip_trans is not None:
            assert skip_inds is not None
        if skip_inds is None:
            skip_inds = [() for _ in range(num_outs)]
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # build 1x1 lateral convs to reduce the input channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

        # build extra pyramid levels on top of the backbone outputs, with
        # strided convs if add_extra_convs else max pooling
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

        # build FPG transitions: one ModuleDict per (stack, level) pair
        self.fpn_transitions = nn.ModuleList()  # stack-wise
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()  # level-wise
            for i in range(self.num_outs):
                # keep an empty dict for levels that skip this stack
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    stage_trans.append(trans)
                    continue
                # same-pathway up transition (used in bottom-up paths)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans
                # same-pathway down transition (used in top-down paths)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans
                # across-pathway lateral transition
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans
                # across-pathway down transition (from the upper level)
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans
                # across-pathway up transition (from the lower level)
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans
                # across-pathway skip transition; its input is the unstacked
                # feature of the same level, so both channel numbers are
                # those of level i
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans
                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)

        # build per-level output transitions applied over all stacked stages
        self.output_transition = nn.ModuleList()  # level-wise
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Build a transition module from its config dict."""
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        """Element-wise sum of all the non-None features in ``fuse_dict``."""
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build all levels from the original feature maps
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                if i in self.skip_inds[j]:
                    next_outs.append(outs[-1][j])
                    continue
                # feature level to process at this step
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # get the transitions of the current stack and level
                if direction == 'td':
                    same_trans = self.fpn_transitions[i][lvl]['same_down']
                else:
                    same_trans = self.fpn_transitions[i][lvl]['same_up']
                across_lateral_trans = self.fpn_transitions[i][lvl][
                    'across_lateral']
                across_down_trans = self.fpn_transitions[i][lvl]['across_down']
                across_up_trans = self.fpn_transitions[i][lvl]['across_up']
                across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # same-pathway connection from the previously processed level
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across-pathway lateral connection from the same level
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # across-pathway up connection from the lower level
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(
                        current_outs[lvl - 1])
                # across-pathway down connection from the upper level
                if lvl < self.num_outs - 1 and across_down_trans is not None:
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # across-pathway skip connection from the unstacked features
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            # store features in bottom-up (fine-to-coarse) order
            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # output transition over the same level of every stage
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = []
            for s in range(len(outs)):
                lvl_out_list.append(outs[s][i])
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)

        return final_outs