| | import warnings |
| | import tempfile |
| | from typing import Optional, Union |
| | from time import time |
| | from pathlib import Path |
| | from functools import partial |
| | from itertools import accumulate |
| | from argparse import Namespace |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | from rdkit import Chem |
| | import torch |
| | from torch.utils.data import DataLoader, SubsetRandomSampler |
| | from torch.distributions.categorical import Categorical |
| | import pytorch_lightning as pl |
| | from torch_scatter import scatter_mean |
| |
|
| | import src.utils as utils |
| | from src.constants import atom_encoder, atom_decoder, aa_encoder, aa_decoder, \ |
| | bond_encoder, bond_decoder, residue_encoder, residue_bond_encoder, \ |
| | residue_decoder, residue_bond_decoder, aa_atom_index, aa_atom_mask |
| | from src.data.dataset import ProcessedLigandPocketDataset, ClusteredDataset, get_wds |
| | from src.data import data_utils |
| | from src.data.data_utils import AppendVirtualNodesInCoM, center_data, Residues, TensorDict, randomize_tensors |
| | from src.model.flows import CoordICFM, TorusICFM, CoordICFMPredictFinal, TorusICFMPredictFinal, SO3ICFM |
| | from src.model.markov_bridge import UniformPriorMarkovBridge, MarginalPriorMarkovBridge |
| | from src.model.dynamics import Dynamics |
| | from src.model.dynamics_hetero import DynamicsHetero |
| | from src.model.diffusion_utils import DistributionNodes |
| | from src.model.loss_utils import TimestepWeights, clash_loss |
| | from src.analysis.visualization_utils import pocket_to_rdkit, mols_to_pdbfile |
| | from src.analysis.metrics import MoleculeValidity, CategoricalDistribution, MolecularProperties |
| | from src.data.molecule_builder import build_molecule |
| | from src.data.postprocessing import process_all |
| | from src.sbdd_metrics.metrics import FullEvaluator |
| | from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics |
| | from tqdm import tqdm |
| |
|
| | |
# Per-residue atom presence mask: row i tells, for amino-acid type aa_decoder[i],
# which of the 14 fixed atom slots are occupied (see range(14) below).
aa_atom_mask_tensor = torch.tensor([aa_atom_mask[aa] for aa in aa_decoder])

# Inverse of aa_atom_index: for each amino acid, map atom-slot index -> atom name.
aa_atom_decoder = {aa: {v: k for k, v in aa_atom_index[aa].items()} for aa in aa_decoder}

# Atom-type id for every (amino acid, atom slot) pair. The first character of the
# atom name (presumably the element symbol) is looked up in atom_encoder; empty
# slots fall back to '-' and therefore to the sentinel value -42.
aa_atom_type_tensor = torch.tensor([[atom_encoder.get(aa_atom_decoder[aa].get(i, '-')[0], -42)
                                     for i in range(14)] for aa in aa_decoder])
| |
|
| |
|
def set_default(namespace, key, default_val):
    """Ensure *namespace* carries attribute *key*.

    An already-present value is kept; only missing keys receive *default_val*.
    Used for backward compatibility with configs that predate newer options.
    """
    attrs = vars(namespace)
    setattr(namespace, key, attrs[key] if key in attrs else default_val)
| |
|
| |
|
| | class DrugFlow(pl.LightningModule): |
    def __init__(
            self,
            pocket_representation: str,
            train_params: Namespace,
            loss_params: Namespace,
            eval_params: Namespace,
            predictor_params: Namespace,
            simulation_params: Namespace,
            virtual_nodes: Union[list, None],
            flexible: bool,
            flexible_bb: bool = False,
            debug: bool = False,
            overfit: bool = False,
    ):
        """
        Flow-matching generator of ligands conditioned on a protein pocket.

        Args:
            pocket_representation: either 'side_chain_bead' or 'CA+'.
            train_params: training configuration (datadir, batch_size, lr, ...).
            loss_params: loss weights and options (lambda_x, lambda_h, ...).
            eval_params: evaluation/sampling configuration.
            predictor_params: architecture options for the dynamics network.
            simulation_params: flow simulation options (n_steps, priors, ...).
            virtual_nodes: falsy to disable; otherwise (min, max) number of
                virtual nodes appended to each training ligand.
            flexible: model pocket side-chain angles (flexible side chains).
            flexible_bb: additionally model backbone frames.
            debug: use the validation split everywhere for quick runs.
            overfit: repeatedly train/validate on a single sample.
        """
        super(DrugFlow, self).__init__()
        self.save_hyperparameters()

        # Backfill defaults for options added after older configs/checkpoints
        # were created, so they keep loading unchanged.
        set_default(train_params, "sharded_dataset", False)
        set_default(train_params, "sample_from_clusters", False)
        set_default(train_params, "lr_step_size", None)
        set_default(train_params, "lr_gamma", None)
        set_default(train_params, "gnina", None)
        set_default(loss_params, "lambda_x", 1.0)
        set_default(loss_params, "lambda_clash", None)
        set_default(loss_params, "reduce", "mean")
        set_default(loss_params, "regularize_uncertainty", None)
        set_default(eval_params, "n_loss_per_sample", 1)
        set_default(eval_params, "n_sampling_steps", simulation_params.n_steps)
        set_default(predictor_params, "transform_sc_pred", False)
        set_default(predictor_params, "add_chi_as_feature", False)
        set_default(predictor_params, "augment_residue_sc", False)
        set_default(predictor_params, "augment_ligand_sc", False)
        set_default(predictor_params, "add_all_atom_diff", False)
        set_default(predictor_params, "angle_act_fn", None)
        set_default(simulation_params, "predict_confidence", False)
        set_default(simulation_params, "predict_final", False)
        set_default(simulation_params, "scheduler_chi", None)

        assert pocket_representation in {'side_chain_bead', 'CA+'}
        self.pocket_representation = pocket_representation

        # Residue self-conditioning requires flexible side chains.
        assert flexible or not predictor_params.augment_residue_sc
        # NOTE(review): these 'in predictor_params' fallbacks are redundant —
        # the set_default calls above already guarantee both attributes exist.
        self.augment_residue_sc = predictor_params.augment_residue_sc \
            if 'augment_residue_sc' in predictor_params else False
        self.augment_ligand_sc = predictor_params.augment_ligand_sc \
            if 'augment_ligand_sc' in predictor_params else False

        assert not (flexible_bb and predictor_params.normal_modes), \
            "Normal mode eigenvectors are only meaningful for fixed backbones"
        assert (not flexible_bb) or flexible, \
            "Currently atom vectors aren't updated if flexible=False"

        # Confidence prediction is only supported by the heterogeneous backbone
        # and is incompatible with predicting the final state directly.
        assert not (simulation_params.predict_confidence and
                    (not predictor_params.heterogeneous_graph or simulation_params.predict_final))

        # Datasets are created lazily in setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.virtual_nodes = virtual_nodes
        self.flexible = flexible
        self.flexible_bb = flexible_bb
        self.debug = debug
        self.overfit = overfit
        self.predict_confidence = simulation_params.predict_confidence

        if self.virtual_nodes:
            # Range of virtual nodes appended per training ligand.
            self.add_virtual_min = virtual_nodes[0]
            self.add_virtual_max = virtual_nodes[1]

        # --- training configuration ---
        self.datadir = train_params.datadir
        self.receptor_dir = train_params.datadir
        self.batch_size = train_params.batch_size
        self.lr = train_params.lr
        self.lr_step_size = train_params.lr_step_size
        self.lr_gamma = train_params.lr_gamma
        self.num_workers = train_params.num_workers
        self.sample_from_clusters = train_params.sample_from_clusters
        self.sharded_dataset = train_params.sharded_dataset
        self.clip_grad = train_params.clip_grad
        if self.clip_grad:
            self.gradnorm_queue = utils.Queue()
            # Seed with a large value — presumably so the very first batches
            # are effectively never clipped; confirm against utils.Queue usage.
            self.gradnorm_queue.add(3000)

        # --- evaluation configuration ---
        self.outdir = eval_params.outdir
        self.eval_batch_size = eval_params.eval_batch_size
        self.eval_epochs = eval_params.eval_epochs
        self.visualize_sample_epoch = eval_params.visualize_sample_epoch
        self.visualize_chain_epoch = eval_params.visualize_chain_epoch
        self.sample_with_ground_truth_size = eval_params.sample_with_ground_truth_size
        self.n_loss_per_sample = eval_params.n_loss_per_sample
        self.n_eval_samples = eval_params.n_eval_samples
        self.n_visualize_samples = eval_params.n_visualize_samples
        self.keep_frames = eval_params.keep_frames
        self.gnina = train_params.gnina

        # --- category <-> index vocabularies ---
        self.atom_encoder = atom_encoder
        self.atom_decoder = atom_decoder
        self.bond_encoder = bond_encoder
        self.bond_decoder = bond_decoder
        self.aa_encoder = aa_encoder
        self.aa_decoder = aa_decoder
        self.residue_encoder = residue_encoder
        self.residue_decoder = residue_decoder
        self.residue_bond_encoder = residue_bond_encoder
        self.residue_bond_decoder = residue_bond_decoder

        # --- feature dimensions per node/edge type ---
        self.atom_nf = len(self.atom_decoder)
        self.residue_nf = len(self.aa_decoder)
        if self.pocket_representation == 'side_chain_bead':
            self.residue_nf += len(self.residue_encoder)
        if self.pocket_representation == 'CA+':
            self.aa_atom_index = aa_atom_index
            # Largest atom-slot index over all amino acids, plus one.
            self.n_atom_aa = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
            self.residue_nf = (self.residue_nf, self.n_atom_aa)
        self.bond_nf = len(self.bond_decoder)
        self.pocket_bond_nf = len(self.residue_bond_decoder)
        self.x_dim = 3

        # Neural network predicting the flow / bridge updates.
        self.dynamics = self.init_model(predictor_params)

        # Continuous flows: ligand coordinates, side-chain angles, backbone frames.
        if simulation_params.predict_final:
            self.module_x = CoordICFMPredictFinal(None)
            self.module_chi = TorusICFMPredictFinal(None, 5) if self.flexible else None
            if self.flexible_bb:
                raise NotImplementedError()
        else:
            self.module_x = CoordICFM(None)
            scheduler_args = None if simulation_params.scheduler_chi is None else vars(simulation_params.scheduler_chi)
            self.module_chi = TorusICFM(None, 5, scheduler_args) if self.flexible else None
        self.module_trans = CoordICFM(None) if self.flexible_bb else None
        self.module_rot = SO3ICFM(None) if self.flexible_bb else None

        # Discrete Markov bridges for atom types (h); marginal priors are stored
        # as buffers so they are persisted in checkpoints.
        if simulation_params.prior_h == 'uniform':
            self.module_h = UniformPriorMarkovBridge(self.atom_nf, loss_type=loss_params.discrete_loss)
        elif simulation_params.prior_h == 'marginal':
            self.register_buffer('prior_h', self.get_categorical_prop('atom'))
            self.module_h = MarginalPriorMarkovBridge(self.atom_nf, self.prior_h, loss_type=loss_params.discrete_loss)

        # Same for bond types (e).
        if simulation_params.prior_e == 'uniform':
            self.module_e = UniformPriorMarkovBridge(self.bond_nf, loss_type=loss_params.discrete_loss)
        elif simulation_params.prior_e == 'marginal':
            self.register_buffer('prior_e', self.get_categorical_prop('bond'))
            self.module_e = MarginalPriorMarkovBridge(self.bond_nf, self.prior_e, loss_type=loss_params.discrete_loss)

        # --- loss weighting / reduction ---
        self.loss_reduce = loss_params.reduce
        self.lambda_x = loss_params.lambda_x
        self.lambda_h = loss_params.lambda_h
        self.lambda_e = loss_params.lambda_e
        self.lambda_chi = loss_params.lambda_chi if self.flexible else None
        self.lambda_trans = loss_params.lambda_trans if self.flexible_bb else None
        self.lambda_rot = loss_params.lambda_rot if self.flexible_bb else None
        self.lambda_clash = loss_params.lambda_clash
        self.regularize_uncertainty = loss_params.regularize_uncertainty

        # Optional time-dependent loss weights, parsed from a string of the
        # form "<type>_<key>=<float>_<key>=<float>...".
        if loss_params.timestep_weights is not None:
            weight_type = loss_params.timestep_weights.split('_')[0]
            kwargs = loss_params.timestep_weights.split('_')[1:]
            kwargs = {x.split('=')[0]: float(x.split('=')[1]) for x in kwargs}
            self.timestep_weights = TimestepWeights(weight_type, **kwargs)
        else:
            self.timestep_weights = None

        # --- sampling configuration ---
        self.T_sampling = eval_params.n_sampling_steps
        self.train_step_size = 1 / simulation_params.n_steps
        self.size_distribution = None  # built in setup_sampling()

        # Metric objects are instantiated in setup_metrics().
        self.train_smiles = None
        self.ligand_metrics = None
        self.molecule_properties = None
        self.evaluator = None
        self.ligand_atom_type_distribution = None
        self.ligand_bond_type_distribution = None

        # Per-step outputs collected for epoch-level aggregation.
        self.training_step_outputs = []
        self.validation_step_outputs = []
