"""Pytorch implementation of AESc model architecture"""

__author__ = "Jonas Rabensteiner"

import torch
from torch import nn

try:
    # resize_layer slices tensors directly, so torchvision is not required by
    # this module any more; the name is kept so existing imports of `crop`
    # from this module continue to work when torchvision is installed.
    from torchvision.transforms.functional import crop
except ImportError:
    crop = None


def _half_padding(kernel_size):
    """Return the padding (kernel_size // 2) used by the stride-2 convolutions."""
    if isinstance(kernel_size, int):
        return kernel_size // 2
    return tuple(k // 2 for k in kernel_size)


def resize_layer(conv: torch.Tensor, deconv: torch.Tensor) -> torch.Tensor:
    """Crop a decoder feature map to the spatial size of an encoder feature map.

    With the default 5x5 kernel the stride-2 encoder convolutions round odd
    sizes up, so after upsampling by 2 the decoder map can be larger than the
    matching encoder map. The surplus rows/columns are removed from the
    bottom/right. Height and width are handled independently, so a map that is
    too large in both dimensions is cropped in both.

    Args:
        conv (tensor): feature map of the encoder part, indexed as
            (dim0, dim1, height, width)
        deconv (tensor): corresponding feature map of the decoder part

    Returns:
        tensor: decoder feature map, cropped wherever it exceeded ``conv``.
            Dimensions that are already smaller or equal are left unchanged.
    """
    target_h = min(deconv.shape[2], conv.shape[2])
    target_w = min(deconv.shape[3], conv.shape[3])
    return deconv[:, :, :target_h, :target_w]


class conv_block(nn.Module):
    """Layer "block" of 2D convolution, batch normalization and activation"""

    def __init__(self, in_c, out_c, kernel_size=5, stride=True, activation=None):
        """Instantiates the block layers

        Args:
            in_c (int): number of input channels
            out_c (int): number of output channels
            kernel_size (int, optional): kernel size. Defaults to 5.
            stride (bool, optional): if truthy, convolve with stride 2 and
                padding ``kernel_size // 2``; otherwise use stride 1 with
                "same" padding. Defaults to True.
            activation (nn.Module, optional): activation function. Defaults to
                a new nn.LeakyReLU() when None.
        """
        super().__init__()
        if stride:
            self.conv = nn.Conv2d(
                in_c,
                out_c,
                kernel_size=kernel_size,
                stride=2,
                # kernel_size // 2 equals the previous fixed value 2 for the
                # default 5x5 kernel and scales with other kernel sizes.
                padding=_half_padding(kernel_size),
            )
        else:
            self.conv = nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding="same")
        self.bn = nn.BatchNorm2d(out_c)
        # A fresh module per block instead of one instance created at import
        # time and shared through the default argument.
        self.activation = nn.LeakyReLU() if activation is None else activation

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalization and activation in that order."""
        x = self.conv(inputs)
        x = self.bn(x)
        return self.activation(x)


class encoder_block(nn.Module):
    """Layer "block" of conv_block and dropout, which corresponds to one "step" of the encoder"""

    def __init__(self, in_c, out_c, kernel_size=5, activation=None, dropout_rate=0.0):
        """Instantiates the block layers

        Args:
            in_c (int): number of input channels
            out_c (int): number of output channels
            kernel_size (int, optional): kernel size. Defaults to 5.
            activation (nn.Module, optional): activation function. Defaults to
                a new nn.LeakyReLU() when None.
            dropout_rate (float, optional): dropout rate. Defaults to 0.0.
        """
        super().__init__()
        # Keyword arguments are required here: passed positionally, the
        # activation would end up in conv_block's `stride` parameter.
        self.conv = conv_block(in_c, out_c, kernel_size, stride=True, activation=activation)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the stride-2 conv_block followed by dropout."""
        return self.dropout(self.conv(inputs))


