| | import argparse |
| | import sys |
| | import os |
| | import warnings |
| | import tempfile |
| | import pandas as pd |
| |
|
| | from Bio.PDB import PDBParser |
| | from pathlib import Path |
| | from rdkit import Chem |
| | from torch.utils.data import DataLoader |
| | from functools import partial |
| |
|
# Repository root: the grandparent directory of this file's resolved path.
# It is appended to sys.path so that the `src` package can be imported.
basedir = Path(__file__).resolve().parents[1]
sys.path.append(str(basedir))

# Suppress all Python warnings for the remainder of the process.
warnings.filterwarnings("ignore")
| |
|
| | from src import utils |
| | from src.data.dataset import ProcessedLigandPocketDataset |
| | from src.data.data_utils import TensorDict, process_raw_pair |
| | from src.model.lightning import DrugFlow |
| | from src.sbdd_metrics.metrics import FullEvaluator |
| |
|
| | from tqdm import tqdm |
| | from pdb import set_trace |
| |
|
| |
|
def _mean_of_prefixed_columns(table, prefix, exclude_suffix=None):
    """Store in ``table[prefix]`` the row-wise mean of the columns named ``prefix...``.

    A column is included when its name starts with ``prefix``, is not exactly
    ``prefix``, and (if ``exclude_suffix`` is given) does not end with
    ``exclude_suffix``. Missing values count as 0, and values are cast to float
    before averaging. If no column qualifies, every row gets NaN.

    Modifies ``table`` in place.
    """
    columns = [
        column for column in table.columns
        if column.startswith(prefix)
        and column != prefix
        and not (exclude_suffix is not None and column.endswith(exclude_suffix))
    ]
    if not columns:
        # Nothing to average: make the "undefined" result explicit instead of
        # relying on pandas turning 0 / 0 into NaN.
        table[prefix] = float('nan')
        return
    # Python's sum() starts at 0, matching the original 0 + col1 + col2 + ... order.
    total = sum(table[column].fillna(0).astype(float) for column in columns)
    table[prefix] = total / len(columns)


def aggregate_metrics(table):
    """Add summary columns ``posebusters``, ``reos`` and ``chembl_ring_systems``.

    Each summary is the row-wise mean of the table's columns whose names start
    with that summary name (missing values counted as 0). For
    ``chembl_ring_systems``, columns ending in ``smi`` are skipped; presumably
    they hold SMILES strings rather than numbers (TODO confirm against
    FullEvaluator output). Existing columns with a summary name are overwritten.

    Args:
        table: pandas DataFrame with one row per evaluated molecule. Column
            names are assumed to be strings.

    Returns:
        The same DataFrame object, modified in place.
    """
    _mean_of_prefixed_columns(table, 'posebusters')
    _mean_of_prefixed_columns(table, 'reos')
    _mean_of_prefixed_columns(table, 'chembl_ring_systems', exclude_suffix='smi')
    return table
| |
|
| |
|
def parse_molecule_size(spec):
    """Convert ``--molecule_size`` into the ``num_nodes`` value passed to ``model.sample``.

    Args:
        spec: ``None``, a single non-negative integer such as ``'18'``, or two
            integers separated by a comma such as ``'15,20'``. Surrounding
            whitespace is ignored.

    Returns:
        ``None`` when ``spec`` is ``None`` (the CLI help says the size is then
        sampled), an ``int`` for a fixed size, or ``'uniform_<left>_<right>'``
        for a range.

    Raises:
        ValueError: if ``spec`` matches neither accepted form.
    """
    if spec is None:
        return None
    spec = spec.strip()
    # isdecimal() accepts exactly the characters int() parses; isdigit() also
    # accepts e.g. superscripts, which int() then rejects.
    if spec.isdecimal():
        molecule_size = int(spec)
        print(f'Will generate molecules of size {molecule_size}')
        return molecule_size
    boundaries = [x.strip() for x in spec.split(',')]
    # Raise explicitly (not via assert) so the check also runs under `python -O`.
    if len(boundaries) != 2 or not all(b.isdecimal() for b in boundaries):
        raise ValueError(
            f"Invalid --molecule_size {spec!r}: expected an integer or a range such as '15,20'."
        )
    left, right = int(boundaries[0]), int(boundaries[1])
    print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})')
    return f"uniform_{left}_{right}"