| |
|
| | def on_load_checkpoint(self, checkpoint): |
| | """ |
| | This hook is only used for backward compatibility with checkpoints that |
| | did not save prior_h and prior_e in state_dict in the past |
| | """ |
| | if hasattr(self, "prior_h") and "prior_h" not in checkpoint["state_dict"]: |
| | checkpoint["state_dict"]["prior_h"] = self.get_categorical_prop('atom') |
| | if hasattr(self, "prior_e") and "prior_e" not in checkpoint["state_dict"]: |
| | checkpoint["state_dict"]["prior_e"] = self.get_categorical_prop('bond') |
| | if "prior_e" in checkpoint["state_dict"] and not hasattr(self, "prior_e"): |
| | |
| | self.register_buffer("prior_e", self.get_categorical_prop('bond')) |
| |
|
    def init_model(self, predictor_params):
        """
        Instantiate the dynamics network from the predictor configuration.

        Returns the heterogeneous-graph variant (DynamicsHetero) when
        `heterogeneous_graph` is set and truthy, otherwise the homogeneous
        Dynamics model. The backbone architecture is chosen via
        `predictor_params.backbone`, whose parameters live in the
        correspondingly named '<backbone>_params' attribute.
        """
        model_type = predictor_params.backbone

        if 'heterogeneous_graph' in predictor_params and predictor_params.heterogeneous_graph:
            return DynamicsHetero(
                atom_nf=self.atom_nf,
                residue_nf=self.residue_nf,
                bond_dict=self.bond_encoder,
                pocket_bond_dict=self.residue_bond_encoder,
                model=model_type,
                # .get() tolerates configs that predate this option
                num_rbf_time=predictor_params.__dict__.get('num_rbf_time'),
                model_params=getattr(predictor_params, model_type + '_params'),
                edge_cutoff_ligand=predictor_params.edge_cutoff_ligand,
                edge_cutoff_pocket=predictor_params.edge_cutoff_pocket,
                edge_cutoff_interaction=predictor_params.edge_cutoff_interaction,
                predict_angles=self.flexible,
                predict_frames=self.flexible_bb,
                add_cycle_counts=predictor_params.cycle_counts,
                add_spectral_feat=predictor_params.spectral_feat,
                add_nma_feat=predictor_params.normal_modes,
                reflection_equiv=predictor_params.reflection_equivariant,
                d_max=predictor_params.d_max,
                num_rbf_dist=predictor_params.num_rbf,
                self_conditioning=predictor_params.self_conditioning,
                augment_residue_sc=self.augment_residue_sc,
                augment_ligand_sc=self.augment_ligand_sc,
                add_chi_as_feature=predictor_params.add_chi_as_feature,
                angle_act_fn=predictor_params.angle_act_fn,
                add_all_atom_diff=predictor_params.add_all_atom_diff,
                predict_confidence=self.predict_confidence,
            )

        else:
            # The homogeneous model does not support RBF time embeddings.
            if predictor_params.__dict__.get('num_rbf_time') is not None:
                raise NotImplementedError("RBF time embedding not yet implemented")

            return Dynamics(
                atom_nf=self.atom_nf,
                residue_nf=self.residue_nf,
                joint_nf=predictor_params.joint_nf,
                bond_dict=self.bond_encoder,
                pocket_bond_dict=self.residue_bond_encoder,
                edge_nf=predictor_params.edge_nf,
                hidden_nf=predictor_params.hidden_nf,
                model=model_type,
                model_params=getattr(predictor_params, model_type + '_params'),
                edge_cutoff_ligand=predictor_params.edge_cutoff_ligand,
                edge_cutoff_pocket=predictor_params.edge_cutoff_pocket,
                edge_cutoff_interaction=predictor_params.edge_cutoff_interaction,
                predict_angles=self.flexible,
                predict_frames=self.flexible_bb,
                add_cycle_counts=predictor_params.cycle_counts,
                add_spectral_feat=predictor_params.spectral_feat,
                add_nma_feat=predictor_params.normal_modes,
                self_conditioning=predictor_params.self_conditioning,
                augment_residue_sc=self.augment_residue_sc,
                augment_ligand_sc=self.augment_ligand_sc,
                add_chi_as_feature=predictor_params.add_chi_as_feature,
                angle_act_fn=predictor_params.angle_act_fn,
            )
