# AIChatMateDev / app_backup.py
# Origin: FrederickSundeep's Hugging Face Space, commit 1c890c3 ("Create app_backup.py")
import os
import time
import nltk
import numpy as np
import random
from flask import Flask, request, render_template, jsonify, Response
from flasgger import Swagger, swag_from
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub import login
from langchain_community.tools import DuckDuckGoSearchRun
from nltk import word_tokenize, pos_tag
from diffusers import StableDiffusionPipeline, SD3Transformer2DModel, FlowMatchEulerDiscreteScheduler, EulerDiscreteScheduler
from io import BytesIO
import base64
# ✅ Safe import of GPU decorator
# On Hugging Face Spaces, `spaces.GPU` runs the wrapped function on GPU
# hardware; outside Spaces the package is absent, so fall back to a no-op
# decorator and let the code run unmodified.
try:
    from spaces import GPU
except ImportError:
    def GPU(func): return func
import torch
# Download required NLTK data (ensure it runs once).
# Fix: the original gated ALL four downloads on the presence of 'punkt'
# alone — if punkt was installed but the tagger corpora were missing, they
# were never fetched. Check each resource individually instead.
def _ensure_nltk_data():
    """Download any of the tokenizer/tagger resources this app needs that
    are not already present locally. Idempotent across restarts."""
    required = {
        'tokenizers/punkt': 'punkt',
        'tokenizers/punkt_tab': 'punkt_tab',
        'taggers/averaged_perceptron_tagger': 'averaged_perceptron_tagger',
        'taggers/averaged_perceptron_tagger_eng': 'averaged_perceptron_tagger_eng',
    }
    for resource_path, package in required.items():
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)

_ensure_nltk_data()
# Flask + Swagger setup
# Serves the chat UI from templates/ and static assets from static/;
# Swagger UI is exposed at /apidocs/ with the spec at /apispec.json.
app = Flask(__name__, static_folder="static", template_folder="templates")
swagger = Swagger(app, template={
    "swagger": "2.0",
    "info": {
        "title": "ChatMate Real-Time API",
        "description": "LangChain + DuckDuckGo enabled AI chatbot",
        "version": "1.0"
    }
}, config={
    "headers": [],
    # rule_filter lambda True: document every registered Flask route.
    "specs": [{"endpoint": 'apispec', "route": '/apispec.json', "rule_filter": lambda rule: True}],
    "static_url_path": "/flasgger_static",
    "swagger_ui": True,
    "specs_route": "/apidocs/"
})
# ✅ Hugging Face login (if token provided)
# NOTE(review): when CHAT_MATE is unset this calls login(token=None) —
# huggingface_hub may then fall back to a cached credential or fail;
# confirm the intended behavior for local runs without the secret.
login(token=os.environ.get("CHAT_MATE"))
def load_image_pipeline():
    """Build and return the Stable Diffusion v1.4 text-to-image pipeline.

    Uses CUDA with fp16 weights when a GPU is available (fp32 on CPU),
    swaps in an Euler discrete scheduler, and enables attention slicing
    to lower peak memory during generation.
    """
    cuda_available = torch.cuda.is_available()
    target_device = "cuda" if cuda_available else "cpu"
    weights_dtype = torch.float16 if cuda_available else torch.float32

    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=weights_dtype,
        use_safetensors=True,
        safety_checker=None,  # disable safety checker to avoid blocking
    )
    sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
    sd_pipe = sd_pipe.to(target_device)
    sd_pipe.enable_attention_slicing()
    return sd_pipe
# Module-level singleton: loading SD weights is expensive, so build the
# pipeline once at import time and reuse it across all requests.
image_pipe = load_image_pipeline()
def generate_image_base64(prompt):
    """Render *prompt* with the global Stable Diffusion pipeline.

    Returns the image as a base64-encoded PNG string, or "" on failure.
    Fixes vs. original:
      * a None prompt no longer crashes with AttributeError on .strip()
        (callers pass request JSON's `data.get("message")` straight in);
      * the success log no longer dumps the entire base64 payload.
    """
    if not prompt or not prompt.strip():
        prompt = "A scenic view of mountains during sunrise"
    try:
        print(f"🎨 Generating image for prompt: {prompt}")
        # Fresh random seed per call so identical prompts vary between requests.
        generator = torch.Generator(device=image_pipe.device).manual_seed(
            random.randint(0, 2**31 - 1)
        )
        result = image_pipe(
            prompt,
            height=512,
            width=512,
            guidance_scale=7.5,
            num_inference_steps=40,
            generator=generator
        )
        image = result.images[0]
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        print("✅ Image generated successfully.")
        return img_b64
    except Exception as e:
        # Best-effort: degrade to an empty string so the chat stream can
        # continue instead of propagating the failure to the client.
        print(f"⚠️ Image generation failed: {e}")
        return ""
