| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from torch.autograd import Variable |
| |
|
| |
|
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a 2-D convolution padded by half the kernel size on each axis.

    With stride 1 and an odd kernel size this padding keeps the spatial
    size of the input unchanged.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either a single int or a (height, width)
            tuple/list of ints.
        bias: Whether the convolution has a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    if isinstance(kernel_size, (tuple, list)):
        # Per-axis padding so non-square kernels are padded correctly too.
        padding = tuple(k // 2 for k in kernel_size)
    else:
        padding = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=padding, bias=bias)
| |
|
| |
|
class ResUnit(nn.Module):
    """Residual unit: 3x3 conv, 1x1 expand, ReLU, 1x1 reduce, plus skip.

    The channel count is ``dim`` at the input and output and ``2 * dim``
    between the two 1x1 convolutions.
    """

    def __init__(self, dim):
        """Create the three convolutions for a ``dim``-channel residual unit."""
        super().__init__()

        self.act = nn.ReLU(inplace=True)
        self.conv1 = default_conv(dim, dim, 3)
        self.conv2 = default_conv(dim, 2 * dim, 1)
        self.conv3 = default_conv(2 * dim, dim, 1)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = x

        # Only the expanded (2 * dim) features pass through the ReLU.
        expanded = self.conv2(self.conv1(x))
        out = self.conv3(self.act(expanded))

        return out + residual
| |
|
| |
|
class FusionBlock(nn.Module):
    """Fuse features of ``img_snow`` and ``mask`` through two difference paths.

    Both inputs are embedded, their difference is processed once directly
    and once after an extra residual unit on each input; the two results
    are concatenated on the channel axis and reduced back to ``embed_dim``.
    """

    def __init__(self, n_color, embed_dim):
        """Create the sub-modules.

        Args:
            n_color: Number of channels of the ``img_snow`` input.
            embed_dim: Number of feature channels; also the number of
                channels the ``mask`` input must have.
        """
        super().__init__()

        self.act = nn.ReLU(inplace=True)

        # Input embeddings for the image and the mask.
        self.conv_1 = default_conv(n_color, embed_dim, 3)
        self.conv_2 = default_conv(embed_dim, embed_dim, 3)

        # Convolutions applied to the first and second difference.
        self.conv_1_2 = default_conv(embed_dim, embed_dim, 3)
        self.conv_2_2 = default_conv(embed_dim, embed_dim, 3)

        self.ru_1 = ResUnit(embed_dim)
        self.ru_2 = ResUnit(embed_dim)

        self.ru_1_1 = ResUnit(embed_dim)
        self.ru_2_1 = ResUnit(embed_dim)

        self.ru = ResUnit(embed_dim)
        self.ru_ = ResUnit(embed_dim)

        self.conv_tail_1 = default_conv(2 * embed_dim, embed_dim, 3)
        self.conv_tail_2 = default_conv(embed_dim, embed_dim, 3)

    def forward(self, img_snow, mask):
        """Return the fused ``embed_dim``-channel feature map."""
        snow_feat = self.ru_1(self.conv_1(img_snow))
        mask_feat = self.ru_2(self.conv_2(mask))

        # First path: difference of the embedded features.
        first_diff = self.ru(self.conv_1_2(snow_feat - mask_feat))

        snow_feat = self.ru_1_1(snow_feat)
        mask_feat = self.ru_2_1(mask_feat)

        # Second path: difference after one more residual unit per input.
        second_diff = self.ru_(self.conv_2_2(snow_feat - mask_feat))

        merged = torch.cat((first_diff, second_diff), dim=1)
        return self.conv_tail_2(self.act(self.conv_tail_1(merged)))
