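"""CPU-friendly LLM chat application.

Loads a small Hugging Face language model on the CPU and serves a Gradio chat
interface with token-by-token streaming of responses."""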
import warnings
from threading import Thread
from typing import Iterator, List

# Silence noisy library warnings so the console output stays readable.
warnings.filterwarnings("ignore")

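# Heavy dependencies are optional at import time: the flags below let the
# script start and print install hints instead of crashing when they are
# missing.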
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        TextIteratorStreamer,
    )
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

try:
    import gradio as gr
    GRADIO_AVAILABLE = True
except ImportError:
    GRADIO_AVAILABLE = False

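# CPULLMChat owns the model/tokenizer state and all generation logic; the
# Gradio layer further down only wires UI widgets to these methods.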
class CPULLMChat:
    def __init__(self):
        # Hugging Face model IDs mapped to the labels shown in the dropdown.
        self.models = {
            "microsoft/DialoGPT-medium": "DialoGPT Medium (Recommended for chat)",
            "microsoft/DialoGPT-small": "DialoGPT Small (Faster)",
            "distilgpt2": "DistilGPT2 (Very fast)",
            "gpt2": "GPT2 (Standard)",
            "facebook/blenderbot-400M-distill": "BlenderBot (Conversational)",
        }

        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        self.model_loaded = False

        # Conservative limits for CPU-only inference.
        self.max_input_length = 2048
        self.device = "cpu"
    def load_model(self, model_name: str, progress=None) -> str:
        """Load the selected model and tokenizer onto the CPU."""
        if not TRANSFORMERS_AVAILABLE:
            return "❌ Error: transformers library not installed. Run: pip install torch transformers"

        if model_name == self.current_model_name and self.model_loaded:
            return f"✅ Model {model_name} is already loaded!"

        try:
            if progress is not None:
                progress(0.1, desc="Loading tokenizer...")

            self.current_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left"
            )
            # Many GPT-style tokenizers ship without a pad token; reuse EOS.
            if self.current_tokenizer.pad_token is None:
                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token

            if progress is not None:
                progress(0.5, desc="Loading model...")

            try:
                self.current_model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    device_map={"": self.device},
                    low_cpu_mem_usage=True
                )
            except ValueError:
                # Encoder-decoder checkpoints (e.g. BlenderBot) are not causal
                # LMs, so fall back to the seq2seq auto class.
                self.current_model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    device_map={"": self.device},
                    low_cpu_mem_usage=True
                )

            # Inference only: put the model in eval mode (disables dropout).
            self.current_model.eval()

            self.current_model_name = model_name
            self.model_loaded = True

            if progress is not None:
                progress(1.0, desc="Model loaded successfully!")

            return f"✅ Successfully loaded: {model_name}"

        except Exception as e:
            self.model_loaded = False
            return f"❌ Failed to load model {model_name}: {str(e)}"

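    # Responses stream to the UI: model.generate() runs on a worker thread
    # while TextIteratorStreamer hands decoded text chunks back to this one.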
    def generate_response(
        self,
        message: str,
        chat_history: List[List[str]],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
    ) -> Iterator[str]:
        """Stream a response. `chat_history` should contain only completed
        [user, bot] turns, not the message currently being answered."""
        if not self.model_loaded:
            yield "❌ Please load a model first!"
            return

        if not message.strip():
            yield "Please enter a message."
            return

        try:
            # Keep only the last few turns so the prompt stays short on CPU.
            recent_history = chat_history[-5:]

            if "DialoGPT" in self.current_model_name:
                # DialoGPT expects turns concatenated with EOS separators.
                chat_history_ids = None

                for user_msg, bot_msg in recent_history:
                    if user_msg:
                        user_input_ids = self.current_tokenizer.encode(
                            user_msg + self.current_tokenizer.eos_token,
                            return_tensors='pt'
                        )
                        if chat_history_ids is not None:
                            chat_history_ids = torch.cat([chat_history_ids, user_input_ids], dim=-1)
                        else:
                            chat_history_ids = user_input_ids

                    if bot_msg:
                        bot_input_ids = self.current_tokenizer.encode(
                            bot_msg + self.current_tokenizer.eos_token,
                            return_tensors='pt'
                        )
                        if chat_history_ids is not None:
                            chat_history_ids = torch.cat([chat_history_ids, bot_input_ids], dim=-1)
                        else:
                            chat_history_ids = bot_input_ids

                # Append the new user message after the encoded history.
                new_user_input_ids = self.current_tokenizer.encode(
                    message + self.current_tokenizer.eos_token,
                    return_tensors='pt'
                )

                if chat_history_ids is not None:
                    input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
                else:
                    input_ids = new_user_input_ids

            else:
                # Plain LMs get a simple "User:/Assistant:" transcript prompt.
                conversation_text = ""
                for user_msg, bot_msg in recent_history:
                    if user_msg and bot_msg:
                        conversation_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"

                conversation_text += f"User: {message}\nAssistant:"
                input_ids = self.current_tokenizer.encode(conversation_text, return_tensors='pt')

            # Truncate from the left so the most recent turns are kept.
            if input_ids.shape[1] > self.max_input_length:
                input_ids = input_ids[:, -self.max_input_length:]

            streamer = TextIteratorStreamer(
                self.current_tokenizer,
                timeout=60.0,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = {
                'input_ids': input_ids,
                # Single unpadded sequence, so the mask is all ones; passing it
                # explicitly avoids transformers warnings.
                'attention_mask': torch.ones_like(input_ids),
                'streamer': streamer,
                'max_new_tokens': max_new_tokens,
                'temperature': temperature,
                'top_p': top_p,
                'top_k': top_k,
                'repetition_penalty': repetition_penalty,
                'do_sample': True,
                'pad_token_id': self.current_tokenizer.pad_token_id,
                'eos_token_id': self.current_tokenizer.eos_token_id,
                'no_repeat_ngram_size': 2,
            }

            # generate() blocks, so run it on a background thread and consume
            # the streamer here.
            generation_thread = Thread(
                target=self.current_model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.start()

            partial_response = ""
            for new_text in streamer:
                partial_response += new_text
                yield partial_response

        except Exception as e:
            yield f"❌ Generation error: {str(e)}"

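# Minimal headless usage sketch (no Gradio needed), assuming transformers is
# installed; any model ID from CPULLMChat.models works:
#
#     chat = CPULLMChat()
#     print(chat.load_model("distilgpt2"))
#     reply = ""
#     for reply in chat.generate_response("Hello!", []):
#         pass
#     print(reply)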
def create_interface():
    """Create the Gradio interface"""

    if not GRADIO_AVAILABLE:
        print("❌ Error: gradio library not installed. Run: pip install gradio")
        return None

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Error: transformers library not installed. Run: pip install torch transformers")
        return None

    chat_system = CPULLMChat()

    # Light CSS polish for the chat layout.
    css = """
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    .chat-message {
        padding: 10px;
        margin: 5px 0;
        border-radius: 10px;
    }
    .user-message {
        background-color: #e3f2fd;
        margin-left: 20%;
    }
    .bot-message {
        background-color: #f1f8e9;
        margin-right: 20%;
    }
    """

    with gr.Blocks(css=css, title="CPU LLM Chat") as demo:
        gr.Markdown("# 🤖 CPU-Optimized LLM Chat")
        gr.Markdown("*A lightweight chat interface for running language models on CPU*")

        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(chat_system.models.keys()),
                    value="microsoft/DialoGPT-medium",
                    label="Select Model",
                    info="Choose a model to load. DialoGPT models work best for chat."
                )
                load_btn = gr.Button("🚀 Load Model", variant="primary")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="No model loaded",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("### 💡 Model Info")
                gr.Markdown("""
                - **DialoGPT Medium**: Best quality, slower
                - **DialoGPT Small**: Good balance
                - **DistilGPT2**: Fastest option
                - **GPT2**: General purpose
                - **BlenderBot**: Conversational AI
                """)

        chatbot = gr.Chatbot(
            label="Chat History",
            height=400,
            show_label=True,
            container=True
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here... (Press Ctrl+Enter to send)",
                lines=3,
                max_lines=10,
                show_label=False
            )
            send_btn = gr.Button("📤 Send", variant="primary")

        with gr.Accordion("⚙️ Generation Parameters", open=False):
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=512,
                    value=256,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher values = more creative, lower = more focused"
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-p",
                    info="Nucleus sampling parameter"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k",
                    info="Top-k sampling parameter"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Penalty for repeating tokens"
                )

        with gr.Accordion("💬 Example Messages", open=False):
            examples = [
                "Hello! How are you today?",
                "Tell me a short story about a robot.",
                "What's the difference between AI and machine learning?",
                "Can you help me write a poem about nature?",
                "Explain quantum computing in simple terms.",
            ]

            example_buttons = []
            for example in examples:
                btn = gr.Button(example, variant="secondary")
                example_buttons.append(btn)

        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

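        # Event handlers. `respond` is a generator: each yield pushes the
        # partially built reply into the Chatbot, which is what makes streaming
        # visible in the UI (demo.queue() in main() enables this).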
        def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
            if not chat_system.model_loaded:
                history.append([message, "❌ Please load a model first!"])
                yield history, ""
                return

            history.append([message, ""])

            # Pass only completed turns; the current message goes in separately
            # so it is not encoded twice in the prompt.
            for partial_response in chat_system.generate_response(
                message, history[:-1], max_new_tokens, temperature, top_p, top_k, repetition_penalty
            ):
                history[-1][1] = partial_response
                yield history, ""

        def load_model_handler(model_name, progress=gr.Progress()):
            return chat_system.load_model(model_name, progress)

        def set_example(example_text):
            return example_text

        def clear_chat():
            return [], ""

        load_btn.click(load_model_handler, inputs=[model_dropdown], outputs=[model_status])

        msg.submit(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])
        send_btn.click(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])

        clear_btn.click(clear_chat, outputs=[chatbot, msg])

        # Each example button fills the message box with its own text.
        for btn, example in zip(example_buttons, examples):
            btn.click(set_example, inputs=[gr.State(example)], outputs=[msg])

        # Footer with quick-start help.
        gr.Markdown("""
        ---
        ### 📝 Instructions:
        1. **Select and load a model** using the dropdown and "Load Model" button
        2. **Wait for the model to load** (may take 1-2 minutes on first load)
        3. **Start chatting** once you see the "✅ Successfully loaded" message
        4. **Adjust parameters** if needed for different response styles

        ### 💻 System Requirements:
        - CPU with at least 4GB RAM available
        - Python 3.8+ with torch and transformers installed

        ### ⚡ Performance Tips:
        - Use DialoGPT-small for fastest responses
        - Keep max tokens under 300 for better speed
        - Lower temperature (0.3-0.7) for more consistent responses
        """)

    return demo


def main():
    """Main function to run the application"""

    print("===== CPU LLM Chat Application =====")
    print("Checking dependencies...")

    if not GRADIO_AVAILABLE:
        print("❌ Gradio not found. Install with: pip install gradio")
        return

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Transformers not found. Install with: pip install torch transformers")
        return

    print("✅ All dependencies found!")
    print("Starting web interface...")

    try:
        demo = create_interface()
        if demo:
            # queue() is required for generator handlers (streamed replies).
            demo.queue(max_size=10).launch(
                server_name="0.0.0.0",
                server_port=7860,
                share=False,
                show_error=True,
                inbrowser=False
            )
    except KeyboardInterrupt:
        print("\n👋 Application stopped by user")
    except Exception as e:
        print(f"❌ Error starting application: {e}")


if __name__ == "__main__":
    main()