| |
|
| | def _load_histogram(self, type): |
| | """ |
| | Load empirical categorical distributions of atom or bond types from disk. |
| | Returns None if the required file is not found. |
| | """ |
| | assert type in {"atom", "bond"} |
| | filename = 'ligand_type_histogram.npy' if type == 'atom' else 'ligand_bond_type_histogram.npy' |
| | encoder = self.atom_encoder if type == 'atom' else self.bond_encoder |
| | hist_file = Path(self.datadir, filename) |
| | if not hist_file.exists(): |
| | return None |
| | hist = np.load(hist_file, allow_pickle=True).item() |
| | return CategoricalDistribution(hist, encoder) |
| |
|
| | def get_categorical_prop(self, type): |
| | hist = self._load_histogram(type) |
| | encoder = self.atom_encoder if type == 'atom' else self.bond_encoder |
| | |
| | |
| | return torch.zeros(len(encoder)) * float("nan") if hist is None else torch.tensor(hist.p) |
| |
|
| | def configure_optimizers(self): |
| | optimizers = [ |
| | torch.optim.AdamW(self.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12), |
| | ] |
| |
|
| | if self.lr_step_size is None or self.lr_gamma is None: |
| | lr_schedulers = [] |
| | else: |
| | lr_schedulers = [ |
| | torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=self.lr_step_size, gamma=self.lr_gamma), |
| | ] |
| | return optimizers, lr_schedulers |
| |
|
| | def setup(self, stage: Optional[str] = None): |
| |
|
| | self.setup_sampling() |
| |
|
| | if stage == 'fit': |
| | self.train_dataset = self.get_dataset(stage='train') |
| | self.val_dataset = self.get_dataset(stage='val') |
| | self.setup_metrics() |
| | elif stage == 'val': |
| | self.val_dataset = self.get_dataset(stage='val') |
| | self.setup_metrics() |
| | elif stage == 'test': |
| | self.test_dataset = self.get_dataset(stage='test') |
| | self.setup_metrics() |
| | elif stage == 'generation': |
| | pass |
| | else: |
| | raise NotImplementedError |
| |
|
| | def get_dataset(self, stage, pocket_transform=None): |
| |
|
| | |
| | if self.virtual_nodes and stage == "train": |
| | ligand_transform = AppendVirtualNodesInCoM( |
| | atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max) |
| | else: |
| | ligand_transform = None |
| |
|
| | |
| | catch_errors = stage == "train" |
| |
|
| | if self.sharded_dataset: |
| | return get_wds( |
| | data_path=self.datadir, |
| | stage='val' if self.debug else stage, |
| | ligand_transform=ligand_transform, |
| | pocket_transform=pocket_transform, |
| | ) |
| |
|
| | if self.sample_from_clusters and stage == "train": |
| | return ClusteredDataset( |
| | pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'), |
| | ligand_transform=ligand_transform, |
| | pocket_transform=pocket_transform, |
| | catch_errors=catch_errors |
| | ) |
| |
|
| | return ProcessedLigandPocketDataset( |
| | pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'), |
| | ligand_transform=ligand_transform, |
| | pocket_transform=pocket_transform, |
| | catch_errors=catch_errors |
| | ) |
| |
|
| | def setup_sampling(self): |
| | |
| | histogram_file = Path(self.datadir, 'size_distribution.npy') |
| | size_histogram = np.load(histogram_file).tolist() |
| | self.size_distribution = DistributionNodes(size_histogram) |
| |
|
| | def setup_metrics(self): |
| | |
| | smiles_file = Path(self.datadir, 'train_smiles.npy') |
| | self.train_smiles = None if not smiles_file.exists() else np.load(smiles_file) |
| |
|
| | self.ligand_metrics = MoleculeValidity() |
| | self.molecule_properties = MolecularProperties() |
| | self.evaluator = FullEvaluator(gnina=self.gnina, exclude_evaluators=['geometry', 'ring_count']) |
| | self.ligand_atom_type_distribution = self._load_histogram('atom') |
| | self.ligand_bond_type_distribution = self._load_histogram('bond') |
| |
|
| | def train_dataloader(self): |
| | shuffle = None if self.overfit else False if self.sharded_dataset else True |
| | return DataLoader(self.train_dataset, self.batch_size, shuffle=shuffle, |
| | sampler=SubsetRandomSampler([0]) if self.overfit else None, |
| | num_workers=self.num_workers, |
| | collate_fn=self.train_dataset.collate_fn, |
| | |
| | pin_memory=True) |
| |
|
| | def val_dataloader(self): |
| | if self.overfit: |
| | return self.train_dataloader() |
| |
|
| | return DataLoader(self.val_dataset, self.eval_batch_size, |
| | shuffle=False, num_workers=self.num_workers, |
| | collate_fn=self.val_dataset.collate_fn, |
| | pin_memory=True) |
| |
|
| | def test_dataloader(self): |
| | return DataLoader(self.test_dataset, self.eval_batch_size, shuffle=False, |
| | num_workers=self.num_workers, |
| | collate_fn=self.test_dataset.collate_fn, |
| | pin_memory=True) |
| |
|
| | def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs): |
| | for m, value in metrics_dict.items(): |
| | self.log(f'{m}/{split}', value, batch_size=batch_size, **kwargs) |
| |
|
| | def aggregate_metrics(self, step_outputs, prefix): |
| | if 'timestep' in step_outputs[0]: |
| | timesteps = torch.cat([x['timestep'] for x in step_outputs]).squeeze() |
| |
|
| | if 'loss_per_sample' in step_outputs[0]: |
| | losses = torch.cat([x['loss_per_sample'] for x in step_outputs]) |
| | pearson_corr = torch.corrcoef(torch.stack([timesteps, losses], dim=0))[0, 1] |
| | self.log(f'corr_loss_timestep/{prefix}', pearson_corr, prog_bar=False) |
| |
|
| | if 'eps_hat_norm' in step_outputs[0]: |
| | eps_norm = torch.cat([x['eps_hat_norm'] for x in step_outputs]) |
| | pearson_corr = torch.corrcoef(torch.stack([timesteps, eps_norm], dim=0))[0, 1] |
| | self.log(f'corr_eps_timestep/{prefix}', pearson_corr, prog_bar=False) |
| |
|
| | def on_train_epoch_end(self): |
| | self.aggregate_metrics(self.training_step_outputs, 'train') |
| | self.training_step_outputs.clear() |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
    def get_sc_transform_fn(self, zt_chi, zt_x, t, z0_chi, ligand_mask, pocket):
        """
        Build self-conditioning transforms passed to the dynamics model.

        Returns a dict with up to two callables:
        - 'residues': maps predicted side-chain angles (and optionally frames)
          to all-atom coordinate offsets relative to the current residue centers.
        - 'atoms': maps predicted ligand velocities to per-atom displacement
          vectors toward the extrapolated final coordinates.
        """
        sc_transform = {}

        if self.augment_residue_sc:
            def pred_all_atom(pred_chi, pred_trans=None, pred_rot=None):
                # Work on a copy so the caller's pocket stays untouched.
                temp_pocket = pocket.deepcopy()

                # Only move the backbone frame when both components are given.
                if pred_trans is not None and pred_rot is not None:
                    zt_trans = pocket['x']
                    zt_rot = pocket['axis_angle']
                    z1_trans_pred = self.module_trans.get_z1_given_zt_and_pred(
                        zt_trans, pred_trans, None, t, pocket['mask'])
                    z1_rot_pred = self.module_rot.get_z1_given_zt_and_pred(
                        zt_rot, pred_rot, None, t, pocket['mask'])
                    temp_pocket.set_frame(z1_trans_pred, z1_rot_pred)

                # Extrapolate end-point side-chain angles from the prediction
                # (only the first five chi angles are modeled).
                z1_chi_pred = self.module_chi.get_z1_given_zt_and_pred(
                    zt_chi[..., :5], pred_chi, z0_chi, t, pocket['mask'])
                temp_pocket.set_chi(z1_chi_pred)

                # All-atom coordinates expressed relative to the *original*
                # (unmoved) residue centers.
                all_coord = temp_pocket['v'] + temp_pocket['x'].unsqueeze(1)
                return all_coord - pocket['x'].unsqueeze(1)

            sc_transform['residues'] = pred_all_atom

        if self.augment_ligand_sc:
            # Displacement of each ligand atom toward its extrapolated z1.
            sc_transform['atoms'] = lambda pred: (self.module_x.get_z1_given_zt_and_pred(
                zt_x, pred.squeeze(1), None, t, ligand_mask) - zt_x).unsqueeze(1)

        return sc_transform
