| | |
| | import torch.nn.functional as F |
| | from mmcv.cnn import ConvModule |
| | from mmengine.model import BaseModule, ModuleList |
| |
|
| |
|
class ConvUpsample(BaseModule):
    """Stack of 3x3 ``ConvModule`` layers with optional 2x upsampling.

    ``num_layers`` ``ConvModule`` layers are applied in sequence. After each
    of the first ``num_upsample`` layers the feature map is upsampled by a
    factor of 2 with bilinear interpolation; the remaining layers are plain
    convolutions. The number of upsampling steps therefore cannot exceed the
    number of convolution layers.

    Args:
        in_channels (int): Number of channels in the input feature map.
        inner_channels (int): Number of channels produced by every
            convolution layer (and hence by the whole module).
        num_layers (int): Number of convolution layers. Default: 1.
        num_upsample (int, optional): Number of upsampling steps. Must be no
            more than ``num_layers``. Upsampling is applied after each of the
            first ``num_upsample`` convolution layers.
            Default: None, which means ``num_layers``.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        init_cfg (dict, optional): Config dict for initialization.
            Default: None.
        kwargs (keyword arguments): Other arguments forwarded to every
            ``ConvModule``.

    Raises:
        AssertionError: If ``num_upsample`` is greater than ``num_layers``.
    """

    def __init__(self,
                 in_channels,
                 inner_channels,
                 num_layers=1,
                 num_upsample=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        if num_upsample is None:
            num_upsample = num_layers
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept as
        # AssertionError for backward compatibility with existing callers.
        if num_upsample > num_layers:
            raise AssertionError(
                f'num_upsample({num_upsample}) must be no more than '
                f'num_layers({num_layers})')
        self.num_layers = num_layers
        self.num_upsample = num_upsample
        self.conv = ModuleList()
        # Only the first layer consumes ``in_channels``; every later layer
        # maps ``inner_channels`` -> ``inner_channels``.
        layer_in_channels = in_channels
        for _ in range(num_layers):
            self.conv.append(
                ConvModule(
                    layer_in_channels,
                    inner_channels,
                    3,
                    padding=1,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
            layer_in_channels = inner_channels

    def forward(self, x):
        """Apply the convolution layers, upsampling after the first few.

        Args:
            x (Tensor): Input feature map. Presumably of shape
                (N, in_channels, H, W), since bilinear interpolation is
                used -- confirm against callers.

        Returns:
            Tensor: Output of the last convolution layer, upsampled 2x
            after each of the first ``num_upsample`` layers.
        """
        for layer_idx, conv in enumerate(self.conv):
            x = conv(x)
            # Upsample only after the first ``num_upsample`` conv layers.
            if layer_idx < self.num_upsample:
                x = F.interpolate(
                    x, scale_factor=2, mode='bilinear', align_corners=False)
        return x
| |
|