| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from thop import profile |
| from model.auxiliary import VSSM |
| import torch |
| from model.LaSEA import * |
| import torch |
| import time |
| from thop import profile |
class ChannelAttention(nn.Module):
    """Channel attention gate built from pooled descriptors and a shared bottleneck.

    The input is reduced to one value per channel twice (average pooling and
    max pooling). Both descriptors go through the same two 1x1 convolutions
    (squeeze by ``ratio``, ReLU, expand), are summed, and are passed through a
    sigmoid.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Channel reduction factor of the bottleneck. Defaults to 16.

    Returns (from ``forward``):
        A tensor with ``in_planes`` channels and spatial size 1x1 holding
        values in (0, 1); callers multiply it with the feature map.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # The reduction used to be hard-coded to 16, silently ignoring
        # ``ratio``. Clamp to 1 so narrow inputs never get a 0-channel conv.
        hidden_planes = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Both pooled descriptors share fc1/fc2 weights.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
|
|
|
|
class SpatialAttention(nn.Module):
    """Spatial attention gate computed from channel-wise mean and max maps.

    The input is collapsed along the channel axis (dim 1) into a mean map and
    a max map. The two maps are stacked into a 2-channel tensor, convolved
    down to a single channel without bias, and passed through a sigmoid.

    Args:
        kernel_size: Size of the convolution kernel; must be 3 or 7.

    Returns (from ``forward``):
        A single-channel map with values in (0, 1) and the same spatial size
        as the input (padding is chosen to keep the size).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported kernel sizes.
        if kernel_size == 7:
            pad = 3
        else:
            pad = 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))
