| import os |
|
|
| import dotenv |
| import gradio as gr |
| from mistralai.client import MistralClient |
| from mistralai.models.chat_completion import ChatMessage |
|
|
# Let python-dotenv populate the process environment before it is read below.
dotenv.load_dotenv()

# Server-side Mistral key; None when the variable is not defined.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
|
|
|
|
# Page heading rendered at the top of the app.
TITLE = '<h1 align="center">MistralAI Playground 💬</h1>'

# Banner inviting visitors to duplicate the Space and use their own API key.
# Assembled line by line; the result starts and ends with a newline.
DUPLICATE = (
    "\n"
    '<div style="text-align: center; display: flex; justify-content: center; align-items: center;">\n'
    '<a href="https://huggingface.co/spaces/douglarek/MistralAI?duplicate=true">\n'
    '<img src="https://bit.ly/3gLdBN6" alt="Duplicate Space" style="margin-right: 10px;">\n'
    "</a>\n"
    "<span>Duplicate the Space and run securely with your\n"
    '<a href="https://console.mistral.ai/user/api-keys"> Mistral API KEY</a>.\n'
    "</span>\n"
    "</div>\n"
)

# Avatar pair handed to the chatbot widget: no image in the first slot, an
# icon URL in the second.
# NOTE(review): the URL points at a "gemini" icon although this app talks to
# Mistral -- confirm whether that is intentional.
AVATAR_IMAGES = (
    None,
    "https://media.roboflow.com/spaces/gemini-icon.png",
)
|
|
|
|
# Widgets are created here, outside any layout, and placed later through
# their .render() calls.

# Conversation pane.
chatbot_component = gr.Chatbot(
    label="MistralAI",
    height=400,
    scale=2,
    avatar_images=AVATAR_IMAGES,
    bubble_full_width=False,
)

# Prompt entry plus its two action buttons.
text_prompt_component = gr.Textbox(
    placeholder="Hi there! [press Enter]",
    show_label=False,
    autofocus=True,
    scale=8,
)
run_button_component = gr.Button(
    value="Run",
    variant="primary",
    scale=1,
)
clear_button_component = gr.ClearButton(
    value="Clear",
    variant="secondary",
    scale=1,
)

# Password-style field for a visitor-supplied key; hidden whenever a key was
# already found in the environment.
mistral_key_component = gr.Textbox(
    label="MISTRAL API KEY",
    info="You have to provide your own MISTRAL_API_KEY for this app to function properly",
    placeholder="...",
    type="password",
    value="",
    visible=MISTRAL_API_KEY is None,
)

# Model selector; the selected string itself is what the callback receives.
model_component = gr.Dropdown(
    label="Model",
    choices=["mistral-tiny", "mistral-small", "mistral-medium"],
    value="mistral-small",
    type="value",
    scale=1,
)
# Sampling-temperature slider (0 to 1.0 in steps of 0.05, default 0.7).
temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.7,
    step=0.05,
    label="Temperature",
    # Adjacent literals are concatenated verbatim, so every fragment except
    # the last must end with a space; the original ran two sentences
    # together as "deterministic.We generally ...".
    info=(
        "What sampling temperature to use, between 0.0 and 1.0. "
        "Higher values like 0.8 will make the output more random, "
        "while lower values like 0.2 will make it more focused and deterministic. "
        "We generally recommend altering this or top_p but not both."
    ),
)
|
|
# Components whose current values are passed, in this order, to `user`.
user_inputs = [
    text_prompt_component,
    chatbot_component,
]

# Components whose current values are passed, in this order, to `bot`
# (key, model, temperature, history).
bot_inputs = [
    mistral_key_component,
    model_component,
    temperature_component,
    chatbot_component,
]

# Module-level Mistral client handle. It starts out unset, so the annotation
# has to admit None as well as a client instance.
client: MistralClient | None = None
|
|
|
|
def preprocess_chat_history(history):
    """Flatten chatbot turns into a list of ``ChatMessage`` objects.

    Args:
        history: Iterable of two-item ``(human, assistant)`` pairs.

    Returns:
        A new list holding, for each pair in order, a ``"user"`` message
        followed by an ``"assistant"`` message. Falsy entries (``None`` or
        an empty string) are skipped.
    """
    return [
        ChatMessage(role=role, content=content)
        for human, assistant in history
        for role, content in (("user", human), ("assistant", assistant))
        if content
    ]
|
|
|
|
| def bot( |
| mistral_key: str | None, |
| model: str, |
| temperature: float, |
| history, |
| ): |
| if not history: |
| return history |
|
|
| mistral_key = mistral_key or MISTRAL_API_KEY |
| if not mistral_key: |
| raise ValueError("MISTRAL_API_KEY is not set. Please follow the instructions in the README to set it up.") |
| global client |
| if client is None: |
| client = MistralClient(api_key=mistral_key) |
|
|
| chat_history = preprocess_chat_history(history) |
| history[-1][1] = "" |
| for chunk in client.chat_stream(model=model, messages=chat_history, temperature=temperature): |
| print("chunk", chunk) |
| if chunk.choices and chunk.choices[0].delta.content: |
| history[-1][1] += chunk.choices[0].delta.content |
| yield history |
|
|
|
|
def user(text_prompt: str, history):
    """Record the submitted prompt as a new, unanswered chat turn.

    Args:
        text_prompt: Text taken from the prompt box.
        history: Current list of ``[human, assistant]`` pairs, or ``None``
            when there is no conversation yet.

    Returns:
        A ``(prompt_box_value, history)`` tuple. The first item is always
        ``""``, which clears the prompt box. The second is ``history`` with
        ``[text_prompt, None]`` appended when the prompt is non-empty.
    """
    # Treat a missing history as an empty conversation.
    if history is None:
        history = []

    if text_prompt:
        # A list rather than a tuple, so the reply slot can later be filled
        # in place.
        history.append([text_prompt, None])

    return "", history
|
|
|
|
with gr.Blocks() as demo:
    # Static header.
    gr.HTML(TITLE)
    gr.HTML(DUPLICATE)

    # Place the widgets that were created at module level.
    with gr.Column():
        mistral_key_component.render()
        chatbot_component.render()
        with gr.Row():
            text_prompt_component.render()
            run_button_component.render()
            clear_button_component.render()
        with gr.Accordion("Parameters", open=False):
            model_component.render()
            temperature_component.render()

    def _chain_chat_events(trigger):
        """Attach the two-step chat chain to one event trigger.

        Step one runs ``user`` unqueued to record the prompt; step two runs
        ``bot`` to write the reply into the chatbot.
        """
        recorded = trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        )
        recorded.then(
            fn=bot,
            inputs=bot_inputs,
            outputs=[chatbot_component],
        )

    # Listeners are registered in the same order as before: Run button,
    # Clear button, then Enter in the prompt box.
    _chain_chat_events(run_button_component.click)
    clear_button_component.click(
        lambda: (None, None),
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    )
    _chain_chat_events(text_prompt_component.submit)
|
|
|
|
# Turn on the request queue (max_size=99), then start the app with errors
# shown in the UI. queue() hands back the same Blocks object, so launching
# on `demo` is the same as launching on the chained result.
demo.queue(max_size=99)
demo.launch(show_error=True, debug=False)
|
|