| | |
| | from typing import Optional |
| |
|
| | import torch.nn as nn |
| | from mmcv.cnn import ConvModule |
| | from mmengine.model import BaseModule |
| | from torch import Tensor |
| |
|
| | from mmseg.registry import MODELS |
| | from mmseg.utils import OptConfigType |
| |
|
| |
|
class BasicBlock(BaseModule):
    """Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    Two 3x3 ``ConvModule`` layers whose result is added to the (optionally
    downsampled) input, followed by an optional output activation.

    Args:
        in_channels (int): Input channels.
        channels (int): Output channels.
        stride (int): Stride of the first convolution (``conv1``).
            Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Only applied to ``conv1``; ``conv2`` has no
            activation. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at the
            last of the block. If None (or empty), no output activation is
            built. Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 1

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # No activation here: it is applied (if configured) after the
        # residual addition in ``forward``.
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.downsample = downsample
        # ``self.act`` is only created when an output activation is
        # configured; ``forward`` checks for its presence.
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)

    def forward(self, x: Tensor) -> Tensor:
        """Run the residual block.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: ``conv2(conv1(x))`` plus the identity branch, passed
            through the output activation when one was configured.
        """
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)

        # Compare against None explicitly: container modules such as
        # ``nn.Sequential`` define ``__len__``, so plain truthiness would
        # depend on how many children the module has.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        if hasattr(self, 'act'):
            out = self.act(out)

        return out
| |
|
| |
|
class Bottleneck(BaseModule):
    """Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    A 1x1 -> 3x3 -> 1x1 ``ConvModule`` stack whose result is added to the
    (optionally downsampled) input, followed by an optional output
    activation.

    Args:
        in_channels (int): Input channels.
        channels (int): Channels of the two inner convolutions. The block
            outputs ``channels * expansion`` channels.
        stride (int): Stride of the 3x3 convolution (``conv2``).
            Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Applied to ``conv1`` and ``conv2``; ``conv3`` has
            no activation. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer at
            the last of the block. If None (or empty), no output activation
            is built. Default: None.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 2

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # No activation here: it is applied (if configured) after the
        # residual addition in ``forward``.
        self.conv3 = ConvModule(
            channels,
            channels * self.expansion,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        # ``self.act`` is only created when an output activation is
        # configured; ``forward`` checks for its presence.
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Run the bottleneck block.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: ``conv3(conv2(conv1(x)))`` plus the identity branch,
            passed through the output activation when one was configured.
        """
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        # Compare against None explicitly: container modules such as
        # ``nn.Sequential`` define ``__len__``, so plain truthiness would
        # depend on how many children the module has.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        if hasattr(self, 'act'):
            out = self.act(out)

        return out
| |
|