{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": { "id": "PwhltETnCLY1" },
   "source": [
    "# AMP Classification using ProtBERT Embeddings + Fast MLP\n",
    "\n",
    "This notebook extracts ProtBERT embeddings for peptide sequences and trains a simple Multi-Layer Perceptron (MLP) to classify antimicrobial peptides (AMPs) vs non-AMPs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "qv_84qo0CLY6" },
   "outputs": [],
   "source": [
    "# %pip (not !pip) installs into the kernel's own environment.\n",
    "%pip install -q torch transformers scikit-learn numpy pandas tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "4wld_6KBCLY7" },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score\n",
    "\n",
    "# Seed everything stochastic so training runs are reproducible.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print('Device:', device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "7n3m1GLLCLY8" },
   "source": [ "## Load Dataset" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "wAg_vM3JCLY8" },
   "outputs": [],
   "source": [
    "IN_COLAB = 'google.colab' in sys.modules\n",
    "if IN_COLAB:\n",
    "    from google.colab import drive\n",
    "    drive.mount('/content/drive')\n",
    "    file_path = '/content/drive/MyDrive/ampData.csv'\n",
    "else:\n",
    "    file_path = 'ampData.csv'\n",
    "\n",
    "df = pd.read_csv(file_path)\n",
    "df['sequence'] = df['sequence'].astype(str).str.upper().str.strip()\n",
    "df = df.dropna(subset=['sequence', 'label']).reset_index(drop=True)\n",
    "\n",
    "# Encode labels to integers 0/1: a no-op when labels are already 0/1,\n",
    "# and required when they are strings (torch.tensor(..., dtype=float32)\n",
    "# would fail on string labels downstream).\n",
    "label_encoder = LabelEncoder()\n",
    "df['label'] = label_encoder.fit_transform(df['label'])\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "8HxUUO6SCLY8" },
   "source": [ "## Extract ProtBERT Embeddings" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "CltjDxknCLY9" },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert')\n",
    "model = AutoModel.from_pretrained('Rostlab/prot_bert').to(device)\n",
    "model.eval()  # disable dropout so the embeddings are deterministic\n",
    "\n",
    "def get_embedding(sequence):\n",
    "    \"\"\"Return the mean-pooled ProtBERT embedding for one amino-acid sequence.\n",
    "\n",
    "    ProtBERT expects residues separated by spaces. NOTE: the mean is taken\n",
    "    over all tokens, including the [CLS]/[SEP] special tokens.\n",
    "    \"\"\"\n",
    "    seq = ' '.join(list(sequence))\n",
    "    tokens = tokenizer(seq, return_tensors='pt', truncation=True, padding=True).to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**tokens)\n",
    "    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()\n",
    "\n",
    "embeddings = [get_embedding(seq) for seq in tqdm(df['sequence'], desc='Extracting Embeddings')]\n",
    "\n",
    "X = np.array(embeddings)\n",
    "y = df['label'].values\n",
    "\n",
    "# Cache the expensive extraction so it can be skipped on later runs.\n",
    "np.save('X_embeddings.npy', X)\n",
    "np.save('y_labels.npy', y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "TZpCHIpTCLY9" },
   "source": [ "## Train-Test Split" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "HUhsld4YCLY9" },
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED, stratify=y)\n",
    "\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32).to(device)\n",
    "X_test = torch.tensor(X_test, dtype=torch.float32).to(device)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "aeeNh2s9CLY-" },
   "source": [ "## Define MLP Classifier" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "V04ShQ1VCLY-" },
   "outputs": [],
   "source": [
    "class MLPClassifier(nn.Module):\n",
    "    \"\"\"Small MLP over precomputed ProtBERT embeddings; outputs raw logits.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.layers = nn.Sequential(\n",
    "            nn.Linear(input_dim, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.3),\n",
    "            nn.Linear(512, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layers(x)\n",
    "\n",
    "model_mlp = MLPClassifier(X_train.shape[1]).to(device)\n",
    "# BCEWithLogitsLoss fuses sigmoid + BCE and is numerically more stable\n",
    "# than a Sigmoid output layer with BCELoss.\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "optimizer = optim.Adam(model_mlp.parameters(), lr=1e-4)\n",
    "\n",
    "print(model_mlp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "XAsOa6l7CLY-" },
   "source": [ "## Train MLP" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "7sXSUh3WCLY-" },
   "outputs": [],
   "source": [
    "epochs = 20\n",
    "batch_size = 64\n",
    "\n",
    "n_batches = (X_train.size(0) + batch_size - 1) // batch_size\n",
    "for epoch in range(epochs):\n",
    "    model_mlp.train()\n",
    "    perm = torch.randperm(X_train.size(0))\n",
    "    total_loss = 0.0\n",
    "    for i in range(0, X_train.size(0), batch_size):\n",
    "        idx = perm[i:i+batch_size]\n",
    "        x_batch, y_batch = X_train[idx], y_train[idx]\n",
    "        optimizer.zero_grad()\n",
    "        logits = model_mlp(x_batch)\n",
    "        loss = criterion(logits, y_batch)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    # Report the mean batch loss so the number is comparable across dataset sizes.\n",
    "    print(f\"Epoch {epoch+1}/{epochs}, Loss: {total_loss / n_batches:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "A4XbUrqRCLY-" },
   "source": [ "## Evaluate" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "YtieKVFhCLY_" },
   "outputs": [],
   "source": [
    "model_mlp.eval()\n",
    "with torch.no_grad():\n",
    "    # The model outputs logits; apply sigmoid to get probabilities.\n",
    "    preds = torch.sigmoid(model_mlp(X_test)).cpu().numpy().flatten()\n",
    "\n",
    "pred_labels = (preds >= 0.5).astype(int)\n",
    "print('ROC-AUC:', roc_auc_score(y_test.cpu(), preds))\n",
    "print('PR-AUC:', average_precision_score(y_test.cpu(), preds))\n",
    "print('\\nClassification Report:\\n', classification_report(y_test.cpu(), pred_labels))\n",
    "print('Confusion Matrix:\\n', confusion_matrix(y_test.cpu(), pred_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "ADjCmp8PCLY_" },
   "source": [ "## Save Model" ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "v0j_4vwKCLY_" },
   "outputs": [],
   "source": [
    "torch.save(model_mlp.state_dict(), 'fast_mlp_amp.pt')\n",
    "print('Model saved as fast_mlp_amp.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "IuJCNyBTXkBH" },
   "outputs": [],
   "source": [
    "# Only attempt a browser download in Colab; this import crashes elsewhere.\n",
    "if IN_COLAB:\n",
    "    from google.colab import files\n",
    "    files.download('fast_mlp_amp.pt')"
   ]
  }
 ],
 "metadata": {
  "colab": { "provenance": [] },
  "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
  "language_info": { "name": "python", "version": "3.x" }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}