class decoder_block(nn.Module):
    """Layer "block" of conv_block, skip connection, upsampling and dropout which corresponds
    to one "step" of the decoder"""

    def __init__(self, in_c, out_c, kernel_size=5, activation=None, dropout_rate=0.0,
                 skip_connections=True):
        """Instantiates the block layers

        Args:
            in_c (int): number of input channels
            out_c (int): number of output channels; must match the channels of
                the skip tensor when skip connections are enabled
            kernel_size (int, optional): kernel size. Defaults to 5.
            activation (nn.Module, optional): activation function. Defaults to
                a new nn.LeakyReLU() when None.
            dropout_rate (float, optional): dropout rate. Defaults to 0.0.
            skip_connections (bool, optional): whether the encoder feature map
                is added to the convolved input. Defaults to True.
        """
        super().__init__()
        self.up = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv = conv_block(in_c, out_c, kernel_size, stride=False, activation=activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.skip_connections = skip_connections

    def forward(self, inputs: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Run one decoder step.

        Args:
            inputs (tensor): feature map coming from the previous decoder step
            skip (tensor): encoder feature map of the matching resolution

        Returns:
            tensor: feature map upsampled by a factor of 2
        """
        x = self.conv(inputs)
        # Cropping happens even without skip connections so that the spatial
        # sizes of the decoder follow those of the encoder.
        x = resize_layer(skip, x)
        if self.skip_connections:
            x = skip + x
        x = self.up(x)
        return self.dropout(x)


class AESc(nn.Module):
    """Autoencoder with skip connections (AESc) according to the original paper by
    Anne-Sophie Collin."""

    def __init__(self, cmap="rgb", kernel_size=5, activation=None, dropout_rate=0.0):
        """Instantiates the model layers

        Args:
            cmap (str, optional): color map to use for the model. "gray" gives
                a 1-channel model, any other value a 3-channel model.
                Defaults to "rgb".
            kernel_size (int, optional): kernel size of the encoder and decoder
                blocks. The output layer uses conv_block's default kernel size.
                Defaults to 5.
            activation (nn.Module, optional): activation function used by every
                encoder and decoder block. A module passed here is shared by
                all blocks; when None each block creates its own
                nn.LeakyReLU(). Defaults to None.
            dropout_rate (float, optional): dropout rate. Defaults to 0.0.
        """
        super().__init__()
        image_channels = 1 if cmap == "gray" else 3

        # Encoder: six stride-2 steps, doubling the channels from 16 to 512.
        self.e1 = encoder_block(image_channels, 16, kernel_size, activation, dropout_rate)
        self.e2 = encoder_block(16, 32, kernel_size, activation, dropout_rate)
        self.e3 = encoder_block(32, 64, kernel_size, activation, dropout_rate)
        self.e4 = encoder_block(64, 128, kernel_size, activation, dropout_rate)
        self.e5 = encoder_block(128, 256, kernel_size, activation, dropout_rate)
        self.e6 = encoder_block(256, 512, kernel_size, activation, dropout_rate)

        # Decoder: an initial upsampling followed by five decoder steps.
        self.d1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.d2 = decoder_block(512, 256, kernel_size, activation, dropout_rate)
        self.d3 = decoder_block(256, 128, kernel_size, activation, dropout_rate)
        self.d4 = decoder_block(128, 64, kernel_size, activation, dropout_rate)
        self.d5 = decoder_block(64, 32, kernel_size, activation, dropout_rate)
        self.d6 = decoder_block(32, 16, kernel_size, activation, dropout_rate)

        # Output: stride-1 conv_block without activation, followed by a sigmoid.
        self.outputs = conv_block(16, image_channels, stride=False, activation=nn.Identity())
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Reconstruct the input batch.

        Args:
            inputs (tensor): image batch indexed as (batch, channels, height, width)

        Returns:
            tensor: sigmoid output (values in [0, 1]) cropped to the spatial
                size of ``inputs``
        """
        # Encoder
        p1 = self.e1(inputs)
        p2 = self.e2(p1)
        p3 = self.e3(p2)
        p4 = self.e4(p3)
        p5 = self.e5(p4)
        p6 = self.e6(p5)

        # Decoder: each step receives the encoder map of matching resolution.
        x = self.d1(p6)
        x = self.d2(x, p5)
        x = self.d3(x, p4)
        x = self.d4(x, p3)
        x = self.d5(x, p2)
        x = self.d6(x, p1)

        # Output
        x = self.sigmoid(self.outputs(x))
        return resize_layer(inputs, x)