def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The argparse namespace, extended with ``num_nodes`` (``--molecule_size``
        converted by ``parse_molecule_size``).

    Invalid values terminate the program through ``ArgumentParser.error``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--protein', type=str, required=True, help="Input PDB file.")
    p.add_argument('--ref_ligand', type=str, required=True, help="SDF file with reference ligand used to define the pocket.")
    p.add_argument('--checkpoint', type=str, required=True, help="Model checkpoint file.")
    p.add_argument('--molecule_size', type=str, required=False, default=None, help="Maximum number of atoms in the sampled molecules. Can be a single number or a range, e.g. '15,20'. If None, size will be sampled.")
    p.add_argument('--output', type=str, required=False, default='samples.sdf', help="Output file.")
    p.add_argument('--n_samples', type=int, required=False, default=10, help="Number of sampled molecules.")
    p.add_argument('--batch_size', type=int, required=False, default=32, help="Batch size.")
    p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0, help="Distance cutoff to define the pocket around the reference ligand.")
    p.add_argument('--n_steps', type=int, required=False, default=None, help="Number of denoising steps.")
    p.add_argument('--device', type=str, required=False, default='cuda:0', help="Device to use.")
    p.add_argument('--datadir', type=Path, required=False, default=Path(basedir, 'src', 'default'), help="Needs to be specified to sample molecule sizes.")
    p.add_argument('--seed', type=int, required=False, default=42, help="Random seed.")
    p.add_argument('--filter', action='store_true', required=False, default=False, help="Apply basic filters and keep sampling until `n_samples` molecules passing these filters are found.")
    p.add_argument('--metrics_output', type=str, required=False, default=None, help="If provided, metrics will be computed and saved in csv format at this location.")
    p.add_argument('--gnina', type=str, required=False, default=None, help="Path to a gnina executable. Required for computing docking scores.")
    p.add_argument('--reduce', type=str, required=False, default=None, help="Path to a reduce executable. Required for computing interactions.")
    args = p.parse_args(argv)

    # main() uses min(batch_size, n_samples) as the batch size; reject
    # non-positive values here, before any checkpoint is loaded.
    if args.n_samples < 1:
        p.error("--n_samples must be a positive integer.")
    if args.batch_size < 1:
        p.error("--batch_size must be a positive integer.")
    try:
        args.num_nodes = parse_molecule_size(args.molecule_size)
    except ValueError as err:
        p.error(str(err))
    return args


def evaluate_batch(evaluator, rdmols, rdpockets, seen_smiles):
    """Score each sampled ligand against its pocket and return one row per pair.

    Each pocket is written to a temporary PDB file because ``evaluator`` is
    called with a receptor path. The table also gets a provisional ``novel``
    column (SMILES not in ``seen_smiles``). ``select_rows`` refines this column
    in place. The table also gets the summary columns added by
    ``aggregate_metrics``.
    """
    results = []
    with tempfile.TemporaryDirectory() as tmpdir:
        receptor_path = Path(tmpdir, 'receptor.pdb')
        for mol, receptor in zip(rdmols, rdpockets):
            Chem.MolToPDBFile(receptor, str(receptor_path))
            results.append(evaluator(mol, receptor_path))

    table = pd.DataFrame(results)
    # Created here so the column keeps its original position in the CSV.
    table['novel'] = ~table['representation.smiles'].isin(seen_smiles)
    return aggregate_metrics(table)


