| |
| """ |
| Inference script. |
| |
| To run with base.yaml as the config, |
| |
| > python run_inference.py |
| |
| To specify a different config, |
| |
| > python run_inference.py --config-name symmetry |
| |
| where symmetry can be the filename of any other config (without .yaml extension) |
| See https://hydra.cc/docs/advanced/hydra-command-line-flags/ for more options. |
| |
| """ |
|
|
| import re |
| import os, time, pickle |
| import torch |
| from omegaconf import OmegaConf |
| import hydra |
| import logging |
| from rfdiffusion.util import writepdb_multi, writepdb |
| from rfdiffusion.inference import utils as iu |
| from hydra.core.hydra_config import HydraConfig |
| import numpy as np |
| import random |
| import glob |
|
|
|
|
def make_deterministic(seed=0):
    """Seed every global RNG this script relies on with the same value.

    Seeds, in order: torch (``torch.manual_seed``), NumPy's global generator
    (``np.random.seed``) and Python's ``random`` module.

    Args:
        seed: Integer seed applied to all three generators. Defaults to 0.
    """
    seeders = (torch.manual_seed, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
# Trailing "_<digits>.pdb" index of an existing design file (raw string: \d is
# not a valid str escape).
_DESIGN_INDEX_RE = re.compile(r".*_(\d+)\.pdb$")


def _next_design_startnum(output_prefix):
    """Return one past the highest ``_<N>.pdb`` index found for *output_prefix*.

    Scans files matching ``<output_prefix>*.pdb``; files without a trailing
    ``_<digits>.pdb`` are ignored. Returns 0 when no indexed file exists.
    """
    log = logging.getLogger(__name__)
    # Escape the prefix so glob metacharacters in a path are matched literally.
    pattern = glob.escape(output_prefix) + "*.pdb"
    highest = -1
    for path in glob.glob(pattern):
        match = _DESIGN_INDEX_RE.match(path)
        log.debug(f"Existing output {path}: index match {match}")
        if match:
            highest = max(highest, int(match.group(1)))
    return highest + 1


def _ensure_parent_dir(path):
    """Create the parent directory of *path* (if it has one)."""
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _sample_trajectory(sampler):
    """Run the reverse-diffusion loop for one design.

    Steps t from ``int(sampler.t_step_input)`` down to
    ``sampler.inf_conf.final_step`` inclusive.

    Returns:
        (seq_init, denoised_xyz_stack, px0_xyz_stack, plddt_stack), where the
        two xyz stacks are flipped so index 0 holds the last step's output.

    Raises:
        ValueError: if the step range is empty (nothing to sample).
    """
    x_init, seq_init = sampler.sample_init()
    final_step = sampler.inf_conf.final_step
    denoised_frames, px0_frames, plddt_frames = [], [], []

    x_t = torch.clone(x_init)
    seq_t = torch.clone(seq_init)
    for t in range(int(sampler.t_step_input), final_step - 1, -1):
        px0, x_t, seq_t, plddt = sampler.sample_step(
            t=t, x_t=x_t, seq_init=seq_t, final_step=final_step
        )
        px0_frames.append(px0)
        denoised_frames.append(x_t)
        plddt_frames.append(plddt[0])

    if not denoised_frames:
        raise ValueError(
            f"No diffusion steps to run: t_step_input={sampler.t_step_input} "
            f"is below final_step={final_step}"
        )

    # Frames were collected from high t to low t; flip along the step axis so
    # index 0 is the final (t == final_step) structure.
    denoised_xyz_stack = torch.flip(torch.stack(denoised_frames), [0])
    px0_xyz_stack = torch.flip(torch.stack(px0_frames), [0])
    return seq_init, denoised_xyz_stack, px0_xyz_stack, torch.stack(plddt_frames)


def _output_sequence_and_bfacts(seq_init):
    """Build the sequence indices and per-residue B-factors written to PDB.

    Derived from ``seq_init`` only. Positions whose argmax token is 21
    (presumably the mask token of RFdiffusion's alphabet; confirm) are written
    as token 7 (presumably glycine; confirm) with B-factor 0. All other
    positions keep their token and get B-factor 1.
    """
    seq_idx = torch.argmax(seq_init, dim=-1)
    is_masked = seq_idx == 21
    final_seq = torch.where(is_masked, 7, seq_idx)
    bfacts = torch.ones_like(final_seq.squeeze())
    bfacts[is_masked] = 0
    return final_seq, bfacts


def _write_trajectories(out_prefix, denoised_xyz_stack, px0_xyz_stack,
                        bfacts, final_seq, chain_idx):
    """Write the X_t and pX0 trajectories to ``<dir>/traj/<name>_*_traj.pdb``."""
    # os.path.join avoids an absolute "/traj/..." path when out_prefix has no
    # directory component.
    traj_prefix = os.path.join(
        os.path.dirname(out_prefix), "traj", os.path.basename(out_prefix)
    )
    _ensure_parent_dir(traj_prefix)
    seq = final_seq.squeeze()
    for tag, stack in (("Xt-1", denoised_xyz_stack), ("pX0", px0_xyz_stack)):
        writepdb_multi(
            f"{traj_prefix}_{tag}_traj.pdb",
            stack,
            bfacts,
            seq,
            use_hydrogens=False,
            backbone_only=False,
            chain_ids=chain_idx,
        )


@hydra.main(version_base=None, config_path="../config/inference", config_name="base")
def main(conf: HydraConfig) -> None:
    """Generate ``inference.num_designs`` designs and write them to disk.

    Args:
        conf: Config object composed by Hydra. It is read here as
            ``conf.inference.deterministic`` and passed whole to
            ``iu.sampler_selector``.

    For each design index ``i`` this writes:

    - ``<output_prefix>_<i>.pdb``: the final denoised structure.
    - ``<output_prefix>_<i>.trb``: a pickled dict with ``config``, ``plddt``,
      ``device`` and ``time``, plus the contig mappings when the sampler has a
      ``contig_map``.
    - If ``write_trajectory`` is set, the X_t and pX0 trajectories under a
      ``traj/`` subdirectory.

    A ``design_startnum`` of -1 continues numbering after the highest existing
    index. In cautious mode, designs whose PDB already exists are skipped.
    """
    log = logging.getLogger(__name__)
    if conf.inference.deterministic:
        make_deterministic()

    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}")
    else:
        device_name = "CPU"
        log.info("////////////////////////////////////////////////")
        log.info("///// NO GPU DETECTED! Falling back to CPU /////")
        log.info("////////////////////////////////////////////////")

    sampler = iu.sampler_selector(conf)

    design_startnum = sampler.inf_conf.design_startnum
    if design_startnum == -1:
        design_startnum = _next_design_startnum(sampler.inf_conf.output_prefix)

    for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
        if conf.inference.deterministic:
            # Per-design seed so each design is reproducible on its own.
            make_deterministic(i_des)

        start_time = time.time()
        out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"
        log.info(f"Making design {out_prefix}")
        if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"):
            log.info(
                f"(cautious mode) Skipping this design because {out_prefix}.pdb already exists."
            )
            continue

        seq_init, denoised_xyz_stack, px0_xyz_stack, plddt_stack = _sample_trajectory(sampler)

        _ensure_parent_dir(out_prefix)
        final_seq, bfacts = _output_sequence_and_bfacts(seq_init)

        writepdb(
            f"{out_prefix}.pdb",
            denoised_xyz_stack[0, :, :4],
            final_seq,
            sampler.binderlen,
            chain_idx=sampler.chain_idx,
            bfacts=bfacts,
        )

        trb = dict(
            config=OmegaConf.to_container(sampler._conf, resolve=True),
            plddt=plddt_stack.cpu().numpy(),
            device=device_name,
            time=time.time() - start_time,
        )
        if hasattr(sampler, "contig_map"):
            trb.update(sampler.contig_map.get_mappings())
        with open(f"{out_prefix}.trb", "wb") as f_out:
            pickle.dump(trb, f_out)

        if sampler.inf_conf.write_trajectory:
            _write_trajectories(
                out_prefix,
                denoised_xyz_stack,
                px0_xyz_stack,
                bfacts,
                final_seq,
                sampler.chain_idx,
            )

        log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator on main() composes the
    # config and supplies it, so main is called with no arguments here.
    main()
|
|