| import os |
| from pathlib import Path |
|
|
| import gradio as gr |
| import torch |
| from huggingface_hub import snapshot_download |
| from transformers import pipeline |
|
|
# Hugging Face Hub repository id of the instruction-tuned model served by this app.
model_name = "databricks/dolly-v2-12b"

# Weights are stored under ./models/<repo id>. The path is relative to the
# current working directory (not to this file), so launch the app from a
# consistent location or the model will be downloaded again elsewhere.
local_dir = f"./models/{model_name}"
_local_path = Path(local_dir)
# Download only when the target directory is missing or empty; `any(iterdir())`
# stops at the first entry instead of listing the whole directory.
# NOTE(review): the check only tests for emptiness, so an interrupted download
# that left some files behind is not resumed here — delete the directory to
# force a fresh snapshot.
if not _local_path.exists() or not any(_local_path.iterdir()):
    snapshot_download(model_name, local_dir=local_dir)

# Build the text-generation pipeline from the local snapshot.
# SECURITY: trust_remote_code=True executes Python code shipped inside the model
# repository; only point this at repositories you trust.
# bfloat16 weights take half the memory of float32; device_map="auto" asks
# transformers/accelerate to place the model on the available device(s).
generate_text = pipeline(
    model=local_dir,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
|
|
# UI look: Gradio's Monochrome base theme recoloured with indigo/blue accents,
# slate neutrals, small corner radii, and Open Sans with system fallbacks.
_THEME_FONTS = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    font=_THEME_FONTS,
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
|
|
with gr.Blocks(theme=theme) as demo:
    # Layout: the conversation view, a text input box, and a reset button.
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the submitted message to the chat as a pending turn.

        Returns an empty string (clearing the textbox) and the updated history,
        whose last pair is ``[user_message, None]`` until ``bot`` fills in the
        reply. Blank or whitespace-only messages are ignored so the model is
        never prompted with empty input.
        """
        if not user_message or not user_message.strip():
            return "", history
        # The Chatbot value can be None (e.g. after "Clear"); treat it as empty.
        return "", (history or []) + [[user_message, None]]

    def bot(history):
        """Generate the model's reply for the most recent, still-pending turn.

        A text-generation pipeline returns a list of ``{"generated_text": ...}``
        dicts rather than a plain string, so the text is unpacked before it is
        stored as the chatbot reply.
        """
        # Nothing to answer: empty chat, or the last turn already has a reply
        # (which happens when `user` ignored a blank message).
        if not history or history[-1][1] is not None:
            return history
        outputs = generate_text(history[-1][0])
        history[-1][1] = outputs[0]["generated_text"]
        return history

    # Two-step submit: `user` echoes the message immediately (outside the
    # queue), then `bot` runs the slow generation and fills in the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # Returning None clears the Chatbot component.
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
if __name__ == "__main__":
    # When run as a script, serve the UI on every network interface ("0.0.0.0")
    # at port 3000.
    # NOTE(review): no authentication is configured here, so anyone who can
    # reach the port can use the model — confirm this exposure is intended.
    _launch_kwargs = dict(server_name="0.0.0.0", server_port=3000)
    demo.launch(**_launch_kwargs)
|
|
|
|
| |
| |
|
|
| |
| |