| | from rscd.models.backbones.lgpnet.unet_parts import *
|
| |
|
class BCDNET(nn.Module):
    """Fusion head of the Local-Global Pyramid Network (LGPNet).

    Takes an indexable collection of feature maps, concatenates the entries at
    ``feat_indices`` along dim 1, refines the result with ``TribleConv`` and
    projects it to ``n_classes`` channels with ``OutConv``.

    Args:
        n_channels: stored as ``self.n_channels``; not used by any layer here.
        n_classes: number of output channels produced by ``OutConv``.
        in_channels: channel count of the concatenated features fed into
            ``TribleConv`` (default 128, the previously hard-coded value).
        mid_channels: channel count produced by ``TribleConv`` and consumed
            by ``OutConv`` (default 64, the previously hard-coded value).
        feat_indices: positions in the ``forward`` input whose feature maps
            are concatenated, in order (default ``(2, 3)``, the previously
            hard-coded positions).

    Raises:
        ValueError: if ``feat_indices`` is empty.
    """

    def __init__(self, n_channels, n_classes, in_channels=128, mid_channels=64,
                 feat_indices=(2, 3)):
        super(BCDNET, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Copy into a tuple so later mutation of a caller-supplied list
        # cannot silently change which features get fused.
        self.feat_indices = tuple(feat_indices)
        if not self.feat_indices:
            raise ValueError("feat_indices must select at least one feature map")
        # Registration order (conv, then final) and attribute names are kept
        # so existing state_dict checkpoints still load.
        self.conv = TribleConv(in_channels, mid_channels)
        self.final = OutConv(mid_channels, n_classes)

    def forward(self, x=None):
        """Fuse the selected feature maps and return per-class logits.

        Args:
            x: indexable sequence of feature maps. The entries at
                ``self.feat_indices`` are concatenated along dim 1, so they
                must agree in every other dimension and their channel counts
                must sum to ``in_channels``.
                NOTE(review): presumably (B, C, H, W) tensors from the
                backbone — confirm against the caller.

        Returns:
            The output of ``OutConv`` applied to the refined fused features.

        Raises:
            ValueError: if ``x`` is not provided.
        """
        # None sentinel instead of a mutable ``[]`` default; fail with a
        # clear message rather than an opaque IndexError.
        if x is None:
            raise ValueError("BCDNET.forward() requires a sequence of feature maps")
        fusionfeats = torch.cat([x[i] for i in self.feat_indices], dim=1)
        refined = self.conv(fusionfeats)
        return self.final(refined)
|
| |
|
| |
|
class TribleConv(nn.Module):
    """Two stacked (3x3 convolution => BatchNorm => ReLU) stages.

    Despite the class name, the conv/BN/ReLU stage is applied twice. Each
    convolution uses ``kernel_size=3, padding=1`` (stride 1), so the spatial
    size of the input is preserved.

    Args:
        in_channels: channel count of the input to the first convolution.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps
        # out_channels; layers are created in the same order as before.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # The attribute name `trible_conv` is part of the state_dict keys.
        self.trible_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv/BN/ReLU stages on ``x``."""
        out = self.trible_conv(x)
        return out
|
| |
|
| |
|
if __name__ == '__main__':
    # Smoke check: build the head for 3-channel input / 1 class and print
    # its module tree.
    head = BCDNET(n_classes=1, n_channels=3)
    print(head)