| |
| |
| import argparse |
| import json |
| import os |
| from pathlib import Path |
| import platform |
| import tempfile |
| import time |
| from typing import List, Dict |
| import zipfile |
|
|
| from cacheout import Cache |
| import gradio as gr |
| import huggingface_hub |
| import numpy as np |
| import torch |
|
|
| from project_settings import project_path, environment |
| from toolbox.torch.utils.data.tokenizers.pretrained_bert_tokenizer import PretrainedBertTokenizer |
| from toolbox.torch.utils.data.vocabulary import Vocabulary |
|
|
|
|
def get_args():
    """Parse command-line arguments for the demo.

    Returns an argparse.Namespace with:
      waba_intent_examples_file: path to the JSON file of UI examples.
      waba_intent_md_file: path to the markdown description file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--waba_intent_examples_file",
        type=str,
        default=(project_path / "waba_intent_examples.json").as_posix(),
    )
    parser.add_argument(
        "--waba_intent_md_file",
        type=str,
        default=(project_path / "waba_intent.md").as_posix(),
    )
    return parser.parse_args()
|
|
|
|
| model_cache = Cache(maxsize=256, ttl=1 * 60, timer=time.time) |
|
|
|
|
def load_waba_intent_model(repo_id: str):
    """Download a WABA intent model snapshot from the HuggingFace Hub and load it.

    The snapshot is placed under ``trained_models/<repo_id>``; authentication
    uses the ``hf_token`` entry of the project environment.

    :param repo_id: HuggingFace Hub repository id, e.g. "nxcloud/waba_intent_en".
    :return: dict with keys "model" (TorchScript module), "vocabulary", "tokenizer".
    """
    local_dir = project_path / "trained_models/{}".format(repo_id)
    local_dir.mkdir(parents=True, exist_ok=True)

    huggingface_hub.login(token=environment.get("hf_token"))
    huggingface_hub.snapshot_download(repo_id=repo_id, local_dir=local_dir)

    return {
        "model": torch.jit.load((local_dir / "final.zip").as_posix()),
        "vocabulary": Vocabulary.from_files((local_dir / "vocabulary").as_posix()),
        "tokenizer": PretrainedBertTokenizer(local_dir.as_posix()),
    }
|
|
|
|
def click_waba_intent_button(repo_id: str, text: str):
    """Predict the intent label of ``text`` with the model from ``repo_id``.

    The model group is fetched from the TTL cache and (re)loaded on a miss.

    :param repo_id: HuggingFace Hub repository id of the model to use.
    :param text: input text to classify.
    :return: tuple ``(label_str, prob)`` — predicted label and its probability
             rounded to 4 decimal places.
    """
    model_group = model_cache.get(repo_id)
    if model_group is None:
        model_group = load_waba_intent_model(repo_id)
        model_cache.set(key=repo_id, value=model_group)

    model = model_group["model"]
    vocabulary = model_group["vocabulary"]
    tokenizer = model_group["tokenizer"]

    # text -> token strings -> token ids in one pass
    token_ids: List[int] = [
        vocabulary.get_token_index(token, namespace="tokens")
        for token in tokenizer.tokenize(text)
    ]
    # Pad short inputs up to the minimum length the model expects
    # (sequences of >= 5 tokens are passed through unchanged, as before).
    if len(token_ids) < 5:
        token_ids = vocabulary.pad_or_truncate_ids_by_max_length(token_ids, max_length=5)

    inputs = torch.from_numpy(np.array([token_ids]))
    outputs = model.forward(inputs)

    probs = outputs["probs"]
    best_idx = torch.argmax(probs, dim=-1).tolist()[0]
    batch_probs = probs.tolist()[0]

    label_str = vocabulary.get_token_from_index(best_idx, namespace="labels")
    prob = round(batch_probs[best_idx], 4)

    return label_str, prob
|
|
|
|
def main():
    """Run the Gradio text-classification demo.

    Loads the UI examples and the markdown description from the paths given on
    the command line, builds the Blocks UI, and serves it on port 7860.
    Binds to localhost on Windows (development box) and to all interfaces
    otherwise (container/server deployment).
    """
    args = get_args()

    brief_description = """
## Text Classification
"""

    with open(args.waba_intent_examples_file, "r", encoding="utf-8") as f:
        waba_intent_examples = json.load(f)
    with open(args.waba_intent_md_file, "r", encoding="utf-8") as f:
        waba_intent_md = f.read()

    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("waba_intent"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                waba_intent_repo_id = gr.Dropdown(
                                    choices=["nxcloud/waba_intent_en"],
                                    value="nxcloud/waba_intent_en",
                                    label="repo_id"
                                )
                                waba_intent_text = gr.Textbox(label="text", max_lines=5)
                                waba_intent_button = gr.Button("predict", variant="primary")

                            with gr.Column(scale=1):
                                waba_intent_label = gr.Textbox(label="label")
                                waba_intent_prob = gr.Textbox(label="prob")

                        gr.Examples(
                            examples=waba_intent_examples,
                            inputs=[
                                waba_intent_repo_id,
                                waba_intent_text,
                            ],
                            outputs=[
                                waba_intent_label,
                                waba_intent_prob,
                            ],
                            fn=click_waba_intent_button
                        )

                        gr.Markdown(value=waba_intent_md)

        waba_intent_button.click(
            fn=click_waba_intent_button,
            inputs=[
                waba_intent_repo_id,
                waba_intent_text,
            ],
            outputs=[
                waba_intent_label,
                waba_intent_prob,
            ],
        )

    # Fix: the original `share=False if platform.system() == "Windows" else False`
    # was a degenerate conditional — both branches were False. Keep that
    # effective behavior (no public share link) explicitly.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )
|
|
|
|
# Script entry point: only run the demo when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|