| | import os |
| | import math |
| | from pathlib import Path |
| | import sys |
| | from contextlib import contextmanager |
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | from tqdm import tqdm |
| | from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| | from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence |
| | from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM |
| | from lightning.pytorch import seed_everything |
seed_everything(1986)  # global seed for torch/numpy/python RNGs (Lightning helper)


# Input CSV with paired sequences, binding affinities, and OpenFold iPTM scores.
CSV_PATH = Path("./Classifier_Weight/training_data_cleaned/binding_affinity/c-binding_with_openfold_scores.csv")

# Root directory under which all generated datasets and reports are written.
OUT_ROOT = Path(
    "./Classifier_Weight/training_data_cleaned/binding_affinity"
)

# --- Wild-type (amino-acid sequence) embedding settings: ESM-2 650M ---
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022  # tokenizer truncation length for ESM inputs
WT_BATCH = 32

# --- SMILES embedding settings: PeptideCLM RoFormer + SPE tokenizer files ---
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768
SMI_BATCH = 128

# --- Train/val split configuration ---
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30  # quantile bins used for distribution-matched stratification

# Column names expected in the input CSV.
COL_SEQ1 = "seq1"
COL_SEQ2 = "seq2"
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"

# Device used for all model inference (single GPU, CPU fallback).
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Logging controls: QUIET silences stdout, USE_TQDM toggles progress bars,
# LOG_FILE (if set to a path) receives appended log lines.
QUIET = True
USE_TQDM = False
LOG_FILE = None
| |
|
def log(msg: str):
    """Append *msg* to LOG_FILE (when configured) and echo to stdout unless QUIET."""
    if LOG_FILE is not None:
        target = Path(LOG_FILE)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(LOG_FILE, "a") as handle:
            handle.write(msg.rstrip() + "\n")
    if not QUIET:
        print(msg)
| |
|
def pbar(it, **kwargs):
    """Wrap *it* in a tqdm progress bar when USE_TQDM is on; otherwise pass it through."""
    if not USE_TQDM:
        return it
    return tqdm(it, **kwargs)
| |
|
@contextmanager
def section(title: str):
    """Context manager that logs a banner when a pipeline stage starts and finishes."""
    log(f"\n=== {title} ===")
    yield
    # NOTE: the closing banner is skipped if the body raises (no try/finally).
    log(f"=== done: {title} ===")
| |
|
| |
|
| | |
| | |
| | |
def has_uaa(seq: str) -> bool:
    """Return True when the sequence contains an 'X' marker (case-insensitive)."""
    return "x" in str(seq).lower()
| |
|
def affinity_to_class(a: float) -> str:
    """Bucket an affinity value into 'High' (>= 9), 'Moderate' (>= 7), or 'Low'."""
    for threshold, label in ((9.0, "High"), (7.0, "Moderate")):
        if a >= threshold:
            return label
    return "Low"
