import numpy as np
import pandas as pd
import torch
import tqdm
from datasets import Dataset, DatasetDict, Features, Sequence, Value
from lightning.pytorch import seed_everything
from transformers import AutoTokenizer, EsmModel

seed_everything(1986)

# m1: the 20 canonical residues (plus [PAD]); kept for reference, unused below.
m1 = [
    '[PAD]', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H',
    'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'
]

# m2: token -> id map matching the vocabulary the .npz files were encoded with
# (a BERT-style protein vocabulary).
m2 = dict(zip(
    ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'L',
     'A', 'G', 'V', 'E', 'S', 'I', 'K', 'R', 'D', 'T', 'P', 'N',
     'Q', 'F', 'Y', 'M', 'H', 'C', 'W', 'X', 'U', 'B', 'Z', 'O'],
    range(30)
))

reverse_m2 = {v: k for k, v in m2.items()}
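
# Quick sanity check of the decoding: with the m2 ordering above, ids 6, 7, 5
# map to 'A', 'G', 'L' (a hypothetical toy input; PAD ids are dropped).
assert ''.join(reverse_m2[t] for t in [6, 7, 5, 0] if t != 0) == 'AGL'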

sequences = []
labels = []

# Decode the positive set: map token ids back to residues, dropping PAD (id 0).
print("Processing positive sequences...")
with np.load('nf-positive.npz') as pos:
    pos_data = pos['arr_0']
    for seq in pos_data:
        sequence = ''.join(reverse_m2[token] for token in seq if token != 0)
        sequences.append(sequence)
        labels.append(1)

# Decode the negative set the same way.
print("Processing negative sequences...")
with np.load('nf-negative.npz') as neg:
    neg_data = neg['arr_0']
    for seq in neg_data:
        sequence = ''.join(reverse_m2[token] for token in seq if token != 0)
        sequences.append(sequence)
        labels.append(0)

# Assign stable ids so cluster membership can be traced back to rows.
ids = [f"seq_{i:06d}" for i in range(len(sequences))]
df = pd.DataFrame({
    "id": ids,
    "sequence": sequences,
    "label": labels,
})
print("Before dedup:", len(df))

df = (
    df
    .drop_duplicates(subset=["sequence"])
    .reset_index(drop=True)
)

print("After dedup:", len(df))

df.to_csv("nf_all.csv", index=False)
print("Saved nf_all.csv")

# Write a FASTA copy for MMseqs2 clustering.
with open("nf_all.fasta", "w") as f:
    for seq_id, seq in zip(df["id"], df["sequence"]):
        f.write(f">{seq_id}\n{seq}\n")

print("Saved nf_all.fasta")

| | """ |
| | mkdir -p mmseqs_tmp |
| | |
| | mmseqs createdb nf_all.fasta nfDB |
| | |
| | mmseqs cluster nfDB nfDB_clu mmseqs_tmp \ |
| | --min-seq-id 0.3 -c 0.8 --cov-mode 0 |
| | |
| | mmseqs createtsv nfDB nfDB nfDB_clu clusters-nf.tsv |
| | """ |

# Homology-aware train/val split: whole clusters are assigned to one side.
train_fraction = 0.8
csv_path = "nf_all.csv"
clusters_tsv = "clusters-nf.tsv"
rng = np.random.default_rng(1986)  # seeded so the split is reproducible

df = pd.read_csv(csv_path)

id_to_index = {sid: i for i, sid in enumerate(df["id"])}

# member id -> cluster representative id, from the MMseqs2 TSV.
cluster_map = {}
with open(clusters_tsv) as f:
    for line in f:
        if not line.strip():
            continue
        rep_id, member_id = line.strip().split('\t')
        cluster_map[member_id] = rep_id

# Any id missing from the TSV becomes its own singleton cluster.
for sid in df["id"]:
    if sid not in cluster_map:
        cluster_map[sid] = sid

# Group row indices by cluster representative.
cluster_to_indices = {}
for sid, cid in cluster_map.items():
    idx = id_to_index[sid]
    cluster_to_indices.setdefault(cid, []).append(idx)
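
# Sanity check: how much redundancy did the clustering find?
cluster_sizes = [len(v) for v in cluster_to_indices.values()]
print(f"Clusters: {len(cluster_sizes)}, "
      f"largest: {max(cluster_sizes)}, "
      f"singletons: {sum(s == 1 for s in cluster_sizes)}")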

cluster_ids = list(cluster_to_indices.keys())
rng.shuffle(cluster_ids)

total_n = len(df)
train_target = int(train_fraction * total_n)

train_indices = []
val_indices = []
current_train = 0

# Greedy first-fit: add whole clusters to train until the target is reached;
# everything else goes to val.
for cid in cluster_ids:
    indices = cluster_to_indices[cid]
    if current_train + len(indices) <= train_target:
        train_indices.extend(indices)
        current_train += len(indices)
    else:
        val_indices.extend(indices)
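
# The greedy assignment only approximates train_fraction; report the result.
print(f"Achieved train fraction: {len(train_indices) / total_n:.3f} "
      f"(target {train_fraction})")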

split = np.full(total_n, "val", dtype=object)
split[train_indices] = "train"

df_with_split = df.copy()
df_with_split["split"] = split
df_with_split.to_csv("nf_meta_with_split.csv", index=False)

df_train = df_with_split[df_with_split["split"] == "train"].reset_index(drop=True)
df_val = df_with_split[df_with_split["split"] == "val"].reset_index(drop=True)

df_train.to_csv("nf_train.csv", index=False)
df_val.to_csv("nf_val.csv", index=False)

print("Split counts:")
print(df_with_split["split"].value_counts())
print()
print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")
print("Wrote:")
print(" - nf_meta_with_split.csv")
print(" - nf_train.csv")
print(" - nf_val.csv")

# Embed every sequence with ESM-2 650M, mean-pooled to one vector per sequence.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings"

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = model.to(device)
model.eval()

def compute_embeddings(sequences, batch_size=32):
    """Mean-pool the last hidden state over non-padding tokens.

    Returns a numpy array of shape (N, hidden_dim).
    """
    embeddings = []
    for i in tqdm.trange(0, len(sequences), batch_size):
        batch_sequences = sequences[i:i + batch_size]

        inputs = tokenizer(
            batch_sequences,
            padding=True,
            max_length=1022,
            truncation=True,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_states = outputs.last_hidden_state

        # Zero out padding positions, then average over the real tokens.
        attention_mask = inputs["attention_mask"].unsqueeze(-1)
        masked_hidden_states = last_hidden_states * attention_mask
        sum_hidden_states = masked_hidden_states.sum(dim=1)
        seq_lengths = attention_mask.sum(dim=1)
        batch_embeddings = sum_hidden_states / seq_lengths

        embeddings.append(batch_embeddings.cpu())

    return torch.cat(embeddings, dim=0).numpy()
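
# Note: the mean above includes the CLS/EOS positions, since the attention
# mask covers them. A sketch of residue-only pooling inside the batch loop,
# if preferred (uses tokenizer.all_special_ids; not what was run here):
#
# special = torch.tensor(tokenizer.all_special_ids, device=device)
# residue_mask = (~torch.isin(inputs["input_ids"], special)).unsqueeze(-1)
# batch_embeddings = (last_hidden_states * residue_mask).sum(dim=1) / residue_mask.sum(dim=1)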


def create_and_save_datasets():
    meta = pd.read_csv(meta_path)
    sequences = meta["sequence"].tolist()
    labels = meta["label"].tolist()
    splits = meta["split"].tolist()

    print(f"Total sequences: {len(sequences)}")
    print("Split counts:", pd.Series(splits).value_counts().to_dict())

    print("Computing ESM embeddings...")
    embeddings = compute_embeddings(sequences)

    full_ds = Dataset.from_dict({
        "sequence": sequences,
        "embedding": embeddings,
        "label": labels,
        "split": splits,
    })

    train_ds = full_ds.filter(lambda x: x["split"] == "train")
    val_ds = full_ds.filter(lambda x: x["split"] == "val")

    train_ds = train_ds.remove_columns("split")
    val_ds = val_ds.remove_columns("split")

    ds_dict = DatasetDict({
        "train": train_ds,
        "val": val_ds,
    })

    ds_dict.save_to_disk(save_path)
    print(f"Saved DatasetDict with train/val to: {save_path}")
    print("Train size:", len(ds_dict["train"]))
    print("Val size:", len(ds_dict["val"]))

    return ds_dict


ds = create_and_save_datasets()

ex = ds["train"][0]
print("\nExample from train:")
print("Sequence:", ex["sequence"])
print("Embedding shape:", np.array(ex["embedding"]).shape)
print("Label:", ex["label"])

torch.cuda.empty_cache()

# Second pass: per-residue (unpooled) embeddings, stored in float16 to keep
# the on-disk dataset manageable.
meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings_unpooled"

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D", add_pooling_layer=False).to(device).eval()

cls_id = tokenizer.cls_token_id
eos_id = tokenizer.eos_token_id

@torch.no_grad()
def embed_one(seq, max_length=1022):
    """Return a (length, hidden_dim) float16 array of per-residue embeddings."""
    inputs = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    out = model(**inputs)
    h = out.last_hidden_state[0]
    attn = inputs["attention_mask"][0].bool()
    ids = inputs["input_ids"][0]

    # Keep only real residue positions: drop the CLS/EOS special tokens.
    keep = attn.clone()
    if cls_id is not None:
        keep &= (ids != cls_id)
    if eos_id is not None:
        keep &= (ids != eos_id)

    return h[keep].detach().cpu().to(torch.float16).numpy()
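
# Smoke test on a toy peptide (hypothetical sequence): with CLS/EOS stripped,
# the output should have one row per residue and ESM-2 650M's hidden size 1280.
demo = embed_one("MKTAYIAK")
assert demo.shape == (8, 1280), demo.shape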

H = 1280  # hidden size of esm2_t33_650M_UR50D
features = Features({
    "sequence": Value("string"),
    "label": Value("int64"),
    "embedding": Sequence(Sequence(Value("float16"), length=H)),
    "attention_mask": Sequence(Value("int8")),
    "length": Value("int64"),
})


def make_generator(df):
    # Embed one sequence at a time and yield rows matching the schema above.
    for seq, lab in tqdm.tqdm(zip(df["sequence"].tolist(), df["label"].astype(int).tolist()), total=len(df)):
        emb = embed_one(seq)
        emb_list = emb.tolist()
        li = len(emb_list)
        yield {
            "sequence": seq,
            "label": int(lab),
            "embedding": emb_list,
            "attention_mask": [1] * li,
            "length": li,
        }

def build_split(df):
    return Dataset.from_generator(make_generator, gen_kwargs={"df": df}, features=features)


meta = pd.read_csv(meta_path)
train_df = meta[meta["split"] == "train"].reset_index(drop=True)
val_df = meta[meta["split"] == "val"].reset_index(drop=True)

train_ds = build_split(train_df)
val_ds = build_split(val_df)

# Save once as a DatasetDict; a single call writes save_path/train and
# save_path/val (sharded at 1GB), so no per-split save_to_disk is needed.
ds_dict = DatasetDict({"train": train_ds, "val": val_ds})
ds_dict.save_to_disk(save_path, max_shard_size="1GB")
print(ds_dict)
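
# To reload the saved dataset later, a minimal sketch:
#
# from datasets import load_from_disk
# ds_dict = load_from_disk(save_path)
# ds_dict["train"][0]["length"]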