| |
|
    def compute_loss(self, ligand, pocket, return_info=False):
        """
        Samples time steps and computes network predictions.

        Draws one random t per sample, interpolates noisy states zt for ligand
        coordinates, atom types and bond types (plus pocket side chains and
        backbone frames when flexible), runs the dynamics network and assembles
        the weighted training loss.

        Args:
            ligand: batched ligand tensors ('x', 'one_hot', 'bonds', ...).
            pocket: batched pocket tensors; converted to a Residues object.
            return_info: also return a dict of scalar loss components.

        Returns:
            The per-batch loss tensor, or (loss, info) when return_info=True.
        """
        pocket = Residues(**pocket)

        # Center the complex; use the pocket center of mass as the coordinate
        # prior mean, falling back to the ligand when the pocket is empty.
        ligand, pocket = center_data(ligand, pocket)
        if pocket['x'].numel() > 0:
            pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        else:
            pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)

        # One uniform timestep per sample in the batch.
        t = torch.rand(ligand['size'].size(0), device=ligand['x'].device).unsqueeze(-1)

        # Sample prior states z0 and interpolated states zt for coordinates (x),
        # atom types (h) and bond types (e).
        z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
        z0_h = self.module_h.sample_z0(ligand['mask'])
        z0_e = self.module_e.sample_z0(ligand['bond_mask'])
        zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
        zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
        zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])

        if self.flexible_bb:
            # Backbone translations and rotations follow their own flows.
            z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask'])
            z1_trans = pocket['x'].detach().clone()
            zt_trans = self.module_trans.sample_zt(z0_trans, z1_trans, t, pocket['mask'])

            z0_rot = self.module_rot.sample_z0(pocket['mask'])
            z1_rot = pocket['axis_angle'].detach().clone()
            zt_rot = self.module_rot.sample_zt(z0_rot, z1_rot, t, pocket['mask'])

            # Move the pocket to its noised frame before running the network.
            pocket.set_frame(zt_trans, zt_rot)

        z0_chi, zt_chi = None, None
        if self.flexible:
            # Only the first five side-chain chi angles are modeled.
            z1_chi = pocket['chi'][:, :5].detach().clone()

            z0_chi = self.module_chi.sample_z0(pocket['mask'])
            zt_chi = self.module_chi.sample_zt(z0_chi, z1_chi, t, pocket['mask'])

            # Apply the noised angles; empty pockets need an explicit empty v.
            pocket.set_chi(zt_chi)
            if pocket['x'].numel() == 0:
                pocket.set_empty_v()

        # Self-conditioning transforms, then the network forward pass.
        sc_transform = self.get_sc_transform_fn(zt_chi, zt_x, t, z0_chi, ligand['mask'], pocket)
        pred_ligand, pred_residues = self.dynamics(
            zt_x, zt_h, ligand['mask'], pocket, t,
            bonds_ligand=(ligand['bonds'], zt_e), sc_transform=sc_transform
        )

        if self.predict_confidence:
            loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce='none')

            # Gaussian negative log-likelihood with predicted per-atom sigma.
            k = self.module_x.dim
            sigma = pred_ligand['uncertainty_vel']
            loss_x = loss_x / (2 * sigma ** 2) + k * torch.log(sigma)

            # Optional pull of sigma toward 1 to keep the NLL well-behaved.
            if self.regularize_uncertainty is not None:
                loss_x = loss_x + self.regularize_uncertainty * (pred_ligand['uncertainty_vel'] - 1) ** 2

            loss_x = self.module_x.reduce_loss(loss_x, ligand['mask'], reduce=self.loss_reduce)
        else:
            loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)

        # Discrete bridge losses need the next timestep of the training grid.
        t_next = torch.clamp(t + self.train_step_size, max=1.0)
        loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
        loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)

        # Weighted sum of all active loss components.
        loss = self.lambda_x * loss_x + self.lambda_h * loss_h + self.lambda_e * loss_e
        if self.flexible:
            loss_chi = self.module_chi.compute_loss(pred_residues['chi'], z0_chi, z1_chi, zt_chi, t, pocket['mask'], reduce=self.loss_reduce)
            loss = loss + self.lambda_chi * loss_chi

        if self.flexible_bb:
            loss_trans = self.module_trans.compute_loss(pred_residues['trans'], z0_trans, z1_trans, t, pocket['mask'], reduce=self.loss_reduce)
            loss_rot = self.module_rot.compute_loss(pred_residues['rot'], z0_rot, z1_rot, zt_rot, t, pocket['mask'], reduce=self.loss_reduce)
            loss = loss + self.lambda_trans * loss_trans + self.lambda_rot * loss_rot

        if self.lambda_clash is not None and self.lambda_clash > 0:
            # Extrapolate everything to its predicted end state (z1) so the
            # clash penalty acts on the final geometry.
            if self.flexible_bb:
                pred_z1_trans = self.module_trans.get_z1_given_zt_and_pred(zt_trans, pred_residues['trans'], z0_trans, t, pocket['mask'])
                pred_z1_rot = self.module_rot.get_z1_given_zt_and_pred(zt_rot, pred_residues['rot'], z0_rot, t, pocket['mask'])
                pocket.set_frame(pred_z1_trans, pred_z1_rot)

            if self.flexible:
                pred_z1_chi = self.module_chi.get_z1_given_zt_and_pred(zt_chi, pred_residues['chi'], z0_chi, t, pocket['mask'])
                pocket.set_chi(pred_z1_chi)

            # Flatten pocket atoms into (n_atoms, ...) tensors.
            # NOTE(review): aa_atom_type_tensor / aa_atom_mask_tensor are
            # module-level CPU tensors — confirm device handling on GPU runs.
            pocket_coord = pocket['x'].unsqueeze(1) + pocket['v']
            pocket_types = aa_atom_type_tensor[pocket['one_hot'].argmax(dim=-1)]
            pocket_mask = pocket['mask'].unsqueeze(-1).repeat((1, pocket['v'].size(1)))

            # Keep only atom slots that exist for each residue type.
            atom_mask = aa_atom_mask_tensor[pocket['one_hot'].argmax(dim=-1)]
            pocket_coord = pocket_coord[atom_mask]
            pocket_types = pocket_types[atom_mask]
            pocket_mask = pocket_mask[atom_mask]

            # Predicted final ligand coordinates and (hard) atom types.
            pred_z1_x = self.module_x.get_z1_given_zt_and_pred(zt_x, pred_ligand['vel'], z0_x, t, ligand['mask'])
            pred_z1_h = pred_ligand['logits_h'].argmax(dim=-1)
            loss_clash = clash_loss(pred_z1_x, pred_z1_h, ligand['mask'],
                                    pocket_coord, pocket_types, pocket_mask)
            loss = loss + self.lambda_clash * loss_clash

        # Optional time-dependent reweighting of the total loss.
        if self.timestep_weights is not None:
            w_t = self.timestep_weights(t).squeeze()
            loss = w_t * loss

        loss = loss.mean(0)

        info = {
            'loss_x': loss_x.mean().item(),
            'loss_h': loss_h.mean().item(),
            'loss_e': loss_e.mean().item(),
        }
        if self.flexible:
            info['loss_chi'] = loss_chi.mean().item()
        if self.flexible_bb:
            info['loss_trans'] = loss_trans.mean().item()
            info['loss_rot'] = loss_rot.mean().item()
        if self.lambda_clash is not None:
            # NOTE(review): loss_clash is only defined when lambda_clash > 0;
            # a non-None lambda_clash of 0 would raise NameError here.
            info['loss_clash'] = loss_clash.mean().item()
        if self.predict_confidence:
            # Correlate per-molecule uncertainty / prediction entropy with t.
            sigma_x_mol = scatter_mean(pred_ligand['uncertainty_vel'], ligand['mask'], dim=0)
            info['pearson_sigma_x'] = torch.corrcoef(torch.stack([sigma_x_mol.detach(), t.squeeze()]))[0, 1].item()
            info['mean_sigma_x'] = sigma_x_mol.mean().item()
            entropy_h = Categorical(logits=pred_ligand['logits_h']).entropy()
            entropy_h_mol = scatter_mean(entropy_h, ligand['mask'], dim=0)
            info['pearson_entropy_h'] = torch.corrcoef(torch.stack([entropy_h_mol.detach(), t.squeeze()]))[0, 1].item()
            info['mean_entropy_h'] = entropy_h_mol.mean().item()
            entropy_e = Categorical(logits=pred_ligand['logits_e']).entropy()
            entropy_e_mol = scatter_mean(entropy_e, ligand['bond_mask'], dim=0)
            info['pearson_entropy_e'] = torch.corrcoef(torch.stack([entropy_e_mol.detach(), t.squeeze()]))[0, 1].item()
            info['mean_entropy_e'] = entropy_e_mol.mean().item()

        return (loss, info) if return_info else loss
