import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 2048
# Initial position of the "Max new tokens" slider in the UI.
DEFAULT_MAX_NEW_TOKENS = 1024

# Prompts longer than this many tokens are cut down to their last
# MAX_INPUT_TOKEN_LENGTH tokens before generation. Can be overridden through
# the environment variable of the same name.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
|
|
|
# Markdown banner rendered at the top of the demo page.
DESCRIPTION = """\
# DeepSeek-6.7B-Chat

This Space demonstrates model [DeepSeek-Coder](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) by DeepSeek, a code model with 6.7B parameters fine-tuned for chat instructions.
"""

# The model is only loaded when CUDA is available, so tell visitors up front
# when the page is being served from a CPU-only machine.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
|
|
|
|
# Load the model and tokenizer only when a CUDA device is present; on a
# CPU-only machine `model` and `tokenizer` are left undefined.
if torch.cuda.is_available():
    model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The system prompt comes from the UI textbox (see `generate`), not from
    # a tokenizer-supplied default.
    tokenizer.use_default_system_prompt = False
|
|
|
|
|
|
|
@spaces.GPU
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` as a growing string.

    The conversation is assembled from the optional system prompt, every
    ``(user, assistant)`` pair in ``chat_history`` and finally ``message``,
    then tokenized with the tokenizer's chat template. If the prompt exceeds
    ``MAX_INPUT_TOKEN_LENGTH`` tokens, only the last ``MAX_INPUT_TOKEN_LENGTH``
    tokens are kept and a ``gr.Warning`` is raised in the UI.

    Generation runs in a background thread and is consumed through a
    ``TextIteratorStreamer``.

    Args:
        message: The new user message.
        chat_history: Earlier turns, iterated as ``(user, assistant)`` pairs.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Accepted for interface compatibility; not read, because
            generation is invoked with ``do_sample=False``.
        top_p: Accepted for interface compatibility; not read (see above).
        top_k: Accepted for interface compatibility; not read (see above).
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The full text generated so far (not just the newest chunk), with any
        literal ``<|EOT|>`` marker removed.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest message survives the cut.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation off-thread so this generator can relay chunks from the
    # streamer while the model is still producing tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Strip the marker from the joined text rather than from each chunk,
        # so a marker split across two chunks is still removed.
        yield "".join(outputs).replace("<|EOT|>", "")
|
|
|
|
|
# Chat UI wired to `generate`.
#
# The values of `additional_inputs` are passed to `generate` positionally,
# after (message, chat_history). The components must therefore appear in the
# order of its remaining parameters:
#   system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty
# Without the Temperature slider every later value shifts one parameter to
# the left, and the "Repetition penalty" slider ends up in `top_k`, so it
# never reaches the model.
# NOTE: `generate` currently calls the model with `do_sample=False`, so the
# temperature / top-p / top-k values are received but do not affect decoding.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1,
        ),
    ],
    stop_btn=None,
    examples=[
        ["implement snake game using pygame"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["write a program to find the factorial of a number"],
    ],
)
|
|
|
# Page layout: the description banner followed by the chat interface.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
|
|
|
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.queue().launch(share=True)
|
|
|