| |
|
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
    """Assign a stratified train/val split so affinity distributions match across splits.

    Rows are bucketed by affinity quantiles (falling back to the coarse
    affinity_class labels when qcut cannot form bins); each bucket is shuffled
    and split TRAIN_FRAC / (1 - TRAIN_FRAC) independently.
    """
    out = df.copy()

    # Keep only rows with a parseable numeric affinity.
    out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
    out = out.dropna(subset=[COL_AFF]).reset_index(drop=True)

    out["affinity_class"] = out[COL_AFF].apply(affinity_to_class)

    # Prefer fine-grained quantile bins; fall back to the 3-way class labels.
    try:
        out["aff_bin"] = pd.qcut(out[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
    except Exception:
        out["aff_bin"] = out["affinity_class"]
    strat_col = "aff_bin"

    rng = np.random.RandomState(RANDOM_SEED)

    out["split"] = None
    for _, bucket in out.groupby(strat_col, observed=True):
        indices = bucket.index.to_numpy()
        rng.shuffle(indices)
        cutoff = int(math.floor(len(indices) * TRAIN_FRAC))
        out.loc[indices[:cutoff], "split"] = "train"
        out.loc[indices[cutoff:], "split"] = "val"

    # Any row that somehow missed a bucket defaults to train.
    out["split"] = out["split"].fillna("train")
    return out
| |
|
| | def _summ(x): |
| | x = np.asarray(x, dtype=float) |
| | x = x[~np.isnan(x)] |
| | if len(x) == 0: |
| | return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan} |
| | return { |
| | "n": int(len(x)), |
| | "mean": float(np.mean(x)), |
| | "std": float(np.std(x)), |
| | "p50": float(np.quantile(x, 0.50)), |
| | "p95": float(np.quantile(x, 0.95)), |
| | } |
| |
|
| | def _len_stats(seqs): |
| | lens = np.asarray([len(str(s)) for s in seqs], dtype=float) |
| | if len(lens) == 0: |
| | return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan} |
| | return { |
| | "n": int(len(lens)), |
| | "mean": float(lens.mean()), |
| | "std": float(lens.std()), |
| | "p50": float(np.quantile(lens, 0.50)), |
| | "p95": float(np.quantile(lens, 0.95)), |
| | } |
| |
|
def verify_split_before_embedding(
    df2: pd.DataFrame,
    affinity_col: str,
    split_col: str,
    seq_col: str,
    iptm_col: str,
    aff_class_col: str = "affinity_class",
    aff_bins: int = 30,
    save_report_prefix: str | None = None,
    verbose: bool = False,
):
    """Sanity-check a train/val split before running the expensive embedding steps.

    Compares affinity and sequence-length statistics between splits, computes the
    maximum per-bin proportion gap, logs a one-line summary, and optionally writes
    `.stats.csv` / `.class_prop.csv` report files next to *save_report_prefix*.

    Raises AssertionError when the split column is missing, holds unexpected
    values, or no affinity value survives numeric coercion.
    """
    df2 = df2.copy()
    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")

    # Hard preconditions — fail loudly before any GPU time is spent.
    assert split_col in df2.columns, f"Missing split col: {split_col}"
    assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
    assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."

    # Debug bins mirror the stratification used when the split was created.
    try:
        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
    except Exception:
        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)

    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
    va = df2[df2[split_col] == "val"].reset_index(drop=True)

    tr_aff = _summ(tr[affinity_col].to_numpy())
    va_aff = _summ(va[affinity_col].to_numpy())
    tr_len = _len_stats(tr[seq_col].tolist())
    va_len = _len_stats(va[seq_col].tolist())

    # Per-split bin proportions; max absolute difference measures distribution drift.
    bin_ct = (
        df2.groupby([split_col, "_aff_bin_dbg"])
        .size()
        .groupby(level=0)
        .apply(lambda s: s / s.sum())
    )
    # NOTE(review): .loc["train"]/.loc["val"] assume both splits are non-empty — confirm upstream.
    tr_bins = bin_ct.loc["train"]
    va_bins = bin_ct.loc["val"]
    all_bins = tr_bins.index.union(va_bins.index)
    tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
    va_bins = va_bins.reindex(all_bins, fill_value=0.0)
    max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))

    msg = (
        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
        f"max_bin_diff={max_bin_diff:.4f}"
    )
    log(msg)

    # Optional console detail (only when both verbose and not QUIET).
    if verbose and (not QUIET):
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
        print("\n[verbose] affinity_class counts:\n", class_ct)
        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))

    # Optional on-disk report for later inspection.
    if save_report_prefix is not None:
        out = Path(save_report_prefix)
        out.parent.mkdir(parents=True, exist_ok=True)

        stats_df = pd.DataFrame([
            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
        ])
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()

        stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
        class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
| |
|
| |
|
| | |
| | |
| | |
@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
    """Mean-pool last-hidden-state embeddings for a list of protein sequences.

    Pooling is attention-mask weighted, so padding positions contribute nothing
    (special tokens covered by the attention mask ARE included in the mean).

    Args:
        seqs: list of sequence strings.
        tokenizer / model: HF tokenizer and encoder already on DEVICE.
        batch_size: sequences per forward pass.
        max_length: truncation length for tokenization.

    Returns:
        (len(seqs), hidden_size) numpy array of pooled embeddings.
    """
    # Fix: np.vstack([]) raises ValueError — return a well-shaped empty matrix
    # so callers can handle an empty split without crashing.
    if len(seqs) == 0:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)

    embs = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        out = model(**inputs)
        h = out.last_hidden_state  # (B, L, H)

        # Masked mean over the sequence dimension.
        attn = inputs["attention_mask"].unsqueeze(-1)
        summed = (h * attn).sum(dim=1)
        denom = attn.sum(dim=1).clamp(min=1e-9)  # guard against all-pad rows
        pooled = (summed / denom).detach().cpu().numpy()
        embs.append(pooled)

    return np.vstack(embs)
