| |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from data_process.resnet import Resnet18 |
| from data_process.utils import encode_segmentation_rgb_batch |
|
|
|
|
class ConvBNReLU(nn.Module):
    """Conv2d (bias-free) -> BatchNorm2d -> ReLU building block.

    Args:
        in_chan: Number of input channels.
        out_chan: Number of output channels.
        ks: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero-padding.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super().__init__()
        # The conv has no bias because the BatchNorm that follows applies its own shift.
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return F.relu(self.bn(self.conv(x)))

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal weights (a=1), zero bias."""
        direct_convs = (child for child in self.children() if isinstance(child, nn.Conv2d))
        for conv in direct_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)
|
|
|
|
class BiSeNetOutput(nn.Module):
    """Prediction head: 3x3 ConvBNReLU to ``mid_chan``, then a bias-free 1x1 conv to ``n_classes``.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return the per-class maps (``n_classes`` channels, same spatial size as ``x``)."""
        hidden = self.conv(x)
        return self.conv_out(hidden)

    def init_weight(self):
        """Re-initialise direct Conv2d children (here only ``conv_out``): Kaiming-normal (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split every parameter of this module tree into two groups.

        The names ``wd``/``nowd`` presumably mean "with / without weight decay";
        confirm against the optimizer setup.

        Returns:
            (wd_params, nowd_params): Linear/Conv2d weights go in the first list.
            Their biases (when present) and all BatchNorm2d parameters go in the second.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params.extend(module.parameters())
        return wd_params, nowd_params
|
|
|
|
class AttentionRefinementModule(nn.Module):
    """Refine a feature map with a global channel-attention gate.

    ``feat = ConvBNReLU(x)``; the gate is
    ``sigmoid(BN(conv1x1(global_avg_pool(feat))))``; the output is ``feat * gate``.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return ``out_chan``-channel features scaled channel-wise by the attention gate."""
        feat = self.conv(x)
        # Pooling window == full spatial extent, so the result is 1x1 per channel.
        pooled = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return feat * gate

    def init_weight(self):
        """Re-initialise direct Conv2d children (here ``conv_atten``): Kaiming-normal (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)
|
|
|
|
class ContextPath(nn.Module):
    """BiSeNet context path: a ``Resnet18`` backbone plus attention refinement.

    The backbone is expected to return three feature maps ``(feat8, feat16, feat32)``.
    ``feat16`` and ``feat32`` must have 256 and 512 channels respectively, matching
    the ARM / ``conv_avg`` input widths below. The stride each name implies is
    presumably 8 / 16 / 32 — TODO confirm against ``Resnet18``.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Run the backbone and fuse its deeper features top-down.

        Returns:
            (feat8, feat16_up, feat32_up):
                feat8     -- the backbone's first output, passed through unchanged;
                feat16_up -- 128 channels, resized to feat8's spatial size;
                feat32_up -- 128 channels, resized to feat16's spatial size.
        """
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: average-pool feat32 to 1x1, project to 128 channels,
        # then broadcast back to feat32's spatial size.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode="nearest")

        # Deepest level: refine, add global context, upsample to feat16's size.
        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode="nearest")
        feat32_up = self.conv_head32(feat32_up)

        # Middle level: refine, add the upsampled deeper path, upsample to feat8's size.
        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode="nearest")
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1), zero bias.

        NOTE(review): the ARM and ConvBNReLU children registered here are not
        nn.Conv2d themselves, so this only touches ``self.resnet`` if that happens
        to be a bare Conv2d (presumably it is not). In practice it is likely a
        no-op; the submodules initialise themselves in their own constructors.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split every parameter of this module tree (backbone included) into two groups.

        The names ``wd``/``nowd`` presumably mean "with / without weight decay".

        Returns:
            (wd_params, nowd_params): Linear/Conv2d weights go in the first list.
            Their biases (when present) and all BatchNorm2d parameters go in the second.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
|
|
|
|
| |
class SpatialPath(nn.Module):
    """BiSeNet spatial path: three stride-2 ConvBNReLU stages, then a 1x1 projection to 128 channels.

    The three stride-2 stages reduce the spatial size to roughly 1/8 of the
    3-channel input. Note that ``BiSeNet.forward`` in this file does not
    instantiate this module; it uses the context path's first output instead.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Apply ``conv1 -> conv2 -> conv3 -> conv_out`` in sequence."""
        out = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv_out):
            out = stage(out)
        return out

    def init_weight(self):
        """Re-initialise direct Conv2d children (Kaiming-normal, a=1, zero bias).

        Every child here is a ConvBNReLU rather than a bare Conv2d, so nothing matches.
        """
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split every parameter of this module tree into two groups.

        The names ``wd``/``nowd`` presumably mean "with / without weight decay".

        Returns:
            (wd_params, nowd_params): Linear/Conv2d weights go in the first list.
            Their biases (when present) and all BatchNorm2d parameters go in the second.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params.extend(module.parameters())
        return wd_params, nowd_params
