| |
| |
| |
| |
|
|
| from collections import OrderedDict |
| import math |
| import requests |
| from io import BytesIO |
| from functools import partial |
| import pickle |
| from typing import Callable, Optional, Sequence, Tuple, List |
| import numpy as np |
| import os |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.nn.init import trunc_normal_ |
| from torchvision import transforms |
| from torchvision.transforms import InterpolationMode |
|
|
class GLU(nn.Module):
    """Gated MLP block with bias-free linear layers.

    Pipeline: ``linear_proj`` -> LayerNorm -> GELU, then a SiLU-gated 4x
    expansion (``silu(gate_proj(h)) * dense_h_to_4h(h)``) projected back to
    ``hidden_size`` by ``dense_4h_to_h``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        wide = hidden_size * 4
        # Submodules are registered in this exact order; it fixes both the
        # state_dict key order and the RNG draws of the Linear initialisers.
        self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(hidden_size, wide, bias=False)
        self.gate_proj = nn.Linear(hidden_size, wide, bias=False)
        self.dense_4h_to_h = nn.Linear(wide, hidden_size, bias=False)

    def forward(self, x):
        hidden = self.act1(self.norm1(self.linear_proj(x)))
        gate = self.act2(self.gate_proj(hidden))
        up = self.dense_h_to_4h(hidden)
        return self.dense_4h_to_h(gate * up)
def swiglu(x):
    """SwiGLU activation.

    Splits the last dimension into two halves with ``torch.chunk`` and returns
    ``silu(first_half) * second_half``.
    """
    halves = torch.chunk(x, 2, dim=-1)
    gate, value = halves[0], halves[1]
    return F.silu(gate) * value
|
|
class GLU_new(nn.Module):
    """SwiGLU feed-forward block: Linear(h -> 2*inter) -> swiglu -> Linear(inter -> h) -> dropout.

    Args:
        hidden_size: input/output feature size.
        dropout: dropout probability applied to the block output.
        intermediate_size: width of the gated hidden layer. Defaults to 1280,
            the value this block previously hard-coded, so existing checkpoints
            keep their shapes. Pass ``None`` to derive it from ``hidden_size``
            as ``floor(4 * hidden_size * 2/3 / 64) * 64`` (at least 64).
    """

    def __init__(self, hidden_size, dropout=0.1, intermediate_size: Optional[int] = 1280):
        super().__init__()
        if intermediate_size is None:
            # Common "2/3 of 4h, rounded down to a multiple of 64" sizing; the
            # max() keeps small hidden sizes from producing a zero-width layer.
            intermediate_size = max(64, int((4 * hidden_size * 2 / 3) / 64) * 64)

        self.act = swiglu
        # Fused gate+value projection: swiglu chunks the last dim into two halves.
        self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.dense_h_to_4h(x)
        x = self.act(x)
        x = self.dense_4h_to_h(x)
        x = self.dropout(x)
        return x
|
|
|
|
# Module-level query count. The classes defined below do not read it; they take
# ``n_queries`` as a constructor argument instead.
n_queries = 32


def get_abs_pos(abs_pos, tgt_size):
    """Resize a flattened square grid of absolute position embeddings.

    Args:
        abs_pos: tensor of shape (S, C) where S is treated as side*side
            (side = int(sqrt(S))).
        tgt_size: desired number of positions; its integer square root is the
            target grid side.
    Returns:
        ``abs_pos`` itself when the grid sides already match. Otherwise a
        (side_out * side_out, C) tensor produced by bicubic interpolation in
        float32 and cast back to ``abs_pos.dtype``.
    """
    orig_dtype = abs_pos.dtype
    side_in = int(math.sqrt(abs_pos.size(0)))
    side_out = int(math.sqrt(tgt_size))

    if side_in == side_out:
        return abs_pos

    # (S, C) -> (1, C, side_in, side_in): NCHW layout expected by interpolate.
    grid = abs_pos.float().reshape(1, side_in, side_in, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(side_out, side_out),
        mode="bicubic",
        align_corners=False,
    )
    # Back to (side_out * side_out, C) in the caller's dtype.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)
|
|
| from einops import rearrange, repeat |
|
|
def get_1d_sincos_pos_embed(embed_dim, pos):
    """Return fixed 1-D sine/cosine position embeddings.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: positions to encode. Either a NumPy array of any shape (it is
            flattened) or any array-like, e.g. a Python list. M is the total
            number of positions.
    Returns:
        np.ndarray of shape (M, embed_dim). The first half of each row is
        ``sin(pos * omega)`` and the second half is ``cos(pos * omega)``, with
        ``omega_i = 1 / 10000 ** (2 * i / embed_dim)``.
    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")

    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    # np.asarray makes the documented list input work; arrays pass through
    # without a copy, so their dtype (and the result dtype) is unchanged.
    pos = np.asarray(pos).reshape(-1)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2) outer product

    return np.concatenate([np.sin(out), np.cos(out)], axis=1)
