| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an identity skip connection.

    Both convolutions use ``kernel_size=3`` and ``padding=1`` with the default
    stride, so the channel count and the spatial size of the input are
    preserved and the input can be added back onto the result.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        # A single in-place ReLU module is shared by both activation sites.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # In-place add of the skip path; ``out`` is a fresh tensor, so ``x``
        # itself is left untouched.
        out += x
        return self.relu(out)
|
|
class Encoder(nn.Module):
    """Convolutional encoder producing the latent mean and log-variance.

    Every entry of ``hidden_dims`` adds one stage made of a stride-2 3x3
    convolution, batch norm, LeakyReLU and a ``ResidualBlock``.  The last
    feature map is flattened and fed to two linear heads.

    Args:
        input_channels: Number of channels of the input images.
        hidden_dims: Output channels of each downsampling stage.  ``None``
            selects ``[64, 128, 256, 512, 1024]``.
        latent_dim: Size of the latent vector produced by each head.
        input_size: Spatial size of the input images, either an int (square)
            or a ``(height, width)`` pair.  It is used to size the linear
            heads.  The default of 512 with the default ``hidden_dims`` gives
            a 16x16 final feature map, i.e. ``1024 * 256`` input features.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, input_channels=1, hidden_dims=None, latent_dim=32,
                 input_size=512):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]
        # Copy so the module never aliases a list owned by the caller.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one entry")
        self.hidden_dims = hidden_dims

        modules = []
        in_channels = input_channels
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                    ResidualBlock(h_dim),
                )
            )
            in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # Size the heads from the real flattened feature map
        # (channels * height * width) instead of a product of two channel
        # counts, which only matched for the default configuration.
        feat_h, feat_w = self._feature_map_size(input_size, len(hidden_dims))
        flat_features = hidden_dims[-1] * feat_h * feat_w
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

    @staticmethod
    def _feature_map_size(input_size, num_stages):
        """Return ``(height, width)`` after ``num_stages`` stride-2 stages."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        for _ in range(num_stages):
            # Conv2d(kernel_size=3, stride=2, padding=1):
            # out = floor((in + 2*1 - 3) / 2) + 1 = (in - 1) // 2 + 1
            height = (height - 1) // 2 + 1
            width = (width - 1) // 2 + 1
        return height, width

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(mu, log_var)``."""
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
|
|
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to images.

    A linear layer expands the latent vector into a feature map with
    ``hidden_dims[-1]`` channels.  The reversed ``hidden_dims`` then drive
    ``len(hidden_dims) - 1`` upsampling stages (stride-2 transposed conv,
    batch norm, LeakyReLU, ``ResidualBlock``), followed by a final stride-2
    transposed conv, a 3x3 conv to ``output_channels`` and a sigmoid.  The
    spatial size is therefore doubled ``len(hidden_dims)`` times in total.

    Args:
        latent_dim: Size of the latent vector.
        output_channels: Number of channels of the produced images.
        hidden_dims: Channel counts in encoder order (they are reversed
            internally).  ``None`` selects ``[64, 128, 256, 512, 1024]``.
        output_size: Spatial size of the produced images, either an int
            (square) or a ``(height, width)`` pair.  The default of 512 with
            the default ``hidden_dims`` starts from a 1024x16x16 feature map.

    Raises:
        ValueError: If ``hidden_dims`` is empty or ``output_size`` is not a
            positive multiple of ``2 ** len(hidden_dims)``.
    """

    def __init__(self, latent_dim=32, output_channels=1, hidden_dims=None,
                 output_size=512):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]
        # Copy so the module never aliases a list owned by the caller.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one entry")
        # The attribute keeps the caller's (encoder) ordering.
        self.hidden_dims = hidden_dims

        reversed_dims = hidden_dims[::-1]
        feat_h, feat_w = self._initial_feature_size(output_size, len(reversed_dims))
        # Remembered so forward() reshapes to the same shape the linear layer
        # was sized for, instead of a hard-coded 1024x16x16.
        self._feature_shape = (reversed_dims[0], feat_h, feat_w)
        self.decoder_input = nn.Linear(latent_dim, reversed_dims[0] * feat_h * feat_w)

        modules = []
        for in_dim, out_dim in zip(reversed_dims[:-1], reversed_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2,
                                       padding=1, output_padding=1),
                    nn.BatchNorm2d(out_dim),
                    nn.LeakyReLU(),
                    ResidualBlock(out_dim),
                )
            )
        self.decoder = nn.Sequential(*modules)

        last_dim = reversed_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(last_dim, last_dim, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(last_dim),
            nn.LeakyReLU(),
            nn.Conv2d(last_dim, output_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _initial_feature_size(output_size, num_upsamples):
        """Return the ``(height, width)`` that doubles ``num_upsamples`` times into ``output_size``."""
        if isinstance(output_size, int):
            height = width = output_size
        else:
            height, width = output_size
        # ConvTranspose2d(kernel_size=3, stride=2, padding=1, output_padding=1)
        # maps a size ``n`` to exactly ``2 * n``.
        scale = 2 ** num_upsamples
        if height <= 0 or width <= 0 or height % scale or width % scale:
            raise ValueError(
                f"output_size {output_size!r} must be a positive multiple of {scale}"
            )
        return height // scale, width // scale

    def forward(self, z):
        """Decode latent vectors ``z`` into images with values in ``[0, 1]``."""
        z = self.decoder_input(z)
        z = z.view(-1, *self._feature_shape)
        z = self.decoder(z)
        return self.final_layer(z)
|
|
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        input_channels: Channels of the input images; also used as the
            decoder's ``output_channels``.
        latent_dim: Size of the latent vector.
        hidden_dims: Channel counts handed to both encoder and decoder.
            ``None`` selects ``[64, 128, 256, 512, 1024]``.
    """

    def __init__(self,
                 input_channels=1,
                 latent_dim=32,
                 hidden_dims=None):
        super().__init__()

        dims = [64, 128, 256, 512, 1024] if hidden_dims is None else hidden_dims

        self.encoder = Encoder(
            input_channels=input_channels,
            hidden_dims=dims,
            latent_dim=latent_dim,
        )
        self.decoder = Decoder(
            latent_dim=latent_dim,
            output_channels=input_channels,
            hidden_dims=dims,
        )

    def encode(self, input):
        """Run the encoder and return ``(mu, log_var)``."""
        return self.encoder(input)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Run the decoder on latent vectors ``z``."""
        return self.decoder(z)

    def forward(self, input):
        """Return ``(reconstruction, mu, log_var)`` for the batch ``input``."""
        mu, log_var = self.encode(input)
        reconstruction = self.decode(self.reparameterize(mu, log_var))
        return reconstruction, mu, log_var
|
|
| |
def loss_function(recon_x, x, mu, log_var):
    """Return the summed reconstruction loss plus the KL term.

    Args:
        recon_x: Reconstructed batch, passed as the input of
            ``F.binary_cross_entropy``.
        x: Target batch for the binary cross-entropy.
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        A scalar tensor: sum-reduced binary cross-entropy plus
        ``-0.5 * sum(1 + log_var - mu**2 - exp(log_var))``.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return reconstruction + kl_divergence