# ✅ Prompt checker
IMAGE_PROMPT_KEYWORDS = {"generate image", "create image", "make image", "draw", "illustrate", "paint", "show me", "visualize"}

def is_image_prompt(text):
    """Return True when *text* looks like an image-generation request.

    Case-insensitive substring match against IMAGE_PROMPT_KEYWORDS.
    Fix vs. original: None or empty input returns False instead of
    raising AttributeError (the route handlers pass `data.get("message")`
    through unchecked).
    """
    if not text:
        return False
    lowered = text.lower()
    return any(keyword in lowered for keyword in IMAGE_PROMPT_KEYWORDS)
# ✅ Load the chat LLM — Mixtral-8x7B Instruct, quantized to 4-bit so it
# fits in GPU memory; device_map="auto" shards layers across available devices.
# Previously used: model_id = "meta-llama/Llama-3.1-8B-Instruct"
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): `load_in_4bit=True` as a from_pretrained kwarg is deprecated
# in newer transformers versions in favor of
# quantization_config=BitsAndBytesConfig(load_in_4bit=True) — confirm against
# the pinned transformers version in requirements.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",load_in_4bit=True,
                     torch_dtype=torch.float16,
                     trust_remote_code=True)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
# ✅ Simple keyword-based check for real-time info
REAL_TIME_KEYWORDS = {"latest", "current", "news", "today", "price", "time", "live", "trending", "update", "happening"}

def should_search(message):
    """Heuristic: does *message* mention something needing live data?

    Case-insensitive substring test against REAL_TIME_KEYWORDS.
    """
    lowered = message.lower()
    for keyword in REAL_TIME_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def should_search_intent_nltk(message):
    """POS-based heuristic: True when *message* starts with a question word
    (what/when/where/who/how/which/why) AND contains at least one noun or
    verb per NLTK's part-of-speech tagger.

    Fix vs. original: an empty or whitespace-only message now returns
    False instead of raising IndexError on `tokens[0]`.
    """
    question_starters = {"what", "when", "where", "who", "how", "which", "why"}
    tokens = word_tokenize(message.lower())
    if not tokens:
        return False
    tags = pos_tag(tokens)
    starts_with_question = tokens[0] in question_starters
    has_nouns_or_verbs = any(tag.startswith("NN") or tag.startswith("VB") for _, tag in tags)
    return starts_with_question and has_nouns_or_verbs
# ✅ Search tool
# DuckDuckGo search via LangChain community tools; invoked for messages the
# heuristics above flag as needing real-time information.
search_tool = DuckDuckGoSearchRun()
# ✅ Chat using model with chat template and history
@GPU
def generate_full_reply(message, history):
    """Run one chat turn through the local LLM.

    Assembles system prompt + prior history + the new user message,
    renders it with the tokenizer's chat template, samples a completion
    (temperature 0.7), and returns only the newly generated continuation.
    """
    conversation = [{"role": "system", "content": "You are a helpful AI assistant."}]
    conversation.extend(history)
    conversation.append({"role": "user", "content": message})

    rendered_prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    generated = pipe(rendered_prompt, do_sample=True, temperature=0.7)[0]["generated_text"]
    # The pipeline echoes the prompt; strip it and keep the continuation.
    return generated.split(rendered_prompt)[-1].strip()
# ✅ Flask route
@app.route("/")
def home():
    """Serve the chat UI (templates/index.html)."""
    return render_template("index.html")