| |
|
| | def training_step(self, data, *args): |
| | ligand, pocket = data['ligand'], data['pocket'] |
| | try: |
| | loss, info = self.compute_loss(ligand, pocket, return_info=True) |
| | except RuntimeError as e: |
| | |
| | if self.trainer.num_devices < 2 and 'out of memory' in str(e): |
| | print('WARNING: ran out of memory, skipping to the next batch') |
| | return None |
| | else: |
| | raise e |
| |
|
| | log_dict = {k: v for k, v in info.items() if isinstance(v, float) |
| | or torch.numel(v) <= 1} |
| | |
| | |
| | |
| | |
| |
|
| | self.log_metrics({'loss': loss, **log_dict}, 'train', |
| | batch_size=len(ligand['size'])) |
| |
|
| | out = {'loss': loss, **info} |
| | self.training_step_outputs.append(out) |
| | return out |
| |
|
    def validation_step(self, data, *args):
        """
        Compute validation losses and draw samples for later evaluation.

        The loss is averaged over `n_loss_per_sample` independent noise draws;
        generated ligands/pockets are stashed in `validation_step_outputs` for
        `on_validation_epoch_end`.
        """
        loss_list, info_list = [], []
        # NOTE(review): the dynamics net is put in train mode for the loss
        # computation — presumably so validation losses are directly comparable
        # to training losses; confirm intended dropout/norm behavior.
        self.dynamics.train()
        for _ in range(self.n_loss_per_sample):
            # Pass copies — compute_loss presumably modifies its inputs
            # (centering etc.); verify against center_data.
            loss, info = self.compute_loss(data['ligand'].copy(),
                                           data['pocket'].copy(),
                                           return_info=True)
            loss_list.append(loss.item())
            info_list.append(info)
        self.dynamics.eval()
        if len(loss_list) >= 1:
            loss = np.mean(loss_list)
            info = {k: np.mean([x[k] for x in info_list]) for k in info_list[0]}
            self.log_metrics({'loss': loss, **info}, 'val', batch_size=len(data['ligand']['size']))

        # Sample molecules for this pocket batch.
        rdmols, rdpockets, _ = self.sample(
            data=data,
            n_samples=self.n_eval_samples,
            num_nodes="ground_truth" if self.sample_with_ground_truth_size else None,
        )

        out = {
            'ligands': rdmols,
            'pockets': rdpockets,
            'receptor_files': [Path(self.receptor_dir, 'val', x) for x in data['pocket']['name']]
        }
        self.validation_step_outputs.append(out)
        return out
| |
|
| | |
| | |
| |
|
    def on_validation_epoch_end(self):
        """
        Evaluate the ligands sampled during validation, log the metrics and —
        on the configured epochs — write sample and chain visualizations.
        """
        outdir = Path(self.outdir, f'epoch_{self.current_epoch}')

        # Flatten the per-batch outputs collected in validation_step.
        rdmols = [m for x in self.validation_step_outputs for m in x['ligands']]
        rdpockets = [p for x in self.validation_step_outputs for p in x['pockets']]
        receptors = [r for x in self.validation_step_outputs for r in x['receptor_files']]
        self.validation_step_outputs.clear()

        # Atom/bond type statistics of the generated molecules; every pair of
        # atoms without a bond is counted explicitly as 'NOBOND'.
        ligand_atom_types = [atom_encoder[a.GetSymbol()] for m in rdmols for a in m.GetAtoms()]
        ligand_bond_types = []
        for m in rdmols:
            bonds = m.GetBonds()
            no_bonds = m.GetNumAtoms() * (m.GetNumAtoms() - 1) // 2 - m.GetNumBonds()
            ligand_bond_types += [bond_encoder['NOBOND']] * no_bonds
            for b in bonds:
                ligand_bond_types.append(bond_encoder[b.GetBondType().name])

        tic = time()
        results = self.analyze_sample(
            rdmols, ligand_atom_types, ligand_bond_types, receptors=(rdpockets if len(rdpockets) != 0 else None)
        )
        self.log_metrics(results, 'val')
        print(f'Evaluation took {time() - tic:.2f} seconds')

        if (self.current_epoch + 1) % self.visualize_sample_epoch == 0:
            tic = time()
            outdir.mkdir(exist_ok=True, parents=True)

            # Center each ligand/pocket pair at the ligand's mean position
            # before writing, for easier visual inspection.
            rdmols = rdmols[:self.n_visualize_samples]
            rdpockets = rdpockets[:self.n_visualize_samples]
            for m, p in zip(rdmols, rdpockets):
                center = m.GetConformer().GetPositions().mean(axis=0)
                for i in range(m.GetNumAtoms()):
                    x, y, z = m.GetConformer().GetPositions()[i] - center
                    m.GetConformer().SetAtomPosition(i, (x, y, z))
                for i in range(p.GetNumAtoms()):
                    x, y, z = p.GetConformer().GetPositions()[i] - center
                    p.GetConformer().SetAtomPosition(i, (x, y, z))

            utils.write_sdf_file(Path(outdir, 'molecules.sdf'), rdmols)
            utils.write_sdf_file(Path(outdir, 'pockets.sdf'), rdpockets)

            print(f'Sample visualization took {time() - tic:.2f} seconds')

        if (self.current_epoch + 1) % self.visualize_chain_epoch == 0:
            tic = time()
            outdir.mkdir(exist_ok=True, parents=True)

            # Pick a random validation item; sharded (iterable) datasets don't
            # support indexing, so iterate up to a random position instead.
            if self.sharded_dataset:
                index = torch.randint(len(self.val_dataset), size=(1,)).item()
                for i, x in enumerate(self.val_dataset):
                    if i == index:
                        break
                batch = self.val_dataset.collate_fn([x])
            else:
                batch = self.val_dataset.collate_fn([self.val_dataset[torch.randint(len(self.val_dataset), size=(1,))]])
            batch['pocket'] = Residues(**batch['pocket']).to(self.device)
            pocket_copy = batch['pocket'].copy()

            # Without pocket atoms to condition on, the ligand size must be
            # sampled explicitly from the size distribution.
            if len(batch['pocket']['x']) > 0:
                ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames)
            else:
                num_nodes, _ = self.size_distribution.sample()
                ligand_chain, pocket_chain, info = self.sample_chain(batch['pocket'], self.keep_frames, num_nodes=num_nodes)

            # For flexible pockets, prepend the ground truth as the first
            # frame so the generated chain can be compared against it.
            if self.flexible or self.flexible_bb:
                ground_truth_pocket = pocket_to_rdkit(
                    pocket_copy, self.pocket_representation,
                    self.atom_encoder, self.atom_decoder,
                    self.aa_decoder, self.residue_decoder,
                    self.aa_atom_index
                )[0]
                ground_truth_ligand = build_molecule(
                    batch['ligand']['x'], batch['ligand']['one_hot'].argmax(1),
                    bonds=batch['ligand']['bonds'],
                    bond_types=batch['ligand']['bond_one_hot'].argmax(1),
                    atom_decoder=self.atom_decoder,
                    bond_decoder=self.bond_decoder
                )
                pocket_chain.insert(0, ground_truth_pocket)
                ligand_chain.insert(0, ground_truth_ligand)

            utils.write_sdf_file(Path(outdir, 'chain_ligand.sdf'), ligand_chain)
            mols_to_pdbfile(pocket_chain, Path(outdir, 'chain_pocket.pdb'))

            self.log_metrics(info, 'val')
            print(f'Chain visualization took {time() - tic:.2f} seconds')
