| | import argparse |
| | import sys |
| | import yaml |
| | import torch |
| | import numpy as np |
| | import pickle |
| | from argparse import Namespace |
| |
|
| | from pathlib import Path |
| |
|
# Repository root: two directory levels above this script. It is appended to
# sys.path so the ``src`` package imports below resolve when the script is run
# directly rather than as an installed package.
basedir = Path(__file__).resolve().parent.parent
sys.path.append(f"{basedir}")
| |
|
| | from src import utils |
| | from src.utils import dict_to_namespace, namespace_to_dict |
| | from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb |
| | from src.data.data_utils import TensorDict, Residues |
| | from src.data.postprocessing import process_all |
| | from src.model.lightning import DrugFlow |
| | from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow |
| |
|
| | from tqdm import tqdm |
| | from pdb import set_trace |
| |
|
| |
|
def combine(base_args, override_args):
    """Recursively merge ``override_args`` into ``base_args``.

    For every attribute of ``override_args``:

    * if ``base_args`` lacks it or holds ``None``, the value is added;
    * if both sides hold a ``Namespace``, they are merged recursively;
    * otherwise the override value replaces the base value.

    Every addition/replacement is printed so the effective configuration is
    visible in the logs.

    Args:
        base_args: Namespace to update. It is mutated in place.
        override_args: Namespace whose values take precedence.

    Returns:
        ``base_args`` (the same object) after merging.

    Raises:
        AssertionError: if either argument is a plain ``dict``.
    """
    assert not isinstance(base_args, dict)
    assert not isinstance(override_args, dict)

    arg_dict = base_args.__dict__
    for key, value in override_args.__dict__.items():
        current = arg_dict.get(key)  # None for both "missing" and "set to None"
        if current is None:
            print(f"Add parameter {key}: {value}")
            arg_dict[key] = value
        elif isinstance(value, Namespace) and isinstance(current, Namespace):
            # Merge nested sections key by key instead of overwriting them.
            arg_dict[key] = combine(current, value)
        else:
            # A Namespace override of a non-Namespace value is assigned
            # wholesale; recursing would fail on objects without __dict__.
            print(f"Replace parameter {key}: {current} -> {value}")
            arg_dict[key] = value
    return base_args
| |
|
| |
|
def path_to_str(input_dict):
    """Recursively replace ``pathlib.Path`` values with their string form.

    Nested dicts are converted in place. Plain lists and tuples found among
    the values are rebuilt with their ``Path`` elements converted as well, so
    the result can be written with ``yaml.dump`` without Python-specific tags.

    Args:
        input_dict: Mapping whose values may be ``Path`` objects, nested
            dicts, or lists/tuples of such values. Mutated in place.

    Returns:
        The same ``input_dict`` object.
    """
    def _convert(value):
        # Convert a single value, descending into dicts and plain sequences.
        if isinstance(value, dict):
            return path_to_str(value)
        if type(value) is list:
            return [_convert(v) for v in value]
        if type(value) is tuple:
            return tuple(_convert(v) for v in value)
        if isinstance(value, Path):
            return str(value)
        return value

    # Reassigning existing keys does not change the dict's size, so this is
    # safe while iterating.
    for key, value in input_dict.items():
        input_dict[key] = _convert(value)
    return input_dict
| |
|
| |
|
def _write_sample(cfg, samples_dir, name, idx, mol, pocket):
    """Write one generated ligand and its pocket under ``samples_dir/name``.

    Files written (``idx`` is the per-name sample counter):
    ``{idx}_ligand_{prop}.pdb`` for each atom property on the first atom,
    ``{idx}_ligand.sdf`` and ``{idx}_pocket.pdb``.
    """
    output_dir = Path(samples_dir, name)
    output_dir.mkdir(parents=True, exist_ok=True)
    if cfg.postprocess:
        mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)

    # For every property present on the first atom, store its mean over all
    # atoms as a molecule-level property and export a PDB with the property
    # passed as ``bfactor`` (presumably written to the B-factor column —
    # confirm in visualization_utils). Skipped for atom-less molecules, where
    # indexing the first atom would raise IndexError and abort the run.
    if mol.GetNumAtoms() > 0:
        for prop in mol.GetAtoms()[0].GetPropsAsDict():
            mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()]))
            prop_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
            mol_as_pdb(mol, prop_pdb_path, bfactor=prop)

    ligand_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
    pocket_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
    utils.write_sdf_file(ligand_sdf_path, [mol])
    mols_to_pdbfile([pocket], pocket_pdb_path)


