| import torch
|
| import torch.nn as nn
|
|
|
| from modules.basic_layers import (
|
| SinusoidalPositionalEmbedding,
|
| ResGatedBlock,
|
| MaxViTBlock,
|
| Downsample,
|
| Upsample
|
| )
|
|
|
class UnetDownBlock(nn.Module):
    """One encoder stage of the U-Net.

    Downsamples the incoming features, fuses them with two warped feature
    maps, refines the result with gated residual convolutions, and finishes
    with a MaxViT attention block. Every sub-block is conditioned on the
    time embedding ``temb``.

    NOTE(review): the first conv expects ``3 * in_channels + 2`` channels
    after concatenation — presumably each warp tensor carries
    ``in_channels + 1`` channels; confirm against the caller.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        # Convolutional downsampling; channel count is left unchanged here.
        self.pool = Downsample(
            in_channels=in_channels,
            out_channels=in_channels,
            use_conv=True,
        )

        # Channel count after concatenating the pooled input with both warps.
        fused_channels = 3 * in_channels + 2

        convs = []
        for idx in range(num_conv_blocks):
            convs.append(
                ResGatedBlock(
                    in_channels=fused_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    emb_channels=temb_channels,
                    gated_conv=True,
                )
            )
        self.conv = nn.ModuleList(convs)

        self.maxvit = MaxViTBlock(
            channels=out_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout,
            emb_channels=temb_channels,
        )

    def forward(
        self,
        x: torch.Tensor,
        warp0: torch.Tensor,
        warp1: torch.Tensor,
        temb: torch.Tensor
    ):
        """Downsample ``x``, concatenate both warps along the channel axis,
        then apply the conv stack followed by the attention block."""
        x = self.pool(x)
        x = torch.cat([x, warp0, warp1], dim=1)
        for block in self.conv:
            x = block(x, temb)
        return self.maxvit(x, temb)
|
|
|
class UnetMiddleBlock(nn.Module):
    """Bottleneck of the U-Net.

    Runs ``ResGatedBlock -> MaxViTBlock -> ResGatedBlock`` in sequence,
    with every sub-block conditioned on the time embedding.
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        dropout: float = 0.0
    ):
        super().__init__()

        # Entry conv: in_channels -> mid_channels.
        entry = ResGatedBlock(
            in_channels=in_channels,
            out_channels=mid_channels,
            emb_channels=temb_channels,
            gated_conv=True,
        )
        # Attention at the bottleneck resolution.
        attention = MaxViTBlock(
            channels=mid_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout,
            emb_channels=temb_channels,
        )
        # Exit conv: mid_channels -> out_channels.
        final = ResGatedBlock(
            in_channels=mid_channels,
            out_channels=out_channels,
            emb_channels=temb_channels,
            gated_conv=True,
        )
        self.middle_blocks = nn.ModuleList([entry, attention, final])

    def forward(self, x, temb):
        """Apply the three bottleneck sub-blocks in order."""
        for stage in self.middle_blocks:
            x = stage(x, temb)
        return x
