| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding as used in ProtFlow paper.

    Maps each time value t to ``[sin(t * w_k)..., cos(t * w_k)...]`` with
    geometrically spaced frequencies ``w_k`` running from 1 down to 1/10000.
    The output always has exactly ``dim`` channels (odd ``dim`` is zero-padded).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: tensor of time values, e.g. shape (B,), (B, 1) or (B, 1, 1).
                Trailing size-1 dimensions are dropped before embedding; a
                0-dim scalar is treated as a batch of one.
        Returns:
            Tensor of shape (*time_shape_without_trailing_ones, dim); a (B,)
            input gives (B, dim), including when B == 1.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids ZeroDivisionError for dim < 4 (half_dim == 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)

        # Drop only trailing singleton dims: a bare squeeze() would also remove
        # the batch axis when the batch size is 1.
        while time.dim() > 1 and time.shape[-1] == 1:
            time = time.squeeze(-1)

        args = time.unsqueeze(-1) * freqs.unsqueeze(0)
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)

        # sin/cos halves only cover 2 * (dim // 2) channels; pad the odd one.
        if self.dim % 2 == 1:
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
|
|
class LabelMLP(nn.Module):
    """
    Turn integer class labels into ``hidden_dim``-wide conditioning vectors.

    Labels are looked up in a learned table and refined by a small GELU MLP,
    independently of the time embedding.
    """

    def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
        super().__init__()
        self.num_classes = num_classes

        lookup = nn.Embedding(num_classes, mlp_dim)
        stages = [
            lookup,
            nn.Linear(mlp_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        # Indices inside the Sequential define the state_dict keys
        # (label_mlp.0, label_mlp.1, ...), so the stage order must stay fixed.
        self.label_mlp = nn.Sequential(*stages)

        # Re-initialise the lookup table with a small std (nn.Embedding's own
        # default init is N(0, 1)).
        nn.init.normal_(lookup.weight, std=0.02)

    def forward(self, labels):
        """
        Embed a batch of class labels.

        Args:
            labels: (B,) integer tensor of class ids
                - 0: AMP (MIC < 100)
                - 1: Non-AMP (MIC >= 100)
                - 2: Mask (Unknown MIC)
        Returns:
            (B, hidden_dim) tensor of label embeddings.
        """
        label_vectors = self.label_mlp(labels)
        return label_vectors
|
|
class AMPFlowMatcherCFGConcat(nn.Module):
    """
    Flow Matching model with Classifier-Free Guidance using concatenation approach.

    - Stack of ``n_layers`` (default 12) ``nn.TransformerEncoderLayer`` blocks;
      every layer except the first and the last also receives a learned
      projection of the previous layer's input (short skip connections)
    - Time embedding + MLP-processed label embedding (concatenated then
      projected), added to the hidden state before every layer
    - Learned positional embeddings for sequences of up to ``max_seq_len``
      (default 25) positions
    """

    def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16,
                 dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
        """
        Args:
            hidden_dim: transformer width; must be divisible by ``n_heads``.
            compressed_dim: channel count of the compressed latent input/output.
            n_layers: number of transformer encoder layers.
            n_heads: attention heads per encoder layer.
            dim_ff: feed-forward width inside each encoder layer.
            dropout: dropout rate inside each encoder layer.
            max_seq_len: number of learned positional embeddings.
            use_cfg: if True, build the label MLP used for CFG conditioning.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.compressed_dim = compressed_dim
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_cfg = use_cfg

        # Time conditioning: sinusoidal features refined by a 2-layer MLP.
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Label conditioning for CFG: 0=AMP, 1=Non-AMP, 2=Mask (unconditional).
        if use_cfg:
            self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)

        # Fuses the concatenated [time, label] embedding back to hidden_dim;
        # forward only uses it when use_cfg is True and labels are given.
        # NOTE(review): source indentation was lost; confirm this is not meant to
        # sit under `if use_cfg:` (that would change state_dict keys for use_cfg=False).
        self.condition_proj = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
        # Not used in forward (output_proj maps back to compressed_dim);
        # presumably kept so existing checkpoints keep loading.
        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)

        # Learned absolute positional embeddings, one per position.
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=dim_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ) for _ in range(n_layers)
        ])

        # skip_projections[i - 1] feeds layer i for 1 <= i <= n_layers - 2; the
        # final projection is never used but is left in place so state_dict
        # keys stay unchanged.
        self.skip_projections = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
        ])

        self.output_proj = nn.Linear(hidden_dim, compressed_dim)

    @staticmethod
    def _flatten_time(t, batch_size):
        """Return ``t`` as a 1-D tensor of ``batch_size`` time values."""
        t = t.reshape(-1)
        if t.numel() == 1:
            # One shared time value for the whole batch.
            return t.expand(batch_size)
        if t.numel() != batch_size:
            raise ValueError(
                f"expected 1 or {batch_size} time values, got {t.numel()}"
            )
        return t

    def forward(self, x, t, labels=None, mask=None):
        """
        Predict the flow velocity for a batch of compressed latents.

        Args:
            x: compressed latent (B, L, compressed_dim) - AMP embeddings
            t: time values, one per sample: shape (B,), (B, 1) or any shape
               holding B elements; a single value (0-dim or 1 element) is
               broadcast to the whole batch
            labels: class labels (B,) integer tensor for CFG - 0=AMP,
               1=Non-AMP, 2=Mask; ignored when use_cfg is False
            mask: optional (B, L) key-padding mask passed to every encoder
               layer as ``src_key_padding_mask`` (True = position ignored)
        Returns:
            Tensor of shape (B, L, compressed_dim).
        Raises:
            ValueError: if ``t`` holds neither 1 nor B values.
        """
        B, L, _ = x.shape

        x = self.compress_proj(x)

        # NOTE: sequences longer than max_seq_len get no positional embedding.
        if L <= self.max_seq_len:
            x = x + self.pos_embed[:, :L, :]

        # A flat (B,) time tensor keeps the embedding at (B, hidden_dim) even for
        # B == 1, where squeeze()-based reshaping would drop the batch axis.
        t_emb = self.time_embed(self._flatten_time(t, B))
        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)

        if self.use_cfg and labels is not None:
            label_emb = self.label_mlp(labels).unsqueeze(1).expand(-1, L, -1)
            combined_emb = torch.cat([t_emb, label_emb], dim=-1)
            projected_emb = self.condition_proj(combined_emb)
        else:
            projected_emb = t_emb

        last = len(self.layers) - 1
        # Input to the previous layer (after its own skip add, before the
        # conditioning add). Only this one value is ever read back, and it is
        # never modified in place, so no clone or history list is needed.
        prev_input = None
        for i, layer in enumerate(self.layers):
            if 0 < i < last:
                x = x + self.skip_projections[i - 1](prev_input)
            if i < last:
                prev_input = x

            # Conditioning is re-added before every layer.
            x = x + projected_emb
            x = layer(x, src_key_padding_mask=mask)

        return self.output_proj(x)
|
|
class AMPProtFlowPipelineCFG:
    """
    Complete ProtFlow pipeline for AMP generation with CFG.

    Samples compressed latents by Euler-integrating the flow model's velocity
    from t=1 (Gaussian noise) to t=0, decodes them with ``decompressor`` and
    undoes the normalization described by the loaded stats.
    """

    # Label passed to the flow model for the unconditional branch
    # (0=AMP, 1=Non-AMP, 2=Mask).
    UNCONDITIONAL_LABEL = 2

    def __init__(self, compressor, decompressor, flow_model, device='cuda',
                 stats_path='normalization_stats.pt'):
        """
        Args:
            compressor: stored on the instance; not used by this class.
            decompressor: callable applied to the generated latents
                (presumably maps them back to the normalized embedding space).
            flow_model: velocity model called as ``flow_model(x, t, labels=...)``;
                must expose ``max_seq_len`` and ``compressed_dim``.
            device: device used for sampling and for loading the stats.
            stats_path: ``torch.save`` file with 'mean', 'std', 'min' and 'max'
                entries (the keys read when de-normalizing).
        """
        self.compressor = compressor
        self.decompressor = decompressor
        self.flow_model = flow_model
        self.device = device

        self.stats = torch.load(stats_path, map_location=device)

    def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5,
                          condition_label=0, max_batch_size=32):
        """
        Generate AMP samples using CFG.

        Args:
            num_samples: Number of samples to generate (>= 1)
            num_steps: Number of Euler ODE steps from t=1 to t=0
            cfg_scale: CFG guidance scale (higher = stronger conditioning);
                values <= 0 use the unconditional velocity only
            condition_label: 0=AMP, 1=Non-AMP, 2=Mask
            max_batch_size: maximum number of samples integrated together
        Returns:
            (generated, decompressed): latents of shape
            (num_samples, max_seq_len, compressed_dim) and the decompressed,
            de-normalized samples.
        Raises:
            ValueError: if ``num_samples`` or ``max_batch_size`` is < 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

        print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")

        batch_size = min(num_samples, max_batch_size)

        # Sample in inference mode: dropout off and no autograd graph spanning
        # all ODE steps. Previous train/eval modes are restored afterwards.
        modules = [m for m in (self.flow_model, self.decompressor)
                   if isinstance(m, nn.Module)]
        previous_modes = [m.training for m in modules]
        for module in modules:
            module.eval()
        try:
            with torch.no_grad():
                batches = [
                    self._sample_batch(min(batch_size, num_samples - start),
                                       num_steps, cfg_scale, condition_label)
                    for start in range(0, num_samples, batch_size)
                ]
                generated = torch.cat(batches, dim=0)
                decompressed = self._denormalize(self.decompressor(generated))
        finally:
            for module, mode in zip(modules, previous_modes):
                module.train(mode)

        return generated, decompressed

    def _sample_batch(self, batch, num_steps, cfg_scale, condition_label):
        """Integrate one batch of Gaussian noise from t=1 to t=0 with Euler steps."""
        xt = torch.randn(batch, self.flow_model.max_seq_len,
                         self.flow_model.compressed_dim, device=self.device)
        cond_labels = torch.full((batch,), condition_label,
                                 dtype=torch.long, device=self.device)
        uncond_labels = torch.full((batch,), self.UNCONDITIONAL_LABEL,
                                   dtype=torch.long, device=self.device)
        # Time runs backwards from 1 to 0, so each step has negative dt.
        dt = -1.0 / max(num_steps, 1)

        for step in range(num_steps):
            t = torch.full((batch,), 1.0 - step / num_steps, device=self.device)
            if cfg_scale > 0:
                vt_cond = self.flow_model(xt, t, labels=cond_labels)
                vt_uncond = self.flow_model(xt, t, labels=uncond_labels)
                # Classifier-free guidance: extrapolate from the unconditional
                # velocity towards the conditional one.
                vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
            else:
                vt = self.flow_model(xt, t, labels=uncond_labels)
            xt = xt + vt * dt
        return xt

    def _denormalize(self, values):
        """Undo min-max scaling, then z-score scaling, using ``self.stats``."""
        m, s, mn, mx = (self.stats[key] for key in ('mean', 'std', 'min', 'max'))
        values = values * (mx - mn + 1e-8) + mn
        return values * s + m
|
|
| |
if __name__ == "__main__":
    # Smoke test: build the default-sized model and run one forward pass on
    # random inputs to check shapes.
    flow_model = AMPFlowMatcherCFGConcat(
        hidden_dim=480,
        compressed_dim=30,
        n_layers=12,
        n_heads=16,
        dim_ff=3072,
        max_seq_len=25,
        use_cfg=True
    )

    print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: {sum(p.numel() for p in flow_model.parameters()):,}")

    batch_size = 4
    seq_len = 20
    # Read from the model so the input always matches its configuration.
    compressed_dim = flow_model.compressed_dim

    x = torch.randn(batch_size, seq_len, compressed_dim)
    t = torch.rand(batch_size)
    labels = torch.randint(0, 3, (batch_size,))

    # eval() disables dropout so the smoke-test forward pass is deterministic.
    flow_model.eval()
    with torch.no_grad():
        output = flow_model(x, t, labels=labels)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    # This is the shape of the raw time input, not of its embedding.
    print(f"Time input shape: {t.shape}")
    print(f"Labels: {labels}")

    print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")