def select_rows(table, seen_smiles, budget, apply_filters):
    """Choose which evaluated samples to keep and update novelty bookkeeping.

    A sample is "novel" when its SMILES is not in ``seen_smiles``. That set
    grows as samples are kept, so a duplicate of a sample kept earlier in the
    same batch is not novel. With ``apply_filters``, a sample passes only if it
    is novel and scores exactly 1 on both the ``posebusters`` and
    ``chembl_ring_systems`` columns. Without filters every sample passes. At
    most ``budget`` passing samples are kept, in batch order.

    Writes ``table['novel']`` (and ``table['passed_filters']`` when filtering)
    in place and adds the SMILES of kept samples to ``seen_smiles``.

    Returns:
        List of positional row indices of the kept samples.
    """
    smiles_list = table['representation.smiles'].tolist()
    if apply_filters:
        quality_ok = ((table['posebusters'] == 1) & (table['chembl_ring_systems'] == 1)).tolist()
    else:
        quality_ok = [True] * len(smiles_list)

    novel_flags, passed_flags, keep = [], [], []
    for position, (smi, ok) in enumerate(zip(smiles_list, quality_ok)):
        is_novel = smi not in seen_smiles
        passed = (not apply_filters) or (ok and is_novel)
        novel_flags.append(is_novel)
        passed_flags.append(passed)
        if passed and len(keep) < budget:
            keep.append(position)
            seen_smiles.add(smi)

    table['novel'] = novel_flags
    if apply_filters:
        table['passed_filters'] = passed_flags
    return keep


def main():
    """Sample ligands for a pocket and write them (and optional metrics) to disk."""
    args = parse_args()

    utils.set_deterministic(seed=args.seed)
    utils.disable_rdkit_logging()

    if args.molecule_size is None and (args.datadir is None or not args.datadir.exists()):
        raise NotImplementedError(
            "Please provide a path to the processed dataset (using `--datadir`) "
            "to infer the number of nodes. It contains the size distribution histogram."
        )

    if not args.filter:
        # Without filtering every sample is kept, so never draw more per batch than requested.
        args.batch_size = min(args.batch_size, args.n_samples)

    # Load the model.
    model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False)
    if args.datadir is not None:
        model.datadir = args.datadir
    model.setup(stage='generation')
    model.batch_size = model.eval_batch_size = args.batch_size
    model.eval().to(args.device)
    if args.n_steps is not None:
        model.T = args.n_steps

    # Build the conditioning pocket from the protein and reference ligand.
    pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0]
    rdmol = Chem.SDMolSupplier(str(args.ref_ligand))[0]
    if rdmol is None:
        raise ValueError(f"Could not parse a reference ligand from {args.ref_ligand}")

    ligand, pocket = process_raw_pair(
        pdb_model, rdmol,
        dist_cutoff=args.pocket_distance_cutoff,
        pocket_representation=model.pocket_representation,
        compute_nerf_params=True,
        nma_input=args.protein if model.dynamics.add_nma_feat else None
    )
    ligand['name'] = 'ligand'
    # batch_size copies of the same pair: the dataloader yields exactly one batch per pass.
    dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)]
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None),
        pin_memory=True
    )

    smiles = set()
    sampled_molecules = []
    metrics = []
    Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True)
    print(f'Will generate {args.n_samples} samples')

    evaluator = FullEvaluator(gnina=args.gnina, reduce=args.reduce)
    needs_evaluation = args.filter or args.metrics_output is not None

    # NOTE: with --filter this loops until enough samples pass; it never
    # terminates if no sample ever passes.
    with tqdm(total=args.n_samples) as pbar:
        while len(sampled_molecules) < args.n_samples:
            for data in dataloader:
                budget = args.n_samples - len(sampled_molecules)
                if budget <= 0:
                    break
                new_data = {
                    'ligand': TensorDict(**data['ligand']).to(args.device),
                    'pocket': TensorDict(**data['pocket']).to(args.device),
                }
                rdmols, rdpockets, _ = model.sample(
                    new_data,
                    n_samples=1,
                    timesteps=args.n_steps,
                    num_nodes=args.num_nodes,
                )

                if needs_evaluation:
                    table = evaluate_batch(evaluator, rdmols, rdpockets, smiles)
                    keep = select_rows(table, smiles, budget, apply_filters=args.filter)
                    if args.metrics_output is not None:
                        # Only rows for molecules actually written to the SDF.
                        metrics.append(table.iloc[keep])
                else:
                    # Truncate the last batch so no more than n_samples are written.
                    keep = range(min(len(rdmols), budget))

                sampled_molecules.extend(rdmols[k] for k in keep)
                pbar.update(len(keep))

    utils.write_sdf_file(args.output, sampled_molecules)

    if args.metrics_output is not None:
        pd.concat(metrics).to_csv(args.metrics_output, index=False)


if __name__ == "__main__":
    main()
| |
|