|
|
|
class UnetUpBlock(nn.Module):
    """One decoder stage of the U-Net.

    Concatenates the decoder features with a skip connection, applies a
    MaxViT attention block, upsamples, and refines with gated residual
    convolutions — all conditioned on the time embedding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        # The skip connection doubles the channel count at the input.
        merged_channels = 2 * in_channels

        self.maxvit = MaxViTBlock(
            channels=merged_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout,
            emb_channels=temb_channels,
        )

        # Spatial upsampling; channel count is left unchanged here.
        self.upsample = Upsample(
            in_channels=merged_channels,
            out_channels=merged_channels,
            use_conv=True,
        )

        self.conv = nn.ModuleList([
            ResGatedBlock(
                merged_channels if idx == 0 else out_channels,
                out_channels,
                emb_channels=temb_channels,
                gated_conv=True,
            )
            for idx in range(num_conv_blocks)
        ])

    def forward(
        self,
        x: torch.Tensor,
        skip_connection: torch.Tensor,
        temb: torch.Tensor
    ):
        """Fuse the skip connection, attend, upsample, then refine."""
        x = torch.cat([x, skip_connection], dim=1)
        x = self.maxvit(x, temb)
        x = self.upsample(x)
        for block in self.conv:
            x = block(x, temb)
        return x
|
|
|
class Synthesis(nn.Module):
    """Time-conditioned U-Net that synthesizes an output frame from an input
    tensor plus two pyramids of warped feature maps.

    ``channels`` lists the feature widths per resolution level; each
    encoder/decoder stage moves between adjacent entries. ``warp0`` and
    ``warp1`` must provide one tensor per level (stem + every down stage).

    NOTE(review): the stem conv expects ``3 * in_channels + 4`` channels —
    presumably each level-0 warp carries ``in_channels + 2`` channels;
    confirm against the caller.
    """

    def __init__(
        self,
        in_channels: int,
        channels: list[int],
        temb_channels: int,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        # Maps the scalar timestep to a temb_channels-dim embedding.
        self.t_pos_encoding = SinusoidalPositionalEmbedding(temb_channels)

        # Stem: project the concatenated input + warps, then one gated block.
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(3 * in_channels + 4, channels[0], kernel_size=3, padding=1),
            ResGatedBlock(
                in_channels=channels[0],
                out_channels=channels[0],
                emb_channels=temb_channels,
                gated_conv=True,
            ),
        ])

        n_scales = len(channels) - 1

        self.down_blocks = nn.ModuleList([
            UnetDownBlock(
                channels[level],
                channels[level + 1],
                temb_channels,
                heads=heads,
                window_size=window_size,
                window_attn=window_attn,
                grid_attn=grid_attn,
                expansion_rate=expansion_rate,
                num_conv_blocks=num_conv_blocks,
                dropout=dropout,
            )
            for level in range(n_scales)
        ])

        self.middle_block = UnetMiddleBlock(
            in_channels=channels[-1],
            mid_channels=channels[-1],
            out_channels=channels[-1],
            temb_channels=temb_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout,
        )

        # Decoder stages mirror the encoder, deepest level first.
        self.up_blocks = nn.ModuleList([
            UnetUpBlock(
                channels[level + 1],
                channels[level],
                temb_channels,
                heads=heads,
                window_size=window_size,
                window_attn=window_attn,
                grid_attn=grid_attn,
                expansion_rate=expansion_rate,
                num_conv_blocks=num_conv_blocks,
                dropout=dropout,
            )
            for level in reversed(range(n_scales))
        ])

        # Head: one gated block, then project back to in_channels.
        self.output_blocks = nn.ModuleList([
            ResGatedBlock(
                in_channels=channels[0],
                out_channels=channels[0],
                emb_channels=temb_channels,
                gated_conv=True,
            ),
            nn.Conv2d(channels[0], in_channels, kernel_size=3, padding=1),
        ])

    def forward(
        self,
        x: torch.Tensor,
        warp0: list[torch.Tensor],
        warp1: list[torch.Tensor],
        temb: torch.Tensor
    ):
        """Run the full U-Net: stem, encoder, bottleneck, decoder, head."""
        # Embed the timestep.
        temb = temb.unsqueeze(-1).type(torch.float)
        temb = self.t_pos_encoding(temb)

        # Stem over the level-0 input and warps.
        x = self.input_blocks[0](torch.cat([x, warp0[0], warp1[0]], dim=1))
        x = self.input_blocks[1](x, temb)

        # Encoder; keep each stage's output as a skip connection.
        skips = []
        for level, down in enumerate(self.down_blocks):
            x = down(x, warp0[level + 1], warp1[level + 1], temb)
            skips.append(x)

        x = self.middle_block(x, temb)

        # Decoder consumes the skips deepest-first.
        for up, skip in zip(self.up_blocks, reversed(skips)):
            x = up(x, skip, temb)

        x = self.output_blocks[0](x, temb)
        return self.output_blocks[1](x)