import os
import pickle
import tempfile
import warnings

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, RobertaModel
|
|
| |
| |
# Load the label lookup tables from the pickle file.
# NOTE: pickle.load can execute arbitrary code; only use a trusted file here.
with open("label_mappings.pkl", "rb") as f:
    label_mappings = pickle.load(f)

# Lookup tables for the two prediction targets; a key that is absent from the
# pickle falls back to an empty mapping.
label_to_team, label_to_email = (
    label_mappings.get(key, {}) for key in ("label_to_team", "label_to_email")
)

# Tokenizer for the "roberta-base" checkpoint.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
|
| |
class RoBertaClassifier(nn.Module):
    """RoBERTa encoder with two linear heads: one for teams, one for emails.

    Both heads read the hidden state at sequence position 0 of the encoder's
    last layer.
    """

    def __init__(self, num_teams, num_emails):
        """Build the encoder and the two classification heads.

        Args:
            num_teams: Number of outputs of the team head.
            num_emails: Number of outputs of the email head.
        """
        super().__init__()
        # Attribute names become state_dict key prefixes, so they must stay
        # as-is for saved weights to match.
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        hidden_size = self.roberta.config.hidden_size
        self.team_classifier = nn.Linear(hidden_size, num_teams)
        self.email_classifier = nn.Linear(hidden_size, num_emails)

    def forward(self, input_ids, attention_mask):
        """Return ``(team_logits, email_logits)`` for the given token batch."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return (
            self.team_classifier(first_token_state),
            self.email_classifier(first_token_state),
        )
|
|
| |
# Size each head from its label table.
num_teams = len(label_to_team)
num_emails = len(label_to_email)
model = RoBertaClassifier(num_teams, num_emails)

# Load onto CPU first so the checkpoint opens on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load a trusted checkpoint
# (consider weights_only=True where the installed torch supports it).
checkpoint = torch.load("ticket_classification_model.pth", map_location=torch.device("cpu"))

# Keep only entries the model knows about. The key set is computed once
# instead of rebuilding model.state_dict() for every checkpoint entry.
model_keys = set(model.state_dict())
filtered_checkpoint = {k: v for k, v in checkpoint.items() if k in model_keys}
load_result = model.load_state_dict(filtered_checkpoint, strict=False)

# strict=False would otherwise hide a checkpoint that does not cover the
# model, leaving those parameters at their initial values with no indication.
if load_result.missing_keys:
    warnings.warn(
        "Checkpoint 'ticket_classification_model.pth' is missing "
        f"{len(load_result.missing_keys)} model parameter(s), e.g. "
        f"{load_result.missing_keys[:5]}; these keep their initial values.",
        RuntimeWarning,
    )

# Run on the GPU when one is available, in inference mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
|
|
| |
def predict_tickets(ticket_descriptions):
    """Classify each ticket description and export the predictions to CSV.

    Args:
        ticket_descriptions: Iterable of description strings, one per ticket.

    Returns:
        A ``(markdown, csv_path)`` tuple: the Markdown summary of every
        prediction, and the path of a CSV file holding the same rows
        (columns: Index, Description, Assigned Team, Team Email).
    """
    predictions = []
    csv_data = []
    for idx, description in enumerate(ticket_descriptions, start=1):
        inputs = tokenizer(
            description,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        ).to(device)
        with torch.no_grad():
            team_logits, email_logits = model(inputs.input_ids, inputs.attention_mask)
        predicted_team_index = team_logits.argmax(dim=-1).cpu().item()
        predicted_email_index = email_logits.argmax(dim=-1).cpu().item()
        # Indices absent from the label tables map to a placeholder name.
        predicted_team = label_to_team.get(predicted_team_index, "Unknown Team")
        predicted_email = label_to_email.get(predicted_email_index, "Unknown Email")
        predictions.append(f"**{idx}. {description}**\n - **Assigned Team:** {predicted_team}\n - **Team Email:** {predicted_email}\n")
        csv_data.append([idx, description, predicted_team, predicted_email])

    df = pd.DataFrame(csv_data, columns=["Index", "Description", "Assigned Team", "Team Email"])
    # Write into a fresh temporary directory per call so simultaneous requests
    # cannot overwrite each other's export; the file name itself is unchanged.
    export_dir = tempfile.mkdtemp(prefix="ticket-predictions-")
    csv_file = os.path.join(export_dir, "ticket-predictions.csv")
    df.to_csv(csv_file, index=False)
    return "\n".join(predictions), csv_file
|
|
| |
def gradio_predict(option, text_input, file_input):
    """Collect ticket descriptions from the chosen input and run prediction.

    Args:
        option: Selected input method, "Enter Text" or "Upload CSV".
        text_input: Text with one description per line (may be None).
        file_input: Path of the uploaded CSV file, or None.

    Returns:
        The ``(markdown, csv_path)`` result of ``predict_tickets`` on success,
        or ``(message, None)`` when the input is missing or unusable.
    """
    if option == "Enter Text":
        # Guard against a None value so .split cannot raise.
        raw_lines = (text_input or "").split("\n")
        descriptions = [line.strip() for line in raw_lines if line.strip()]
    elif option == "Upload CSV" and file_input is not None:
        try:
            df = pd.read_csv(file_input)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError, OSError):
            return "⚠️ Error: Could not read the uploaded file as CSV.", None
        if "Description" not in df.columns:
            return "⚠️ Error: CSV must contain a 'Description' column.", None
        # Blank cells load as NaN and numeric cells as numbers; drop the
        # blanks and coerce the rest to str so only real text is classified.
        cells = df["Description"].dropna().astype(str)
        descriptions = [cell for cell in cells if cell.strip()]
    else:
        return "⚠️ Please provide input.", None

    # Nothing to classify: report it instead of producing an empty export.
    if not descriptions:
        return "⚠️ Please provide input.", None

    return predict_tickets(descriptions)
|
|
def clear_inputs():
    """Return the reset values for the five outputs of the CLEAR button.

    Order: input-method radio, text box, uploaded file, results Markdown,
    download file.
    """
    default_option = "Enter Text"
    empty_text = ""
    no_file = None
    return default_option, empty_text, no_file, empty_text, no_file
|
|
| |
# Stylesheet passed to gr.Blocks(css=...). The "#..." selectors match
# components by their elem_id.
custom_css = """
.gradio-container {
max-width: 1000px !important;
margin: auto !important;
}
#title {
text-align: center;
font-size: 26px !important;
font-weight: bold;
}
#predict-button, #clear-button, #download-button {
width: 100% !important;
height: 55px !important;
font-size: 18px !important;
}
#results-box {
height: 350px !important;
overflow-y: auto !important;
background: #f9f9f9;
padding: 15px;
border-radius: 10px;
font-size: 16px;
}
/* Reduce vertical padding for the radio component */
#choose_input_method {
padding-top: 5px !important;
padding-bottom: 5px !important;
}
/* Force both input components to have the same min-height */
#text_input, #file_input {
min-height: 200px !important;
/* Optionally add a consistent border and padding to match styling */
border: 1px solid #ccc;
padding: 10px;
}
"""
|
|
| |
# Build the UI: inputs on the left, results on the right, action buttons
# below, then a footer.
with gr.Blocks(css=custom_css) as app:
    gr.Markdown(
        """
        # AI Solution for Defect Ticket Classification

        **Supports:** Multi-line text input & CSV upload.
        **Output:** Text results & downloadable CSV file.
        **Model:** Fine-tuned **RoBERTa** for classification.

        Enter ticket Description/Comment/Summary or upload a **CSV file** to predict Assigned Team & Team Email.
        """,
        elem_id="title"
    )

    with gr.Row():
        with gr.Column(scale=1):
            option = gr.Radio(
                ["Enter Text", "Upload CSV"],
                label="📝 Choose Input Method",
                value="Enter Text",
                elem_id="choose_input_method"
            )

            # Only one of the two inputs below is visible at a time; see
            # toggle_input.
            text_input = gr.Textbox(
                label="Enter Ticket Description/Comment/Summary (One per line)",
                visible=True,
                lines=6,
                placeholder="Example:\n - Database performance issue\n - Login fails for admin users...",
                elem_id="text_input"
            )

            file_input = gr.File(
                label="📂 Upload CSV (Optional)",
                type="filepath",
                visible=False,
                elem_id="file_input"
            )

        with gr.Column(scale=1):
            gr.Markdown("## Prediction Results")
            results_output = gr.Markdown(elem_id="results-box", visible=True)
            download_csv = gr.File(label="📥 Download Predictions CSV", interactive=False)

    with gr.Row():
        # The elem_id values match the "#predict-button" / "#clear-button"
        # selectors in custom_css; without them that rule never applied.
        predict_btn = gr.Button("PREDICT", variant="primary", elem_id="predict-button")
        clear_btn = gr.Button("CLEAR", variant="secondary", elem_id="clear-button")

    def toggle_input(selected_option):
        """Show the text box for "Enter Text", otherwise the file upload.

        Returns visibility updates for ``(text_input, file_input)``.
        """
        show_text = selected_option == "Enter Text"
        return gr.update(visible=show_text), gr.update(visible=not show_text)

    # Event wiring.
    option.change(fn=toggle_input, inputs=[option], outputs=[text_input, file_input])
    predict_btn.click(fn=gradio_predict, inputs=[option, text_input, file_input], outputs=[results_output, download_csv])
    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[option, text_input, file_input, results_output, download_csv])

    # Footer.
    gr.Markdown("---")
    gr.HTML(
        """
        <div style="text-align: center; color: gray; padding-top: 10px;">
        <p>Developed by NYP student @ Min Thein Win: Student ID: 3907578Y</p>
        </div>
        """
    )
|
|
| |
# Start the web server for the UI; share=True additionally asks Gradio to
# create a public share link.
app.launch(share=True)
|
|