|
|
class Resampler(nn.Module):
    """Cross-attention pooler: learned latent queries attend to a key/value sequence.

    Args:
        kv_dim: feature size of the incoming sequence (projected to ``embed_dim``).
        embed_dim: width of the latents and of the output.
        num_heads: heads of the ``nn.MultiheadAttention`` layer.
        n_queries: number of learned latent queries (the output length).
        max_seqlen: rows in the sin-cos position table. When positional
            embeddings are enabled, inputs longer than this cannot be
            position-encoded.
        perceiver_resampler_positional_emb: add 1-D sin-cos position embeddings
            to both the latents and the key/value sequence.
        use_GLU: use ``GLU_new`` instead of one bias-free Linear as the output
            projection.
        bos_init: not supported; raises ``NotImplementedError``.
        dropout: dropout for the attention layer (and ``GLU_new`` if used).
    """

    def __init__(
        self,
        kv_dim,
        embed_dim,
        num_heads=8,
        n_queries=64,
        max_seqlen=1024,
        perceiver_resampler_positional_emb=True,
        use_GLU=False,
        bos_init=False,
        dropout=0.0,
    ):
        super().__init__()
        self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb

        if self.perceiver_resampler_positional_emb:
            if not 0 < n_queries <= max_seqlen:
                raise ValueError(
                    f"n_queries must be in [1, max_seqlen={max_seqlen}], got {n_queries}"
                )
            # Latent i receives the position embedding of index i * stride.
            self.stride = max_seqlen // n_queries
            pos = np.arange(max_seqlen, dtype=np.float32)
            # Fixed (buffer, not Parameter) table of shape (max_seqlen, embed_dim).
            self.register_buffer(
                "pos_embed",
                torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float(),
            )

        self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
        if bos_init:
            # The former ``self.latents.load('')`` always failed with
            # AttributeError (tensors have no ``load``); fail with a clear message.
            raise NotImplementedError(
                "bos_init=True is not supported: no initial latent weights are available"
            )
        nn.init.trunc_normal_(self.latents, std=1e-3)

        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)
        self.ln_post = nn.LayerNorm(embed_dim)
        if use_GLU:
            print('GLU *********************************')
            self.proj = GLU_new(embed_dim, dropout=dropout)
        else:
            self.proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Runs after the latents were initialised; _init_weights only touches
        # Linear and LayerNorm submodules, so the latents are left as set above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-submodule init: small truncated-normal Linear weights with zero
        bias, and LayerNorm reset to weight 1 / bias 0."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, struc_x):
        """Pool a padded feature sequence into ``n_queries`` vectors.

        Args:
            struc_x (dict): with keys
                ``"encoder_out"``: features of shape (B, L, kv_dim); NaN
                    entries are replaced by 0.
                ``"encoder_padding_mask"``: bool tensor of shape (B, L), True
                    where a position holds real data. It is inverted before
                    being passed as ``key_padding_mask``, so False positions
                    are ignored by attention.
        Returns:
            torch.Tensor of shape (B, n_queries, embed_dim).
        """
        x = struc_x["encoder_out"]
        mask = struc_x["encoder_padding_mask"]

        # Missing values may be encoded as NaN (StructureTransformer.prepare_structure
        # NaN-fills absent pifold rows); zero them so they cannot poison the
        # projection / LayerNorm.
        nan_mask = torch.isnan(x)
        if nan_mask.any():
            x = x.masked_fill(nan_mask, 0.0)

        x = self.ln_kv(self.kv_proj(x))
        b, seqlen = x.shape[:2]

        latents = self.ln_q(self.latents)
        if self.perceiver_resampler_positional_emb:
            n_latents = latents.shape[0]
            # Take exactly n_latents rows: positions 0, stride, 2*stride, ...
            # A bare ``[::stride]`` yields ceil(max_seqlen / stride) rows, which
            # is more than n_latents (a shape mismatch) whenever max_seqlen is
            # not a multiple of n_queries.
            latents = latents + self.pos_embed[: n_latents * self.stride : self.stride]
            x = x + self.pos_embed[:seqlen].unsqueeze(0)

        latents = repeat(latents, "n d -> b n d", b=b)
        out = self.attn(latents, x, x, key_padding_mask=~mask)[0]

        out = self.ln_post(out)
        return self.proj(out)
