| import os |
| from dotenv import load_dotenv |
| import gradio as gr |
| from gradio.components import Textbox, Button, Slider, Checkbox |
| from AinaTheme import theme |
| from huggingface_hub import InferenceClient |
| from urllib.error import HTTPError |
|
|
| load_dotenv() |
|
|
def generate(prompt, model_parameters):
    """Send ``prompt`` to the inference endpoint and return the generated text.

    Args:
        prompt: Text forwarded to ``client.text_generation``.
        model_parameters: Mapping of generation keyword arguments, expanded
            into the ``text_generation`` call.

    Returns:
        The value returned by ``client.text_generation`` (called with
        ``return_full_text=True``), or ``None`` when the request failed; in
        that case a Gradio warning has been shown to the user.
    """
    try:
        return client.text_generation(prompt, **model_parameters, return_full_text=True)
    except Exception as err:
        # Only ``Exception`` is trapped so KeyboardInterrupt/SystemExit still
        # propagate (the former bare ``except:`` swallowed them as well).
        # urllib-style errors expose the HTTP status as ``.code``.
        # NOTE(review): the inference client presumably raises requests/httpx
        # style errors carrying ``.response.status_code`` -- confirm against
        # the installed huggingface_hub version.
        status_code = getattr(err, "code", None)
        if status_code is None:
            status_code = getattr(getattr(err, "response", None), "status_code", None)

        if status_code == 400:
            gr.Warning("The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET.")
        else:
            # Every other failure (including non-400 HTTP errors, which used
            # to return silently) gets the generic message.
            gr.Warning('Inference endpoint is not available right now. Please try again later.')
    return None
|
|
|
|
# Shared inference client; endpoint URL and access token are read from the
# environment (either may be None when the variable is unset).
client = InferenceClient(
    os.getenv("HF_INFERENCE_ENDPOINT_URL"),
    token=os.getenv("HF_INFERENCE_ENDPOINT_TOKEN"),
)
|
|
# Environment-driven settings; each falls back to the default shown when the
# variable is not set.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=100))
# The fallback has to be the *string* "True": the value is compared with
# "True", and the bool ``True`` never equals it, which made the flag False
# whenever the variable was unset.
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
|
|
|
|
def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    """Validate the user's text and run it through ``generate``.

    Args:
        input_: Text typed by the user; ``None`` is treated as empty.
        max_new_tokens, repetition_penalty, top_k, top_p, do_sample,
        temperature: Generation settings, forwarded unchanged to
            ``generate`` as the ``model_parameters`` mapping.

    Returns:
        Whatever ``generate`` returns, or ``None`` (after a Gradio warning)
        when the input is empty/whitespace-only or longer than
        ``MAX_INPUT_CHARACTERS``.
    """
    # ``None`` can arrive through the API; treat it as empty instead of
    # failing on ``None.strip()``.
    stripped = (input_ or "").strip()
    if not stripped:
        gr.Warning('Not possible to inference an empty input')
        return None

    # The UI only disables the submit button for over-long input; this
    # function is also exposed as an API endpoint, so apply the same
    # stripped-length rule as ``change_interactive`` here.
    if len(stripped) > MAX_INPUT_CHARACTERS:
        gr.Warning(f"Max length {MAX_INPUT_CHARACTERS} characters.")
        return None

    model_parameters = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature,
    }

    # The original (unstripped) text is what gets sent to the model.
    return generate(input_, model_parameters)
| |
def change_interactive(text):
    """Decide which buttons are clickable for the current input text.

    Args:
        text: Current content of the input textbox; ``None`` counts as empty.

    Returns:
        A pair of ``gr.update`` objects, in the order the event wiring
        expects (clear button, submit button). The first always sets
        ``interactive=True``; the second is interactive only when the
        stripped text is non-empty and at most ``MAX_INPUT_CHARACTERS`` long.
    """
    # Measure the stripped text for both bounds so whitespace-only input
    # does not enable a submit that ``submit_input`` would reject as empty.
    stripped_length = len((text or "").strip())
    can_submit = 0 < stripped_length <= MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=can_submit)
| |
def clear():
    """Return reset values for the input, the output and the six parameter widgets.

    The first two entries are ``None`` (clearing both textboxes); the
    remaining six are ``gr.update`` objects restoring, in order: max new
    tokens, repetition penalty, top-k, top-p, do-sample and temperature.
    """
    parameter_defaults = (MAX_NEW_TOKENS, 1.2, 50, 0.95, True, 0.5)
    parameter_resets = tuple(gr.update(value=default) for default in parameter_defaults)
    return (None, None, *parameter_resets)
