| import streamlit as st |
| import pandas as pd |
| import torch |
| from transformers import BertTokenizer, AutoModelForSequenceClassification |
| from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support |
| import numpy as np |
| import plotly.graph_objects as go |
| import plotly.express as px |
| from tqdm import tqdm |
|
|
@st.cache_resource
def _load_model_and_tokenizer_cached(model_name, tokenizer_name):
    """Load the classifier and tokenizer once and reuse them across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model would be reloaded from disk each time.
    Exceptions propagate instead of being turned into a fallback value,
    because Streamlit does not cache a call that raises; a transient failure
    is therefore retried on the next rerun rather than remembered.

    Returns:
        (model, tokenizer, device): the model in eval mode on ``device``
        ("cuda" when available, otherwise "cpu").
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()  # inference only: disables dropout
    return model, tokenizer, device


def load_model_and_tokenizer(
    model_name="CIS519PG/News_Classifier_Demo",
    tokenizer_name="bert-base-uncased",
):
    """Return the (cached) classifier, tokenizer and device for evaluation.

    Args:
        model_name: name or path passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        tokenizer_name: name or path passed to ``BertTokenizer.from_pretrained``.

    Returns:
        (model, tokenizer, device) on success, or (None, None, None) after
        showing the error in the Streamlit page if loading fails.
    """
    try:
        return _load_model_and_tokenizer_cached(model_name, tokenizer_name)
    except Exception as e:  # UI boundary: report any loading failure to the user
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None, None
|
|
def preprocess_data(df):
    """Normalize uploaded rows into the records ``evaluate_model`` consumes.

    Reads the "title" and "outlet" columns of *df*. Each outlet is trimmed,
    has internal whitespace collapsed, and is upper-cased. "FOX NEWS" and
    "NBC NEWS" are then mapped to the "FOXNEWS" / "NBC" labels.

    Some rows are skipped, with a single warning giving the count:
      - rows with a missing title or outlet;
      - rows whose outlet is not one of those two labels. Only FOXNEWS and
        NBC have an id in ``evaluate_model``'s label mapping.

    Args:
        df: pandas DataFrame with "title" and "outlet" columns.

    Returns:
        list of ``{"title": str, "outlet": str}`` dicts (possibly empty), or
        None after showing an error when the frame cannot be processed
        (e.g. a required column is missing).
    """
    aliases = {"FOX NEWS": "FOXNEWS", "NBC NEWS": "NBC"}
    known_outlets = {"FOXNEWS", "NBC"}
    try:
        processed_data = []
        skipped = 0
        for title, raw_outlet in zip(df["title"], df["outlet"]):
            if pd.isna(title) or pd.isna(raw_outlet):
                skipped += 1
                continue
            # str() guards against non-string cells; split/join collapses
            # runs of whitespace such as "FOX  NEWS".
            outlet = " ".join(str(raw_outlet).split()).upper()
            outlet = aliases.get(outlet, outlet)
            if outlet not in known_outlets:
                skipped += 1
                continue
            processed_data.append({
                "title": str(title),
                "outlet": outlet
            })
    except Exception as e:  # UI boundary: e.g. KeyError for a missing column
        st.error(f"Error preprocessing data: {str(e)}")
        return None
    if skipped:
        st.warning(
            f"Skipped {skipped} row(s) with a missing title/outlet or an "
            f"outlet other than FOXNEWS/NBC."
        )
    return processed_data
|
|
def evaluate_model(model, tokenizer, device, test_dataset, batch_size=16, max_length=128):
    """Score every title in *test_dataset* and return true labels and probabilities.

    Args:
        model: sequence-classification model, expected to already live on
            *device* (``load_model_and_tokenizer`` moves it there).
        tokenizer: tokenizer applied to each item's "title".
        device: device the encoded inputs are moved to before the forward pass.
        test_dataset: sliceable sequence (e.g. the list from
            ``preprocess_data``) of dicts with "title" and "outlet" keys.
        batch_size: number of titles encoded and scored per forward pass.
        max_length: token length titles are truncated to.

    Returns:
        (references, probabilities): a list of int labels (FOXNEWS=0, NBC=1)
        and a NumPy array with one softmax row per example. For an empty
        dataset the array has shape (0, 2), matching the two labels.

    Raises:
        KeyError: if an item's "outlet" is not "FOXNEWS" or "NBC".
    """
    label2id = {"FOXNEWS": 0, "NBC": 1}
    total = len(test_dataset)
    references = []
    prob_batches = []
    progress_bar = st.progress(0)

    for start in range(0, total, batch_size):
        batch = test_dataset[start:start + batch_size]
        texts = [item['title'] for item in batch]

        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        # Softmax per batch and move one array to the CPU. Collecting per-row
        # numpy arrays and calling torch.tensor on that list afterwards is
        # slow, and it breaks softmax(dim=1) when the list is empty.
        prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
        references.extend(label2id[item['outlet']] for item in batch)
        progress_bar.progress(min((start + len(batch)) / total, 1.0))

    progress_bar.progress(1.0)
    if not prob_batches:
        return references, np.empty((0, len(label2id)))
    return references, np.concatenate(prob_batches, axis=0)
|
|
def plot_roc_curve(references, probabilities):
    """Build a Plotly ROC figure from the class-1 column of *probabilities*.

    Column 1 is the NBC class under ``evaluate_model``'s label mapping.

    Args:
        references: true binary labels.
        probabilities: 2-D array; ``probabilities[:, 1]`` is used as the score.

    Returns:
        (fig, auc_score): the figure with the ROC curve and a dashed diagonal
        "Random Guess" baseline, plus the ROC AUC value.
    """
    positive_scores = probabilities[:, 1]
    false_pos_rate, true_pos_rate, _unused_thresholds = roc_curve(references, positive_scores)
    auc_score = roc_auc_score(references, positive_scores)

    curve = go.Scatter(
        x=false_pos_rate,
        y=true_pos_rate,
        name=f'ROC Curve (AUC = {auc_score:.4f})',
    )
    baseline = go.Scatter(
        x=[0, 1],
        y=[0, 1],
        name='Random Guess',
        line=dict(dash='dash'),
    )
    fig = go.Figure(data=[curve, baseline])
    fig.update_layout(
        title='ROC Curve',
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        showlegend=True,
    )
    return fig, auc_score
|
|
def plot_metrics_by_threshold(references, probabilities, thresholds=None):
    """Sweep decision thresholds and plot F1, precision and recall for class 1.

    A row is predicted as class 1 (NBC) when ``probabilities[:, 1] > threshold``.

    Args:
        references: true labels (0 = FOXNEWS, 1 = NBC).
        probabilities: 2-D array whose column 1 holds the class-1 probability.
        thresholds: thresholds to sweep; defaults to 0.00, 0.01, ..., 0.99.

    Returns:
        (fig, best_metrics): *fig* plots each metric against the threshold.
        *best_metrics* describes the first threshold reaching the highest F1,
        with these keys:
          - 'threshold', 'f1_score';
          - 'confusion_matrix': 2x2, rows/columns ordered FOXNEWS, NBC;
          - 'classification_report'.
        It is empty only when *thresholds* is empty.
    """
    if thresholds is None:
        thresholds = np.arange(0.0, 1.0, 0.01)
    positive_scores = probabilities[:, 1]
    metrics = {
        'threshold': thresholds,
        'f1': [],
        'precision': [],
        'recall': []
    }

    # Start below any achievable F1 so the first threshold is always recorded.
    # Starting at 0 left best_metrics empty whenever every F1 was 0.
    best_f1 = -1.0
    best_threshold = None
    best_preds = None
    for threshold in thresholds:
        preds = (positive_scores > threshold).astype(int)
        # One call yields all three binary metrics. zero_division=0 gives the
        # value sklearn uses by default but without an UndefinedMetricWarning
        # per threshold.
        precision, recall, f1, _ = precision_recall_fscore_support(
            references, preds, average='binary', zero_division=0
        )
        metrics['f1'].append(f1)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        if f1 > best_f1:
            best_f1, best_threshold, best_preds = f1, threshold, preds

    best_metrics = {}
    if best_preds is not None:
        # labels=[0, 1] pins a 2x2 matrix and keeps target_names valid even
        # when a class is absent from the references or the predictions.
        best_metrics = {
            'threshold': best_threshold,
            'f1_score': best_f1,
            'confusion_matrix': confusion_matrix(references, best_preds, labels=[0, 1]),
            'classification_report': classification_report(
                references,
                best_preds,
                labels=[0, 1],
                target_names=['FOXNEWS', 'NBC'],
                digits=4,
                zero_division=0,
            ),
        }

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['f1'], name='F1 Score'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['precision'], name='Precision'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['recall'], name='Recall'))
    fig.update_layout(
        title='Metrics by Threshold',
        xaxis_title='Threshold',
        yaxis_title='Score',
        showlegend=True
    )
    return fig, best_metrics
|
|
def plot_confusion_matrix(cm, labels=None):
    """Render a confusion matrix as an annotated Plotly heatmap.

    Args:
        cm: square matrix (NumPy array or nested list) of counts. ``cm[i][j]``
            is the count for true label ``labels[i]`` predicted as ``labels[j]``.
        labels: class names for the rows/columns; defaults to
            ``['FOXNEWS', 'NBC']``.

    Returns:
        plotly Figure with predicted labels on the x axis and true labels on
        the y axis, first label at the top.
    """
    if labels is None:
        labels = ['FOXNEWS', 'NBC']
    # Accept nested lists too: the cm[i, j] indexing below needs an array.
    cm = np.asarray(cm)
    # Cells above half the maximum get white text, staying readable on the
    # darker end of the 'Blues' colorscale. Computed once, not per cell.
    text_threshold = cm.max() / 2 if cm.size else 0
    annotations = [
        dict(
            text=str(cm[i, j]),
            x=labels[j],
            y=labels[i],
            showarrow=False,
            font=dict(color='white' if cm[i, j] > text_threshold else 'black')
        )
        for i in range(len(labels))
        for j in range(len(labels))
    ]
    fig = go.Figure(data=go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        colorscale='Blues',
        showscale=True
    ))
    fig.update_layout(
        title='Confusion Matrix',
        xaxis_title='Predicted Label',
        yaxis_title='True Label',
        annotations=annotations
    )
    # Plotly draws the first y category at the bottom. Reverse the axis so
    # row 0 sits at the top, matching how the matrix itself reads.
    fig.update_yaxes(autorange='reversed')
    return fig
|
|
def main():
    """Streamlit entry point: upload a CSV, score it with the classifier, show metrics.

    The CSV must contain "title" and "outlet" columns. Each step stops the
    page early, with an error or warning message, when its input is unusable,
    instead of letting an exception surface as a traceback.
    """
    st.title("News Classifier Model Evaluation")
    uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv'])
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except ValueError as e:
        # pandas' EmptyDataError / ParserError and UnicodeDecodeError all
        # derive from ValueError.
        st.error(f"Error reading CSV file: {str(e)}")
        return
    st.write("Preview of uploaded data:")
    st.dataframe(df.head())

    # Validate the schema before paying for model loading.
    missing = [col for col in ("title", "outlet") if col not in df.columns]
    if missing:
        st.error(f"Uploaded CSV is missing required column(s): {', '.join(missing)}")
        return

    model, tokenizer, device = load_model_and_tokenizer()
    # The loader signals failure with (None, None, None); compare explicitly
    # rather than relying on the objects' truthiness.
    if model is None or tokenizer is None:
        return

    test_dataset = preprocess_data(df)
    if test_dataset is None:  # preprocess_data has already shown the error
        return
    if not test_dataset:
        st.warning("No usable rows found in the uploaded data.")
        return

    st.write(f"Total examples: {len(test_dataset)}")
    with st.spinner('Evaluating model...'):
        references, probabilities = evaluate_model(model, tokenizer, device, test_dataset)

    # roc_auc_score raises ValueError when only one class is present.
    if len(set(references)) < 2:
        st.warning(
            "AUC-ROC and threshold metrics need both FOXNEWS and NBC examples; "
            "the uploaded data contains only one class."
        )
        return

    roc_fig, auc_score = plot_roc_curve(references, probabilities)
    st.plotly_chart(roc_fig)
    st.metric("AUC-ROC Score", f"{auc_score:.4f}")

    metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities)
    st.plotly_chart(metrics_fig)

    st.subheader("Best Threshold Evaluation")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}")
    with col2:
        st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}")

    st.subheader("Confusion Matrix")
    cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix'])
    st.plotly_chart(cm_fig)

    st.subheader("Classification Report")
    st.text(best_metrics['classification_report'])


if __name__ == "__main__":
    main()