|
|
class StructureTransformer(nn.Module):
    """Pool pickled per-residue structure embeddings into ``n_queries`` vectors.

    Args:
        width: feature size of the concatenated embeddings (the Resampler's
            ``kv_dim``).
        n_queries: number of output vectors per structure.
        output_dim: width of each output vector (the Resampler's ``embed_dim``).
        embedding_keys: names of the embeddings read from each sample and
            concatenated along the last dim. Defaults to ``{"mpnn_emb"}``.
            Sets are iterated in sorted order, so the concatenation layout does
            not depend on Python's hash seed. Ordered containers (list/tuple)
            keep their order.
        max_seqlen: padded sequence length.
        num_heads: attention heads of the Resampler.
        structure_emb_path_prefix: directory the pickled embeddings are read from.
        **kwargs: forwarded to ``Resampler``.
    """

    def __init__(
        self,
        width: int = 640,
        n_queries: int = 32,
        output_dim: int = 4096,
        embedding_keys=None,
        max_seqlen: int = 1024,
        num_heads: int = 8,
        structure_emb_path_prefix='structure_emb',
        **kwargs
    ):
        super().__init__()

        self.structure_emb_path_prefix = structure_emb_path_prefix
        # Build the default per instance instead of sharing one mutable set
        # created at function-definition time.
        self.embedding_keys = {"mpnn_emb"} if embedding_keys is None else embedding_keys
        self.max_seqlen = max_seqlen
        self.width = width
        self.n_queries = n_queries

        self.attn_pool = Resampler(
            embed_dim=output_dim,
            kv_dim=width,
            n_queries=n_queries,
            max_seqlen=max_seqlen,
            num_heads=num_heads,
            **kwargs
        )

    def _ordered_embedding_keys(self):
        """Return ``embedding_keys`` in a deterministic order (sorted if a set)."""
        keys = self.embedding_keys
        if isinstance(keys, (set, frozenset)):
            return sorted(keys)
        return list(keys)

    def prepare_structure(self, sample):
        """Build a padded embedding matrix and validity mask for one sample.

        Args:
            sample (dict): maps embedding names to tensors, or to lists of
                tensors that are concatenated along dim 0. If ``"pifold_emb"``
                is requested and ``"pifold_mask"`` is present, ``pifold_emb``
                holds only the rows where the mask is > 0. It is expanded IN
                PLACE in ``sample`` to the mask's length, with NaN rows in the
                gaps.
        Returns:
            (emb_pad, emb_mask): a (max_seqlen, width) float tensor with the
            embeddings in its first ``len(emb)`` rows and zeros after, and a
            (max_seqlen,) bool tensor that is True for those rows.
        Raises:
            KeyError: if the sample contains none of the requested embeddings.
        """
        emb_pad = torch.zeros((self.max_seqlen, self.width))
        emb_mask = torch.zeros(self.max_seqlen, dtype=torch.bool)

        if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
            mask = sample["pifold_mask"]
            pifold_emb = sample["pifold_emb"]
            full = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
            full[mask > 0] = pifold_emb
            sample["pifold_emb"] = full

        keys = self._ordered_embedding_keys()
        parts = []
        for key in keys:
            if key not in sample:
                continue
            value = sample[key]
            parts.append(torch.cat(value) if isinstance(value, list) else value)
        if not parts:
            raise KeyError(f"sample contains none of the embedding keys {keys}")

        emb = torch.cat(parts, dim=-1)
        n_rows = len(emb)
        emb_pad[:n_rows] = emb
        emb_mask[:n_rows] = True
        return emb_pad, emb_mask

    def forward(self, x):
        """Run the Resampler on a ``{"encoder_out", "encoder_padding_mask"}`` dict."""
        x = self.attn_pool(x)
        return x

    def encode(self, structure_paths: List[str]):
        """Load, pad and pool a batch of structures.

        Args:
            structure_paths: one entry per sample. Despite the annotation, each
                entry is sliced with ``[:self.n_queries]`` and ``.tolist()``, so
                it is expected to be a tensor of character codes (codes <= 0 are
                skipped). The codes spell a file name under
                ``structure_emb_path_prefix``.
        Returns:
            The ``forward`` output for the batch, or ``None`` (after printing a
            message) if any of the files does not exist.
        """
        ref_param = next(self.attn_pool.parameters())
        structure_embs = []
        structure_mask = []

        for structure_path in structure_paths:
            file_name = ''.join(chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0)
            full_path = os.path.join(self.structure_emb_path_prefix, file_name)
            if not os.path.exists(full_path):
                print('no structure found')
                return None

            # NOTE(review): pickle.load can execute arbitrary code; only point
            # structure_emb_path_prefix at trusted data.
            with open(full_path, 'rb') as f:
                structure, struc_mask = self.prepare_structure(pickle.load(f))

            structure_embs.append(structure)
            structure_mask.append(struc_mask)

        structure_embs = torch.stack(structure_embs, dim=0).to(
            device=ref_param.device, dtype=ref_param.dtype)
        structure_mask = torch.stack(structure_mask, dim=0).to(device=ref_param.device)

        return self({
            'encoder_out': structure_embs,
            'encoder_padding_mask': structure_mask
        })
|
|