import gradio as gr
import torch
import logging
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure professional logging: "timestamp | LEVEL | message" on every record.
logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger used throughout this app
class MagicSupportClassifier:
    """
    Encapsulates the customer support intent classification model.

    Engineered for dynamic label resolution and rapid inference: the label
    set is read from the checkpoint's ``id2label`` config rather than being
    hard-coded, so a retrained model with a different intent taxonomy works
    without code changes.
    """

    def __init__(self, model_id: str = "learn-abc/magicSupport-intent-classifier"):
        """
        Args:
            model_id: Hugging Face hub id (or local path) of a sequence
                classification checkpoint. Loading happens eagerly here.
        """
        self.model_id = model_id
        self.max_length = 128  # tokenizer truncation length, in tokens
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load tokenizer + model, move to device, and cache the class count.

        Raises:
            Exception: re-raises any load failure after logging it.
        """
        logger.info(f"Initializing model {self.model_id} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
            self.model.to(self.device)
            self.model.eval()
            # Extract number of classes dynamically from the checkpoint config.
            self.num_classes = len(self.model.config.id2label)
            logger.info(f"Model loaded successfully with {self.num_classes} intent classes.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _get_iconography(self, label: str) -> str:
        """
        Dynamically assigns UI icons based on intent keywords.
        Future-proofs the application against retrained label sets.
        """
        label_lower = label.lower()
        # First matching category wins; order mirrors the original checks.
        if any(kw in label_lower for kw in ("order", "delivery", "track")):
            return "📦"
        if any(kw in label_lower for kw in ("refund", "payment", "invoice", "fee")):
            return "💳"
        if any(kw in label_lower for kw in ("account", "password", "register", "profile")):
            return "👤"
        if any(kw in label_lower for kw in ("cancel", "delete", "problem", "issue")):
            return "⚠️"
        if any(kw in label_lower for kw in ("contact", "service", "support")):
            return "🎧"
        return "🔹"  # neutral fallback for unrecognized intents

    def _format_label(self, label: str) -> str:
        """Cleans up raw dataset labels (snake_case) for professional UI presentation."""
        return label.replace("_", " ").title()

    @torch.inference_mode()
    def predict(self, text: str, top_k: int = 5):
        """
        Classify a customer query.

        Args:
            text: Raw customer message; surrounding whitespace is stripped.
            top_k: Number of top predictions to render (capped at the model's
                class count).

        Returns:
            Tuple ``(result_html, chart_data)`` where ``chart_data`` maps each
            formatted label to its probability, or ``(message_html, None)`` on
            empty input or inference failure.
        """
        if not text or not text.strip():
            # FIX: the original return statement here was an unterminated
            # string literal spanning multiple lines (a syntax error).
            return "⚠️ Input Required: Please enter a customer query.", None
        try:
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=True,
            ).to(self.device)
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1).squeeze()
            if probs.dim() == 0:
                # Degenerate single-class model: keep probs 1-D for indexing below.
                probs = probs.unsqueeze(0)
            # Cap top_k to the maximum number of available classes.
            actual_top_k = min(top_k, self.num_classes)
            # FIX: compute torch.topk once (the original called it twice).
            top = torch.topk(probs, k=actual_top_k)
            top_indices = top.indices.tolist()
            top_probs = top.values.tolist()
            id2label = self.model.config.id2label
            # Primary prediction formatting.
            top_intent_raw = id2label[top_indices[0]]
            emoji = self._get_iconography(top_intent_raw)
            clean_label = self._format_label(top_intent_raw)
            confidence = top_probs[0] * 100
            result_html = f"""
            {emoji} {clean_label}
            Confidence: {confidence:.1f}%
            📊 Top {actual_top_k} Predictions
            """
            # HTML progress bars for each of the top-k predictions.
            # NOTE(review): the original bar markup appears to have been
            # stripped during extraction (the f-string body was empty, leaving
            # the loop variables unused); this is a minimal reconstruction of
            # a labelled progress bar — confirm against the intended design.
            for idx, prob in zip(top_indices, top_probs):
                intent_raw = id2label[idx]
                icon = self._get_iconography(intent_raw)
                pretty = self._format_label(intent_raw)
                pct = prob * 100
                result_html += (
                    f'<div class="pred-row">{icon} {pretty}'
                    f'<div class="bar" style="width:{pct:.1f}%"></div>'
                    f'<span>{pct:.1f}%</span></div>\n'
                )
            # Full probability distribution for the gr.Label chart.
            chart_data = {
                self._format_label(id2label[i]): float(probs[i].item())
                for i in range(len(probs))
            }
            return result_html, chart_data
        except Exception as e:
            # Broad catch is intentional at this UI boundary: keep the stack
            # trace in the logs and show a friendly message to the user.
            logger.exception(f"Inference error: {e}")
            # FIX: the original return statement here was an unterminated
            # f-string spanning two lines (a syntax error).
            return "❌ System Error: Inference failed. Check application logs.", None
# Initialize application backend.
# NOTE(review): this eagerly downloads/loads the model at import time —
# confirm that is intended for the deployment target.
app_backend = MagicSupportClassifier()

# High-value test scenarios based on Bitext taxonomy.
# Each row is [query text, top-k slider value], matching the gr.Examples
# `inputs=[text_input, top_k_slider]` wiring below.
EXAMPLES = [
    ["I need to cancel my order immediately, it was placed by mistake.", 5],
    ["Where can I find the invoice for my last purchase?", 3],
    ["The item arrived damaged and I want a full refund.", 5],
    ["How do I change the shipping address on my account?", 3],
    ["I forgot my password and cannot log in.", 3],
    ["Are there any hidden fees if I cancel my subscription now?", 5],
]
# Build Gradio Interface: two-column layout — inputs on the left,
# inference results + distribution chart on the right.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="slate"),
    title="MagicSupport Intent Classifier R&D Dashboard",
    css="""
    .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
    .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
    .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
    .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
    .domain-badge { background: #ede9fe; color: #5b21b6; border: 1px solid #ddd6fe;}
    .metric-badge { background: #f1f5f9; color: #334155; border: 1px solid #cbd5e1;}
    footer { display: none !important; }
    """
) as demo:
    # NOTE(review): the header markup appears to have been stripped during
    # extraction — this gr.HTML call currently renders nothing.
    gr.HTML("""
    """)
    with gr.Row():
        # Left column: query input, top-k control, action buttons, examples.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Customer Query",
                placeholder="Type a customer message here (e.g., 'Where is my package?')...",
                lines=3,
            )
            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1, maximum=15, value=5, step=1,
                    label="Display Top-K Predictions"
                )
            with gr.Row():
                predict_btn = gr.Button("🔍 Execute Prediction", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Interface", variant="secondary")
            gr.Examples(
                examples=EXAMPLES,
                inputs=[text_input, top_k_slider],
                label="Actionable Test Scenarios",
                examples_per_page=6,
            )
        # Right column: HTML result panel and full-distribution label chart.
        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")
            with gr.Row():
                chart_output = gr.Label(
                    label="Full Semantic Distribution Map",
                    num_top_classes=app_backend.num_classes  # Dynamically set based on model config
                )
    with gr.Accordion("⚙️ Technical Architecture & Model Details", open=False):
        gr.Markdown("""
        ### Core Specifications
        * **Target Model:** `learn-abc/magicSupport-intent-classifier`
        * **Objective:** Multi-class text sequence classification for customer support routing.
        * **Dataset Lineage:** Trained on the comprehensive `bitext/Bitext-customer-support-llm-chatbot-training-dataset`.
        ### Pipeline Features
        * **Dynamic Label Resolution:** The UI heuristic engine automatically maps raw dataset labels (e.g., `change_shipping_address`) into clean, professional UI elements (e.g., Change Shipping Address) and assigns contextual iconography.
        * **Optimized Inference:** Utilizes PyTorch `inference_mode` for reduced memory footprint and accelerated compute during forward passes.
        """)
    # Event Wiring: button click and textbox Enter both trigger prediction.
    predict_btn.click(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    text_input.submit(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    # Clear resets the text, slider (back to default 5), and both outputs.
    clear_btn.click(
        fn=lambda: ("", 5, "", None),
        outputs=[text_input, top_k_slider, result_output, chart_output],
    )
if __name__ == "__main__":
    # Launch the Gradio server when run as a script (not on import).
    demo.launch()