|
|
|
|
class FeatureFusionModule(nn.Module):
    """Fuse two feature maps with a channel-attention residual.

    The inputs are concatenated on the channel axis (their channel counts must
    sum to ``in_chan``) and projected by a 1x1 ConvBNReLU to ``out_chan``. A gate
    is then computed:
    global average pool -> 1x1 conv to ``out_chan // 4`` -> ReLU
    -> 1x1 conv back to ``out_chan`` -> sigmoid.
    The result is ``feat * gate + feat``.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        squeezed = out_chan // 4
        self.conv1 = nn.Conv2d(out_chan, squeezed, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(squeezed, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Return the fused ``out_chan``-channel map for inputs ``fsp`` and ``fcp``."""
        feat = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Full-extent pooling window -> one value per channel.
        gate = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return feat * gate + feat

    def init_weight(self):
        """Re-initialise direct Conv2d children (``conv1``, ``conv2``): Kaiming-normal (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split every parameter of this module tree into two groups.

        The names ``wd``/``nowd`` presumably mean "with / without weight decay".

        Returns:
            (wd_params, nowd_params): Linear/Conv2d weights go in the first list.
            Their biases (when present) and all BatchNorm2d parameters go in the second.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params.extend(module.parameters())
        return wd_params, nowd_params
|
|
|
|
class BiSeNet(nn.Module):
    """BiSeNet segmentation network (variant without a separate spatial path).

    Args:
        n_classes: Number of output channels of each prediction head.

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        # No SpatialPath is built: forward() feeds the context path's first output
        # into the fusion module in its place.
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def get_mask(
        self,
        x: torch.Tensor,
        crop_size: int,
        *,
        parse_size: int = 512,
        min_mask_sum: int = 5000,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parse a batch of images and turn the labels into a single-channel mask.

        The method does not change train/eval mode. Call ``eval()`` first if
        BatchNorm should use its running statistics.

        Args:
            x: Image batch (N, C, H, W). It is resized to ``parse_size`` x
                ``parse_size`` with ``F.interpolate``'s default (nearest) mode
                before parsing.
            crop_size: Side length of the returned mask.
            parse_size: Resolution the network is run at (default 512, as before).
            min_mask_sum: A sample whose encoded mask sums to this value or less
                is zeroed out (default 5000, as before).

        Returns:
            (mask, ignore_mask_ids):
                mask -- float tensor of shape (N, 1, crop_size, crop_size). It is
                    the encoded mask scaled by 1/255, bilinearly resized, then
                    summed over channels.
                ignore_mask_ids -- bool tensor of shape (N,). True where the
                    sample's mask is entirely zero after thresholding.
        """
        x = F.interpolate(x, size=(parse_size, parse_size))

        # argmax below is non-differentiable, so skip building an autograd graph.
        # Call the module (not .forward) so registered hooks still run.
        with torch.no_grad():
            logits = self(x)[0]
        labels = torch.argmax(logits, dim=1, keepdim=True)

        # Project helper; presumably maps label ids to an 0-255 RGB encoding — TODO confirm.
        parsed_face = encode_segmentation_rgb_batch(labels)

        # Drop near-empty predictions for the whole sample.
        parsed_face = torch.where(
            torch.sum(parsed_face, dim=[1, 2, 3], keepdim=True) > min_mask_sum,
            parsed_face,
            torch.zeros_like(parsed_face),
        )

        ignore_mask_ids = torch.sum(parsed_face, dim=[1, 2, 3]) == 0

        parsed_face = parsed_face.float().mul_(1 / 255.0)
        parsed_face = F.interpolate(parsed_face, size=(crop_size, crop_size), mode="bilinear")
        parsed_face = torch.sum(parsed_face, dim=1, keepdim=True)

        return parsed_face, ignore_mask_ids

    def forward(self, x):
        """Return three ``n_classes``-channel maps, all resized to the input's (H, W).

        Returns:
            (feat_out, feat_out16, feat_out32): outputs of ``conv_out`` on the fused
            features and of ``conv_out16`` / ``conv_out32`` on the context path's
            second and third outputs. All are upsampled bilinearly with
            ``align_corners=True``.
        """
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
        # The backbone's first feature map stands in for the spatial path.
        feat_sp = feat_res8
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode="bilinear", align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode="bilinear", align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode="bilinear", align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1), zero bias.

        NOTE(review): the direct children (cp, ffm, conv_out*) are composite
        modules, not nn.Conv2d, so this currently matches nothing. Each submodule
        initialises itself in its own constructor.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Collect parameter groups from every direct child's ``get_params()``.

        Groups from FeatureFusionModule / BiSeNetOutput children go into the
        ``lr_mul_*`` lists. Those from other children (the context path) go into
        the plain lists. The names suggest a learning-rate multiplier and weight
        decay settings; confirm against the training script.

        Returns:
            (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for _, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the network, run one random batch, and exercise get_params().
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BiSeNet(19)
    net.to(device)
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).to(device)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        out, out16, out32 = net(in_ten)
    print(out.shape)

    net.get_params()
|
|