| |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import gradio as gr |
|
|
| |
def _load_dialogpt(name):
    """Return ``(tokenizer, model)`` for checkpoint *name* via ``from_pretrained``.

    The tokenizer is loaded first, then the causal-LM model.
    """
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForCausalLM.from_pretrained(name),
    )


checkpoint = "microsoft/DialoGPT-medium"
tokenizer, model = _load_dialogpt(checkpoint)

# Module-level conversation token ids. This is one value for the whole
# process, so anything that stores dialogue here shares it across every
# Gradio session. NOTE(review): per-session state is preferable; confirm
# nothing still reads this name before removing it.
chat_history_ids = None
|
|
def respond(user_input, history=None, max_length=1000, max_context_tokens=800):
    """Generate a DialoGPT reply to *user_input* and record the exchange.

    The model prompt is rebuilt on every call from *history*, the per-session
    ``gr.State`` list, instead of being kept in a module-level global. A
    global is shared by every session served by this process. It also
    survives the "Clear Chat" button, which only empties the session state.

    Args:
        user_input: The new message typed by the user.
        history: ``(user_message, bot_reply)`` pairs from earlier turns of
            this session. It is appended to in place. A fresh list is created
            when omitted, because a ``[]`` default would be shared between
            calls.
        max_length: Cap on prompt + reply tokens, forwarded to
            ``model.generate``. This is the value the app previously
            hard-coded.
        max_context_tokens: Only the most recent this-many prompt tokens are
            sent to the model. Keeping this below *max_length* leaves room
            for a reply once the conversation grows long.

    Returns:
        ``(history, history)``: one copy for the Chatbot display and one for
        the State component, matching the event's two outputs.
    """
    if history is None:
        history = []

    # Each utterance is terminated by EOS, as in the original single-turn
    # encoding. Re-encode every past user/bot message plus the new input and
    # concatenate the ids along the sequence axis.
    utterances = [text for turn in history for text in turn if text is not None]
    utterances.append(user_input)
    bot_input_ids = torch.cat(
        [
            tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
            for text in utterances
        ],
        dim=-1,
    )

    # Keep only the tail of a long conversation. Without this, the prompt
    # eventually reaches max_length and generate() has no room left to answer.
    bot_input_ids = bot_input_ids[:, -max_context_tokens:]

    output_ids = model.generate(
        bot_input_ids,
        # A single unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(bot_input_ids),
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The generated ids begin with the prompt; decode only the new tokens.
    output = tokenizer.decode(
        output_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True
    )

    history.append((user_input, output))
    return history, history
|
|
| |
def _submit_message(message, history):
    """Answer *message* via ``respond`` and clear the input textbox.

    Blank or whitespace-only messages are ignored instead of being sent to
    the model.
    """
    if not message or not message.strip():
        return "", history, history
    history, _ = respond(message, history)
    return "", history, history


def _clear_chat():
    """Empty the transcript and session state, and reset the module-level context."""
    global chat_history_ids
    # respond() may keep conversation token ids in this module-level global.
    # Resetting it stops a cleared chat from still feeding old turns to the
    # model.
    chat_history_ids = None
    return [], []


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message here")
    clear = gr.Button("Clear Chat")

    # List of (user_message, bot_reply) pairs. Gradio keeps a separate copy
    # of a gr.State value for each browser session.
    state = gr.State([])

    # The textbox is also an output, so it is cleared after each submission.
    msg.submit(_submit_message, [msg, state], [msg, chatbot, state])
    clear.click(_clear_chat, None, [chatbot, state])
|
|
def main():
    """Start the Gradio server for the chat UI built in ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()