|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
|
|
| from rscd.models.backbones.Decom_Backbone import ResNet3D |
|
|
|
|
class AFCD3D_backbone(nn.Module):
    """Two-image feature extractor built on a ``ResNet3D``-wrapped ResNet-18.

    The two inputs are stacked along a newly inserted axis (dim 2) and pushed
    through the wrapped network's stem and its four residual stages; the
    activation after the stem ReLU and after each stage is returned.
    """

    def __init__(self, pretrained=True):
        """Build the backbone.

        Args:
            pretrained: Forwarded unchanged to ``torchvision.models.resnet18``.
                Defaults to ``True``, which matches the previously hard-coded
                behaviour. Pass ``False`` to skip loading pretrained weights,
                e.g. when the model is restored from a checkpoint afterwards.
        """
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=pretrained)
        # NOTE(review): ResNet3D is defined elsewhere in the project; presumably
        # it adapts the 2-D ResNet layers to 3-D inputs -- confirm there.
        self.resnet = ResNet3D(resnet)

    def forward(self, imageA, imageB):
        """Extract multi-level features from a pair of images.

        Args:
            imageA: First input tensor.
            imageB: Second input tensor; must be concatenable with ``imageA``
                along the inserted axis.
                (assumes both are (batch, channels, H, W) -- TODO confirm)

        Returns:
            A list ``[size, x0, x1, x2, x3, x4]`` where ``size`` holds the
            dimensions of the stacked input from index 3 onward (the trailing
            ``(H, W)`` under the 4-D input assumption), ``x0`` is the output of
            the stem ReLU (before max-pooling) and ``x1``..``x4`` are the
            outputs of ``layer1``..``layer4``.
        """
        # Give each image a singleton axis at dim 2, then join the pair along
        # it so that dim 2 has length 2 in the stacked tensor.
        pair = torch.cat([imageA.unsqueeze(2), imageB.unsqueeze(2)], 2)
        size = pair.size()[3:]

        # Stem: conv -> batch norm -> ReLU. x0 is captured before max-pooling.
        x0 = self.resnet.relu(self.resnet.bn1(self.resnet.conv1(pair)))

        # Residual stages, each fed by the previous one.
        x1 = self.resnet.layer1(self.resnet.maxpool(x0))
        x2 = self.resnet.layer2(x1)
        x3 = self.resnet.layer3(x2)
        x4 = self.resnet.layer4(x3)

        return [size, x0, x1, x2, x3, x4]
|
|
|
|