"""Streamlit dashboard for exploring a trained logistic-regression pipeline.

The script loads a fitted scikit-learn pipeline (with the named steps
``preprocessor`` and ``classifier``) and a held-out test set, then renders on
demand: coefficient importances, SHAP global/local explanations, and ROC /
precision-recall curves.  Plots are only drawn after the sidebar form has been
submitted.
"""

import streamlit as st
import shap
import matplotlib.pyplot as plt
import numpy as np
import joblib
import pandas as pd
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc

# Must be the first Streamlit command executed on the page.
st.set_page_config(page_title="Model Analysis Dashboard", layout="wide")

MODEL_PATH = "model_1mvp.pkl"
TEST_DATA_PATH = "test_data.csv"
target = "y"


# --- Load model and test data ---
@st.cache_resource
def _load_model(path=MODEL_PATH):
    """Load the fitted pipeline once per server process.

    ``st.cache_resource`` hands back the same object on every rerun instead of
    pickling/unpickling a copy, which is what ``st.cache_data`` would do.
    """
    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code.  Only point MODEL_PATH at a trusted artifact.
    return joblib.load(path)


@st.cache_data
def _load_test_data(path=TEST_DATA_PATH):
    """Read the held-out test set from CSV (cached as plain data)."""
    return pd.read_csv(path)


def load_model_and_data():
    """Return ``(model, df_test)`` using the cached loaders above."""
    return _load_model(), _load_test_data()


def _to_dense(matrix):
    """Return *matrix* as a dense ``numpy.ndarray``.

    Preprocessors may emit a SciPy sparse matrix or a DataFrame; positional
    row indexing and the SHAP plots below need a plain dense array.
    NOTE(review): densifying assumes the transformed test set fits in memory.
    """
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


@st.cache_data
def _compute_shap(_pipeline, _features):
    """Compute SHAP values for the test set once and cache the result.

    The leading underscores tell Streamlit not to hash the arguments; this is
    safe because both come from the cached loaders and never change during a
    session.

    Returns:
        ``(transformed_features, shap_values, expected_value)``
    """
    pre = _pipeline.named_steps["preprocessor"]
    clf = _pipeline.named_steps["classifier"]
    dense = _to_dense(pre.transform(_features))
    explainer = shap.LinearExplainer(
        clf, dense, feature_names=pre.get_feature_names_out()
    )
    return dense, explainer.shap_values(dense), explainer.expected_value


def _show_figure(fig):
    """Render *fig*, then close it so figures do not accumulate across reruns."""
    st.pyplot(fig)
    plt.close(fig)


def render_coefficients(classifier, feature_names, top_n):
    """Plot and tabulate the *top_n* coefficients by absolute magnitude."""
    st.header("Logistic Regression Coefficients")
    importance = pd.DataFrame(
        {"feature": feature_names, "coefficient": classifier.coef_[0]}
    ).sort_values(by="coefficient", key=abs, ascending=False)
    top = importance.head(top_n)

    fig, ax = plt.subplots(figsize=(8, 6))
    top.set_index("feature")["coefficient"].plot(kind="barh", ax=ax, color="#4C72B0")
    # barh draws the first row at the bottom; flip so the largest |coef| is on top,
    # matching the order of the table below.
    ax.invert_yaxis()
    ax.set_title("Logistic Regression Feature Importance (Coefficients)")
    ax.set_xlabel("Coefficient Value")
    ax.set_ylabel("Feature")
    _show_figure(fig)
    st.dataframe(top.style.format({"coefficient": "{:.3f}"}))


def render_shap(pipeline, features, feature_names, show_global, show_local, local_idx):
    """Render the SHAP summary and/or waterfall plot for the test set."""
    st.header("SHAP Analysis")
    transformed, shap_values, expected_value = _compute_shap(pipeline, features)

    if show_global:
        st.subheader("Global Feature Importance (SHAP Summary Plot)")
        plt.figure(figsize=(10, 6))
        shap.summary_plot(
            shap_values, transformed, feature_names=feature_names, show=False
        )
        # SHAP draws on (and may resize) the current figure.
        _show_figure(plt.gcf())

    if show_local:
        st.subheader("Local Explanation (SHAP Waterfall Plot)")
        row = int(local_idx)
        # NOTE(review): assumes a binary classifier with a single base value;
        # ravel()[0] also tolerates a length-1 array.
        base_value = float(np.ravel(expected_value)[0])
        plt.figure(figsize=(10, 6))
        shap.plots.waterfall(
            shap.Explanation(
                values=shap_values[row],
                base_values=base_value,
                data=transformed[row],
                feature_names=feature_names,
            ),
            max_display=15,
            show=False,
        )
        _show_figure(plt.gcf())


