from collections.abc import Iterable
from abc import abstractmethod
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from src.constants import INT_TYPE, FLOAT_TYPE
from src.model.gvp import GVPModel, GVP, LayerNorm
from src.model.gvp_transformer import GVPTransformerModel

from pdb import set_trace

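# n-choose-k computed via the log-gamma function: numerically stable and
# differentiable. Expects floating-point tensors, e.g.
#   binomial_coefficient(torch.tensor(5.), torch.tensor(2.))  # -> ~10.0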
def binomial_coefficient(n, k):
    return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()

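# Per-node counts of 3-, 4- and 5-cycles obtained from powers of the dense
# adjacency matrix (structural features in the spirit of DiGress).
# Example (sanity check): for a single triangle, every node lies on exactly
# one 3-cycle and on no 4- or 5-cycles:
#   A = torch.ones(3, 3) - torch.eye(3)
#   cycle_counts(A)  # -> tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]])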
def cycle_counts(adj):
    assert (adj.diag() == 0).all()
    assert (adj == adj.T).all()

    A = adj.float()
    d = A.sum(dim=-1)

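    # Powers of the adjacency matrix; diag(A^k) counts closed walks of length k.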
    A2 = A @ A
    A3 = A2 @ A
    A4 = A3 @ A
    A5 = A4 @ A

    x3 = A3.diag() / 2
    x4 = (A4.diag() - d * (d - 1) - A @ d) / 2

| | """ New (different from DiGress) |
| | case where correction is relevant: |
| | 2 o |
| | | |
| | 1,3 o--o 4 |
| | | / |
| | 0,5 o |
| | """ |
    T = adj * A2
    x5 = (A5.diag() - 2 * T @ d - 4 * d * x3 - 2 * A @ x3 + 10 * x3) / 2

    return torch.stack([x3, x4, x5], dim=-1)


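# Laplacian eigenvector features: for each graph in the batch, take up to k
# eigenvectors of the (normalized) graph Laplacian with non-zero eigenvalue.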
def eigenfeatures(A, batch_mask, k=5):
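    # Split the batched adjacency matrix into one dense block per graph.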
    batch = []
    for i in torch.unique(batch_mask, sorted=True):
        batch_inds = torch.where(batch_mask == i)[0]
        batch.append(A[torch.meshgrid(batch_inds, batch_inds, indexing='ij')])

    eigenfeats = [get_nontrivial_eigenvectors(adj)[:, :k] for adj in batch]
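    # Zero-pad graphs that have fewer than k non-trivial eigenvectors.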
    eigenfeats = [torch.cat([
        x, torch.zeros(x.size(0), max(k - x.size(1), 0), device=x.device)], dim=-1)
        for x in eigenfeats]
    return torch.cat(eigenfeats, dim=0)


def get_nontrivial_eigenvectors(A, normalize_l=True, thresh=1e-5,
                                norm_eps=1e-12):
    """
    Compute eigenvectors of the graph Laplacian corresponding to non-zero
    eigenvalues.
    """
    assert (A == A.T).all(), "undirected graph"

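    # Graph Laplacian L = D - A, optionally symmetrically normalized.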
    d = A.sum(-1)
    D = d.diag()
    L = D - A

    if normalize_l:
        D_inv_sqrt = (1 / (d.sqrt() + norm_eps)).diag()
        L = D_inv_sqrt @ L @ D_inv_sqrt

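    # torch.linalg.eigh returns eigenvalues in ascending order.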
    eigvals, eigvecs = torch.linalg.eigh(L)

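    # index of the first non-trivial eigenvector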
    try:
        idx = torch.nonzero(eigvals > thresh)[0].item()
    except IndexError:
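        # all eigenvalues are (numerically) zero, e.g. for a graph without edges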
        idx = eigvecs.size(1)

    return eigvecs[:, idx:]


class DynamicsBase(nn.Module):
    """
    Implements self-conditioning logic and basic functions.
    """
    def __init__(
            self,
            predict_angles=False,
            predict_frames=False,
            add_cycle_counts=False,
            add_spectral_feat=False,
            self_conditioning=False,
            augment_residue_sc=False,
            augment_ligand_sc=False
    ):
        super().__init__()
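
        # Set defaults only for attributes a subclass has not already defined
        # before calling super().__init__().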
        if not hasattr(self, 'predict_angles'):
            self.predict_angles = predict_angles

        if not hasattr(self, 'predict_frames'):
            self.predict_frames = predict_frames

        if not hasattr(self, 'add_cycle_counts'):
            self.add_cycle_counts = add_cycle_counts

        if not hasattr(self, 'add_spectral_feat'):
            self.add_spectral_feat = add_spectral_feat

        if not hasattr(self, 'self_conditioning'):
            self.self_conditioning = self_conditioning

        if not hasattr(self, 'augment_residue_sc'):
            self.augment_residue_sc = augment_residue_sc

        if not hasattr(self, 'augment_ligand_sc'):
            self.augment_ligand_sc = augment_ligand_sc

        if self.self_conditioning:
            self.prev_ligand = None
            self.prev_residues = None

    @abstractmethod
    def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
                 h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
        """
        Implement the forward pass.
        Returns:
            - pred_ligand: dict with keys 'vel', 'logits_h' and 'logits_e'
            - pred_residues: dict with keys 'chi', 'trans' and 'rot'
        """
        pass

    def make_sc_input(self, pred_ligand, pred_residues, sc_transform):
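        # Assemble the previous prediction into (scalar, vector) feature tuples
        # that are appended to the network inputs for self-conditioning.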
        if self.predict_confidence:
            h_atoms_sc = (torch.cat([pred_ligand['logits_h'],
                                     pred_ligand['uncertainty_vel'].unsqueeze(1)], dim=-1),
                          pred_ligand['vel'].unsqueeze(1))
        else:
            h_atoms_sc = (pred_ligand['logits_h'], pred_ligand['vel'].unsqueeze(1))
        e_atoms_sc = pred_ligand['logits_e']

        if self.predict_frames:
            h_residues_sc = (torch.cat([pred_residues['chi'], pred_residues['rot']], dim=-1),
                             pred_residues['trans'].unsqueeze(1))
        elif self.predict_angles:
            h_residues_sc = pred_residues['chi']
        else:
            h_residues_sc = None

        if self.augment_residue_sc and h_residues_sc is not None:
            if self.predict_frames:
                h_residues_sc = (h_residues_sc[0], torch.cat(
                    [h_residues_sc[1],
                     sc_transform['residues'](pred_residues['chi'],
                                              pred_residues['trans'].squeeze(1),
                                              pred_residues['rot'])], dim=1))
            else:
                h_residues_sc = (h_residues_sc, sc_transform['residues'](pred_residues['chi']))

        if self.augment_ligand_sc:
            h_atoms_sc = (h_atoms_sc[0], torch.cat(
                [h_atoms_sc[1], sc_transform['atoms'](pred_ligand['vel'].unsqueeze(1))], dim=1))

        return h_atoms_sc, e_atoms_sc, h_residues_sc

    def forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None, sc_transform=None):
        """
        Implements self-conditioning as in https://arxiv.org/abs/2208.04202
        """
        h_atoms_sc, e_atoms_sc = None, None
        h_residues_sc = None

        if self.self_conditioning:
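            # At sampling time (t > 0), condition on the prediction from the
            # previous denoising step.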
            if not self.training and t.min() > 0.0:
                assert t.min() == t.max(), "currently only supports sampling at the same time step"
                assert self.prev_ligand is not None
                assert self.prev_residues is not None or not self.predict_frames

            else:
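                # First sampling step or training: initialize the
                # self-conditioning inputs with zeros.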
                zeros_ligand = {'logits_h': torch.zeros_like(h_atoms),
                                'vel': torch.zeros_like(x_atoms),
                                'logits_e': torch.zeros_like(bonds_ligand[1])}
                if self.predict_confidence:
                    zeros_ligand['uncertainty_vel'] = torch.zeros(
                        len(x_atoms), dtype=x_atoms.dtype, device=x_atoms.device)

                zeros_residues = {}
                if self.predict_angles:
                    zeros_residues['chi'] = torch.zeros((pocket['one_hot'].size(0), 5), device=pocket['one_hot'].device)
                if self.predict_frames:
                    zeros_residues['trans'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)
                    zeros_residues['rot'] = torch.zeros((pocket['one_hot'].size(0), 3), device=pocket['one_hot'].device)

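                # Self-condition on ~50% of training steps: run the model once
                # without gradients and feed its prediction back in.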
                if self.training and random.random() > 0.5:
                    with torch.no_grad():
                        h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
                            zeros_ligand, zeros_residues, sc_transform)

                        self.prev_ligand, self.prev_residues = self._forward(
                            x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
                            h_atoms_sc, e_atoms_sc, h_residues_sc)

                else:
                    self.prev_ligand = zeros_ligand
                    self.prev_residues = zeros_residues

            h_atoms_sc, e_atoms_sc, h_residues_sc = self.make_sc_input(
                self.prev_ligand, self.prev_residues, sc_transform)

        pred_ligand, pred_residues = self._forward(
            x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand,
            h_atoms_sc, e_atoms_sc, h_residues_sc
        )

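        # Cache the prediction so the next sampling step can condition on it.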
        if self.self_conditioning and not self.training:
            self.prev_ligand = pred_ligand.copy()
            self.prev_residues = pred_residues.copy()

        return pred_ligand, pred_residues

    def compute_extra_features(self, batch_mask, edge_indices, edge_types):

        feat = torch.zeros(len(batch_mask), 0, device=batch_mask.device)

        if not (self.add_cycle_counts or self.add_spectral_feat):
            return feat

        adj = batch_mask[:, None] == batch_mask[None, :]

        E = torch.zeros_like(adj, dtype=INT_TYPE)
        E[edge_indices[0], edge_indices[1]] = edge_types

        A = (E > 0).float()

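        # Cycle counts are clipped at 10 to keep the feature range bounded.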
        if self.add_cycle_counts:
            cycle_features = cycle_counts(A)
            cycle_features[cycle_features > 10] = 10

            feat = torch.cat([feat, cycle_features], dim=-1)

        if self.add_spectral_feat:
            feat = torch.cat([feat, eigenfeatures(A, batch_mask)], dim=-1)

        return feat


class Dynamics(DynamicsBase):
    def __init__(self, atom_nf, residue_nf, joint_nf, bond_dict, pocket_bond_dict,
                 edge_nf, hidden_nf, act_fn=torch.nn.SiLU(), condition_time=True,
                 model='egnn', model_params=None,
                 edge_cutoff_ligand=None, edge_cutoff_pocket=None,
                 edge_cutoff_interaction=None,
                 predict_angles=False, predict_frames=False,
                 add_cycle_counts=False, add_spectral_feat=False,
                 add_nma_feat=False, self_conditioning=False,
                 augment_residue_sc=False, augment_ligand_sc=False,
                 add_chi_as_feature=False, angle_act_fn=None):
        super().__init__()
        self.model = model
        self.edge_cutoff_l = edge_cutoff_ligand
        self.edge_cutoff_p = edge_cutoff_pocket
        self.edge_cutoff_i = edge_cutoff_interaction
        self.hidden_nf = hidden_nf
        self.predict_angles = predict_angles
        self.predict_frames = predict_frames
        self.bond_dict = bond_dict
        self.pocket_bond_dict = pocket_bond_dict
        self.bond_nf = len(bond_dict)
        self.pocket_bond_nf = len(pocket_bond_dict)
        self.edge_nf = edge_nf
        self.add_cycle_counts = add_cycle_counts
        self.add_spectral_feat = add_spectral_feat
        self.add_nma_feat = add_nma_feat
        self.self_conditioning = self_conditioning
        self.augment_residue_sc = augment_residue_sc
        self.augment_ligand_sc = augment_ligand_sc
        self.add_chi_as_feature = add_chi_as_feature
        self.predict_confidence = False

        if self.self_conditioning:
            self.prev_vel = None
            self.prev_h = None
            self.prev_e = None
            self.prev_a = None
            self.prev_ca = None
            self.prev_rot = None

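        # Optional graph features enlarge the ligand scalar input:
        # 3 cycle counts and 5 Laplacian eigenvector features.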
        lig_nf = atom_nf
        if self.add_cycle_counts:
            lig_nf = lig_nf + 3
        if self.add_spectral_feat:
            lig_nf = lig_nf + 5

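        # Interpret a scalar joint_nf as (scalar_channels, vector_channels).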
        if not isinstance(joint_nf, Iterable):
            joint_nf = (joint_nf, 0)

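        # A tuple-valued residue_nf indicates GVP-style (scalar, vector)
        # features; the bookkeeping below tracks how optional inputs grow the
        # encoder dimensions.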
        if isinstance(residue_nf, Iterable):
            _atom_in_nf = (lig_nf, 0)
            _residue_atom_dim = residue_nf[1]

            if self.add_nma_feat:
                residue_nf = (residue_nf[0], residue_nf[1] + 5)

            if self.self_conditioning:
                _atom_in_nf = (_atom_in_nf[0] + atom_nf, 1)

                if self.augment_ligand_sc:
                    _atom_in_nf = (_atom_in_nf[0], _atom_in_nf[1] + 1)

                if self.predict_angles:
                    residue_nf = (residue_nf[0] + 5, residue_nf[1])

                if self.predict_frames:
                    residue_nf = (residue_nf[0], residue_nf[1] + 2)

                if self.augment_residue_sc:
                    assert self.predict_angles
                    residue_nf = (residue_nf[0], residue_nf[1] + _residue_atom_dim)

            if self.add_chi_as_feature:
                residue_nf = (residue_nf[0] + 5, residue_nf[1])

            self.atom_encoder = nn.Sequential(
                GVP(_atom_in_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
                LayerNorm(joint_nf, learnable_vector_weight=True),
                GVP(joint_nf, joint_nf, activations=(None, None)),
            )

            self.residue_encoder = nn.Sequential(
                GVP(residue_nf, joint_nf, activations=(act_fn, torch.sigmoid)),
                LayerNorm(joint_nf, learnable_vector_weight=True),
                GVP(joint_nf, joint_nf, activations=(None, None)),
            )

        else:
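            # Scalar-only features: vector channels, self-conditioning and NMA
            # vectors are not supported by the plain linear encoders.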
            assert joint_nf[1] == 0
            assert not self.self_conditioning
            assert not self.add_nma_feat

            if self.add_chi_as_feature:
                residue_nf += 5

            self.atom_encoder = nn.Sequential(
                nn.Linear(lig_nf, 2 * atom_nf),
                act_fn,
                nn.Linear(2 * atom_nf, joint_nf[0])
            )

            self.residue_encoder = nn.Sequential(
                nn.Linear(residue_nf, 2 * residue_nf),
                act_fn,
                nn.Linear(2 * residue_nf, joint_nf[0])
            )

        self.atom_decoder = nn.Sequential(
            nn.Linear(joint_nf[0], 2 * atom_nf),
            act_fn,
            nn.Linear(2 * atom_nf, atom_nf)
        )

        self.edge_decoder = nn.Sequential(
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, self.bond_nf)
        )

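        # With self-conditioning, previously predicted bond logits are
        # concatenated to the bond one-hots, doubling the encoder input.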
        _atom_bond_nf = 2 * self.bond_nf if self.self_conditioning else self.bond_nf
        self.ligand_bond_encoder = nn.Sequential(
            nn.Linear(_atom_bond_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, self.edge_nf)
        )

        self.pocket_bond_encoder = nn.Sequential(
            nn.Linear(self.pocket_bond_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, self.edge_nf)
        )

        out_nf = (joint_nf[0], 1)
        res_out_nf = (0, 0)
        if self.predict_angles:
            res_out_nf = (res_out_nf[0] + 5, res_out_nf[1])
        if self.predict_frames:
            res_out_nf = (res_out_nf[0], res_out_nf[1] + 2)
        self.residue_decoder = nn.Sequential(
            GVP(out_nf, out_nf, activations=(act_fn, torch.sigmoid)),
            LayerNorm(out_nf, learnable_vector_weight=True),
            GVP(out_nf, res_out_nf, activations=(None, None)),
        ) if res_out_nf != (0, 0) else None

        if angle_act_fn is None:
            self.angle_act_fn = None
        elif angle_act_fn == 'tanh':
            self.angle_act_fn = lambda x: np.pi * torch.tanh(x)
        else:
            raise NotImplementedError(f"Angle activation {angle_act_fn} is not available")

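        # Learnable embedding shared by all ligand-pocket interaction edges.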
        self.cross_emb = nn.Parameter(torch.zeros(self.edge_nf),
                                      requires_grad=True)

        if condition_time:
            dynamics_node_nf = (joint_nf[0] + 1, joint_nf[1])
        else:
            print('Warning: dynamics model is NOT conditioned on time.')
            dynamics_node_nf = (joint_nf[0], joint_nf[1])

        if model == 'egnn':
            raise NotImplementedError

        elif model == 'gvp':
            self.net = GVPModel(
                node_in_dim=dynamics_node_nf, node_h_dim=model_params.node_h_dim,
                node_out_nf=joint_nf[0], edge_in_nf=self.edge_nf,
                edge_h_dim=model_params.edge_h_dim, edge_out_nf=hidden_nf,
                num_layers=model_params.n_layers,
                drop_rate=model_params.dropout,
                vector_gate=model_params.vector_gate,
                reflection_equiv=model_params.reflection_equivariant,
                d_max=model_params.d_max,
                num_rbf=model_params.num_rbf,
                update_edge_attr=True
            )

        elif model == 'gvp_transformer':
            self.net = GVPTransformerModel(
                node_in_dim=dynamics_node_nf,
                node_h_dim=model_params.node_h_dim,
                node_out_nf=joint_nf[0],
                edge_in_nf=self.edge_nf,
                edge_h_dim=model_params.edge_h_dim,
                edge_out_nf=hidden_nf,
                num_layers=model_params.n_layers,
                dk=model_params.dk,
                dv=model_params.dv,
                de=model_params.de,
                db=model_params.db,
                dy=model_params.dy,
                attn_heads=model_params.attn_heads,
                n_feedforward=model_params.n_feedforward,
                drop_rate=model_params.dropout,
                reflection_equiv=model_params.reflection_equivariant,
                d_max=model_params.d_max,
                num_rbf=model_params.num_rbf,
                vector_gate=model_params.vector_gate,
                attention=model_params.attention,
            )

        elif model == 'gnn':
            raise NotImplementedError

        else:
            raise NotImplementedError(f"{model} is not available")

        self.condition_time = condition_time

    def _forward(self, x_atoms, h_atoms, mask_atoms, pocket, t, bonds_ligand=None,
                 h_atoms_sc=None, e_atoms_sc=None, h_residues_sc=None):
        """
        :param x_atoms: ligand atom coordinates
        :param h_atoms: ligand atom features
        :param mask_atoms: batch mask for ligand atoms
        :param pocket: must contain keys: 'x', 'one_hot', 'mask', 'bonds' and 'bond_one_hot'
        :param t: time step
        :param bonds_ligand: tuple - bond indices (2, n_bonds) & bond types (n_bonds, bond_nf)
        :param h_atoms_sc: additional node feature for self-conditioning, (s, V)
        :param e_atoms_sc: additional edge feature for self-conditioning, only scalar
        :param h_residues_sc: additional node feature for self-conditioning, tensor or tuple
        :return: (pred_ligand, pred_residues) dictionaries
        """
        x_residues, h_residues, mask_residues = pocket['x'], pocket['one_hot'], pocket['mask']
        if 'bonds' in pocket:
            bonds_pocket = (pocket['bonds'], pocket['bond_one_hot'])
        else:
            bonds_pocket = None

        if self.add_chi_as_feature:
            h_residues = torch.cat([h_residues, pocket['chi'][:, :5]], dim=-1)

        if 'v' in pocket:
            v_residues = pocket['v']
            if self.add_nma_feat:
                v_residues = torch.cat([v_residues, pocket['nma_vec']], dim=1)
            h_residues = (h_residues, v_residues)

        if h_residues_sc is not None:
            if isinstance(h_residues_sc, tuple):
                h_residues = (torch.cat([h_residues[0], h_residues_sc[0]], dim=-1),
                              torch.cat([h_residues[1], h_residues_sc[1]], dim=1))
            else:
                h_residues = (torch.cat([h_residues[0], h_residues_sc], dim=-1),
                              h_residues[1])

        if bonds_ligand is not None:
            ligand_bond_indices = bonds_ligand[0]

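            # Bonds are stored once per atom pair; duplicate them in both
            # directions for message passing.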
            ligand_edge_indices = torch.cat(
                [bonds_ligand[0], bonds_ligand[0].flip(dims=[0])], dim=1)
            ligand_edge_types = torch.cat([bonds_ligand[1], bonds_ligand[1]], dim=0)

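            # Cycle counts and spectral features are derived from the ligand
            # bond graph.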
            extra_features = self.compute_extra_features(
                mask_atoms, ligand_edge_indices, ligand_edge_types.argmax(-1))
            h_atoms = torch.cat([h_atoms, extra_features], dim=-1)

        if bonds_pocket is not None:
            pocket_edge_indices = torch.cat(
                [bonds_pocket[0], bonds_pocket[0].flip(dims=[0])], dim=1)
            pocket_edge_types = torch.cat([bonds_pocket[1], bonds_pocket[1]], dim=0)

        if h_atoms_sc is not None:
            h_atoms = (torch.cat([h_atoms, h_atoms_sc[0]], dim=-1),
                       h_atoms_sc[1])

        if e_atoms_sc is not None:
            e_atoms_sc = torch.cat([e_atoms_sc, e_atoms_sc], dim=0)
            ligand_edge_types = torch.cat([ligand_edge_types, e_atoms_sc], dim=-1)

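        # Embed node and edge features in the joint hidden space.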
        h_atoms = self.atom_encoder(h_atoms)
        e_ligand = self.ligand_bond_encoder(ligand_edge_types)

        if len(h_residues) > 0:
            h_residues = self.residue_encoder(h_residues)
            e_pocket = self.pocket_bond_encoder(pocket_edge_types)
        else:
            e_pocket = pocket_edge_types
            h_residues = (h_residues, h_residues)
            pocket_edge_indices = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)
            pocket_edge_types = torch.tensor([[], []], dtype=torch.long, device=h_residues[0].device)

        if isinstance(h_atoms, tuple):
            h_atoms, v_atoms = h_atoms
            h_residues, v_residues = h_residues
            v = torch.cat((v_atoms, v_residues), dim=0)
        else:
            v = None

        edges, edge_feat = self.get_edges(
            mask_atoms, mask_residues, x_atoms, x_residues,
            bond_inds_ligand=ligand_edge_indices, bond_inds_pocket=pocket_edge_indices,
            bond_feat_ligand=e_ligand, bond_feat_pocket=e_pocket)

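        # Combine ligand and pocket nodes into a single graph.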
        x = torch.cat((x_atoms, x_residues), dim=0)
        h = torch.cat((h_atoms, h_residues), dim=0)
        mask = torch.cat([mask_atoms, mask_residues])

        if self.condition_time:
            if np.prod(t.size()) == 1:
                # t is the same for all elements in the batch
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t varies over the batch dimension
                h_time = t[mask]
            h = torch.cat([h, h_time], dim=1)

        assert torch.all(mask[edges[0]] == mask[edges[1]])

        if self.model == 'egnn':
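            # Only ligand coordinates are updated; pocket nodes stay fixed.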
            update_coords_mask = torch.cat((torch.ones_like(mask_atoms),
                                            torch.zeros_like(mask_residues))).unsqueeze(1)
            h_final, vel, edge_final = self.net(
                h, x, edges, batch_mask=mask, edge_attr=edge_feat,
                update_coords_mask=update_coords_mask)

        elif self.model == 'gvp' or self.model == 'gvp_transformer':
            h_final, vel, edge_final = self.net(
                h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat)

        elif self.model == 'gnn':
            xh = torch.cat([x, h], dim=1)
            output = self.net(xh, edges, node_mask=None, edge_attr=edge_feat)
            vel = output[:, :3]
            h_final = output[:, 3:]

        else:
            raise NotImplementedError(f"Wrong model ({self.model})")

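        # Ligand atoms occupy the first len(mask_atoms) rows of the node tensors.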
        h_final_atoms = self.atom_decoder(h_final[:len(mask_atoms)])

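        # Zero out NaNs during training so a single bad batch does not kill the
        # run; fail loudly at inference.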
        if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final_atoms)):
            if self.training:
                vel[torch.isnan(vel)] = 0.0
                h_final_atoms[torch.isnan(h_final_atoms)] = 0.0
            else:
                raise ValueError("NaN detected in network output")

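        # Keep only ligand-ligand edges for bond prediction.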
        ligand_edge_mask = (edges[0] < len(mask_atoms)) & (edges[1] < len(mask_atoms))
        edge_final = edge_final[ligand_edge_mask]
        edges = edges[:, ligand_edge_mask]

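        # Symmetrize the directed edge embeddings so both directions of a bond
        # contribute equally, then read out logits at the original bond indices.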
        edge_logits = torch.zeros(
            (len(mask_atoms), len(mask_atoms), self.hidden_nf),
            device=mask_atoms.device)
        edge_logits[edges[0], edges[1]] = edge_final
        edge_logits = (edge_logits + edge_logits.transpose(0, 1)) * 0.5

        edge_logits = edge_logits[ligand_bond_indices[0], ligand_bond_indices[1]]

        edge_final_atoms = self.edge_decoder(edge_logits)

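        # Decode side-chain angles and, optionally, residue frame updates.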
        residue_angles = None
        residue_trans, residue_rot = None, None
        if self.residue_decoder is not None:
            h_residues = h_final[len(mask_atoms):]
            vec_residues = vel[len(mask_atoms):].unsqueeze(1)
            residue_angles = self.residue_decoder((h_residues, vec_residues))
            if self.predict_frames:
                residue_angles, residue_frames = residue_angles
                residue_trans = residue_frames[:, 0, :].squeeze(1)
                residue_rot = residue_frames[:, 1, :].squeeze(1)
            if self.angle_act_fn is not None:
                residue_angles = self.angle_act_fn(residue_angles)

        pred_ligand = {'vel': vel[:len(mask_atoms)], 'logits_h': h_final_atoms,
                       'logits_e': edge_final_atoms}
        pred_residues = {'chi': residue_angles, 'trans': residue_trans,
                         'rot': residue_rot}
        return pred_ligand, pred_residues

    def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand,
                  x_pocket, bond_inds_ligand=None, bond_inds_pocket=None,
                  bond_feat_ligand=None, bond_feat_pocket=None, self_edges=False):

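        # Fully connect nodes that belong to the same sample; the distance
        # cutoffs below sparsify these blocks.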
        adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
        adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
        adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]

        if self.edge_cutoff_l is not None:
            adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)

            # make sure bonded atoms stay connected
            adj_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = True

        if self.edge_cutoff_p is not None and len(x_pocket) > 0:
            adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)

            # make sure bonded residues stay connected
            adj_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = True

        if self.edge_cutoff_i is not None and len(x_pocket) > 0:
            adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)

        adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
                         torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)

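        # XOR with the identity clears the diagonal, i.e. removes self-loops.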
        if not self_edges:
            adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj))

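        # Atom pairs without a chemical bond receive the learned NOBOND
        # embedding; real bond features overwrite it at the bond indices.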
        ligand_nobond_onehot = F.one_hot(torch.tensor(
            self.bond_dict['NOBOND'], device=bond_feat_ligand.device),
            num_classes=self.ligand_bond_encoder[0].in_features)
        ligand_nobond_emb = self.ligand_bond_encoder(
            ligand_nobond_onehot.to(FLOAT_TYPE))
        feat_ligand = ligand_nobond_emb.repeat(*adj_ligand.shape, 1)
        feat_ligand[bond_inds_ligand[0], bond_inds_ligand[1]] = bond_feat_ligand

        if len(adj_pocket) > 0:
            pocket_nobond_onehot = F.one_hot(torch.tensor(
                self.pocket_bond_dict['NOBOND'], device=bond_feat_pocket.device),
                num_classes=self.pocket_bond_nf)
            pocket_nobond_emb = self.pocket_bond_encoder(
                pocket_nobond_onehot.to(FLOAT_TYPE))
            feat_pocket = pocket_nobond_emb.repeat(*adj_pocket.shape, 1)
            feat_pocket[bond_inds_pocket[0], bond_inds_pocket[1]] = bond_feat_pocket

            feat_cross = self.cross_emb.repeat(*adj_cross.shape, 1)

            feats = torch.cat(
                (torch.cat((feat_ligand, feat_cross), dim=1),
                 torch.cat((feat_cross.transpose(0, 1), feat_pocket), dim=1)),
                dim=0)
        else:
            feats = feat_ligand

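        # Convert the dense adjacency into an edge list and gather per-edge
        # features.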
        edges = torch.stack(torch.where(adj), dim=0)
        edge_feat = feats[edges[0], edges[1]]

        return edges, edge_feat
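

if __name__ == '__main__':
    # Minimal smoke test (not part of the original module; assumes only torch):
    # a single triangle graph.
    A = torch.ones(3, 3) - torch.eye(3)
    print(cycle_counts(A))  # one 3-cycle per node, no 4- or 5-cycles
    print(eigenfeatures(A, torch.zeros(3, dtype=torch.long)))  # (3, 5), zero-padded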