from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

from theme_dropdown import create_theme_dropdown
|
|
|
|
# Hugging Face Hub id of the checkpoint served by this app.
model_name = "RootYuan/RedLing-7B-v0.1"

# Upper bound on the number of tokens generated for a single reply.
max_new_tokens = 2048

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
# Default system prompt shown in (and editable from) the UI; it is prepended
# verbatim to the rendered conversation. The leading/trailing newlines are
# part of the prompt text.
DEFAULT_SYSTEM_MESSAGE = (
    "\n"
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
)

# Placeholder token for an image, repeated 32 times on its own line.
VISION_TOKEN = "<img>"
VISION_TOKENS = f"\n{VISION_TOKEN * 32}\n"

# End-of-turn marker separating user and assistant segments.
EOT_TOKEN = "<EOT>"

# Format of one exchange: "USER:<text><EOT>ASSISTANT:<text><eos>".
# NOTE(review): there is no `{media}` field in this template, so a `media=`
# keyword passed to .format() is silently ignored by str.format — confirm
# whether vision tokens are meant to appear in the prompt.
PROMPT_TEMPLATE = "USER:{user}" + EOT_TOKEN + "ASSISTANT:{assistant}{eos_token}"
|
|
# Load the tokenizer and the bfloat16 weights for `model_name` once, at import
# time, and place the model on the selected `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to(device)

# Theme-picker component and the JavaScript it needs; both are wired up in the UI.
dropdown, js = create_theme_dropdown()
|
|
def get_uuid():
    """Return a new random (version 4) UUID in its canonical string form."""
    return f"{uuid4()}"
|
|
|
|
def add_text(message, history):
    """Queue the user's message as a new, not-yet-answered chat turn.

    Returns a pair: an empty string (used to clear the message textbox) and a
    new history list with ``[message, ""]`` appended. The incoming ``history``
    list is not modified.
    """
    pending_turn = [message, ""]
    cleared_box = ""
    return cleared_box, history + [pending_turn]
|
|
|
|
# Lower-cased file extension -> label stored next to the file path in history.
_MEDIA_TYPES = {
    **dict.fromkeys(("jpg", "jpeg", "png", "gif", "bmp", "webp"), "image"),
    **dict.fromkeys(("mp4", "mov", "avi", "mkv", "webm"), "video"),
    **dict.fromkeys(("mp3", "wav", "ogg", "flac", "m4a"), "audio"),
}


def add_media(media, history):
    """Append an uploaded file to the chat history as a media-only turn.

    Args:
        media: uploaded file object from the upload button; only its ``name``
            attribute (the file path) is used.
        history: current chatbot value, a list of ``[user, assistant]`` pairs.

    Returns:
        A new history list ending with ``[(path, media_type), ""]``, where
        ``media_type`` is ``"image"``, ``"video"`` or ``"audio"`` according to
        the (case-insensitive) file extension, or ``"file"`` if unrecognised.
    """
    media_name = media.name
    extension = media_name.rsplit(".", 1)[-1].lower()
    # The upload button accepts images, video and audio, so every extension
    # must resolve to some label rather than leaving media_type unbound.
    media_type = _MEDIA_TYPES.get(extension, "file")
    return history + [[(media_name, media_type), ""]]
|
|
|
|
def convert_history_to_text(history):
    """Render the chat history as a single prompt string for the model.

    Each completed turn is formatted with ``PROMPT_TEMPLATE`` and closed with
    ``EOT_TOKEN``. The final turn (the one being answered) is formatted with
    an empty ``eos_token``, so the prompt ends right after ``ASSISTANT:`` plus
    any partial answer.

    A media turn is an entry whose first element is a tuple, as appended by
    ``add_media``. It produces no text of its own. Instead it marks the *next*
    text turn, including the final one, to be formatted with
    ``media=VISION_TOKENS``.

    NOTE(review): ``PROMPT_TEMPLATE`` as defined in this file has no
    ``{media}`` field, so ``str.format`` discards the ``media`` argument.
    Confirm whether vision tokens are meant to reach the prompt.

    Args:
        history: list of ``[user, assistant]`` pairs from the chatbot.

    Returns:
        The concatenated prompt text, or ``""`` for an empty history.
    """
    if not history:
        return ""

    parts = []
    media_pending = False
    for item in history[:-1]:
        if isinstance(item[0], tuple):
            media_pending = True
            continue
        parts.append(
            PROMPT_TEMPLATE.format(
                media=VISION_TOKENS if media_pending else "",
                user=item[0],
                assistant=item[1],
                eos_token=EOT_TOKEN,
            )
        )
        media_pending = False

    # The turn being answered. An image uploaded right before it must still
    # mark it, and it gets no end-of-turn token.
    last = history[-1]
    parts.append(
        PROMPT_TEMPLATE.format(
            media=VISION_TOKENS if media_pending else "",
            user=last[0],
            assistant=last[1],
            eos_token="",
        )
    )
    return "".join(parts)