| |
|
| |
|
class MARB(nn.Module):
    """Residual block with three parallel branches of kernel size 1, 3 and 5.

    Each branch is followed by a 1x1 convolution. Neighbouring branches are
    concatenated pairwise and reduced to ``dim`` channels; the two reduced
    maps are concatenated, passed through ReLU and a final 1x1 convolution,
    and the input is added back.
    """

    def __init__(self, dim):
        """Create the convolutions for a ``dim``-channel block."""
        super().__init__()

        self.act = nn.ReLU(inplace=True)

        # Branch heads with kernel sizes 1, 3 and 5.
        self.conv_dl2 = default_conv(dim, dim, 1)
        self.conv_dl3 = default_conv(dim, dim, 3)
        self.conv_dl5 = default_conv(dim, dim, 5)

        # One 1x1 convolution per branch.
        self.conv1_1 = default_conv(dim, dim, 1)
        self.conv1_2 = default_conv(dim, dim, 1)
        self.conv1_3 = default_conv(dim, dim, 1)

        # Reductions of the pairwise concatenations back to ``dim`` channels.
        self.conv2_1 = default_conv(2 * dim, dim, 1)
        self.conv2_2 = default_conv(2 * dim, dim, 1)

        self.conv_tail = default_conv(2 * dim, dim, 1)

    def forward(self, x):
        """Return ``x`` plus the fused multi-branch features."""
        branch_k1 = self.conv1_1(self.conv_dl2(x))
        branch_k3 = self.conv1_2(self.conv_dl3(x))
        branch_k5 = self.conv1_3(self.conv_dl5(x))

        # The middle (3x3) branch takes part in both pairings.
        pair_a = self.conv2_1(torch.cat((branch_k1, branch_k3), dim=1))
        pair_b = self.conv2_2(torch.cat((branch_k3, branch_k5), dim=1))

        stacked = torch.cat((pair_a, pair_b), dim=1)
        return self.conv_tail(self.act(stacked)) + x
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class MaskBlock(nn.Module):
    """Convolutional block built around two element-wise multiplications.

    The features are first squared element-wise, then, after a 3x3
    convolution and ReLU, two separate 1x1 projections of the same tensor
    are multiplied together before the tail convolution.
    """

    def __init__(self, embed_dim):
        """Create the convolutions; every layer keeps ``embed_dim`` channels."""
        super().__init__()
        self.act = nn.ReLU(inplace=True)
        self.conv_head = default_conv(embed_dim, embed_dim, 3)

        self.conv_self = default_conv(embed_dim, embed_dim, 1)

        self.conv1 = default_conv(embed_dim, embed_dim, 3)
        self.conv1_1 = default_conv(embed_dim, embed_dim, 1)
        self.conv1_2 = default_conv(embed_dim, embed_dim, 1)
        self.conv_tail = default_conv(embed_dim, embed_dim, 3)

    def forward(self, x):
        """Return the ``embed_dim``-channel output for input features ``x``."""
        feat = self.conv_self(self.conv_head(x))
        # Element-wise square of the projected features.
        feat = feat.mul(feat)
        feat = self.act(self.conv1(feat))

        # Product of two independent 1x1 projections of the same tensor.
        left = self.conv1_1(feat)
        right = self.conv1_2(feat)
        gated = left.mul(right)

        return self.conv_tail(gated)
| |
|
def dwt_init(x):
    """Split a 4-D tensor into four half-resolution sub-bands.

    Every non-overlapping 2x2 spatial block of ``x`` is halved and combined
    with four sign patterns, producing the LL, HL, LH and HH sub-bands.

    Args:
        x: Tensor indexed as (batch, channel, height, width).

    Returns:
        The four sub-bands concatenated along dim 1 in the order
        LL, HL, LH, HH, i.e. four times the channels at half the height
        and width.
    """
    # The four positions inside each 2x2 block, each scaled by 1/2.
    top_left = x[:, :, 0::2, 0::2] / 2
    bottom_left = x[:, :, 1::2, 0::2] / 2
    top_right = x[:, :, 0::2, 1::2] / 2
    bottom_right = x[:, :, 1::2, 1::2] / 2

    bands = (
        top_left + bottom_left + top_right + bottom_right,    # LL
        -top_left - bottom_left + top_right + bottom_right,   # HL
        -top_left + bottom_left - top_right + bottom_right,   # LH
        top_left - bottom_left - top_right + bottom_right,    # HH
    )

    return torch.cat(bands, 1)
| |
|
| |
|
def iwt_init(x):
    """Rebuild a full-resolution tensor from four stacked sub-bands.

    The channel axis of ``x`` is read as four consecutive groups of
    ``C // 4`` channels. Each group is halved and the groups are combined
    with four sign patterns to fill the four positions of every 2x2
    output block.

    Args:
        x: Tensor indexed as (batch, channel, height, width). Only the
            first ``4 * (channel // 4)`` channels are used.

    Returns:
        Tensor of shape (batch, channel // 4, 2 * height, 2 * width),
        allocated on the same device as ``x`` and with the dtype of the
        scaled sub-bands.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()

    out_channel = in_channel // (r ** 2)
    out_height, out_width = r * in_height, r * in_width

    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2

    # Allocate next to the input instead of forcing float32 on CUDA, so the
    # function also works on CPU-only machines and keeps the input precision.
    h = torch.zeros(
        [in_batch, out_channel, out_height, out_width],
        dtype=x1.dtype,
        device=x.device,
    )

    # Each assignment fills one of the four positions of every 2x2 block.
    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h
| |
|
class DWT(nn.Module):
    """Module wrapper that applies ``dwt_init`` in its forward pass."""

    def __init__(self):
        """Initialise the module; it registers no parameters."""
        super().__init__()
        # Plain Python attribute on the module object.
        self.requires_grad = False

    def forward(self, x):
        """Return ``dwt_init(x)``."""
        return dwt_init(x)
| |
|
| |
|
class IWT(nn.Module):
    """Module wrapper that applies ``iwt_init`` in its forward pass."""

    def __init__(self):
        """Initialise the module; it registers no parameters."""
        super().__init__()
        # Plain Python attribute on the module object.
        self.requires_grad = False

    def forward(self, x):
        """Return ``iwt_init(x)``."""
        return iwt_init(x)