| import streamlit as st |
| from utils import validate_sequence, predict, plot_prediction_graphs |
| from model import models |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
|
|
# session_state key under which the most recent analysis is kept so it
# survives Streamlit reruns triggered by other widgets.
_ANALYSIS_KEY = "aa_analysis"


def _collect_sequences(sequence, uploaded_file):
    """Build parallel ``(names, sequences)`` lists from the sidebar inputs.

    The typed sequence (if any) comes first and is labelled "Seq 1". Rows from
    the uploaded CSV follow. The CSV must have a ``sequence`` column; its
    optional ``name`` column supplies labels, and missing names fall back to
    "Seq <n>" numbering. CSV problems are reported in the sidebar and whatever
    was collected so far is returned.
    """
    names, sequences = [], []

    sequence = sequence.strip() if sequence else ""
    if sequence:
        names.append("Seq 1")
        sequences.append(sequence)

    if uploaded_file is None:
        return names, sequences

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.sidebar.error(f"Could not read CSV file: {err}")
        return names, sequences

    if "sequence" not in df.columns:
        st.sidebar.error("CSV file must contain a 'sequence' column.")
        return names, sequences

    # Drop rows without a sequence *before* extracting names, so both lists
    # stay row-aligned and NaN never reaches the validator/model.
    df = df.dropna(subset=["sequence"])
    csv_sequences = df["sequence"].astype(str).str.strip().tolist()

    # Continue "Seq <n>" numbering after the typed sequence, if any.
    start = len(sequences) + 1
    default_names = [f"Seq {start + i}" for i in range(len(csv_sequences))]
    if "name" in df.columns:
        csv_names = [
            str(name) if pd.notna(name) else default
            for name, default in zip(df["name"], default_names)
        ]
    else:
        csv_names = default_names

    names.extend(csv_names)
    sequences.extend(csv_sequences)
    return names, sequences


def _analyze(names, sequences):
    """Run every model in ``models`` on each valid sequence.

    Returns ``(results, all_data)``:

    * ``results`` is a list of table rows with ``Name``, ``Sequence`` and
      ``<model>_prediction`` / ``<model>_confidence`` (rounded to 3 places).
    * ``all_data`` maps name -> {model_name: (prediction, confidence)}, the
      input passed to ``plot_prediction_graphs``.

    Sequences rejected by ``validate_sequence`` are reported in the sidebar
    and skipped.
    """
    results = []
    all_data = {}
    for name, seq in zip(names, sequences):
        if not validate_sequence(seq):
            st.sidebar.error(f"Invalid sequence for {name}: {seq}")
            continue
        row = {"Name": name, "Sequence": seq}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            row[f"{model_name}_prediction"] = prediction
            row[f"{model_name}_confidence"] = round(confidence, 3)
            graph_data[model_name] = (prediction, confidence)
        results.append(row)
        # NOTE(review): duplicate names overwrite each other here, while
        # `results` keeps every row.
        all_data[name] = graph_data
    return results, all_data


def main():
    """Render the AA property inference Streamlit page.

    The sidebar accepts a typed amino-acid sequence and/or a CSV upload
    (``sequence`` column, optional ``name`` column). Pressing "Analyze
    Sequence" runs each model from ``models`` on every valid sequence and
    stores the outcome. The page then shows a results table and, when
    "Show Prediction Graphs" is ticked, the prediction graphs.
    """
    st.set_page_config(layout="wide")
    st.title("AA Property Inference Demo", anchor=None)

    # NOTE(review): `.reportview-container` is a class name from older
    # Streamlit releases; newer versions may not apply this rule — confirm.
    st.markdown("""
<style>
.reportview-container {
    font-family: 'Courier New', monospace;
}
</style>
<p style='font-size:16px;'><span style='font-size:24px;'>←</span> Don't know where to start? Open tab to input a sequence.</p>
""", unsafe_allow_html=True)

    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader(
        "Or upload a CSV file with amino acid sequences", type="csv"
    )
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    # st.button is True only on the rerun caused by the click. Any other
    # widget interaction (e.g. ticking "Show Prediction Graphs") reruns the
    # script with it False, so the last analysis lives in session_state.
    if analyze_pressed:
        names, sequences = _collect_sequences(sequence, uploaded_file)
        st.session_state[_ANALYSIS_KEY] = _analyze(names, sequences)

    results, all_data = st.session_state.get(_ANALYSIS_KEY, ([], {}))

    if results:
        results_df = pd.DataFrame(results)
        st.write("### Results")
        st.dataframe(results_df.style.format(precision=3), width=None, height=None)

    if show_graphs and all_data:
        st.write("## Graphs")
        plot_prediction_graphs(all_data, models.keys())


if __name__ == "__main__":
    main()
|
|