|
|
|
|
class ResNet(nn.Module):
    """Residual block with channel and spatial attention on the main branch.

    Main branch: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, then the result is
    re-weighted by ``ChannelAttention`` and by ``SpatialAttention`` in that
    order. The skip branch is added afterwards and a final ReLU is applied.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 conv (and of the projection shortcut).

    Attributes:
        shortcut: 1x1 conv + BN projection when the stride is not 1 or the
            channel count changes; otherwise ``None`` (identity skip).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: project only when the shapes would not match.
        needs_projection = stride != 1 or out_channels != in_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

        # Attention gates applied to the main branch before the skip is added.
        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)

        feat = self.relu(self.bn1(self.conv1(x)))
        feat = self.bn2(self.conv2(feat))

        feat = self.ca(feat) * feat
        feat = self.sa(feat) * feat

        feat += identity
        return self.relu(feat)
|
|
|
|
class DCCS(nn.Module):
    """U-shaped encoder/decoder that fuses convolutional and VSSM features.

    Encoder: ``conv_init`` followed by four stages (``encoder_0`` .. ``encoder_3``),
    with 2x max pooling between stages. After each stage the convolutional
    features are concatenated on the channel axis with the VSSM feature of the
    same index and reduced back by a 1x1 ``post_fuse*`` convolution.
    Bottleneck: ``middle_layer`` followed by ``GLFA``.
    Decoder: four stages, each consuming the fused encoder feature of the same
    level concatenated with the 2x-upsampled output of the deeper level.

    Args:
        input_channels: Number of channels of the input tensor.
        block: Class used to build every stage; it is called as
            ``block(in_channels, out_channels)``.

    NOTE(review): ``VSSM`` and ``GLFA`` are project modules not visible here.
    The code assumes ``VSSM(x)`` returns at least four tensors laid out
    channels-last, with 16/32/64/128 channels at 1, 1/2, 1/4 and 1/8 of the
    input resolution (required for the concatenations below) - confirm.
    Input height/width presumably have to be divisible by 16 so that the
    pooled and upsampled maps line up - confirm.
    """

    def __init__(self, input_channels, block=ResNet):
        super().__init__()
        param_channels = [16, 32, 64, 128, 256]
        param_blocks = [2, 2, 2, 2]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        # Was scale_factor=4, contradicting the name; the layer is parameter-free
        # and not used by forward(), so the correction does not affect outputs.
        self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)

        self.conv_init = nn.Conv2d(input_channels, param_channels[0], 1, 1)
        self.encoder_0 = self._make_layer(param_channels[0], param_channels[0], block)
        self.encoder_1 = self._make_layer(param_channels[0], param_channels[1], block, param_blocks[0])
        self.encoder_2 = self._make_layer(param_channels[1], param_channels[2], block, param_blocks[1])
        self.encoder_3 = self._make_layer(param_channels[2], param_channels[3], block, param_blocks[2])

        self.middle_layer = self._make_layer(param_channels[3], param_channels[4], block, param_blocks[3])

        # Decoder inputs are skip features concatenated with upsampled deeper features.
        self.decoder_3 = self._make_layer(param_channels[3] + param_channels[4], param_channels[3], block,
                                          param_blocks[2])
        self.decoder_2 = self._make_layer(param_channels[2] + param_channels[3], param_channels[2], block,
                                          param_blocks[1])
        self.decoder_1 = self._make_layer(param_channels[1] + param_channels[2], param_channels[1], block,
                                          param_blocks[0])
        self.decoder_0 = self._make_layer(param_channels[0] + param_channels[1], param_channels[0], block)

        # Per-scale single-channel heads and the conv that merges the four masks.
        self.output_0 = nn.Conv2d(param_channels[0], 1, 1)
        self.output_1 = nn.Conv2d(param_channels[1], 1, 1)
        self.output_2 = nn.Conv2d(param_channels[2], 1, 1)
        self.output_3 = nn.Conv2d(param_channels[3], 1, 1)
        self.final = nn.Conv2d(4, 1, 3, 1, 1)

        self.VSSM = VSSM()
        # 1x1 convs that bring the concatenated (conv + VSSM) features back to
        # the stage width.
        self.post_fuse3 = nn.Conv2d(param_channels[3] * 2, param_channels[3], kernel_size=1)
        self.post_fuse2 = nn.Conv2d(param_channels[2] * 2, param_channels[2], kernel_size=1)
        self.post_fuse1 = nn.Conv2d(param_channels[1] * 2, param_channels[1], kernel_size=1)
        self.post_fuse0 = nn.Conv2d(param_channels[0] * 2, param_channels[0], kernel_size=1)
        # Tied to the bottleneck width instead of a literal 256 (same value).
        self.GLFA = GLFA(in_channels=param_channels[4])

    def _make_layer(self, in_channels, out_channels, block, block_num=1):
        """Stack ``block_num`` blocks; only the first one changes the channel count."""
        layers = [block(in_channels, out_channels)]
        layers.extend(block(out_channels, out_channels) for _ in range(block_num - 1))
        return nn.Sequential(*layers)

    def forward(self, x, warm_flag):
        """Run the network.

        Args:
            x: Input tensor passed both to ``conv_init`` and to ``VSSM``.
            warm_flag: When truthy, also compute the per-scale masks and the
                fused prediction; otherwise only the full-resolution head runs.

        Returns:
            ``([mask0, mask1, mask2, mask3], output)`` when ``warm_flag`` is
            truthy, where ``output`` fuses the four masks upsampled to the
            resolution of ``mask0``; otherwise ``([], output)`` with
            ``output = output_0(x_d0)``.
        """
        # Move the last axis of each VSSM output to the channel position so it
        # can be concatenated with conv features on dim=1.
        outputs = self.VSSM(x)
        x_e0f, x_e1f, x_e2f, x_e3f = (
            outputs[i].permute(0, 3, 1, 2).contiguous() for i in range(4)
        )

        # Encoder: conv stage -> concat with VSSM feature -> 1x1 reduction.
        x_e0z = self.encoder_0(self.conv_init(x))
        x_e0z = self.post_fuse0(torch.cat([x_e0z, x_e0f], dim=1))

        x_e1z = self.encoder_1(self.pool(x_e0z))
        x_e1z = self.post_fuse1(torch.cat([x_e1z, x_e1f], dim=1))

        x_e2z = self.encoder_2(self.pool(x_e1z))
        x_e2z = self.post_fuse2(torch.cat([x_e2z, x_e2f], dim=1))

        x_e3z = self.encoder_3(self.pool(x_e2z))
        x_e3z = self.post_fuse3(torch.cat([x_e3z, x_e3f], dim=1))

        # Bottleneck.
        x_m = self.middle_layer(self.pool(x_e3z))
        x_m = self.GLFA(x_m)

        # Decoder with skip connections from the fused encoder features.
        x_d3 = self.decoder_3(torch.cat([x_e3z, self.up(x_m)], 1))
        x_d2 = self.decoder_2(torch.cat([x_e2z, self.up(x_d3)], 1))
        x_d1 = self.decoder_1(torch.cat([x_e1z, self.up(x_d2)], 1))
        x_d0 = self.decoder_0(torch.cat([x_e0z, self.up(x_d1)], 1))

        if warm_flag:
            mask0 = self.output_0(x_d0)
            mask1 = self.output_1(x_d1)
            mask2 = self.output_2(x_d2)
            mask3 = self.output_3(x_d3)
            # Bring every mask to the resolution of mask0 before fusing.
            output = self.final(torch.cat([mask0, self.up(mask1), self.up_4(mask2), self.up_8(mask3)], dim=1))
            return [mask0, mask1, mask2, mask3], output

        output = self.output_0(x_d0)
        return [], output
|
|
|
|