| import streamlit as st |
| from transformers import AutoTokenizer |
| import torch |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
|
|
_VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")


def validate_sequence(sequence, max_length=200):
    """Return True if *sequence* is a non-empty, valid amino-acid string.

    A sequence is accepted when:

    * it contains at least one residue,
    * it has no more than *max_length* residues, and
    * every character is one of the 20 uppercase one-letter codes in
      ``ACDEFGHIKLMNPQRSTVWY``.

    Lowercase letters are rejected.

    Args:
        sequence: the protein sequence to check.
        max_length: maximum accepted length (inclusive); defaults to 200.

    Returns:
        bool: whether the sequence is valid.
    """
    # Check the length bounds first. The lower bound matters because
    # ``all()`` over an empty sequence is True, so "" used to be accepted.
    if not 0 < len(sequence) <= max_length:
        return False
    return _VALID_AMINO_ACIDS.issuperset(sequence)
|
|
def load_model(model_name):
    """Load the checkpoint ``<model_name>_model.pth`` onto the CPU and return it.

    The object returned by ``torch.load`` is used directly as the model:
    ``.eval()`` is called on it before it is returned. The file is therefore
    expected to hold a whole pickled model object, not just a state dict.

    NOTE(review): ``torch.load`` unpickles the file, which can execute
    arbitrary code. Only load checkpoints from trusted sources.
    Recent PyTorch releases default ``torch.load`` to ``weights_only=True``,
    which may refuse whole-model pickles; confirm against the installed
    version.
    """
    checkpoint_path = '{}_model.pth'.format(model_name)
    cpu = torch.device('cpu')
    loaded = torch.load(checkpoint_path, map_location=cpu)
    # Put the model in evaluation mode before handing it back.
    loaded.eval()
    return loaded
|
|
|
|
_ESM_TOKENIZER_NAME = 'facebook/esm2_t6_8M_UR50D'
# Tokenizers already loaded, keyed by name, so repeated predictions do not
# call ``AutoTokenizer.from_pretrained`` again.
_TOKENIZER_CACHE = {}


def _get_tokenizer(name=_ESM_TOKENIZER_NAME):
    """Return the Hugging Face tokenizer *name*, loading it only on first use."""
    tokenizer = _TOKENIZER_CACHE.get(name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(name)
        _TOKENIZER_CACHE[name] = tokenizer
    return tokenizer


def predict(model, sequence, *, tokenizer=None, confidence_scale=0.85):
    """Classify one protein sequence and return ``(label, confidence)``.

    Args:
        model: callable that accepts the tokenizer's output as keyword
            arguments and returns an object with a ``logits`` tensor.
        sequence: amino-acid string to classify.
        tokenizer: optional pre-built tokenizer. Defaults to the
            ``facebook/esm2_t6_8M_UR50D`` tokenizer, which is loaded once
            and then reused.
        confidence_scale: factor applied to the highest softmax probability.
            Defaults to 0.85, the value previously hard-coded here.

    Returns:
        tuple: the argmax class index (``int``) and the scaled top
        probability (``float``).

    Raises:
        RuntimeError: from ``Tensor.item()`` if the logits do not reduce to a
            single prediction (e.g. more than one sequence in the batch).
    """
    if tokenizer is None:
        tokenizer = _get_tokenizer()
    encoded = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        output = model(**encoded)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    # NOTE(review): the reported confidence is deliberately scaled down
    # (0.85 by default). The rationale is not visible here; confirm it is
    # intended before relying on these values as probabilities.
    confidence = probabilities.max().item() * confidence_scale
    return predicted_label.item(), confidence
|
|
def plot_prediction_graphs(data, model_keys):
    """Render per-model confidence bar charts in the Streamlit app.

    For every model in *model_keys*, one figure with two side-by-side panels
    is drawn:

    * left panel: sequences that model labelled 0;
    * right panel: sequences that model labelled 1.

    Each panel shows confidence scores in descending order. Entries with any
    other label are not plotted. Each sequence name keeps the same bar colour
    in every figure.

    Args:
        data: mapping ``name -> {model_name: prediction}``. Each
            ``data[name][model_name]`` is indexable, with the predicted label
            at ``[0]`` and the confidence at ``[1]``.
        model_keys: model names to plot. Every ``data[name]`` must contain
            each of them.
    """
    # One stable colour per sequence name, shared by every figure.
    unique_names = sorted(data)
    palette = sns.color_palette("hsv", len(unique_names))
    color_dict = dict(zip(unique_names, palette))

    for model_name in model_keys:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        for prediction_val, ax in zip((0, 1), axes):
            # (name, confidence) pairs for this label, highest confidence first.
            ranked = sorted(
                (
                    (name, values[model_name][1])
                    for name, values in data.items()
                    if values[model_name][0] == prediction_val
                ),
                key=lambda item: item[1],
                reverse=True,
            )
            if ranked:
                names = [name for name, _ in ranked]
                conf_values = [conf for _, conf in ranked]
                colors = [color_dict[name] for name in names]
                sns.barplot(x=names, y=conf_values, palette=colors, ax=ax)
            else:
                # No sequence got this label. Show a placeholder instead of
                # passing empty data and an empty palette to sns.barplot.
                ax.text(0.5, 0.5, 'No sequences', ha='center', va='center',
                        transform=ax.transAxes)
            ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
            ax.set_xlabel('Names')
            ax.set_ylabel('Confidence')
            ax.tick_params(axis='x', rotation=45)

        st.pyplot(fig)
        # Streamlit has rendered the figure by now. Close it so pyplot does
        # not accumulate open figures across reruns of the app.
        plt.close(fig)