| import random |
|
|
| import torch |
| import numpy as np |
|
|
| import torch.nn.functional as F |
|
|
| from src.PeptiVerse.inference import PeptiVersePredictor |
| from src.utils.model_utils import _print |
|
|
|
|
class MadSBMSampler:
    """Iterative unmasking sampler driven by a MadSBM model.

    Starting from a (partially) masked token sequence, each step queries the
    model, converts its ``'dit'`` output into per-position jump probabilities,
    and replaces a random subset of still-masked positions with tokens drawn
    from the temperature-scaled, nucleus-filtered ``'madsbm'`` logits. When
    guidance is enabled, PeptiVerse binding-affinity predictions are used to
    choose among several candidate fills (see ``binding_guidance``).
    """

    # Default PeptiVerse artifact locations (previously hard-coded in __init__).
    DEFAULT_MANIFEST_PATH = "/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt"
    DEFAULT_CLASSIFIER_WEIGHT_ROOT = "/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse"

    def __init__(self, model, config, device, guidance=None, *, seed=42,
                 manifest_path=DEFAULT_MANIFEST_PATH,
                 classifier_weight_root=DEFAULT_CLASSIFIER_WEIGHT_ROOT):
        """Store the model/config and optionally load the PeptiVerse predictor.

        Args:
            model: callable model exposing ``.tokenizer`` (with ``mask_token_id``);
                called as ``model(input_ids=..., attention_mask=..., t=...)``
                and expected to return a dict with ``'dit'``, ``'madsbm'`` and
                ``'esm'`` entries.
            config: config object read for ``time_embed.min_time``,
                ``model.ablate`` and ``sampling.*`` hyper-parameters.
            device: torch device used for tensors created by the sampler.
            guidance: if truthy, a ``PeptiVersePredictor`` is constructed and
                binding affinities become available.
            seed: seed passed to ``seed_everything`` (``None`` skips seeding).
                Note that seeding is process-wide.
            manifest_path: PeptiVerse manifest file (used only with guidance).
            classifier_weight_root: PeptiVerse weight directory (guidance only).
        """
        self.config = config
        self.device = device
        self.model = model
        self.tokenizer = model.tokenizer
        self.mask_id = self.tokenizer.mask_token_id
        self.eps = config.time_embed.min_time
        self.seed_everything(seed=seed)

        # Always define these so `sample` can check for a predictor instead of
        # failing with AttributeError when the sampler was built without guidance.
        self.guidance = guidance
        self.peptiverse = None
        if guidance:
            self.peptiverse = PeptiVersePredictor(
                manifest_path=manifest_path,
                classifier_weight_root=classifier_weight_root,
                device=self.device,
            )

        # Diagnostics recorded by the most recent `sample` call.
        self.last_action_traj = {}
        self.last_total_action = None
        self.last_converge_idx = None

    @torch.inference_mode()
    def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
        """Unmask ``xt`` over ``num_steps`` steps and return the final sequence.

        Args:
            xt: integer token tensor of shape ``(1, L)``; masked positions hold
                ``self.mask_id``. It is cloned, never modified in place.
            num_steps: number of denoising steps (must be >= 1).
            tracer: object with ``log_step(xt, step_idx)``; called with step 0
                (initial state), 1..num_steps, and num_steps + 1 (final state).
            target_toks: target token ids forwarded to the PeptiVerse predictor.
            guidance: if truthy, candidate fills are chosen by
                ``binding_guidance`` (requires a predictor and ``target_toks``).

        Returns:
            ``(xt, binding_affin)`` where ``binding_affin`` is the predictor's
            ``'affinity'`` output for the final sequence, or ``None`` when no
            predictor or no ``target_toks`` is available.

        Raises:
            ValueError: if ``num_steps < 1``, the batch size is not 1, or
                guidance is requested without a predictor / ``target_toks``.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        xt = xt.clone()
        B, L = xt.shape
        if B != 1:
            raise ValueError("Do only 1 sequence at a time")

        peptiverse = getattr(self, "peptiverse", None)
        if guidance and (peptiverse is None or target_toks is None):
            raise ValueError(
                "binding guidance needs a PeptiVerse predictor (construct the "
                "sampler with guidance) and target_toks"
            )

        sampling_cfg = self.config.sampling
        t_max = 1.0 - self.eps
        dt = 1.0 / num_steps
        attn_mask = torch.ones_like(xt, device=self.device)

        action_traj = {}
        tot_action = 0.0
        converge_idx = num_steps
        converged = False

        tracer.log_step(xt=xt, step_idx=0)

        for k in range(num_steps):
            # Anneal t linearly from just below t_max down to eps.
            prog = (k + 1) / float(num_steps)
            t_val = t_max - (t_max - self.eps) * prog
            t = torch.full((B,), fill_value=float(t_val), device=self.device)

            outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
            u_tilt = outs['dit']
            total_logits = outs['madsbm']
            # In the ablation the action is measured against a uniform reference.
            esm_logits = None if self.config.model.ablate else outs['esm']

            actional = self.compute_action(u_tilt, esm_logits=esm_logits)
            action_traj[f"action_step_{k+1}"] = actional
            tot_action += actional * dt

            # Per-position jump probability 1 - exp(-R_tot * jump_scale * dt);
            # -expm1 keeps precision when the rate is tiny.
            r_theta = torch.exp(u_tilt * sampling_cfg.rate_scale)
            R_tot = r_theta.sum(dim=-1)
            rate = (-R_tot * sampling_cfg.jump_scale * dt).clamp(min=-40.0, max=0.0)
            jump_prob = -torch.expm1(rate)

            logits = self.top_p_filter(total_logits / sampling_cfg.tau, sampling_cfg.top_p)
            probs = F.softmax(logits, dim=-1).reshape(-1, logits.size(-1))

            # RNG draw order (candidate sample, jump mask, then guidance samples)
            # matches the previous implementation so seeded runs reproduce.
            candidate_toks = torch.multinomial(probs, 1).view(B, L)
            can_jump = torch.rand(B, L, device=self.device) < jump_prob
            # Only masked positions may change; decoded tokens stay frozen.
            updatable = can_jump & self.is_masked(xt)

            if guidance:
                candidate_toks = self.binding_guidance(probs, target_toks, B, L, xt=xt)
            xt[updatable] = candidate_toks[updatable]

            tracer.log_step(xt=xt, step_idx=k + 1)

            if not converged and not self.is_masked(xt).any():
                converge_idx = k + 1
                converged = True

        # Fill anything still masked with the last step's greedy prediction.
        still_masked = self.is_masked(xt)
        if still_masked.any():
            final_toks = total_logits.argmax(dim=-1)
            xt[still_masked] = final_toks[still_masked]
        tracer.log_step(xt, num_steps + 1)

        self.last_action_traj = action_traj
        self.last_total_action = tot_action
        self.last_converge_idx = converge_idx

        binding_affin = None
        if peptiverse is not None and target_toks is not None:
            binding_affin = peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=xt,
            )['affinity']

        return xt, binding_affin

    def binding_guidance(self, probs, target_toks, B, L, xt=None):
        """Pick one of ``config.sampling.M`` sampled fills, weighted by affinity.

        Args:
            probs: per-position categorical probabilities, ``(B * L, V)``.
            target_toks: target ids forwarded to the PeptiVerse predictor.
            B, L: batch size and sequence length used to reshape samples.
            xt: optional current sequence ``(B, L)``. When given, positions
                that are already decoded (not ``self.mask_id``) keep their
                current tokens in every candidate, so the predictor scores
                sequences consistent with the state being updated.

        Returns:
            The chosen ``(B, L)`` candidate, drawn with probability
            ``softmax(affinity / config.sampling.tau)``.

        Raises:
            ValueError: if ``config.sampling.M < 1``.
        """
        M = self.config.sampling.M
        if M < 1:
            raise ValueError(f"config.sampling.M must be >= 1, got {M}")

        # Draw all candidates first, then score them (same RNG order as before).
        candidate_toks = [torch.multinomial(probs, 1).view(B, L) for _ in range(M)]
        if xt is not None:
            frozen = ~self.is_masked(xt)
            candidate_toks = [torch.where(frozen, xt, cand) for cand in candidate_toks]

        affinities = [
            self.peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=toks.detach(),
            )['affinity']
            for toks in candidate_toks
        ]

        affinities = torch.tensor(affinities, dtype=torch.float32)
        weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
        chosen_idx = torch.multinomial(weights, 1).item()
        return candidate_toks[chosen_idx]

    def compute_action(self, u_tilt, esm_logits=None):
        """Computes the action functional for evals.

        action = mean over positions of ``sum_v R0[v] * psi(u_tilt[..., v])``
        with ``psi(u) = exp(u) - u - 1``. ``R0`` is ``softmax(esm_logits)``
        when given, otherwise uniform ``1 / u_tilt.size(-1)``.

        Returns:
            The action as a Python float.
        """
        if esm_logits is not None:
            R0 = torch.softmax(esm_logits, dim=-1)
        else:
            # Normalize over the rates actually present; tokenizer.vocab_size
            # may not match the model's output width (e.g. added tokens).
            R0 = 1.0 / u_tilt.size(-1)

        # expm1(u) - u == exp(u) - u - 1, computed without cancellation near 0.
        psi_u = torch.expm1(u_tilt) - u_tilt
        action_per_tok = (R0 * psi_u).sum(dim=-1)
        return action_per_tok.mean().item()

    def top_p_filter(self, logits, p_val):
        """Implementation of nucleus / top-p sampling.

        Returns a copy of ``logits`` (the input is not modified) in which tokens
        outside the smallest set whose cumulative probability exceeds ``p_val``
        are set to ``-inf``. The most likely token is always kept. For
        ``p_val >= 1`` nothing is filtered and ``logits`` is returned as is.
        """
        if p_val >= 1.0:
            # Avoid dropping tail tokens when the float cumsum rounds above 1.
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_idx_to_remove = cum_probs > p_val
        # Shift right so the token that crosses the threshold is kept.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        sorted_idx_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
        return logits.masked_fill(idx_to_remove, float('-inf'))

    def is_masked(self, xt):
        """Return a boolean tensor marking positions equal to ``self.mask_id``."""
        return xt == self.mask_id

    def seed_everything(self, seed):
        """Seed Python, NumPy and torch (CPU and CUDA) RNGs; ``None`` is a no-op.

        Also sets cuDNN to deterministic, non-benchmark mode. These are
        process-wide settings, not per-instance ones.
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|