| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | from .utils import _SimpleSegmentationModel |
| |
|
| |
|
| | __all__ = ["DeepLabV3"] |
| |
|
| |
|
class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 semantic segmentation model.

    Described in `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_. The class body is empty: everything
    is inherited from ``_SimpleSegmentationModel``.

    Arguments:
        backbone (nn.Module): network that computes the features for the model.
            It should return an OrderedDict[Tensor] whose "out" entry is the
            last feature map used, plus an "aux" entry when an auxiliary
            classifier is used.
        classifier (nn.Module): takes the "out" element returned by the
            backbone and produces a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during
            training.
    """
| |
|
class DeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ decoder head.

    Projects the low-level feature map to 48 channels, runs ASPP on the "out"
    feature map, bilinearly resizes the ASPP result to the low-level spatial
    size, concatenates both along the channel axis and classifies the result.

    Arguments:
        in_channels (int): channels of ``feature['out']`` (the ASPP input).
        low_level_channels (int): channels of ``feature['low_level']``.
        num_classes (int): output channels of the final 1x1 convolution.
        aspp_dilate (sequence of int): dilation rates handed to ``ASPP``.
    """

    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=(12, 24, 36)):
        super(DeepLabHeadV3Plus, self).__init__()
        low_level_width = 48  # channels kept from the low-level feature map
        aspp_width = 256      # channels produced by ASPP (fixed inside ASPP)

        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, low_level_width, 1, bias=False),
            nn.BatchNorm2d(low_level_width),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        # The classifier sees the concatenation of both branches (48 + 256 = 304).
        self.classifier = nn.Sequential(
            nn.Conv2d(low_level_width + aspp_width, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Return the dense prediction for a backbone feature dict.

        Arguments:
            feature (dict): must provide the entries "low_level" and "out".

        Returns:
            Tensor: classifier output at the spatial size of the projected
            low-level feature map.
        """
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        # Bring the ASPP output to the low-level resolution before fusing.
        output_feature = F.interpolate(
            output_feature,
            size=low_level_feature.shape[2:],
            mode='bilinear',
            align_corners=False,
        )
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
| |
|
class DeepLabHead(nn.Module):
    """DeepLabV3 head: ASPP followed by a 3x3 conv block and a 1x1 classifier.

    Arguments:
        in_channels (int): channels of ``feature['out']`` (the ASPP input).
        num_classes (int): output channels of the final 1x1 convolution.
        aspp_dilate (sequence of int): dilation rates handed to ``ASPP``.
    """

    def __init__(self, in_channels, num_classes, aspp_dilate=(12, 24, 36)):
        super(DeepLabHead, self).__init__()
        width = 256  # channels produced by ASPP (fixed inside ASPP)

        self.classifier = nn.Sequential(
            ASPP(in_channels, aspp_dilate),
            nn.Conv2d(width, width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Classify ``feature['out']``; other entries of the dict are ignored."""
        return self.classifier(feature['out'])

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
| |
|
class AtrousSeparableConvolution(nn.Module):
    """Atrous (dilated) separable convolution.

    A depthwise convolution (``groups=in_channels``) that carries the kernel
    size, stride, padding and dilation, followed by a 1x1 pointwise
    convolution mapping ``in_channels`` to ``out_channels``.

    Arguments:
        in_channels (int): input channels.
        out_channels (int): output channels of the pointwise convolution.
        kernel_size (int or tuple): kernel size of the depthwise convolution.
        stride (int or tuple): stride of the depthwise convolution.
        padding (int or tuple): padding of the depthwise convolution.
        dilation (int or tuple): dilation of the depthwise convolution.
        bias (bool, Tensor or None): whether both convolutions get a bias
            term. A tensor (for example another layer's ``.bias``) counts as
            True and ``None`` as False; only its presence matters, its values
            are not copied.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        # nn.Conv2d tests ``bias`` for truthiness, which fails for a
        # multi-element tensor, so reduce it to a plain flag first.
        use_bias = True if isinstance(bias, torch.Tensor) else bool(bias)

        self.body = nn.Sequential(
            # Depthwise: one spatial (possibly dilated) filter per input channel.
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=use_bias, groups=in_channels),
            # Pointwise: 1x1 convolution that mixes channels.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=use_bias),
        )

        self._init_weight()

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.body(x)

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
| |
|
class ASPPConv(nn.Sequential):
    """One dilated ASPP branch: 3x3 atrous convolution -> BatchNorm -> ReLU.

    The padding equals the dilation, so the 3x3 kernel keeps the spatial size.

    Arguments:
        in_channels (int): input channels.
        out_channels (int): output channels.
        dilation (int): dilation rate, also used as the padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = nn.Conv2d(in_channels, out_channels, 3,
                           padding=dilation, dilation=dilation, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        super(ASPPConv, self).__init__(atrous, norm, activation)
| |
|
class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Global average pooling to 1x1, then 1x1 convolution -> BatchNorm -> ReLU;
    the result is bilinearly resized back to the input's spatial size.

    Arguments:
        in_channels (int): input channels.
        out_channels (int): output channels.
    """

    def __init__(self, in_channels, out_channels):
        global_pool = nn.AdaptiveAvgPool2d(1)
        squeeze_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        super(ASPPPooling, self).__init__(global_pool, squeeze_conv, norm, activation)

    def forward(self, x):
        """Run the pooled branch and resize it to ``x``'s last two dimensions."""
        spatial_size = x.shape[-2:]
        pooled = nn.Sequential.forward(self, x)
        return F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=False)
| |
|
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs parallel branches over the same input: a 1x1 convolution, one
    ``ASPPConv`` per dilation rate, and an ``ASPPPooling`` branch. The branch
    outputs are concatenated along the channel axis and projected to 256
    channels by a 1x1 conv -> BatchNorm -> ReLU -> Dropout(0.1) block.

    Arguments:
        in_channels (int): channels of the input feature map.
        atrous_rates (iterable of int): dilation rates, one ``ASPPConv``
            branch each. Any number of rates is accepted; three rates give
            the classic five-branch layout.
    """

    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256

        branches = [nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))]

        # One dilated branch per requested rate, in the order given.
        for rate in tuple(atrous_rates):
            branches.append(ASPPConv(in_channels, out_channels, rate))
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # The projection input width follows the actual number of branches.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),)

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1 and project."""
        res = torch.cat([conv(x) for conv in self.convs], dim=1)
        return self.project(res)
| |
|
| |
|
| |
|
def convert_to_separable_conv(module):
    """Recursively swap Conv2d layers with kernel size > 1 for separable ones.

    A ``nn.Conv2d`` whose first kernel dimension is larger than 1 is replaced
    by a new ``AtrousSeparableConvolution`` built from the same channel
    counts, kernel size, stride, padding and dilation, with a bias term iff
    the original layer had one. The replacement is newly constructed; this
    function does not copy the original weights into it. Any other module is
    kept, and its children are converted and re-registered under their
    original names (so containers are modified in place).

    Arguments:
        module (nn.Module): root of the module tree to convert.

    Returns:
        nn.Module: ``module`` itself, or its separable replacement.
    """
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        new_module = AtrousSeparableConvolution(
            module.in_channels,
            module.out_channels,
            module.kernel_size,
            module.stride,
            module.padding,
            module.dilation,
            # ``module.bias`` is a Parameter or None; the constructor expects
            # a flag, and a multi-element tensor has no truth value.
            module.bias is not None,
        )
    for name, child in module.named_children():
        new_module.add_module(name, convert_to_separable_conv(child))
    return new_module