| |
|
| |
|
| | |
| | |
| | |
| | @property |
| | def total_batch_idx(self) -> int: |
| | """Returns the current batch index (across epochs)""" |
| | |
| | |
| | return max(0, self.batch_progress.total.ready - 1) |
| |
|
| | @property |
| | def batch_idx(self) -> int: |
| | """Returns the current batch index (within this epoch)""" |
| | |
| | |
| | return max(0, self.batch_progress.current.ready - 1) |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | def analyze_sample(self, rdmols, atom_types, bond_types, aa_types=None, receptors=None): |
| | out = {} |
| |
|
| | |
| | kl_div_atom = self.ligand_atom_type_distribution.kl_divergence(atom_types) \ |
| | if self.ligand_atom_type_distribution is not None else -1 |
| | out['kl_div_atom_types'] = kl_div_atom |
| |
|
| | |
| | kl_div_bond = self.ligand_bond_type_distribution.kl_divergence(bond_types) \ |
| | if self.ligand_bond_type_distribution is not None else -1 |
| | out['kl_div_bond_types'] = kl_div_bond |
| |
|
| | if aa_types is not None: |
| | kl_div_aa = self.pocket_type_distribution.kl_divergence(aa_types) \ |
| | if self.pocket_type_distribution is not None else -1 |
| | out['kl_div_residue_types'] = kl_div_aa |
| |
|
| | |
| | results = [] |
| | if receptors is not None: |
| | with tempfile.TemporaryDirectory() as tmpdir: |
| | for mol, receptor in zip(tqdm(rdmols, desc='FullEvaluator'), receptors): |
| | receptor_path = Path(tmpdir, 'receptor.pdb') |
| | Chem.MolToPDBFile(receptor, str(receptor_path)) |
| | results.append(self.evaluator(mol, receptor_path)) |
| | else: |
| | for mol in tqdm(rdmols, desc='FullEvaluator'): |
| | self.evaluator = FullEvaluator(pb_conf='mol') |
| | results.append(self.evaluator(mol)) |
| |
|
| | results = pd.DataFrame(results) |
| | agg_results = aggregated_metrics(results, self.evaluator.dtypes, VALIDITY_METRIC_NAME).fillna(0) |
| | agg_results['metric'] = agg_results['metric'].str.replace('.', '/') |
| |
|
| | col_results = collection_metrics(results, self.train_smiles, VALIDITY_METRIC_NAME, exclude_evaluators='fcd') |
| | col_results['metric'] = 'collection/' + col_results['metric'] |
| |
|
| | all_results = pd.concat([agg_results, col_results]) |
| | out.update(**dict(all_results[['metric', 'value']].values)) |
| |
|
| | return out |
| |
|
    def sample_zt_given_zs(self, zs_ligand, zs_pocket, s, t, delta_eps_x=None, uncertainty=None):
        """Advance the joint ligand/pocket state by one step, from time s to t.

        Runs the dynamics network on the state at time s and lets each
        modality's flow/bridge module sample its state at time t: ligand
        coordinates ('x'), atom types ('h'), bond types ('e') and, when
        enabled, pocket backbone frames and side-chain chi angles.

        Args:
            zs_ligand: ligand state at time s ('x', 'h', 'e', 'bonds',
                'mask', 'edge_mask').
            zs_pocket: pocket (Residues) state at time s.
            s: current time per sample — presumably shape (n_samples, 1);
                TODO confirm against simulate().
            t: next time per sample, same shape as s.
            delta_eps_x: optional additive correction to the predicted ligand
                velocity (guidance hook).
            uncertainty: dict with 'sigma_x_squared' and 'entropy_h' that is
                accumulated IN PLACE when self.predict_confidence is set.

        Returns:
            (zt_ligand, zt_pocket): copies of the inputs advanced to time t.
        """
        # predict velocities/logits for all modalities at time s
        sc_transform = self.get_sc_transform_fn(zs_pocket.get('chi'), zs_ligand['x'], s, None, zs_ligand['mask'], zs_pocket)
        pred_ligand, pred_residues = self.dynamics(
            zs_ligand['x'], zs_ligand['h'], zs_ligand['mask'], zs_pocket, s, bonds_ligand=(zs_ligand['bonds'], zs_ligand['e']),
            sc_transform=sc_transform
        )

        # optional guidance term added to the coordinate velocity
        if delta_eps_x is not None:
            pred_ligand['vel'] = pred_ligand['vel'] + delta_eps_x

        zt_ligand = zs_ligand.copy()
        zt_ligand['x'] = self.module_x.sample_zt_given_zs(zs_ligand['x'], pred_ligand['vel'], s, t, zs_ligand['mask'])

        zt_ligand['h'] = self.module_h.sample_zt_given_zs(zs_ligand['h'], pred_ligand['logits_h'], s, t, zs_ligand['mask'])
        zt_ligand['e'] = self.module_e.sample_zt_given_zs(zs_ligand['e'], pred_ligand['logits_e'], s, t, zs_ligand['edge_mask'])

        zt_pocket = zs_pocket.copy()
        if self.flexible_bb:
            # backbone: advance residue translations and rotations
            zt_trans_pocket = self.module_trans.sample_zt_given_zs(zs_pocket['x'], pred_residues['trans'], s, t, zs_pocket['mask'])
            zt_rot_pocket = self.module_rot.sample_zt_given_zs(zs_pocket['axis_angle'], pred_residues['rot'], s, t, zs_pocket['mask'])

            zt_pocket.set_frame(zt_trans_pocket, zt_rot_pocket)

        if self.flexible:
            # side chains: only the first 5 chi angles are modeled
            zt_chi_pocket = self.module_chi.sample_zt_given_zs(zs_pocket['chi'][..., :5], pred_residues['chi'], s, t, zs_pocket['mask'])

            zt_pocket.set_chi(zt_chi_pocket)

        if self.predict_confidence:
            assert uncertainty is not None
            # accumulate dt-weighted predictive variance and entropy per atom
            dt = (t - s).view(-1)[zt_ligand['mask']]
            uncertainty['sigma_x_squared'] += (dt * pred_ligand['uncertainty_vel']**2)
            uncertainty['entropy_h'] += (dt * Categorical(logits=pred_ligand['logits_h']).entropy())

        return zt_ligand, zt_pocket
