| import torch |
| import pandas as pd |
| import seaborn as sns |
| import matplotlib.pyplot as plt |
| from umap import UMAP |
| from sklearn.manifold import TSNE |
| from sklearn.decomposition import PCA |
| from transformers import AutoModel, AutoTokenizer |
|
|
# Root directory of the generation-benchmark outputs (cluster-specific path).
path = "/workspace/sg666/MDpLM/benchmarks/Generation"
# Use the GPU when available; every model/tensor transfer below targets this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hugging Face hub id of the ESM-2 650M protein language model used for embeddings.
esm_model_path = "facebook/esm2_t33_650M_UR50D"
|
|
| |
def load_esm2_model(model_name):
    """Load an ESM-2 checkpoint from the Hugging Face hub.

    Args:
        model_name: Hub id of the checkpoint (e.g. "facebook/esm2_t33_650M_UR50D").

    Returns:
        A ``(tokenizer, model)`` pair, with the model moved to the global ``device``.
    """
    esm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    esm_model = AutoModel.from_pretrained(model_name)
    return esm_tokenizer, esm_model.to(device)
|
|
def get_latents(model, tokenizer, sequence):
    """Embed a single protein sequence with ESM-2.

    Runs one forward pass under ``torch.no_grad`` and mean-pools the final
    hidden states across all token positions (special tokens included).

    Args:
        model: ESM-2 model already on the global ``device``.
        tokenizer: Matching tokenizer.
        sequence: Amino-acid sequence string.

    Returns:
        The pooled embedding as a plain Python list of floats.
    """
    encoded = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    pooled = hidden_states.mean(dim=1).squeeze(0)
    return pooled.cpu().numpy().tolist()
|
|
| |
def parse_fasta_file(file_path, n_samples=100, random_state=None):
    """Parse a FASTA file and return a random sample of its sequences.

    Multi-line records are joined into one sequence string, and every record
    is labeled with the source tag "UniProt" (the only source this parser
    assigns).

    Args:
        file_path: Path to the FASTA file.
        n_samples: Number of sequences to sample (default 100, matching the
            previous hard-coded value). Capped at the number of records, so a
            file with fewer than ``n_samples`` sequences no longer raises
            ``ValueError`` from ``DataFrame.sample``.
        random_state: Optional seed forwarded to ``DataFrame.sample`` for
            reproducible sampling.

    Returns:
        A DataFrame with columns ["Sequence", "Sequence Source"] and a fresh
        integer index.
    """
    sequences = []
    current_seq = []
    with open(file_path, 'r') as file:
        for line in file:  # stream instead of readlines(): no full-file list
            line = line.strip()
            if line.startswith('>'):
                # Header line: flush the record accumulated so far.
                if current_seq:
                    sequences.append(("".join(current_seq), "UniProt"))
                    current_seq = []
            elif line:
                current_seq.append(line)
    # The final record has no trailing header to trigger its flush.
    if current_seq:
        sequences.append(("".join(current_seq), "UniProt"))

    df = pd.DataFrame(sequences, columns=["Sequence", "Sequence Source"])
    # Cap the sample size so small files are handled gracefully.
    sampled = df.sample(min(n_samples, len(df)), random_state=random_state)
    return sampled.reset_index(drop=True)
|
|
|
|
| |
# Load ProtGPT2 generations, scrub generation artifacts, and drop sequences
# containing non-standard amino acids.
protgpt2_sequences = pd.read_csv(path + "/ProtGPT2/protgpt2_generated_sequences.csv")
# Strip special tokens / stray quoting / newlines, and map unknown residue X -> G.
for old, new in (('<|ENDOFTEXT|>', ''), ('""', ''), ('\n', ''), ('X', 'G')):
    protgpt2_sequences['Sequence'] = protgpt2_sequences['Sequence'].str.replace(old, new, regex=False)
protgpt2_sequences = protgpt2_sequences.drop(columns=['Perplexity'])
protgpt2_sequences['Sequence Source'] = "ProtGPT2"
# One entry per offending residue (so a sequence may appear multiple times),
# exactly as the original nested loop produced.
bad_sequences = [
    seq
    for seq in protgpt2_sequences['Sequence']
    for residue in seq
    if residue in ('B', 'U', 'Z', 'O')
]
protgpt2_sequences = protgpt2_sequences[~protgpt2_sequences['Sequence'].isin(bad_sequences)]
|
|
|
|
| |
# Load MeMDLM de-novo generations and normalize to the shared schema.
memdlm_sequences = (
    pd.read_csv(path + "/mdlm_de-novo_generation_results.csv")
    .rename(columns={"Generated Sequence": "Sequence"})
    .drop(columns=['Perplexity'])
    .reset_index(drop=True)
)
memdlm_sequences['Sequence Source'] = "MeMDLM"
|
|
| |
| |
| |
|
|
| |
# Load 100 held-out membrane-protein test sequences as the reference set.
other_sequences = pd.read_csv("/workspace/sg666/MDpLM/data/membrane/test.csv")
other_sequences['Sequence Source'] = "Test Set"
other_sequences = other_sequences.sample(n=100)

# Stack all three sources into one frame for joint embedding and projection.
data = pd.concat([memdlm_sequences, protgpt2_sequences, other_sequences])
|
|
|
|
| |
# Load ESM-2 once for embedding. load_esm2_model already moves the model to
# `device`, so the previous redundant second `.to(device)` call is dropped.
tokenizer, model = load_esm2_model(esm_model_path)
|
|
|
|
| |
# Embed every sequence (one forward pass each), then expand the per-sequence
# embedding lists into a wide numeric frame indexed by the source label.
data = data.reset_index(drop=True)
data['Embeddings'] = [get_latents(model, tokenizer, seq) for seq in data['Sequence']]
umap_df = pd.DataFrame(data['Embeddings'].tolist(), index=data['Sequence Source'])
|
|
|
|
| |
# Project the high-dimensional ESM embeddings down to 2-D for plotting.
umap = UMAP(n_components=2)
projection = umap.fit_transform(umap_df)
umap_df['UMAP1'] = projection[:, 0]
umap_df['UMAP2'] = projection[:, 1]
|
|
| |
# Scatter the 2-D projection, colored by which pipeline produced each sequence.
plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(
    data=umap_df,
    x='UMAP1',
    y='UMAP2',
    hue='Sequence Source',  # resolved from the frame's index name
    palette=['#297272', '#ff7477', "#9A77D0"],
    s=100,
)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title('ESM-650M Embeddings of Membrane Protein Sequences')
plt.savefig('esm_umap.png')
plt.show()