# Spaces: Running on Zero
| import os | |
| import time | |
| import nltk | |
| import numpy as np | |
| import random | |
| from flask import Flask, request, render_template, jsonify, Response | |
| from flasgger import Swagger, swag_from | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from huggingface_hub import login | |
| from langchain_community.tools import DuckDuckGoSearchRun | |
| from nltk import word_tokenize, pos_tag | |
| from diffusers import StableDiffusionPipeline, SD3Transformer2DModel, FlowMatchEulerDiscreteScheduler, EulerDiscreteScheduler | |
| from io import BytesIO | |
| import base64 | |
| # ✅ Safe import of GPU decorator | |
try:
    from spaces import GPU
except ImportError:
    def GPU(func=None, **_options):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` package is absent.

        Supports both decorator spellings so the same source runs with or
        without the package:

        * bare decorator: ``@GPU``
        * decorator factory with keyword options: ``@GPU(duration=60)``

        In either form the decorated function is returned unchanged and the
        options are ignored.
        """
        if callable(func):
            return func
        # Called with options only: hand back a pass-through decorator.
        return lambda wrapped: wrapped
| import torch | |
# Download required NLTK data.
# Each resource is probed on its own: checking only `punkt` would let a cached
# tokenizer hide a missing tagger, and pos_tag() would then fail at request time.
_NLTK_RESOURCES = (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab", "punkt_tab"),
    ("taggers/averaged_perceptron_tagger", "averaged_perceptron_tagger"),
    ("taggers/averaged_perceptron_tagger_eng", "averaged_perceptron_tagger_eng"),
)
for _resource_path, _package_name in _NLTK_RESOURCES:
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        # Not present locally yet: fetch just this package.
        nltk.download(_package_name)
# Flask + Swagger setup
app = Flask(__name__, static_folder="static", template_folder="templates")

# Top-level API description passed to Flasgger as the spec template.
_swagger_template = {
    "swagger": "2.0",
    "info": {
        "title": "ChatMate Real-Time API",
        "description": "LangChain + DuckDuckGo enabled AI chatbot",
        "version": "1.0",
    },
}

# Flasgger configuration: where the JSON spec and the interactive UI are served.
_swagger_config = {
    "headers": [],
    "specs": [
        {
            "endpoint": "apispec",
            "route": "/apispec.json",
            # The filter accepts every rule it is asked about.
            "rule_filter": lambda rule: True,
        },
    ],
    "static_url_path": "/flasgger_static",
    "swagger_ui": True,
    "specs_route": "/apidocs/",
}

swagger = Swagger(app, template=_swagger_template, config=_swagger_config)
# ✅ Hugging Face login (only when a token is provided)
# login(token=None) falls back to an interactive token prompt, which a
# headless server process cannot answer, so skip the call when the
# CHAT_MATE environment variable is unset or empty.
_hf_token = os.environ.get("CHAT_MATE")
if _hf_token:
    login(token=_hf_token)
def load_image_pipeline():
    """Build the Stable Diffusion v1.4 text-to-image pipeline.

    CUDA with float16 weights is used when ``torch.cuda.is_available()``
    reports a GPU; otherwise the pipeline runs on CPU with float32. The
    scheduler is replaced by an ``EulerDiscreteScheduler`` built from the
    pipeline's own scheduler config, and attention slicing is enabled.

    Returns:
        The configured ``StableDiffusionPipeline``, already moved to the
        selected device.
    """
    if torch.cuda.is_available():
        target_device, weights_dtype = "cuda", torch.float16
    else:
        target_device, weights_dtype = "cpu", torch.float32

    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=weights_dtype,
        use_safetensors=True,
        safety_checker=None,  # disable safety checker to avoid blocking
    )
    current_config = sd_pipeline.scheduler.config
    sd_pipeline.scheduler = EulerDiscreteScheduler.from_config(current_config)
    sd_pipeline = sd_pipeline.to(target_device)
    sd_pipeline.enable_attention_slicing()
    return sd_pipeline


# Built once at import time; this module-level instance is what image
# generation uses.
image_pipe = load_image_pipeline()
def generate_image_base64(prompt):
    """Render *prompt* with the module-level ``image_pipe`` and return a base64 PNG.

    Args:
        prompt: Text description of the image. ``None`` or a blank string
            falls back to a default scenic prompt.

    Returns:
        The PNG bytes encoded as a base64 string, or ``""`` when generation
        fails for any reason (the error is printed, not raised).
    """
    # `not prompt` covers None as well as "", so .strip() is never called on None.
    if not prompt or not prompt.strip():
        prompt = "A scenic view of mountains during sunrise"
    try:
        print(f"🎨 Generating image for prompt: {prompt}")
        # A fresh random seed per call, so repeating a prompt is not pinned
        # to one fixed seed.
        seed = random.randint(0, 2**31 - 1)
        generator = torch.Generator(device=image_pipe.device).manual_seed(seed)
        result = image_pipe(
            prompt,
            height=512,
            width=512,
            guidance_scale=7.5,
            num_inference_steps=40,
            generator=generator,
        )
        buffer = BytesIO()
        result.images[0].save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        # Log only the payload size; dumping the whole encoded image floods the log.
        print(f"✅ Image generated successfully ({len(img_b64)} base64 chars).")
        return img_b64
    except Exception as e:
        # Deliberate best-effort: report the failure and return an empty payload.
        print(f"⚠️ Image generation failed: {e}")
        return ""
# ✅ Prompt checker
IMAGE_PROMPT_KEYWORDS = {"generate image", "create image", "make image", "draw", "illustrate", "paint", "show me", "visualize"}


def is_image_prompt(text):
    """Return True when *text* contains any of ``IMAGE_PROMPT_KEYWORDS``.

    Matching is a case-insensitive substring test, so a keyword also matches
    inside a longer word (for example "draw" inside "drawing"). ``None`` and
    other non-string input are reported as "not an image prompt" instead of
    raising.
    """
    if not isinstance(text, str):
        return False
    lowered = text.lower()  # lower-case once rather than once per keyword
    return any(kw in lowered for kw in IMAGE_PROMPT_KEYWORDS)
# ✅ Load the instruct-tuned chat model (Mixtral 8x7B Instruct)
# Alternative checkpoint kept for reference: "meta-llama/Llama-3.1-8B-Instruct"
from transformers import BitsAndBytesConfig  # needed only for the 4-bit config below

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # Passing load_in_4bit straight to from_pretrained is deprecated; the
    # quantization settings belong in a BitsAndBytesConfig.
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
    # NOTE(review): trust_remote_code allows code shipped with the model repo
    # to run locally -- confirm it is actually required for this checkpoint.
    trust_remote_code=True,
)
# Text-generation pipeline used for chat replies; caps each reply at 512 new tokens.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
# ✅ Simple keyword-based check for real-time info
REAL_TIME_KEYWORDS = {"latest", "current", "news", "today", "price", "time", "live", "trending", "update", "happening"}


def should_search(message):
    """Return True when *message* contains any of ``REAL_TIME_KEYWORDS``.

    Matching is a case-insensitive substring test, so a keyword also matches
    inside a longer word (for example "time" inside "sometimes"). ``None``
    and other non-string input return False instead of raising.
    """
    if not isinstance(message, str):
        return False
    lowered = message.lower()
    return any(kw in lowered for kw in REAL_TIME_KEYWORDS)
def should_search_intent_nltk(message):
    """Heuristically decide whether *message* reads like a question to look up.

    Returns True when the first token is a question word (what, when, where,
    who, how, which, why) and the POS tagger assigns at least one noun
    (``NN*``) or verb (``VB*``) tag. Blank or non-string input returns False.
    """
    if not isinstance(message, str):
        return False
    question_starters = {"what", "when", "where", "who", "how", "which", "why"}
    tokens = word_tokenize(message.lower())
    # Guard the tokens[0] lookup (blank input yields no tokens) and skip the
    # tagger entirely when the message does not open with a question word.
    if not tokens or tokens[0] not in question_starters:
        return False
    return any(tag.startswith(("NN", "VB")) for _, tag in pos_tag(tokens))


# ✅ Search tool
search_tool = DuckDuckGoSearchRun()
# ✅ Chat using model with chat template and history
# NOTE(review): `GPU` is imported in this file but no decorator is visible on
# any function in this view -- confirm whether this function should carry it.
def generate_full_reply(message, history):
    """Generate the assistant's reply to *message*, given the earlier turns.

    Args:
        message: The latest user message.
        history: Earlier turns, placed between the system prompt and the new
            user message; presumably ``{"role": ..., "content": ...}`` dicts
            (verify against the client). ``None`` is treated as no history.

    Returns:
        The newly generated text with surrounding whitespace stripped.
    """
    system_prompt = "You are a helpful AI assistant."
    messages = [
        {"role": "system", "content": system_prompt},
        *(history or []),
        {"role": "user", "content": message},
    ]
    # NOTE(review): some chat templates reject a leading "system" turn --
    # confirm the configured model's template accepts this layout.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # return_full_text=False makes the pipeline return only the completion, so
    # the prompt no longer has to be cut out of the output by string splitting.
    outputs = pipe(prompt, do_sample=True, temperature=0.7, return_full_text=False)
    return outputs[0]["generated_text"].strip()
# ✅ Flask route
# NOTE(review): no route registration was visible for this handler, which
# leaves it unreachable; "/" is the conventional landing-page path -- confirm.
@app.route("/")
def home():
    """Render and return the ``index.html`` template."""
    return render_template("index.html")
# NOTE(review): no route registration was visible for this handler, which
# leaves it unreachable; the "/chat" path is chosen by convention -- confirm
# it matches the URL the client actually calls.
@app.route("/chat", methods=["POST"])
def chat():
    """Answer one chat message and return the whole reply as JSON.

    Request body (JSON object):
        message: required, non-empty string.
        history: optional list of earlier turns; missing or null means none.

    Returns:
        ``{"reply": <text>}`` on success, or ``{"error": <text>}`` with HTTP
        status 400 when the body is not a JSON object or ``message`` is
        missing or blank.
    """
    # silent=True yields None instead of raising on a missing or malformed body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    message = data.get("message")
    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "Field 'message' must be a non-empty string."}), 400
    history = data.get("history") or []

    # Check if real-time search is needed
    if should_search(message) or should_search_intent_nltk(message):
        result = f"(Live info) {search_tool.run(message)}"
    else:
        result = generate_full_reply(message, history)
    return jsonify({"reply": result})
# NOTE(review): no route registration was visible for this handler, which
# leaves it unreachable; the "/chat-stream" path is a guess -- confirm it
# matches the URL the client actually calls.
@app.route("/chat-stream", methods=["POST"])
def chat_stream():
    """Answer one chat message as a streamed ``text/plain`` response.

    Request body (JSON object):
        message: required, non-empty string.
        history: optional list of earlier turns; missing or null means none.

    Streamed output:
        * image prompts: a status line, then the base64 PNG wrapped in
          ``[BASE64_IMAGE_START]`` / ``[BASE64_IMAGE_END]`` markers;
        * everything else: the search result or model reply, one line per
          chunk with a 0.05 s pause between chunks.

    Returns:
        The streaming ``Response``, or ``{"error": <text>}`` with HTTP status
        400 when the body is not a JSON object or ``message`` is missing or
        blank.
    """
    # Read everything needed from the request up front; the generator below
    # runs while the response is being streamed.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    message = data.get("message")
    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "Field 'message' must be a non-empty string."}), 400
    history = data.get("history") or []

    def generate():
        if is_image_prompt(message):
            yield "[🖼️ Generating image...]\n"
            image_b64 = generate_image_base64(message)
            yield f"[BASE64_IMAGE_START]{image_b64}[BASE64_IMAGE_END]\n"
            return

        if should_search(message) or should_search_intent_nltk(message):
            reply = f"(Live info) {search_tool.run(message)}"
        else:
            reply = generate_full_reply(message, history)

        # Both text paths stream identically: line by line with a short pause.
        for line in reply.splitlines(keepends=True):
            yield line
            time.sleep(0.05)

    return Response(generate(), mimetype='text/plain')
# ✅ Warm-up on startup
if __name__ == "__main__":
    print("🔧 Warming up...")
    # One throwaway generation before the server starts accepting requests;
    # the reply itself is discarded.
    generate_full_reply("Hello", [])
    # PORT from the environment wins; 7860 is the fallback.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port)