| |
|
| | def simulate(self, ligand, pocket, timesteps, t_start, t_end=1.0, |
| | return_frames=1, guide_log_prob=None): |
| | """ |
| | Take a version of the ligand and pocket (at any time step t_start) and |
| | simulate the generative process from t_start to t_end. |
| | """ |
| |
|
| | assert 0 < return_frames <= timesteps |
| | assert timesteps % return_frames == 0 |
| | assert 0.0 <= t_start < 1.0 |
| | assert 0 < t_end <= 1.0 |
| | assert t_start < t_end |
| |
|
| | device = ligand['x'].device |
| | n_samples = len(pocket['size']) |
| | delta_t = (t_end - t_start) / timesteps |
| |
|
| | |
| | out_ligand = { |
| | 'x': torch.zeros((return_frames, len(ligand['mask']), self.x_dim), device=device), |
| | 'h': torch.zeros((return_frames, len(ligand['mask']), self.atom_nf), device=device), |
| | 'e': torch.zeros((return_frames, len(ligand['edge_mask']), self.bond_nf), device=device) |
| | } |
| | if self.predict_confidence: |
| | out_ligand['sigma_x'] = torch.zeros((return_frames, len(ligand['mask'])), device=device) |
| | out_ligand['entropy_h'] = torch.zeros((return_frames, len(ligand['mask'])), device=device) |
| | out_pocket = { |
| | 'x': torch.zeros((return_frames, len(pocket['mask']), 3), device=device), |
| | 'v': torch.zeros((return_frames, len(pocket['mask']), self.n_atom_aa, 3), device=device) |
| | } |
| |
|
| | cumulative_uncertainty = { |
| | 'sigma_x_squared': torch.zeros(len(ligand['mask']), device=device), |
| | 'entropy_h': torch.zeros(len(ligand['mask']), device=device) |
| | } if self.predict_confidence else None |
| |
|
| | for i, t in enumerate(torch.linspace(t_start, t_end - delta_t, timesteps)): |
| | t_array = torch.full((n_samples, 1), fill_value=t, device=device) |
| |
|
| | if guide_log_prob is not None: |
| | raise NotImplementedError('Not yet implemented for flow matching model') |
| | alpha_t = self.diffusion_x.schedule.alpha(self.gamma_x(t_array)) |
| |
|
| | with torch.enable_grad(): |
| | zt_x_ligand.requires_grad = True |
| | g = guide_log_prob(t_array, x=ligand['x'], h=ligand['h'], batch_mask=ligand['mask'], |
| | bonds=ligand['bonds'], bond_types=ligand['e']) |
| |
|
| | |
| | grad_x_lig = torch.autograd.grad(g.sum(), inputs=ligand['x'])[0] |
| |
|
| | |
| | g_max = 1.0 |
| | clip_mask = (grad_x_lig.norm(dim=-1) > g_max) |
| | grad_x_lig[clip_mask] = \ |
| | grad_x_lig[clip_mask] / grad_x_lig[clip_mask].norm( |
| | dim=-1, keepdim=True) * g_max |
| |
|
| | delta_eps_lig = -1 * (1 - alpha_t[lig_mask]).sqrt() * grad_x_lig |
| | else: |
| | delta_eps_lig = None |
| |
|
| | ligand, pocket = self.sample_zt_given_zs( |
| | ligand, pocket, t_array, t_array + delta_t, delta_eps_lig, cumulative_uncertainty) |
| |
|
| | |
| | if (i + 1) % (timesteps // return_frames) == 0: |
| | idx = (i + 1) // (timesteps // return_frames) |
| | idx = idx - 1 |
| |
|
| | out_ligand['x'][idx] = ligand['x'].detach() |
| | out_ligand['h'][idx] = ligand['h'].detach() |
| | out_ligand['e'][idx] = ligand['e'].detach() |
| | if pocket['x'].numel() > 0: |
| | out_pocket['x'][idx] = pocket['x'].detach() |
| | out_pocket['v'][idx] = pocket['v'][:, :self.n_atom_aa, :].detach() |
| | if self.predict_confidence: |
| | out_ligand['sigma_x'][idx] = cumulative_uncertainty['sigma_x_squared'].sqrt().detach() |
| | out_ligand['entropy_h'][idx] = cumulative_uncertainty['entropy_h'].detach() |
| |
|
| | |
| | out_ligand = {k: v.squeeze(0) for k, v in out_ligand.items()} |
| | out_pocket = {k: v.squeeze(0) for k, v in out_pocket.items()} |
| |
|
| | return out_ligand, out_pocket |
| |
|
| | def init_ligand(self, num_nodes_lig, pocket): |
| | device = pocket['x'].device |
| |
|
| | n_samples = len(pocket['size']) |
| | lig_mask = utils.num_nodes_to_batch_mask(n_samples, num_nodes_lig, device) |
| |
|
| | |
| | lig_bonds = torch.stack(torch.where(torch.triu( |
| | lig_mask[:, None] == lig_mask[None, :], diagonal=1)), dim=0) |
| | lig_edge_mask = lig_mask[lig_bonds[0]] |
| |
|
| | |
| | pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0) |
| | z0_x = self.module_x.sample_z0(pocket_com, lig_mask) |
| | z0_h = self.module_h.sample_z0(lig_mask) |
| | z0_e = self.module_e.sample_z0(lig_edge_mask) |
| |
|
| | return TensorDict(**{ |
| | 'x': z0_x, 'h': z0_h, 'e': z0_e, 'mask': lig_mask, |
| | 'bonds': lig_bonds, 'edge_mask': lig_edge_mask |
| | }) |
| |
|
| | def init_pocket(self, pocket): |
| |
|
| | if self.flexible_bb: |
| | pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0) |
| | z0_trans = self.module_trans.sample_z0(pocket_com, pocket['mask']) |
| | z0_rot = self.module_rot.sample_z0(pocket['mask']) |
| |
|
| | |
| | pocket.set_frame(z0_trans, z0_rot) |
| |
|
| | if self.flexible: |
| | z0_chi = self.module_chi.sample_z0(pocket['mask']) |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | pocket.set_chi(z0_chi) |
| |
|
| | if pocket['x'].numel() == 0: |
| | pocket.set_empty_v() |
| |
|
| | return pocket |
| |
|
| | def parse_num_nodes_spec(self, batch, spec=None, size_model=None): |
| |
|
| | if spec == "2d_histogram" or spec is None: |
| | assert "pocket" in batch |
| | num_nodes = self.size_distribution.sample_conditional( |
| | n1=None, n2=batch['pocket']['size']) |
| |
|
| | |
| | num_nodes[num_nodes < 2] = 2 |
| |
|
| | elif isinstance(spec, (int, torch.Tensor)): |
| | num_nodes = spec |
| |
|
| | elif spec == "ground_truth": |
| | assert "ligand" in batch |
| | num_nodes = batch['ligand']['size'] |
| |
|
| | elif spec == "nn_prediction": |
| | assert size_model is not None |
| | assert "pocket" in batch |
| | predictions = size_model.forward(batch['pocket']) |
| | predictions = torch.softmax(predictions, dim=-1) |
| | predictions[:, :5] = 0.0 |
| | probabilities = predictions / predictions.sum(dim=1, keepdims=True) |
| | num_nodes = torch.distributions.Categorical(probabilities).sample() |
| |
|
| | elif isinstance(spec, str) and spec.startswith("uniform"): |
| | |
| | assert "pocket" in batch |
| | left, right = map(int, spec.split("_")[1:]) |
| | shape = batch['pocket']['size'].shape |
| | num_nodes = torch.randint(left, right + 1, shape, dtype=torch.long) |
| |
|
| | else: |
| | raise NotImplementedError(f"Invalid size specification {spec}") |
| |
|
| | if self.virtual_nodes: |
| | num_nodes += self.add_virtual_max |
| |
|
| | return num_nodes |
| |
|
    @torch.no_grad()
    def sample(self, data, n_samples, num_nodes=None, timesteps=None,
               guide_log_prob=None, size_model=None, **kwargs):
        """Generate n_samples ligands for the pocket contained in `data`.

        Args:
            data: dict with 'pocket' and 'ligand' entries for a single
                complex; the reference ligand provides names and, depending
                on `num_nodes`, size information.
            n_samples: number of ligands to generate.
            num_nodes: size specification forwarded to parse_num_nodes_spec.
            timesteps: number of simulation steps (defaults to
                self.T_sampling); 0 returns raw prior samples.
            guide_log_prob: optional guidance function (not supported by
                simulate() for the flow-matching model).
            size_model: optional network for the "nn_prediction" size spec.

        Returns:
            (rdmols, rdpockets, names): sampled RDKit ligands, their RDKit
            pockets, and the reference ligand names.
        """
        data['pocket'] = Residues(**data['pocket'])

        timesteps = self.T_sampling if timesteps is None else timesteps

        # replicate the pocket n_samples times; empty pockets (ligand-only
        # sampling) have no atoms to repeat, so copy the metadata manually
        if len(data['pocket']['x']) > 0:
            pocket = data_utils.repeat_items(data['pocket'], n_samples)
        else:
            pocket = Residues(**{key: value for key, value in data['pocket'].items()})
            pocket['name'] = pocket['name'] * n_samples
            pocket['size'] = pocket['size'].repeat(n_samples)
            pocket['n_bonds'] = pocket['n_bonds'].repeat(n_samples)

        _ligand = data_utils.repeat_items(data['ligand'], n_samples)

        # determine how many atoms each generated ligand should have
        batch = {"ligand": _ligand, "pocket": pocket}
        num_nodes = self.parse_num_nodes_spec(batch, spec=num_nodes, size_model=size_model)

        # draw the initial (t=0) states; with an empty pocket the ligand
        # prior is centered on the reference ligand instead
        if pocket['x'].numel() > 0:
            ligand = self.init_ligand(num_nodes, pocket)
        else:
            ligand = self.init_ligand(num_nodes, _ligand)
        pocket = self.init_pocket(pocket)

        # timesteps == 0: return the un-simulated prior samples directly
        if timesteps == 0:
            rdmols = [build_molecule(coords=m['x'],
                                     atom_types=m['h'].argmax(1),
                                     bonds=m['bonds'],
                                     bond_types=m['e'].argmax(1),
                                     atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder)
                      for m in data_utils.split_entity(ligand.detach().cpu(), edge_types={"e", "edge_mask"}, edge_mask=ligand["edge_mask"])]

            rdpockets = pocket_to_rdkit(pocket, self.pocket_representation,
                                        self.atom_encoder, self.atom_decoder,
                                        self.aa_decoder, self.residue_decoder,
                                        self.aa_atom_index)

            return rdmols, rdpockets, _ligand['name']

        out_tensors_ligand, out_tensors_pocket = self.simulate(
            ligand, pocket, timesteps, 0.0, 1.0,
            guide_log_prob=guide_log_prob
        )

        # move results to CPU and split flat tensors into per-molecule lists
        x = out_tensors_ligand['x'].detach().cpu()
        ligand_type = out_tensors_ligand['h'].argmax(1).detach().cpu()
        edge_type = out_tensors_ligand['e'].argmax(1).detach().cpu()
        lig_mask = ligand['mask'].detach().cpu()
        lig_bonds = ligand['bonds'].detach().cpu()
        lig_edge_mask = ligand['edge_mask'].detach().cpu()
        sizes = torch.unique(ligand['mask'], return_counts=True)[1].tolist()
        # per-molecule index offsets so bond indices become molecule-local
        offsets = list(accumulate(sizes[:-1], initial=0))
        mol_kwargs = {
            'coords': utils.batch_to_list(x, lig_mask),
            'atom_types': utils.batch_to_list(ligand_type, lig_mask),
            'bonds': utils.batch_to_list_for_indices(lig_bonds, lig_edge_mask, offsets),
            'bond_types': utils.batch_to_list(edge_type, lig_edge_mask)
        }
        if self.predict_confidence:
            # attach per-atom confidence estimates accumulated during simulation
            sigma_x = out_tensors_ligand['sigma_x'].detach().cpu()
            entropy_h = out_tensors_ligand['entropy_h'].detach().cpu()
            mol_kwargs['atom_props'] = [
                {'sigma_x': x[0], 'entropy_h': x[1]}
                for x in zip(utils.batch_to_list(sigma_x, lig_mask),
                             utils.batch_to_list(entropy_h, lig_mask))
            ]
        # dict of lists -> list of per-molecule dicts
        mol_kwargs = [{k: v[i] for k, v in mol_kwargs.items()}
                      for i in range(len(mol_kwargs['coords']))]

        rdmols = [build_molecule(
            **m, atom_decoder=self.atom_decoder, bond_decoder=self.bond_decoder)
            for m in mol_kwargs
        ]

        # pocket with final coordinates/side chains from the simulation
        out_pocket = pocket.copy()
        out_pocket['x'] = out_tensors_pocket['x']
        out_pocket['v'] = out_tensors_pocket['v']
        rdpockets = pocket_to_rdkit(out_pocket, self.pocket_representation,
                                    self.atom_encoder, self.atom_decoder,
                                    self.aa_decoder, self.residue_decoder,
                                    self.aa_atom_index)

        return rdmols, rdpockets, _ligand['name']