|
|
|
|
class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once ``event`` is set.

    ``bot`` sets the event when its stream finishes or is closed, for example
    when the Stop button cancels the event. This stops the background
    generation thread instead of letting it run to ``max_new_tokens``.
    """

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One stop flag per batch row, returned as a bool tensor (the documented
        # return type of StoppingCriteria.__call__); this app uses batch size 1.
        return torch.full(
            (input_ids.shape[0],),
            self.event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def bot(history, temperature, top_k, sys_msg):
    """Stream the model's reply to the latest turn into ``history``.

    Args:
        history: chatbot value. Its last entry is the turn being answered, and
            that entry's second element is overwritten as text streams in.
        temperature: sampling temperature from the UI. ``0`` selects greedy
            decoding.
        top_k: top-k shortlist size from the UI. ``0`` disables top-k filtering.
        sys_msg: system prompt, prepended verbatim to the rendered history.

    Yields:
        ``history`` after each streamed chunk, so the chatbot re-renders.
    """
    print(f"history: {history}")

    prompt = sys_msg + convert_history_to_text(history)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop_event = Event()
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )
    if temperature > 0:
        # generate() applies temperature/top_k only in sampling mode, so enable
        # sampling explicitly; otherwise the sliders depend on the model config.
        generation_kwargs.update(
            do_sample=True, temperature=temperature, top_k=int(top_k)
        )
    else:
        # Temperature 0 (the slider minimum) is not a valid sampling
        # temperature; treat it as greedy decoding.
        generation_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    try:
        for new_text in streamer:
            generated_text += new_text
            history[-1][1] = generated_text
            yield history
    finally:
        # Runs on normal completion, on a streamer timeout, and when the
        # generator is closed by a cancelled event: ask generate() to stop.
        stop_event.set()
|
|
|
|
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    # Header: title on the left, theme picker and dark-mode toggle on the right.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            gr.Markdown(
                """
                # Visual Assistant Lab
                """
            )
        with gr.Column(scale=2):
            with gr.Box():
                dropdown.render()
                toggle_dark = gr.Button(value="Toggle Dark").style(full_width=True)
    dropdown.change(None, dropdown, None, _js=js)
    toggle_dark.click(lambda: None, None, None, _js="() => {document.body.classList.toggle('dark')}")

    # Editable system prompt, passed to `bot` with every request.
    with gr.Row():
        with gr.Accordion("System Message", open=False):
            sys_msg = gr.Textbox(
                value=DEFAULT_SYSTEM_MESSAGE,
                label="System Message",
                info="Instruct the AI Assistant to set its behaviour",
                show_label=False,
            )
    with gr.Row():
        chatbot = gr.Chatbot(label="Assistant").style(height=500)

    # Sampling controls, passed to `bot` with every request.
    with gr.Row():
        with gr.Accordion("Advanced Settings:", open=False):
            with gr.Row().style(equal_height=True):
                with gr.Column():
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.1,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
                with gr.Column():
                    top_k = gr.Slider(
                        label="Top-k",
                        value=0,
                        minimum=0.0,
                        maximum=200,
                        step=1,
                        interactive=True,
                        info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                    )

    # Input row: message box and send button, then upload / stop / clear.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Hi! Type here, Press [Enter] to send...",
                show_label=False,
            ).style(container=False)
        with gr.Column(scale=2):
            send = gr.Button("Send")
    with gr.Row().style(equal_height=True):
        media = gr.UploadButton("Upload files", file_types=["image", "video", "audio"])
        stop = gr.Button("Stop")
        clear = gr.Button("Clear")

    # Enter / Send: append the message (unqueued), then stream the reply (queued).
    send_event = msg.submit(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    media.upload(
        fn=add_media,
        inputs=[media, chatbot],
        outputs=[chatbot],
    )

    send_click_event = send.click(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[send_event, send_click_event],
        queue=False,
    )
    # Clearing must also cancel any in-flight reply; otherwise the still
    # running `bot` stream keeps writing history back into the chatbot.
    clear.click(
        lambda: None,
        None,
        chatbot,
        queue=False,
        cancels=[send_event, send_click_event],
    )
| |
|
|
if __name__ == "__main__":
    # Enable the request queue (at most 128 pending events, gradio 3.x
    # `concurrency_count=2` workers), then start the server. `queue()` returns
    # the Blocks instance, so the calls chain.
    demo.queue(concurrency_count=2, max_size=128).launch()