| | |
| | import torch |
| | import torch.nn as nn |
| |
|
# Default embedding dimensionality. Not referenced within this chunk —
# presumably passed by callers when constructing EmbedDoodle; verify against usage.
EMBEDDING_SIZE = 64
| |
|
class EmbedDoodle(nn.Module):
    """Embed a single-channel 32x32 doodle image as a unit-norm vector.

    Pipeline: four unpadded 3x3 convolutions -> flatten -> linear
    projection into a latent trunk, then a stack of residual SELU MLP
    cells, then a linear head whose output is L2-normalized so every
    embedding lies on the unit hypersphere.
    """

    def __init__(self, embedding_size: int):
        """
        Args:
            embedding_size: dimensionality of the output embedding vector.
        """
        super().__init__()

        latent_size = 256  # width of the residual MLP trunk
        embed_depth = 5    # number of residual cells in the trunk

        def make_cell(in_size: int, hidden_size: int, out_size: int,
                      add_dropout: bool) -> nn.Sequential:
            """Build one 3-layer SELU MLP cell, used as a residual branch."""
            cell = nn.Sequential(
                nn.Linear(in_size, hidden_size),
                nn.SELU(),
                nn.Linear(hidden_size, hidden_size),
            )
            if add_dropout:
                # AlphaDropout preserves the zero-mean / unit-variance
                # statistics that SELU's self-normalization relies on;
                # plain nn.Dropout would disturb them.
                cell.append(nn.AlphaDropout())
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, out_size))
            return cell

        conv_channels = 64
        # Four 3x3 valid (unpadded) convolutions shrink 32 -> 30 -> 28 -> 26 -> 24,
        # so the flattened feature size is 64 channels * 24 * 24 = 36864.
        flat_size = conv_channels * 24 * 24
        self.preprocess = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1, out_channels=conv_channels),
            nn.Conv2d(kernel_size=3, in_channels=conv_channels, out_channels=conv_channels),
            nn.SELU(),
            nn.Conv2d(kernel_size=3, in_channels=conv_channels, out_channels=conv_channels),
            nn.Conv2d(kernel_size=3, in_channels=conv_channels, out_channels=conv_channels),
            nn.AlphaDropout(),  # SELU-compatible dropout (see make_cell)
            nn.SELU(),
            nn.Flatten(),
            nn.Linear(flat_size, latent_size),
            nn.SELU(),
        )

        # Residual trunk: forward applies x = x + cell(x) for each cell.
        self.embedding_path = nn.ModuleList(
            [make_cell(latent_size, latent_size, latent_size, add_dropout=True)
             for _ in range(embed_depth)]
        )

        # Final projection from the latent trunk to the requested embedding size.
        self.embedding_head = nn.Linear(latent_size, embedding_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of doodles to unit-length embeddings.

        Args:
            x: tensor reshapeable to (batch, 1, 32, 32).

        Returns:
            Tensor of shape (batch, embedding_size) with unit L2 norm
            along the last dimension.
        """
        x = x.view(-1, 1, 32, 32)
        x = self.preprocess(x)
        for cell in self.embedding_path:
            x = x + cell(x)  # residual connection around each MLP cell
        x = self.embedding_head(x)
        # Normalize onto the unit hypersphere so cosine similarity between
        # embeddings reduces to a dot product.
        return nn.functional.normalize(x, dim=-1)
| |
|
| |
|