| |
|
    @torch.no_grad()
    def sample_chain(self, pocket, keep_frames, num_nodes=None, timesteps=None,
                     guide_log_prob=None, **kwargs):
        """Sample a single ligand and return the whole generative trajectory.

        Args:
            pocket: dict describing ONE pocket (a single sample).
            keep_frames: number of trajectory frames to keep; must divide
                the number of timesteps.
            num_nodes: size specification forwarded to parse_num_nodes_spec.
            timesteps: number of simulation steps (defaults to self.T_sampling).
            guide_log_prob: optional guidance function (unsupported, see
                simulate()).

        Returns:
            (ligand_chain, pocket_chain, info): lists of RDKit molecules,
            one per kept frame, plus a dict of trajectory statistics.
        """
        pocket = Residues(**pocket)

        info = {}

        timesteps = self.T_sampling if timesteps is None else timesteps

        assert len(pocket['mask'].unique()) <= 1, "sample_chain only supports a single sample"

        num_nodes = self.parse_num_nodes_spec(batch={"pocket": pocket}, spec=num_nodes)

        # draw the initial ligand; for an empty pocket use a dummy so that
        # init_ligand still has coordinates to center the prior on
        if pocket['x'].numel() > 0:
            ligand = self.init_ligand(num_nodes, pocket)
        else:
            dummy_pocket = Residues.empty(pocket['x'].device)
            ligand = self.init_ligand(num_nodes, dummy_pocket)

        pocket = self.init_pocket(pocket)

        # run the full generative process, recording keep_frames snapshots
        out_tensors_ligand, out_tensors_pocket = self.simulate(
            ligand, pocket, timesteps, 0.0, 1.0, guide_log_prob=guide_log_prob, return_frames=keep_frames)

        # trajectory statistics: total displacement and per-atom spread
        info['traj_displacement_lig'] = torch.norm(out_tensors_ligand['x'][-1] - out_tensors_ligand['x'][0], dim=-1).mean()
        info['traj_rms_lig'] = out_tensors_ligand['x'].std(dim=0).mean()

        assert keep_frames == out_tensors_ligand['x'].size(0) == out_tensors_pocket['x'].size(0)
        n_atoms = out_tensors_ligand['x'].size(1)
        n_bonds = out_tensors_ligand['e'].size(1)
        n_residues = out_tensors_pocket['x'].size(1)
        device = out_tensors_ligand['x'].device

        # collapse the leading frame dimension into the node/edge dimension
        # so each frame can be treated as one "molecule" in a flat batch
        def flatten_tensor(chain):
            if len(chain.size()) == 3:
                return chain.view(-1, chain.size(-1))
            elif len(chain.size()) == 4:
                return chain.view(-1, chain.size(-2), chain.size(-1))
            else:
                warnings.warn(f"Could not flatten frame dimension of tensor with shape {list(chain.size())}")
                return chain

        out_tensors_ligand_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_ligand.items()}
        out_tensors_pocket_flat = {k: flatten_tensor(chain) for k, chain in out_tensors_pocket.items()}

        # batch masks treating each frame as its own sample
        ligand_mask_flat = torch.arange(keep_frames).repeat_interleave(n_atoms).to(device)

        pocket_mask_flat = torch.arange(keep_frames).repeat_interleave(n_residues).to(device)

        # bond indices are identical in every frame, so repeat them;
        # offsets stay zero because indices are already frame-local
        bond_mask_flat = torch.arange(keep_frames).repeat_interleave(n_bonds).to(device)
        edges_flat = ligand['bonds'].repeat(1, keep_frames)

        # move everything to CPU and split per frame
        x = out_tensors_ligand_flat['x'].detach().cpu()
        ligand_type = out_tensors_ligand_flat['h'].argmax(1).detach().cpu()
        ligand_mask_flat = ligand_mask_flat.detach().cpu()
        bond_mask_flat = bond_mask_flat.detach().cpu()
        edges_flat = edges_flat.detach().cpu()
        edge_type = out_tensors_ligand_flat['e'].argmax(1).detach().cpu()
        offsets = torch.zeros(keep_frames, dtype=int)
        molecules = list(
            zip(utils.batch_to_list(x, ligand_mask_flat),
                utils.batch_to_list(ligand_type, ligand_mask_flat),
                utils.batch_to_list_for_indices(edges_flat, bond_mask_flat, offsets),
                utils.batch_to_list(edge_type, bond_mask_flat)
                )
        )

        # one RDKit molecule per kept frame
        ligand_chain = [build_molecule(
            *graph, atom_decoder=self.atom_decoder,
            bond_decoder=self.bond_decoder) for graph in molecules
        ]

        # flexible pockets change over time -> build a per-frame pocket batch;
        # rigid pockets are passed through unchanged
        out_pocket = {
            'x': out_tensors_pocket_flat['x'],
            'one_hot': pocket['one_hot'].repeat(keep_frames, 1),
            'mask': pocket_mask_flat,
            'v': out_tensors_pocket_flat['v'],
            'atom_mask': pocket['atom_mask'].repeat(keep_frames, 1),
        } if self.flexible else pocket
        pocket_chain = pocket_to_rdkit(out_pocket, self.pocket_representation,
                                       self.atom_encoder, self.atom_decoder,
                                       self.aa_decoder, self.residue_decoder,
                                       self.aa_atom_index)

        return ligand_chain, pocket_chain, info
| |
|
| | |
| | |
| | def configure_gradient_clipping(self, optimizer, *args, **kwargs): |
| |
|
| | if not self.clip_grad: |
| | return |
| |
|
| | |
| | max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \ |
| | 2 * self.gradnorm_queue.std() |
| |
|
| | |
| | max_grad_norm = min(max_grad_norm, 10.0) |
| |
|
| | |
| | params = [p for g in optimizer.param_groups for p in g['params']] |
| | grad_norm = utils.get_grad_norm(params) |
| |
|
| | |
| | self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm, |
| | gradient_clip_algorithm='norm') |
| |
|
| | if float(grad_norm) > max_grad_norm: |
| | print(f'Clipped gradient with value {grad_norm:.1f} ' |
| | f'while allowed {max_grad_norm:.1f}') |
| | grad_norm = max_grad_norm |
| |
|
| | self.gradnorm_queue.add(float(grad_norm)) |