| |
|
| |
|
| | |
| | |
| | |
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
    """Embed one sequence and return per-token vectors with CLS/EOS removed.

    Returns a (L, hidden_size) float16 numpy array where L counts only positions
    that are attended to and are not the given special-token ids.
    """
    tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    tok = {k: v.to(DEVICE) for k, v in tok.items()}
    out = model(**tok)
    h = out.last_hidden_state[0]
    attn = tok["attention_mask"][0].bool()
    ids = tok["input_ids"][0]

    # Keep attended positions, then strip the special tokens if their ids are known.
    keep = attn.clone()
    if cls_id is not None:
        keep &= (ids != cls_id)
    if eos_id is not None:
        keep &= (ids != eos_id)

    # float16 on CPU halves the memory/disk footprint of the stored embeddings.
    return h[keep].detach().cpu().to(torch.float16).numpy()
| |
|
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
    """Build and save a HF dataset of per-token (unpooled) embeddings for WT pairs.

    Expects df_split to have:
      - target_sequence (seq1)
      - sequence (binder seq2; WT binder)
      - label, affinity_class, COL_AFF, COL_WT_IPTM
    Saves a dataset where each row contains BOTH:
      - target_embedding (Lt,H), target_attention_mask, target_length
      - binder_embedding (Lb,H), binder_attention_mask, binder_length
    """
    cls_id = tokenizer.cls_token_id
    eos_id = tokenizer.eos_token_id
    H = model.config.hidden_size  # per-token embedding width

    # Fixed schema: float16 embeddings keep shards small; attention masks are
    # all-ones because wt_unpooled_one already strips special tokens.
    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_WT_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        # Stream rows one at a time so embeddings never all sit in memory.
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))

            iptm = getattr(r, COL_WT_IPTM)
            iptm = float(iptm) if pd.notna(iptm) else np.nan

            # Embed target and binder independently with the same model.
            t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)
            b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)

            t_list = t_emb.tolist()
            b_list = b_emb.tolist()
            Lt = len(t_list)
            Lb = len(b_list)

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": [1] * Lb,
                "binder_length": int(Lb),

                COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
                COL_AFF: np.float32(aff),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds
| |
|
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
                                         smi_tok, smi_roformer):
    """Build and save a HF dataset pairing ESM target embeddings with SMILES binder embeddings.

    df_split must have:
      - target_sequence (seq1)
      - sequence (binder smiles string)
      - label, affinity_class, COL_AFF, COL_SMI_IPTM
    Saves rows with:
      target_embedding (Lt,Ht) from ESM
      binder_embedding (Lb,Hb) from PeptideCLM
    """
    cls_id = wt_tokenizer.cls_token_id
    eos_id = wt_tokenizer.eos_token_id
    Ht = wt_model_unpooled.config.hidden_size

    # The RoFormer config may expose its width as hidden_size or dim.
    Hb = getattr(smi_roformer.config, "hidden_size", None)
    if Hb is None:
        Hb = getattr(smi_roformer.config, "dim", None)
    if Hb is None:
        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_SMI_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        # Stream rows so per-token embeddings are never all resident at once.
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))

            iptm = getattr(r, COL_SMI_IPTM)
            iptm = float(iptm) if pd.notna(iptm) else np.nan

            # Target side: per-residue ESM embedding (special tokens stripped).
            t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
            t_list = t_emb.tolist()
            Lt = len(t_list)

            # Binder side: per-token SMILES embedding from a batch of one.
            _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
                [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
            )
            b_emb = tok_list[0]
            b_list = b_emb.tolist()
            Lb = int(lengths[0])
            b_mask = mask_list[0].astype(np.int8).tolist()

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": [int(x) for x in b_mask],
                "binder_length": int(Lb),

                COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
                COL_AFF: np.float32(aff),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds
| |
|
| |
|
| | |
| | |
| | |
def get_special_ids(tokenizer_obj):
    """Collect the tokenizer's defined special-token ids as a sorted, de-duplicated list."""
    attr_names = (
        "pad_token_id",
        "cls_token_id",
        "sep_token_id",
        "bos_token_id",
        "eos_token_id",
        "mask_token_id",
    )
    ids = {getattr(tokenizer_obj, name, None) for name in attr_names}
    ids.discard(None)
    return sorted(ids)