def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1):
    """Sample ligands for every batch of ``cfg.set`` and write them to disk.

    Batches are split across ``n_jobs`` workers: this call only handles batch
    ``i`` when ``i % n_jobs == job_id``.

    Args:
        cfg: Sampling configuration namespace (checkpoint, device, set,
            n_samples, sample_with_ground_truth_size, postprocess, ...).
        model_params: Keyword arguments forwarded to
            ``DrugFlow.load_from_checkpoint`` (presumably hyperparameter
            overrides — confirm).
        samples_dir: Output root; one sub-directory per ligand name.
        job_id: Index of this worker, in ``[0, n_jobs)``.
        n_jobs: Total number of workers (>= 1).

    Raises:
        ValueError: if ``n_jobs`` or ``job_id`` are out of range.
        Exception: any sampling error, unless ``cfg.set == 'train'``, in
            which case the failing batch is logged and skipped.
    """
    # Fail fast: an out-of-range job_id would otherwise silently process
    # nothing after the (expensive) model load; n_jobs=0 would divide by zero.
    if n_jobs < 1 or not 0 <= job_id < n_jobs:
        raise ValueError(f'Invalid job_id={job_id} for n_jobs={n_jobs}')

    print('Sampling...')
    model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False,
                                          **model_params)
    model.setup(stage='fit' if cfg.set == 'train' else cfg.set)
    model.eval().to(cfg.device)

    dataloader = getattr(model, f'{cfg.set}_dataloader')()
    print(f'Real batch size is {dataloader.batch_size * cfg.n_samples}')

    # Number of samples already written per ligand name (by this job).
    name2count = {}
    for i, data in enumerate(tqdm(dataloader)):
        if i % n_jobs != job_id:
            print(f'Skipping batch {i}')
            continue

        new_data = {
            'ligand': TensorDict(**data['ligand']).to(cfg.device),
            'pocket': Residues(**data['pocket']).to(cfg.device),
        }
        try:
            rdmols, rdpockets, names = model.sample(
                data=new_data,
                n_samples=cfg.n_samples,
                num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
            )
        except Exception as e:
            # Only the training split tolerates failures; elsewhere re-raise
            # with the original traceback.
            if cfg.set != 'train':
                raise
            names = data['ligand']['name']
            print(f'Failed to sample for {names}: {e}')
            continue

        for mol, pocket, name in zip(rdmols, rdpockets, names):
            name = name.replace('.sdf', '')
            idx = name2count.get(name, 0)
            _write_sample(cfg, samples_dir, name, idx, mol, pocket)
            name2count[name] = idx + 1
| |
|
| |
|
def evaluate(cfg, model_params, samples_dir):
    """Score the samples in ``samples_dir`` and store the metrics next to them.

    Outputs written into ``samples_dir``:
    ``metrics_data.pkl`` (pickled raw metric data), ``metrics_detailed.csv``
    and ``metrics_aggregated.csv`` (tables written without the index column).

    Args:
        cfg: Sampling configuration (reduce, n_samples, exclude_evaluators).
        model_params: Model hyperparameters; ``train_params`` provides the
            gnina path and the data directory holding ``train_smiles.npy``.
        samples_dir: Directory containing the generated samples.
    """
    print('Evaluation...')
    train_params = model_params['train_params']
    excluded = cfg.exclude_evaluators
    if excluded is None:
        excluded = []

    data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
        in_dir=samples_dir,
        gnina_path=train_params.gnina,
        reduce_path=cfg.reduce,
        reference_smiles_path=Path(train_params.datadir, 'train_smiles.npy'),
        n_samples=cfg.n_samples,
        exclude_evaluators=excluded,
    )

    out_dir = Path(samples_dir)
    with open(out_dir / 'metrics_data.pkl', 'wb') as fh:
        pickle.dump(data, fh)
    for table, filename in ((table_detailed, 'metrics_detailed.csv'),
                            (table_aggregated, 'metrics_aggregated.csv')):
        table.to_csv(out_dir / filename, index=False)
| |
|
| |
|
if __name__ == "__main__":
    # Entry point: load the YAML config, derive the output directory from the
    # checkpoint name and number of steps, then optionally sample and evaluate.
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str)
    p.add_argument('--job_id', type=int, default=0, help='Job ID')
    p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
    args = p.parse_args()

    with open(args.config, 'r') as f:
        cfg = yaml.safe_load(f)
    cfg = dict_to_namespace(cfg)

    # Validate the split before doing any expensive work.
    if cfg.set not in {'val', 'test', 'train'}:
        raise ValueError(f"cfg.set must be one of 'val', 'test', 'train', got {cfg.set!r}")

    utils.set_deterministic(seed=cfg.seed)
    utils.disable_rdkit_logging()

    # NOTE(review): torch.load unpickles the checkpoint file; only point
    # cfg.checkpoint at trusted files.
    model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
    if 'model_args' in cfg:
        ckpt_args = dict_to_namespace(model_params)
        model_params = combine(ckpt_args, cfg.model_args).__dict__

    ckpt_path = Path(cfg.checkpoint)
    ckpt_name = ckpt_path.name.split('.')[0]
    n_steps = model_params['simulation_params'].n_steps
    run_name = f'{ckpt_name}_T={n_steps}'

    # Use the configured output root if given (non-empty); otherwise fall back
    # to <checkpoint dir>/../samples. A Path object is always truthy, so this
    # must be decided on the raw config value, not on a constructed Path.
    sample_outdir = getattr(cfg, 'sample_outdir', None)
    if sample_outdir:
        samples_dir = Path(sample_outdir, cfg.set, run_name)
    else:
        samples_dir = Path(ckpt_path.parent.parent, 'samples', cfg.set, run_name)
    samples_dir.mkdir(parents=True, exist_ok=True)

    # Record the effective model and sampling parameters next to the samples.
    with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
        yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
    with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
        yaml.dump(path_to_str(namespace_to_dict(cfg)), f)

    if cfg.sample:
        sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)

    if cfg.evaluate:
        assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines'
        evaluate(cfg, model_params, samples_dir)