# incrl's picture
# Initial Upload (attempt 2)
# 5b557cf verified
import torch
import torch.nn as nn
class CNN(nn.Module):
    """Map a feature map to a flattened ``matrixSize x matrixSize`` matrix.

    The input is passed through three 3x3 convolutions (stride 1, padding 1,
    so the spatial size is preserved) that reduce the channel count to
    ``matrixSize``.  The channel-by-channel Gram matrix of the result,
    normalised by the number of spatial positions, is then flattened and
    fed through a fully connected layer.

    Args:
        layer: Which feature layer the input comes from.  ``'r31'`` expects
            256 input channels, ``'r41'`` expects 512.
        matrixSize: Number of output channels of the conv stack, i.e. the
            side length of the produced matrix.

    Raises:
        ValueError: If ``layer`` is not one of the supported names.
    """

    # (input, first hidden, second hidden) channel widths per supported layer.
    _CHANNELS = {
        'r31': (256, 128, 64),
        'r41': (512, 256, 128),
    }

    def __init__(self, layer, matrixSize=32):
        super(CNN, self).__init__()
        if layer not in self._CHANNELS:
            # Fail here rather than with an AttributeError on the first
            # forward pass because ``self.convs`` was never created.
            raise ValueError(
                "unsupported layer {!r}; expected one of {}".format(
                    layer, sorted(self._CHANNELS)))

        in_ch, hidden1, hidden2 = self._CHANNELS[layer]
        # Keep the module order (conv, relu, conv, relu, conv) so that the
        # state_dict keys ``convs.0``, ``convs.2`` and ``convs.4`` stay valid.
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, hidden1, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden1, hidden2, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden2, matrixSize, 3, 1, 1),
        )
        self.fc = nn.Linear(matrixSize * matrixSize, matrixSize * matrixSize)

    def forward(self, x):
        """Return a ``(batch, matrixSize * matrixSize)`` tensor for ``x``.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)`` whose
                channel count matches the ``layer`` given at construction.
        """
        out = self.convs(x)
        b, c, h, w = out.size()
        # Flatten the spatial dimensions: (b, c, h*w).
        out = out.view(b, c, -1)
        # Gram matrix over channels, averaged over spatial positions: (b, c, c).
        out = torch.bmm(out, out.transpose(1, 2)).div(h * w)
        # Flatten to (b, c*c) for the fully connected layer.
        out = out.view(out.size(0), -1)
        return self.fc(out)
class MulLayer(nn.Module):
    """Apply a learned linear transformation to content features.

    Two ``CNN`` sub-networks predict one ``matrixSize x matrixSize`` matrix
    each (``cnet`` from the first input, ``snet`` from the second).  Their
    product is applied to a 1x1-conv-compressed version of the mean-centred
    first input, and the result is expanded back to the original channel
    count by another 1x1 conv.

    Args:
        layer: ``'r31'`` (256-channel features) or ``'r41'`` (512-channel
            features).
        matrixSize: Channel count of the compressed representation and side
            length of the predicted matrices.

    Raises:
        ValueError: If ``layer`` is not one of the supported names.
    """

    # Channel count of the incoming feature maps per supported layer.
    _IN_CHANNELS = {'r31': 256, 'r41': 512}

    def __init__(self, layer, matrixSize=32):
        super(MulLayer, self).__init__()
        if layer not in self._IN_CHANNELS:
            # Without this check ``compress``/``unzip`` would simply be
            # missing and the error would only surface inside ``forward``.
            raise ValueError(
                "unsupported layer {!r}; expected one of {}".format(
                    layer, sorted(self._IN_CHANNELS)))

        # Registration order (snet, cnet, compress, unzip) is kept so that
        # existing checkpoints keep loading with the same keys.
        self.snet = CNN(layer, matrixSize)
        self.cnet = CNN(layer, matrixSize)
        self.matrixSize = matrixSize

        channels = self._IN_CHANNELS[layer]
        self.compress = nn.Conv2d(channels, matrixSize, 1, 1, 0)
        self.unzip = nn.Conv2d(matrixSize, channels, 1, 1, 0)
        # Kept for compatibility; nothing in this class assigns it afterwards.
        self.transmatrix = None

    @staticmethod
    def _spatial_mean(feat):
        """Return the per-channel spatial mean of ``feat`` as ``(b, c, 1, 1)``."""
        b, c, _, _ = feat.size()
        return torch.mean(feat.view(b, c, -1), dim=2, keepdim=True).unsqueeze(3)

    def forward(self, cF, sF, trans=True):
        """Transform ``cF`` using matrices predicted from ``cF`` and ``sF``.

        Args:
            cF: Content feature map of shape ``(batch, channels, h, w)``.
            sF: Second feature map (fed to ``snet``) of shape
                ``(batch, channels, h', w')``.
            trans: When true, apply the predicted transformation; when
                false, only compress and expand ``cF``.

        Returns:
            ``(out, transmatrix)`` when ``trans`` is true, where ``out`` has
            the shape of ``cF`` and ``transmatrix`` is
            ``(batch, matrixSize, matrixSize)``.  Only ``out`` when ``trans``
            is false.
        """
        # Centre both inputs on their per-channel spatial means.
        cMean = self._spatial_mean(cF).expand_as(cF)
        cF = cF - cMean

        sMean = self._spatial_mean(sF)
        # The mean of ``sF`` is added back at the resolution of ``cF`` below.
        sMeanC = sMean.expand_as(cF)
        sF = sF - sMean.expand_as(sF)

        compress_content = self.compress(cF)
        b, c, h, w = compress_content.size()
        compress_content = compress_content.view(b, c, -1)

        if not trans:
            out = self.unzip(compress_content.view(b, c, h, w))
            return out + cMean

        size = self.matrixSize
        cMatrix = self.cnet(cF)
        sMatrix = self.snet(sF)
        sMatrix = sMatrix.view(sMatrix.size(0), size, size)
        cMatrix = cMatrix.view(cMatrix.size(0), size, size)

        # Combined transformation, applied to every spatial position of the
        # compressed content features.
        transmatrix = torch.bmm(sMatrix, cMatrix)
        transfeature = torch.bmm(transmatrix, compress_content).view(b, c, h, w)
        out = self.unzip(transfeature)
        out = out + sMeanC
        return out, transmatrix