@app.route("/chat", methods=["POST"])
@swag_from({
    'tags': ['Chat'],
    'consumes': ['application/json'],
    'summary': 'Get assistant reply',
    'description': 'Send a message and chat history, and receive a full AI-generated response.',
    'parameters': [{
        'name': 'body',
        'in': 'body',
        'required': True,
        'schema': {
            'type': 'object',
            'properties': {
                'message': {'type': 'string', 'example': 'What is Python?'},
                'history': {
                    'type': 'array',
                    'items': {
                        'type': 'object',
                        'properties': {
                            'role': {'type': 'string', 'example': 'user'},
                            'content': {'type': 'string', 'example': 'Tell me about Python'}
                        }
                    }
                }
            },
            'required': ['message']
        }
    }],
    'responses': {
        200: {
            'description': 'Assistant reply',
            'schema': {
                'type': 'object',
                'properties': {
                    'reply': {'type': 'string'}
                }
            }
        }
    }
})
def chat():
    """Non-streaming chat endpoint.

    Routes the message to DuckDuckGo when the heuristics detect a
    real-time query, otherwise to the local LLM, and returns
    {"reply": <text>}.

    Fix vs. original: a request without 'message' (or non-JSON body)
    now returns 400 instead of crashing with AttributeError inside the
    lowercasing heuristics — the schema already marks it required.
    """
    data = request.get_json(silent=True) or {}
    message = data.get("message")
    history = data.get("history", [])
    if not message:
        return jsonify({"error": "'message' is required"}), 400
    # Check if real-time search is needed
    if should_search(message) or should_search_intent_nltk(message):
        result = f"(Live info) {search_tool.run(message)}"
    else:
        result = generate_full_reply(message, history)
    return jsonify({"reply": result})
@app.route("/chat-stream", methods=["POST"])
@swag_from({
    'tags': ['Chat'],
    'consumes': ['application/json'],
    'summary': 'Stream assistant reply or generate image',
    'description': 'Respond with AI-generated text or image based on prompt intent.',
    'parameters': [{
        'name': 'body',
        'in': 'body',
        'required': True,
        'schema': {
            'type': 'object',
            'properties': {
                'message': {'type': 'string', 'example': 'Draw a cyberpunk city at night'},
                'history': {
                    'type': 'array',
                    'items': {
                        'type': 'object',
                        'properties': {
                            'role': {'type': 'string', 'example': 'user'},
                            'content': {'type': 'string', 'example': 'What is entanglement?'}
                        }
                    }
                }
            },
            'required': ['message']
        }
    }],
    'responses': {
        200: {
            'description': 'Streamed text or image',
            'content': {'text/plain': {}}
        }
    }
})
def chat_stream():
    """Streaming chat endpoint. Three paths, chosen by prompt intent:

      * image prompt -> base64 PNG wrapped in [BASE64_IMAGE_START]...[BASE64_IMAGE_END]
      * real-time query -> DuckDuckGo result, streamed line by line
      * otherwise -> LLM reply, streamed line by line

    Fix vs. original: a request without 'message' (or non-JSON body) now
    returns 400 instead of raising AttributeError inside the intent
    checks; the duplicated line-streaming loop is factored into a helper.
    """
    data = request.get_json(silent=True) or {}
    message = data.get("message")
    history = data.get("history", [])
    if not message:
        return Response("'message' is required", status=400, mimetype='text/plain')

    def stream_lines(text):
        # Yield one line at a time with a small delay to simulate streaming.
        for line in text.splitlines(keepends=True):
            yield line
            time.sleep(0.05)

    def generate():
        if is_image_prompt(message):
            yield "[🖼️ Generating image...]\n"
            image_b64 = generate_image_base64(message)
            yield f"[BASE64_IMAGE_START]{image_b64}[BASE64_IMAGE_END]\n"
        elif should_search(message) or should_search_intent_nltk(message):
            yield from stream_lines(f"(Live info) {search_tool.run(message)}")
        else:
            yield from stream_lines(generate_full_reply(message, history))

    return Response(generate(), mimetype='text/plain')
# ✅ Warm-up on startup
if __name__ == "__main__":
    print("🔧 Warming up...")
    # Prime the text-generation pipeline so the first real request does not
    # pay the model's first-inference latency.
    _ = generate_full_reply("Hello", [])
    # Hugging Face Spaces inject the port via $PORT (default 7860).
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))