def render_performance(pipeline, features, labels):
    """Show ROC AUC / PR AUC metrics and the corresponding curves."""
    st.header("Model Performance Metrics (ROC / PR Curves)")
    if labels.nunique() < 2:
        # roc_auc_score raises when only one class is present.
        st.warning("ROC / PR curves need both classes present in the test labels.")
        return

    y_pred_proba = pipeline.predict_proba(features)[:, 1]
    roc_auc = roc_auc_score(labels, y_pred_proba)
    fpr, tpr, _ = roc_curve(labels, y_pred_proba)
    precision, recall, _ = precision_recall_curve(labels, y_pred_proba)
    pr_auc = auc(recall, precision)

    col1, col2 = st.columns(2)
    with col1:
        st.metric("ROC AUC", f"{roc_auc:.3f}")
    with col2:
        st.metric("PR AUC", f"{pr_auc:.3f}")

    fig_roc, ax_roc = plt.subplots(figsize=(5, 5))
    ax_roc.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.3f})")
    ax_roc.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random Guess")
    ax_roc.set_xlabel("False Positive Rate")
    ax_roc.set_ylabel("True Positive Rate")
    ax_roc.set_title("ROC Curve")
    ax_roc.legend()
    _show_figure(fig_roc)

    fig_pr, ax_pr = plt.subplots(figsize=(5, 5))
    ax_pr.plot(recall, precision, color="#C44E52")
    ax_pr.set_xlabel("Recall")
    ax_pr.set_ylabel("Precision")
    ax_pr.set_title("Precision-Recall Curve")
    _show_figure(fig_pr)


st.title("Model Analysis Dashboard")
st.markdown("""
This dashboard allows you to interactively explore the performance and interpretability of the trained logistic regression model.

- **Logistic Regression Coefficients:** See which features most strongly influence the model's predictions.
- **SHAP Analysis:** Understand both global and local feature importance using SHAP values.
- **ROC/PR Curves:** Evaluate the model's discrimination and precision-recall tradeoff.

Use the sidebar to select which plots to display and to adjust the number of features or the local sample for SHAP explanations.
""")

model, df_test = load_model_and_data()

# Fail with a readable message instead of a traceback on unusable test data.
if df_test.empty:
    st.error("The test data file contains no rows; nothing to analyse.")
    st.stop()
if target not in df_test.columns:
    st.error(f"The test data file has no target column named '{target}'.")
    st.stop()

X_test = df_test.drop(columns=[target])
y_test = df_test[target]

preprocessor = model.named_steps["preprocessor"]
feature_names = preprocessor.get_feature_names_out()

# --- Sidebar: Plot selection and controls ---
with st.sidebar.form("plot_selector"):
    st.markdown("## Select plots to display")
    show_coeff = st.checkbox("Logistic Regression Coefficients", value=True)
    show_shap_global = st.checkbox("SHAP Global (summary plot)", value=True)
    show_shap_local = st.checkbox("SHAP Local (waterfall plot)", value=False)
    show_roc = st.checkbox("ROC/PR Curves", value=True)
    top_n = st.slider("Number of top features for LogReg coefficients", 5, 30, 15)
    local_idx = st.number_input(
        "Local SHAP sample index", min_value=0, max_value=len(X_test) - 1, value=0
    )
    submitted = st.form_submit_button("Update plots")

if not submitted:
    # Nothing is drawn until the form is submitted; tell the user what to do.
    st.info("Choose the plots to display in the sidebar, then click **Update plots**.")
else:
    # --- Logistic Regression Coefficient Plot ---
    if show_coeff:
        render_coefficients(model.named_steps["classifier"], feature_names, top_n)

    # --- SHAP Analysis (computed lazily, cached after the first request) ---
    if show_shap_global or show_shap_local:
        render_shap(
            model, X_test, feature_names, show_shap_global, show_shap_local, local_idx
        )

    # --- ROC and PR Curves ---
    if show_roc:
        render_performance(model, X_test, y_test)