| |
def gradio_app():
    """Build the AIMestre Gradio interface, wire its events and launch it.

    The page consists of a banner row (image and title), a two-panel row
    (input textbox, character counter and model-parameter accordion on one
    side; output textbox with the Clear/Submit buttons on the other) and a
    row of example prompts. After the layout is declared, the input
    textbox's change events and the two button clicks are registered, and
    the app is launched with the API page enabled.
    """
    with gr.Blocks(theme=theme) as demo:
        # --- Banner -------------------------------------------------------
        with gr.Row():
            with gr.Column(scale=0.1):
                gr.Image(
                    "ginesta_small.jpg",
                    elem_id="flor-banner",
                    scale=1,
                    height=256,
                    width=256,
                    show_label=False,
                    show_download_button=False,
                    show_share_button=False,
                )
            with gr.Column():
                gr.Markdown(
                    """# AIMestre

                    Basat en el model [Flor](https://huggingface.co/projecte-aina/FLOR-6.3B) del projecte AINA.

                    """
                )

        # --- Input / output panels ---------------------------------------
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                # Hidden textbox whose value (the character limit) is passed
                # to the JS counter callback registered further down.
                placeholder_max_token = Textbox(
                    visible=False,
                    interactive=False,
                    value=MAX_INPUT_CHARACTERS,
                )
                input_ = Textbox(
                    lines=11,
                    label="Posa aquí el teu escrit en català.",
                    placeholder="e.g. El mercat del barri és fantàstic hi pots trobar.",
                )
                # Two spans updated from JS: the over-limit message and the
                # "<length> / <limit>" counter.
                with gr.Row(variant="panel", equal_height=True):
                    gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                    gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span> / {MAX_INPUT_CHARACTERS}</span>""")

                with gr.Row(variant="panel"):
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        max_new_tokens = Slider(
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens",
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=10,
                            step=0.1,
                            value=1.2,
                            label="Repetition penalty",
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k",
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.95,
                            label="Top p",
                        )
                        do_sample = Checkbox(
                            value=True,
                            label="Do sample",
                        )
                        temperature = Slider(
                            minimum=0,
                            maximum=1,
                            value=0.5,
                            label="Temperature",
                        )
            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=11,
                    label="El mestre diu...",
                    interactive=False,
                    show_copy_button=True,
                )
                with gr.Row(variant="panel"):
                    clear_btn = Button("Clear")
                    submit_btn = Button("Submit", variant="primary", interactive=False)

        # Components passed to ``submit_input``, in its parameter order.
        generation_controls = [max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
        submit_inputs = [input_, *generation_controls]

        # --- Example prompts ---------------------------------------------
        # (label, prompt) pairs, rendered in this order. Each example row has
        # a single column, i.e. a value for the input textbox only.
        example_groups = (
            ("Short prompts:", """La capital de Suècia"""),
            (
                "Zero-shot prompts",
                "Tradueix del Castellà al Català la següent frase: \"Eso es pan comido.\" \nTraducció:",
            ),
            (
                "Few-Shot prompts:",
                """Oració: Els sons melòdics produeixen una sensació de calma i benestar en l'individu. \nParàfrasi: La música és molt relaxant i reconfortant.\n----\nOració: L'animal domèstic mostra una gran alegria i satisfacció. \nParàfrasi: El gos és molt feliç. \n----\nOració: El vehicle es va trencar i vaig haver de contactar amb el servei de remolc perquè el transportés. \nParàfrasi: El cotxe es va trencar i vaig haver de trucar la grua. \n----\nOració: El professor va explicar els conceptes de manera clara i concisa. \nParàfrasi:""",
            ),
        )
        with gr.Row():
            with gr.Column(scale=0.5):
                for group_label, prompt in example_groups:
                    gr.Examples(
                        label=group_label,
                        examples=[[prompt]],
                        # A fresh list per call, as in the hand-written version.
                        inputs=list(submit_inputs),
                        outputs=output,
                        fn=submit_input,
                    )

        # --- Event wiring -------------------------------------------------
        # Enable/disable the buttons whenever the text changes.
        input_.change(
            fn=change_interactive,
            inputs=[input_],
            outputs=[clear_btn, submit_btn],
            api_name=False,
        )

        # Client-side only: show the over-limit message.
        limit_message_js = f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """
        input_.change(fn=None, inputs=[input_], api_name=False, js=limit_message_js)

        # Client-side only: refresh the live counter and colour it when the
        # length exceeds the limit carried by the hidden textbox.
        counter_js = """(i, m) => {
            document.getElementById('inputlenght').textContent = i.length + ' '
            document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
        }"""
        input_.change(fn=None, inputs=[input_, placeholder_max_token], api_name=False, js=counter_js)

        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[input_, output, *generation_controls],
            queue=False,
            api_name=False,
        )
        submit_btn.click(
            fn=submit_input,
            inputs=list(submit_inputs),
            outputs=[output],
            api_name="get-results",
        )

    demo.launch(show_api=True)
|
|
| if __name__ == "__main__": |
| gradio_app() |