| |
|
@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
    """Embed a batch of SMILES strings, returning pooled AND per-token forms.

    Returns:
        pooled: (B, H) numpy array — mean over non-special, non-padding tokens.
        token_emb_list: list of (Li, H) float16 arrays with special tokens removed.
        mask_list: list of all-ones int8 masks, one per row, matching each Li.
        lengths: list of Li ints.
    """
    tok = tokenizer_obj(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = tok["input_ids"].to(DEVICE)
    attention_mask = tok["attention_mask"].to(DEVICE)

    outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state

    # A position counts only if it is attended to AND is not a special token.
    special_ids = get_special_ids(tokenizer_obj)
    valid = attention_mask.bool()
    if len(special_ids) > 0:
        sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
        if hasattr(torch, "isin"):
            valid = valid & (~torch.isin(input_ids, sid))
        else:
            # Fallback for torch versions without torch.isin: OR together per-id matches.
            m = torch.zeros_like(valid)
            for s in special_ids:
                m |= (input_ids == s)
            valid = valid & (~m)

    # Masked mean pooling over the sequence dimension.
    valid_f = valid.unsqueeze(-1).float()
    summed = torch.sum(last_hidden * valid_f, dim=1)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)  # guard: no valid tokens
    pooled = (summed / denom).detach().cpu().numpy()

    # Per-row variable-length token embeddings, stored as float16.
    token_emb_list, mask_list, lengths = [], [], []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
        li = emb.shape[0]
        lengths.append(int(li))
        mask_list.append(np.ones((li,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths
| |
|
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
    """Run smiles_embed_batch_return_both over *seqs* in mini-batches and concatenate.

    Returns (pooled (N,H) array, list of per-row token arrays, list of masks, list of lengths).
    """
    pooled_chunks = []
    token_embs, masks, lengths = [], [], []

    for start in pbar(range(0, len(seqs), batch_size)):
        chunk = seqs[start:start + batch_size]
        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
            chunk, tokenizer_obj, model_roformer, max_length
        )
        pooled_chunks.append(pooled)
        token_embs += tok_list
        masks += m_list
        lengths += lens

    return np.vstack(pooled_chunks), token_embs, masks, lengths
| |
|
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
    """Load the ESM model and pre-compute pooled target embeddings for both splits.

    Returns (tokenizer, model, train embedding matrix, val embedding matrix,
    train sequence->embedding map, val sequence->embedding map).
    """
    wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
    wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

    train_seqs = wt_view_train["target_sequence"].astype(str).tolist()
    val_seqs = wt_view_val["target_sequence"].astype(str).tolist()

    wt_train_tgt_emb = wt_pooled_embeddings(
        train_seqs, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )
    wt_val_tgt_emb = wt_pooled_embeddings(
        val_seqs, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )

    # Lookup caches keyed by the exact sequence string (later duplicates overwrite).
    train_map = dict(zip(train_seqs, wt_train_tgt_emb))
    val_map = dict(zip(val_seqs, wt_val_tgt_emb))
    return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
| | |
| | |
| | |
def main():
    """End-to-end pipeline: load CSV, split WT/SMILES branches, embed, and save datasets."""
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        # Normalize whitespace on the sequence/SMILES columns before dedup.
        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)

        print("Rows after dedup on", DEDUP_COLS, ":", len(df))

    # Fail fast if the CSV lacks any expected column.
    need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
    missing = [c for c in need if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    with section("prepare wt/smiles subsets"):
        # WT branch: binder is the raw amino-acid sequence; drop rows containing 'X'.
        df_wt = df.copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES branch: keep only rows with a numeric SMILES iPTM score.
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)

        # Rows with unnatural amino acids ('X') use REACT_SMILES; others use Fasta2SMILES.
        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        # astype(str) turns NaN/None into "nan"/"None" strings — drop those too.
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)

        log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    with section("split wt and smiles separately"):
        # Each branch gets its own distribution-matched train/val split.
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        # Sanity-check both splits before any GPU work starts.
        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
        # Project a branch frame down to the common schema used downstream.
        out = df_in.copy()
        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
        out["sequence"] = out[binder_seq_col].astype(str).str.strip()
        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]

    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        # One ESM model instance is reused for every embedding pass below.
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        wt_train_tgt_emb = wt_pooled_embeddings(
            wt_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_tgt_emb = wt_pooled_embeddings(
            wt_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_train_tgt_emb = wt_pooled_embeddings(
            smi_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_val_tgt_emb = wt_pooled_embeddings(
            smi_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        # Use the bare RoFormer encoder from the MLM checkpoint (hidden states, no LM head).
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    with section("WT unpooled paired embeddings + save"):
        # Same tokenizer/model as the pooled pass — only the pooling differs.
        wt_tok_unpooled = wt_tok
        wt_esm_unpooled = wt_esm

        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
                                               wt_tok_unpooled, wt_esm_unpooled),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
                                             wt_tok_unpooled, wt_esm_unpooled),
        })
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    with section("SMILES unpooled paired embeddings + save"):
        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()
| |
|