| import gradio as gr |
| import torch |
| import logging |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
# Process-wide log configuration: INFO and above, "time | LEVEL | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
class MagicSupportClassifier:
    """
    Customer-support intent classifier wrapping a Hugging Face checkpoint.

    The tokenizer and model are loaded once at construction time; ``predict``
    is a Gradio-compatible callback returning ``(html_report, label_scores)``.
    Display names and icons are derived at request time from the checkpoint's
    ``config.id2label``, so a retrained label set needs no code changes.
    """

    # Ordered (keywords, icon) rules used by ``_get_iconography``. The first
    # rule with any keyword contained in the lower-cased label wins, so order
    # matters: "cancel_order" gets the parcel icon and "delete_account" the
    # user icon, not the warning icon.
    _ICON_RULES = (
        (("order", "delivery", "track"), "📦"),
        (("refund", "payment", "invoice", "fee"), "💳"),
        (("account", "password", "register", "profile"), "👤"),
        (("cancel", "delete", "problem", "issue"), "⚠️"),
        (("contact", "service", "support"), "🎧"),
    )
    # Icon for labels that match none of the rules above.
    _DEFAULT_ICON = "🔹"

    def __init__(
        self,
        model_id: str = "learn-abc/magicSupport-intent-classifier",
        max_length: int = 128,
    ):
        """
        Load the classifier onto CUDA when available, otherwise the CPU.

        Args:
            model_id: Hugging Face Hub id (or local path) of the checkpoint.
            max_length: Maximum tokens per query; longer input is truncated.

        Raises:
            Exception: Any error from loading the tokenizer/model is re-raised.
        """
        self.model_id = model_id
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self) -> None:
        """Load tokenizer and model, move the model to ``self.device``, set eval mode."""
        logger.info("Initializing model %s on %s...", self.model_id, self.device)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
            self.model.to(self.device)
            self.model.eval()  # inference behaviour (e.g. dropout disabled)
            self.num_classes = len(self.model.config.id2label)
            logger.info("Model loaded successfully with %d intent classes.", self.num_classes)
        except Exception as e:
            # Log with traceback, then re-raise so a broken deployment fails at
            # startup rather than on the first user query.
            logger.exception("Failed to load model: %s", e)
            raise

    def _get_iconography(self, label: str) -> str:
        """
        Pick a UI icon for ``label`` by keyword matching.

        Keyword rules (instead of a fixed label -> icon table) keep the UI
        usable if the model is retrained with a different label set.
        """
        label_lower = label.lower()
        for keywords, icon in self._ICON_RULES:
            if any(keyword in label_lower for keyword in keywords):
                return icon
        return self._DEFAULT_ICON

    def _format_label(self, label: str) -> str:
        """Turn a raw label like ``change_shipping_address`` into ``Change Shipping Address``."""
        return label.replace("_", " ").title()

    def _class_probabilities(self, text: str) -> torch.Tensor:
        """Return a 1-D tensor of softmax probabilities (one per class) for ``text``."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        ).to(self.device)
        logits = self.model(**inputs).logits
        # A single query is a batch of one: take row 0 so the result is always
        # 1-D, even for a one-class head (where .squeeze() would give 0-D).
        return F.softmax(logits, dim=-1)[0]

    def _top_predictions(self, probs: torch.Tensor, top_k) -> list:
        """
        Return the best ``top_k`` ``(class_index, probability)`` pairs, best first.

        ``top_k`` is coerced to ``int`` (UI sliders may deliver floats) and
        clamped to ``[1, number of classes]`` so ``torch.topk`` cannot fail and
        the result is never empty.
        """
        k = max(1, min(int(top_k), probs.numel()))
        values, indices = torch.topk(probs, k=k)
        return list(zip(indices.tolist(), values.tolist()))

    def _render_report(self, ranked, id2label) -> str:
        """
        Build the HTML result panel.

        Args:
            ranked: Non-empty list of ``(class_index, probability)`` pairs,
                best first, as returned by ``_top_predictions``.
            id2label: Mapping from class index to raw dataset label.

        Returns:
            A headline with the top intent and its confidence, followed by one
            labelled progress bar per ranked class.
        """
        top_index, top_prob = ranked[0]
        top_raw = id2label[top_index]
        parts = [
            f"""
            <h2 style='margin-bottom: 5px; display: flex; align-items: center; gap: 8px;'>{self._get_iconography(top_raw)} {self._format_label(top_raw)}</h2>
            <p style='margin-top: 0; font-size: 16px;'><b>Confidence:</b> {top_prob * 100:.1f}%</p>
            <hr style='border-top: 1px solid var(--border-color-primary); margin: 20px 0;'/>
            <h3 style='margin-bottom: 15px;'>📊 Top {len(ranked)} Predictions</h3>
            """
        ]
        for index, prob in ranked:
            raw = id2label[index]
            pct = prob * 100
            parts.append(
                f"""
                <div style="margin-bottom: 16px;">
                    <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                        <strong>{self._get_iconography(raw)} {self._format_label(raw)}</strong>
                        <span style="font-family:monospace;">{pct:.1f}%</span>
                    </div>
                    <div style="background-color: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); border-radius: 6px; width: 100%; height: 10px;">
                        <div style="background-color: #8b5cf6; width: {pct:.2f}%; height: 100%; border-radius: 5px;"></div>
                    </div>
                </div>
                """
            )
        return "".join(parts)

    @torch.inference_mode()
    def predict(self, text: str, top_k: int = 5):
        """
        Classify one customer query; used directly as a Gradio event handler.

        Args:
            text: Raw customer message. ``None`` or whitespace-only is rejected.
            top_k: Number of ranked intents to display; coerced to ``int`` and
                clamped to ``[1, number of classes]``.

        Returns:
            ``(html, scores)`` where ``scores`` maps every formatted label to its
            probability; or ``(message_html, None)`` for blank input or when
            inference fails (the failure is logged with its traceback).
        """
        if not text or not text.strip():
            return "<div style='color: #ef4444; padding: 10px;'>⚠️ <b>Input Required:</b> Please enter a customer query.</div>", None

        try:
            probs = self._class_probabilities(text.strip())
            id2label = self.model.config.id2label
            result_html = self._render_report(self._top_predictions(probs, top_k), id2label)
            # One host transfer for the whole distribution instead of .item() per class.
            chart_data = {
                self._format_label(id2label[index]): prob
                for index, prob in enumerate(probs.tolist())
            }
            return result_html, chart_data
        except Exception as e:
            # UI boundary: report a friendly error instead of crashing the app.
            logger.exception("Inference error: %s", e)
            return "<div style='color: #ef4444;'>❌ <b>System Error:</b> Inference failed. Check application logs.</div>", None
|
|
| |
# Single shared backend: the checkpoint is loaded once, when this module is
# imported, and every UI event handler reuses this instance.
app_backend = MagicSupportClassifier()

# Canned rows for the gr.Examples gallery. Each row is [query, top_k] and
# fills the query textbox and the Top-K slider, in that order.
EXAMPLES = [
    [query, top_k]
    for query, top_k in (
        ("I need to cancel my order immediately, it was placed by mistake.", 5),
        ("Where can I find the invoice for my last purchase?", 3),
        ("The item arrived damaged and I want a full refund.", 5),
        ("How do I change the shipping address on my account?", 3),
        ("I forgot my password and cannot log in.", 3),
        ("Are there any hidden fees if I cancel my subscription now?", 5),
    )
]
|
|
| |
# Top-K slider default; the "Clear" button restores the same value.
_DEFAULT_TOP_K = 5

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="slate"),
    title="MagicSupport Intent Classifier R&D Dashboard",
    css="""
    .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
    .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
    .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
    .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
    .domain-badge { background: #ede9fe; color: #5b21b6; border: 1px solid #ddd6fe;}
    .metric-badge { background: #f1f5f9; color: #334155; border: 1px solid #cbd5e1;}
    footer { display: none !important; }
    """,
) as demo:

    # Page header: title, tagline and domain badges (styled by the CSS above).
    gr.HTML("""
    <div class="header-box">
        <h1>🎧 MagicSupport Intent Classifier</h1>
        <p>
            High-precision semantic routing for automated customer support pipelines.
        </p>
        <div style="margin-top:12px;">
            <span class="badge domain-badge">E-commerce & Retail</span>
            <span class="badge domain-badge">Account Management</span>
            <span class="badge domain-badge">Billing & Refunds</span>
            <span class="badge metric-badge">Based on Bitext Taxonomy</span>
        </div>
    </div>
    """)

    with gr.Row():
        # Left column: query input, Top-K control, action buttons, examples.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Customer Query",
                placeholder="Type a customer message here (e.g., 'Where is my package?')...",
                lines=3,
            )

            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1, maximum=15, value=_DEFAULT_TOP_K, step=1,
                    label="Display Top-K Predictions",
                )

            with gr.Row():
                predict_btn = gr.Button("🚀 Execute Prediction", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Interface", variant="secondary")

            gr.Examples(
                examples=EXAMPLES,
                inputs=[text_input, top_k_slider],
                label="Actionable Test Scenarios",
                examples_per_page=6,
            )

        # Right column: HTML report produced by app_backend.predict.
        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")

    # Full probability distribution over every intent class.
    with gr.Row():
        chart_output = gr.Label(
            label="Full Semantic Distribution Map",
            num_top_classes=app_backend.num_classes,
        )

    # Model id is read from the live backend so the docs cannot drift from it.
    with gr.Accordion("⚙️ Technical Architecture & Model Details", open=False):
        gr.Markdown(f"""
        ### Core Specifications
        * **Target Model:** `{app_backend.model_id}`
        * **Objective:** Multi-class text sequence classification for customer support routing.
        * **Dataset Lineage:** Trained on the comprehensive `bitext/Bitext-customer-support-llm-chatbot-training-dataset`.

        ### Pipeline Features
        * **Dynamic Label Resolution:** The UI heuristic engine automatically maps raw dataset labels (e.g., `change_shipping_address`) into clean, professional UI elements (e.g., Change Shipping Address) and assigns contextual iconography.
        * **Optimized Inference:** Utilizes PyTorch `inference_mode` for reduced memory footprint and accelerated compute during forward passes.
        """)

    # The button click and pressing Enter in the textbox run the same prediction.
    for trigger in (predict_btn.click, text_input.submit):
        trigger(
            fn=app_backend.predict,
            inputs=[text_input, top_k_slider],
            outputs=[result_output, chart_output],
        )
    clear_btn.click(
        fn=lambda: ("", _DEFAULT_TOP_K, "", None),
        outputs=[text_input, top_k_slider, result_output, chart_output],
    )


if __name__ == "__main__":
    demo.launch()