import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina
from lightning.pytorch import seed_everything
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM
from datasets import Dataset, DatasetDict
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

seed_everything(1986)

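# Load the raw Caco-2 permeability table (expects a "SMILES" column and a "Caco2" label column).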
df = pd.read_csv("caco2.csv")

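# Parse each SMILES with RDKit, dropping rows that fail to parse and recording a
# canonical isomeric SMILES for deduplication.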
mols = []
canon = []
keep_rows = []
bad = 0

for i, smi in enumerate(df["SMILES"].astype(str)):
    m = Chem.MolFromSmiles(smi)
    if m is None:
        bad += 1
        continue
    smi_can = Chem.MolToSmiles(m, canonical=True, isomericSmiles=True)
    mols.append(m)
    canon.append(smi_can)
    keep_rows.append(i)

df = df.iloc[keep_rows].reset_index(drop=True)
df["SMILES_CANON"] = canon

print(f"Invalid SMILES dropped: {bad} / {len(df) + bad}")

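# Drop duplicate molecules (identical canonical SMILES), keeping the first occurrence,
# and keep the Mol list aligned with the surviving rows.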
dup_mask = df.duplicated(subset=["SMILES_CANON"], keep="first")
df = df.loc[~dup_mask].reset_index(drop=True)
mols = [m for m, isdup in zip(mols, dup_mask) if not isdup]

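# 2048-bit Morgan fingerprints (radius 2, chirality-aware) used for Tanimoto similarity.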
morgan = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=True)
fps = [morgan.GetFingerprint(m) for m in mols]

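# Butina clustering: molecules with Tanimoto similarity >= 0.6 (distance <= 0.4) are grouped
# into the same cluster. ClusterData expects the flat lower-triangle distance list in the
# order (1,0), (2,0), (2,1), ..., which is exactly what the loop below builds
# (n * (n - 1) / 2 entries in total).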
sim_thresh = 0.6
dist_thresh = 1.0 - sim_thresh

dists = []
n = len(fps)
for i in range(1, n):
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
    dists.extend([1.0 - x for x in sims])

clusters = Butina.ClusterData(dists, nPts=n, distThresh=dist_thresh, isDistData=True)

cluster_ids = np.empty(n, dtype=int)
for cid, idxs in enumerate(clusters):
    for idx in idxs:
        cluster_ids[idx] = cid

df["cluster_id"] = cluster_ids

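# Cluster-based train/val split: whole clusters are assigned to train until roughly
# train_fraction of the molecules are covered; the remaining clusters go to validation.
# This keeps near-duplicate analogues from leaking across the split.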
train_fraction = 0.8
rng = np.random.default_rng(1986)  # seed the generator so the cluster split is reproducible

unique_clusters = df["cluster_id"].unique()
rng.shuffle(unique_clusters)

train_target = int(train_fraction * len(df))
train_clusters = set()
count = 0
for cid in unique_clusters:
    csize = (df["cluster_id"] == cid).sum()
    if count + csize <= train_target:
        train_clusters.add(cid)
        count += csize

df["split"] = np.where(df["cluster_id"].isin(train_clusters), "train", "val")

df[df["split"] == "train"].to_csv("caco2_train.csv", index=False)
df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
df.to_csv("caco2_meta_with_split.csv", index=False)

print(df["split"].value_counts())

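# Embedding generation: encode each split's SMILES with PeptideCLM and store the results as
# HuggingFace datasets. The constants below control truncation length, batch size, the input
# CSVs produced above, and the output path.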
MAX_LENGTH = 768
BATCH_SIZE = 128

TRAIN_CSV = "caco2_train.csv"
VAL_CSV = "caco2_val.csv"

SMILES_COL = "SMILES"
LABEL_COL = "Caco2"

OUT_PATH = "./Classifier_Weight/training_data_cleaned/permeability_caco2/caco2_smiles_with_embeddings"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

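# PeptideCLM-23M-all is a RoFormer masked language model; taking the .roformer attribute keeps
# only the base encoder, so forward passes return hidden states without the LM head.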
print("Loading tokenizer and model...")
tokenizer = SMILES_SPE_Tokenizer(
    "./Classifier_Weight/tokenizer/new_vocab.txt",
    "./Classifier_Weight/tokenizer/new_splits.txt",
)

embedding_model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer
embedding_model.to(device)
embedding_model.eval()

HIDDEN_KEY = "last_hidden_state"

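# Collect the ids of any special tokens (padding, CLS/SEP, BOS/EOS, mask) the tokenizer defines,
# so they can be excluded from mean pooling along with padding positions.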
def get_special_ids(tokenizer):
    cand = [
        getattr(tokenizer, "pad_token_id", None),
        getattr(tokenizer, "cls_token_id", None),
        getattr(tokenizer, "sep_token_id", None),
        getattr(tokenizer, "bos_token_id", None),
        getattr(tokenizer, "eos_token_id", None),
        getattr(tokenizer, "mask_token_id", None),
    ]
    special_ids = sorted({x for x in cand if x is not None})
    if len(special_ids) == 0:
        print("[WARN] No special token ids found on tokenizer; pooling will only exclude padding via attention_mask.")
    return special_ids

SPECIAL_IDS = get_special_ids(tokenizer)
SPECIAL_IDS_T = torch.tensor(SPECIAL_IDS, device=device, dtype=torch.long) if len(SPECIAL_IDS) else None

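# For one batch of SMILES strings this returns:
#   pooled         - (batch, hidden) masked mean over valid tokens, float32 numpy array
#   token_emb_list - per-sequence (length_i, hidden) float16 token embeddings
#   mask_list      - per-sequence all-ones attention masks of matching length
#   lengths        - number of valid tokens per sequence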
@torch.no_grad()
def embed_batch_return_both(batch_sequences, max_length, device):
    tok = tokenizer(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        max_length=max_length,
        truncation=True,
    )
    input_ids = tok["input_ids"].to(device)
    attention_mask = tok["attention_mask"].to(device)

    outputs = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state

    valid = attention_mask.bool()
    if SPECIAL_IDS_T is not None and SPECIAL_IDS_T.numel() > 0:
        valid = valid & (~torch.isin(input_ids, SPECIAL_IDS_T))

    # Mean-pool the hidden states over valid (non-padding, non-special) tokens.
    valid_f = valid.unsqueeze(-1).float()
    summed = torch.sum(last_hidden * valid_f, dim=1)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
    pooled = (summed / denom).detach().cpu().numpy()

    # Also keep the per-token embeddings (float16) for the unpooled dataset.
    token_emb_list = []
    mask_list = []
    lengths = []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
        L_i = emb.shape[0]
        lengths.append(int(L_i))
        mask_list.append(np.ones((L_i,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths

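# Batch driver: embeds all sequences and stacks the pooled vectors into one (N, hidden) array;
# token-level outputs stay as Python lists because sequence lengths differ.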
def generate_embeddings_batched_both(sequences, batch_size, max_length):
    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []

    for i in tqdm(range(0, len(sequences), batch_size), desc="Embedding batches"):
        batch = sequences[i:i + batch_size]
        pooled, token_list, m_list, lens = embed_batch_return_both(batch, max_length, device)
        pooled_all.append(pooled)
        token_emb_all.extend(token_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)

    pooled_all = np.vstack(pooled_all)
    return pooled_all, token_emb_all, mask_all, lengths_all

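# Build two HuggingFace datasets per split: a "pooled" version with one embedding vector per
# molecule and a "full" version that keeps the per-token embeddings.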
def make_split_datasets(csv_path, split_name):
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=[SMILES_COL, LABEL_COL]).reset_index(drop=True)
    df["sequence"] = df[SMILES_COL].astype(str)

    # Drop rows whose label is not numeric.
    labels = pd.to_numeric(df[LABEL_COL], errors="coerce")
    df = df.loc[~labels.isna()].reset_index(drop=True)
    sequences = df["sequence"].tolist()
    labels = pd.to_numeric(df[LABEL_COL], errors="coerce").tolist()

    # Embed every sequence in this split.
    pooled_embs, token_emb_list, mask_list, lengths = generate_embeddings_batched_both(
        sequences, batch_size=BATCH_SIZE, max_length=MAX_LENGTH
    )

    pooled_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": pooled_embs,
    })

    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "label": labels,
        "embedding": token_emb_list,
        "attention_mask": mask_list,
        "length": lengths,
    })

    return pooled_ds, full_ds

train_pooled, train_full = make_split_datasets(TRAIN_CSV, "train")
val_pooled, val_full = make_split_datasets(VAL_CSV, "val")

ds_pooled = DatasetDict({"train": train_pooled, "val": val_pooled})
ds_full = DatasetDict({"train": train_full, "val": val_full})

ds_pooled.save_to_disk(OUT_PATH)
ds_full.save_to_disk(OUT_PATH + "_unpooled")

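# Optional sanity check (illustrative, not part of the original pipeline): reload the saved
# pooled dataset and confirm it round-trips with the expected splits and embedding width.
reloaded = DatasetDict.load_from_disk(OUT_PATH)
print(reloaded)
print("Pooled embedding width:", len(reloaded["train"][0]["embedding"]))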