import torch
import torch.nn as nn

# Number of input feature channels expected for each supported layer name.
# NOTE(review): the names look like VGG relu3_1 / relu4_1 taps - confirm
# against the encoder that produces the features.
_LAYER_CHANNELS = {'r31': 256, 'r41': 512}


def _channels_for(layer):
    """Return the input channel count for *layer*, or raise ValueError."""
    try:
        return _LAYER_CHANNELS[layer]
    except KeyError:
        raise ValueError(
            "unsupported layer %r; expected one of %s"
            % (layer, sorted(_LAYER_CHANNELS))
        )


def _channel_mean(feat):
    """Return the per-sample, per-channel spatial mean of *feat* as (B, C, 1, 1)."""
    return feat.flatten(2).mean(dim=2, keepdim=True).unsqueeze(3)


class CNN(nn.Module):
    """Map a feature map to a flattened ``matrixSize x matrixSize`` matrix.

    Three 3x3 convolutions (stride 1, padding 1, so spatial size is kept)
    reduce the channel count to ``matrixSize``. The channel-by-channel
    product of the result with itself, divided by the number of spatial
    positions, is then passed through a fully connected layer.

    Args:
        layer: ``'r31'`` (256 input channels) or ``'r41'`` (512 input
            channels).
        matrixSize: side length of the produced square matrix.

    Raises:
        ValueError: if ``layer`` is not a supported name.
    """

    def __init__(self, layer, matrixSize=32):
        super(CNN, self).__init__()
        in_ch = _channels_for(layer)
        # r31: 256 -> 128 -> 64 -> matrixSize
        # r41: 512 -> 256 -> 128 -> matrixSize
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch // 2, in_ch // 4, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch // 4, matrixSize, 3, 1, 1),
        )
        self.fc = nn.Linear(matrixSize * matrixSize, matrixSize * matrixSize)

    def forward(self, x):
        """Return a ``(B, matrixSize * matrixSize)`` tensor for input ``x``.

        ``x`` is a 4-D ``(B, C, H, W)`` tensor whose channel count matches
        the layer chosen at construction.
        """
        out = self.convs(x)                      # (B, matrixSize, H, W)
        _, _, h, w = out.size()
        flat = out.flatten(2)                    # (B, matrixSize, H*W)
        # Channel-by-channel second-moment matrix, averaged over positions.
        gram = torch.bmm(flat, flat.transpose(1, 2)).div(h * w)
        return self.fc(gram.flatten(1))


class MulLayer(nn.Module):
    """Apply a learned linear transformation to content features.

    Content features are mean-centred and compressed to ``matrixSize``
    channels with a 1x1 convolution. When transferring, a transformation
    matrix is built as the product of the matrices predicted by ``snet``
    (from the style features) and ``cnet`` (from the content features),
    applied to the compressed content, expanded back to the original
    channel count with a 1x1 convolution, and shifted by the style mean.

    Args:
        layer: ``'r31'`` (256 channels) or ``'r41'`` (512 channels).
        matrixSize: number of channels of the compressed representation
            and side length of the transformation matrix.

    Raises:
        ValueError: if ``layer`` is not a supported name.
    """

    def __init__(self, layer, matrixSize=32):
        super(MulLayer, self).__init__()
        channels = _channels_for(layer)
        self.snet = CNN(layer, matrixSize)
        self.cnet = CNN(layer, matrixSize)
        self.matrixSize = matrixSize
        self.compress = nn.Conv2d(channels, matrixSize, 1, 1, 0)
        self.unzip = nn.Conv2d(matrixSize, channels, 1, 1, 0)
        # Not written by forward(); kept so the attribute stays available
        # to any external code that reads it.
        self.transmatrix = None

    def forward(self, cF, sF, trans=True):
        """Transform content features ``cF``.

        Args:
            cF: content features, a 4-D ``(B, C, H, W)`` tensor.
            sF: style features, a 4-D ``(B, C, H', W')`` tensor. Only read
                when ``trans`` is true.
            trans: when true, apply the style-dependent transformation;
                when false, only compress and expand the centred content
                and add the content mean back.

        Returns:
            ``(out, transmatrix)`` when ``trans`` is true, where ``out``
            has the shape of ``cF`` and ``transmatrix`` is
            ``(B, matrixSize, matrixSize)``; otherwise just ``out``.
        """
        content_mean = _channel_mean(cF)
        centred_content = cF - content_mean      # broadcasts over H, W

        compressed = self.compress(centred_content)
        b, c, h, w = compressed.size()

        if not trans:
            return self.unzip(compressed) + content_mean

        style_mean = _channel_mean(sF)
        centred_style = sF - style_mean

        size = self.matrixSize
        content_matrix = self.cnet(centred_content).view(-1, size, size)
        style_matrix = self.snet(centred_style).view(-1, size, size)
        transmatrix = torch.bmm(style_matrix, content_matrix)

        transformed = torch.bmm(transmatrix, compressed.flatten(2))
        out = self.unzip(transformed.view(b, c, h, w))
        # Re-centre the result on the style statistics.
        return out + style_mean, transmatrix