| | from torch import nn |
| | import torch |
| | import numpy as np |
| | from transformers import AutoModel |
| | import torch.nn.functional as F |
| | import esm |
| | import copy |
| | import pdb |
| |
|
class GaussianFourierProjection(nn.Module):
    """
    Random Fourier features of a scalar input (used here for time steps).

    Each input value ``x`` is mapped to ``[sin(2*pi*x*W), cos(2*pi*x*W)]``,
    where ``W`` holds ``embed_dim // 2`` frequencies drawn once from a normal
    distribution scaled by ``scale`` and never trained.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Fixed random frequencies: a Parameter (so it lands in the state_dict)
        # that is excluded from gradient computation.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Map a (B,) tensor to a (B, 2 * (embed_dim // 2)) feature tensor."""
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
| |
|
class Dense(nn.Module):
    """
    A fully connected layer applied over the last dimension of its input.

    No reshaping is done here: in this file the models project the time
    embedding with it and add the trailing length axis themselves
    (``dense(time_embed)[:, :, None]``).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Map a (..., input_dim) tensor to a (..., output_dim) tensor."""
        return self.dense(x)
| |
|
class Swish(nn.Module):
    """Parameter-free Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
| |
|
class CNNESMModel(nn.Module):
    """Time-conditioned dilated residual CNN on top of a frozen ESM-2 encoder.

    Token ids are encoded by the pretrained ``facebook/esm2_t33_650M_UR50D``
    model (frozen and kept in eval mode). The features then pass through a
    stack of dilated 1-D convolutions, each preceded by adding a projection of
    the time embedding and a GroupNorm. A final head projects them to
    per-position logits over ``alphabet_size`` classes.
    """

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): Number of output logits per position.
            embed_dim (int): Dimensionality of the time embedding.
            hidden_dim (int): Number of channels in the convolutional stack.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        # Pretrained encoder used purely as a frozen feature extractor.
        self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm.eval()
        for param in self.esm.parameters():
            param.requires_grad = False
        # Width of the per-position features the encoder emits. The input conv
        # must consume exactly this many channels. It used to be sized by
        # ``embed_dim``, which only worked if the caller happened to pass the
        # encoder's width.
        esm_dim = self.esm.config.hidden_size

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(esm_dim, n, kernel_size=9, padding=4)

        # padding = 4 * dilation keeps the sequence length unchanged for
        # kernel_size=9, so every block's output can be added residually.
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
        ])

        # One time projection and one norm per conv block.
        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in self.blocks])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in self.blocks])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def train(self, mode=True):
        """Set train/eval mode on the trainable layers.

        The frozen ESM encoder is always kept in eval mode, so a later
        ``model.train()`` does not undo the ``self.esm.eval()`` from __init__.
        """
        super().train(mode)
        self.esm.eval()
        return self

    def forward(self, x, t, attention_mask=None):
        """
        Args:
            x: (B, L) tensor of token ids for the ESM encoder.
                NOTE(review): any special tokens present in ``x`` also receive
                logits; confirm this matches what the caller expects.
            t: (B,) tensor of time steps.
            attention_mask: Optional (B, L) mask forwarded to the ESM encoder
                (e.g. to ignore padding). ``None`` reproduces the unmasked call.
        Returns:
            (B, L, alphabet_size) tensor of logits, shifted so that each
            position's logits have zero mean.
        """
        # The encoder is frozen; don't build an autograd graph through it.
        with torch.no_grad():
            x = self.esm(input_ids=x, attention_mask=attention_mask).last_hidden_state
        # Outside no_grad on purpose: the time-embedding layers are trainable.
        time_embed = self.swish(self.time_embed(t))

        out = x.permute(0, 2, 1)  # (B, hidden_size, L) for Conv1d
        out = self.swish(self.linear(out))

        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # Add the per-channel time bias, broadcast over the length axis.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            # Residual connection whenever the block preserved the shape.
            if h.shape == out.shape:
                out = h + out
            else:
                out = h

        out = self.final(out)
        out = out.permute(0, 2, 1)  # back to (B, L, alphabet_size)

        # Logits are only defined up to an additive constant; center them.
        out = out - out.mean(dim=-1, keepdim=True)
        return out
| |
|
| |
|
class MLPModel(nn.Module):
    """Time-conditioned MLP mapping a fixed-length token sequence to logits.

    The embedded sequence is flattened into a single vector, concatenated with
    a linear embedding of the time step, and mapped by an MLP to
    ``length * input_dim`` outputs. Those outputs are reshaped to per-position
    logits over the ``input_dim`` token classes. Sequences must therefore have
    exactly ``length`` tokens.
    """

    def __init__(
            self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500):
        super().__init__()
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        # Scalar time -> time_dim features; each token -> hidden_dim features.
        self.time_embedding = nn.Linear(1, time_dim)
        self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim)

        self.swish = Swish()

        flat_in = hidden_dim * length + time_dim
        flat_out = self.input_dim * length
        act = self.swish  # one shared, parameter-free activation instance
        self.main = nn.Sequential(
            act,
            nn.Linear(flat_in, hidden_dim),
            act,
            nn.Linear(hidden_dim, hidden_dim),
            act,
            nn.Linear(hidden_dim, hidden_dim),
            act,
            nn.Linear(hidden_dim, flat_out),
        )

    def forward(self, x, t):
        '''
        x: (B, L) tensor of token indices, with L equal to ``length``.
        t: (B,) tensor of time steps.
        Returns a (B, L, input_dim) tensor of logits.
        '''
        time_feat = self.time_embedding(t.unsqueeze(-1))  # (B, time_dim)
        tokens = self.token_embedding(x)                  # (B, L, hidden_dim)

        batch, seq_len, width = tokens.shape
        flat_tokens = tokens.reshape(batch, seq_len * width)

        logits = self.main(torch.cat([flat_tokens, time_feat], dim=1))
        return logits.reshape(batch, seq_len, self.input_dim)
| |
|
class CNNModel(nn.Module):
    """Time-conditioned dilated residual 1-D CNN over a token sequence.

    Tokens are embedded and passed through a stack of dilated convolutions,
    each preceded by adding a projection of the time embedding and a
    GroupNorm. A final head projects them to per-position logits over
    ``alphabet_size`` classes.
    """

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): Vocabulary size of the input tokens and the
                number of output logits per position.
            embed_dim (int): Dimensionality of the token and time embeddings.
            hidden_dim (int): Number of channels in the convolutional stack.
        """
        super().__init__()
        self.alphabet_size = alphabet_size
        self.token_embedding = nn.Embedding(alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )
        self.swish = Swish()

        width = hidden_dim
        self.linear = nn.Conv1d(embed_dim, width, kernel_size=9, padding=4)

        # Dilation per residual block; padding = 4 * dilation keeps the length
        # unchanged for kernel_size=9.
        dilations = (1, 1, 4, 16, 64)
        self.blocks = nn.ModuleList(
            nn.Conv1d(width, width, kernel_size=9, dilation=d, padding=4 * d)
            for d in dilations
        )
        # One time projection and one norm per conv block.
        depth = len(self.blocks)
        self.denses = nn.ModuleList(Dense(embed_dim, width) for _ in range(depth))
        self.norms = nn.ModuleList(nn.GroupNorm(1, width) for _ in range(depth))

        self.final = nn.Sequential(
            nn.Conv1d(width, width, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(width, alphabet_size, kernel_size=1),
        )

    def forward(self, x, t):
        """
        Args:
            x: (B, L) tensor of token indices in [0, alphabet_size).
            t: (B,) tensor of time steps.
        Returns:
            (B, L, alphabet_size) tensor of logits, shifted so that each
            position's logits have zero mean.
        """
        cond = self.swish(self.time_embed(t))                      # (B, embed_dim)
        feats = self.token_embedding(x).transpose(1, 2)            # (B, embed_dim, L)
        feats = self.swish(self.linear(feats))                     # (B, hidden, L)

        for conv, to_channels, norm in zip(self.blocks, self.denses, self.norms):
            time_bias = to_channels(cond).unsqueeze(-1)            # (B, hidden, 1)
            update = self.swish(conv(norm(feats + time_bias)))
            feats = update + feats if update.shape == feats.shape else update

        logits = self.final(feats).transpose(1, 2)                 # (B, L, alphabet_size)
        return logits - logits.mean(dim=-1, keepdim=True)
| |
|
class CNNModel_Large(nn.Module):
    """Deeper variant of the time-conditioned dilated residual 1-D CNN.

    Identical in structure to the small model, but the dilation pattern
    (1, 1, 4, 16, 64) is repeated four times, giving 20 residual conv blocks.
    """

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            alphabet_size (int): Vocabulary size of the input tokens and the
                number of output logits per position.
            embed_dim (int): Dimensionality of the token and time embeddings.
            hidden_dim (int): Number of channels in the convolutional stack.
        """
        super().__init__()
        self.alphabet_size = alphabet_size
        self.token_embedding = nn.Embedding(alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )
        self.swish = Swish()

        width = hidden_dim
        self.linear = nn.Conv1d(embed_dim, width, kernel_size=9, padding=4)

        # Dilation pattern of one stage; padding = 4 * dilation keeps the
        # length unchanged for kernel_size=9. Stages are stacked in order.
        dilations = (1, 1, 4, 16, 64)
        stages = 4
        self.blocks = nn.ModuleList(
            nn.Conv1d(width, width, kernel_size=9, dilation=d, padding=4 * d)
            for _ in range(stages)
            for d in dilations
        )
        # One time projection and one norm per conv block.
        depth = len(self.blocks)
        self.denses = nn.ModuleList(Dense(embed_dim, width) for _ in range(depth))
        self.norms = nn.ModuleList(nn.GroupNorm(1, width) for _ in range(depth))

        self.final = nn.Sequential(
            nn.Conv1d(width, width, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(width, alphabet_size, kernel_size=1),
        )

    def forward(self, x, t):
        """
        Args:
            x: (B, L) tensor of token indices in [0, alphabet_size).
            t: (B,) tensor of time steps.
        Returns:
            (B, L, alphabet_size) tensor of logits, shifted so that each
            position's logits have zero mean.
        """
        cond = self.swish(self.time_embed(t))                      # (B, embed_dim)
        feats = self.token_embedding(x).transpose(1, 2)            # (B, embed_dim, L)
        feats = self.swish(self.linear(feats))                     # (B, hidden, L)

        for conv, to_channels, norm in zip(self.blocks, self.denses, self.norms):
            time_bias = to_channels(cond).unsqueeze(-1)            # (B, hidden, 1)
            update = self.swish(conv(norm(feats + time_bias)))
            feats = update + feats if update.shape == feats.shape else update

        logits = self.final(feats).transpose(1, 2)                 # (B, L, alphabet_size)
        return logits - logits.mean(dim=-1, keepdim=True)