# CodetteVision / app.py
# Author: Raiff1982 (Hugging Face Space) — commit 6e1a42f (verified)
import gradio as gr
import tempfile
import imageio
import torch
import time
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from diffusers import DiffusionPipeline
# -------------------------------------------------
# Device Setup
# -------------------------------------------------
# Prefer the GPU when CUDA is available; every model load below keys off this.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Lazy cache of loaded text-generation pipelines, keyed by the display
# names used in AVAILABLE_MODELS (populated by get_text_model).
text_model_cache = {}
# Per-session chat transcripts: session_id -> list of rendered
# "You >" / "Codette >" lines (appended by codette_terminal).
chat_memory = {}
# -------------------------------------------------
# Available Text Models
# -------------------------------------------------
# Maps UI display names to Hugging Face repo ids. The sentinel value
# "codette_lora" routes through load_codette_lora() instead of a plain
# transformers pipeline() call.
AVAILABLE_MODELS = {
    "Codette LoRA (Llama-3.1)": "codette_lora",
    "Mistral-7B Instruct": "mistralai/Mistral-7B-Instruct-v0.2",
    "Phi-3 Mini": "microsoft/phi-3-mini-4k-instruct",
    "GPT-2 (lightweight)": "gpt2"
}
# -------------------------------------------------
# Load Codette LoRA Adapter
# -------------------------------------------------
def load_codette_lora():
    """Build the Codette text-generation pipeline.

    Loads the Llama-3.1 8B base model, attaches the Codette LoRA adapters
    on top of it, and wraps the result in a transformers text-generation
    pipeline ready for prompting.
    """
    base = "meta-llama/Meta-Llama-3.1-8B"
    # Half precision only pays off on GPU; stay in fp32 on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    tok = AutoTokenizer.from_pretrained(base)
    backbone = AutoModelForCausalLM.from_pretrained(
        base,
        torch_dtype=dtype,
        device_map="auto"
    )
    # Layer the LoRA adapter weights onto the frozen base model.
    adapted = PeftModel.from_pretrained(
        backbone,
        "Raiff1982/codette-lora-adapters"
    )
    return pipeline(
        "text-generation",
        model=adapted,
        tokenizer=tok,
        device_map="auto"
    )
# -------------------------------------------------
# Image Generator (SDXL Turbo)
# -------------------------------------------------
try:
    # SDXL-Turbo: a distilled SDXL that yields usable images in very few
    # denoising steps (codette_terminal runs it with num_inference_steps=2).
    image_generator = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        # The fp16 weight variant is only requested when running on GPU.
        variant="fp16" if device == "cuda" else None
    )
    image_generator.to(device)
    image_enabled = True
except Exception as e:
    # Degrade gracefully: the UI's image checkbox is rendered
    # non-interactive when image_enabled is False.
    print(f"[Image Model Load Error]: {e}")
    image_generator = None
    image_enabled = False
# -------------------------------------------------
# Video Generator (Zeroscope)
# -------------------------------------------------
try:
    # Zeroscope v2: lightweight text-to-video diffusion model (576w variant).
    video_pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )
    video_pipeline.to(device)
    video_enabled = True
except Exception as e:
    # Degrade gracefully: the UI's video checkbox is rendered
    # non-interactive when video_enabled is False.
    print(f"[Video Model Load Error]: {e}")
    video_pipeline = None
    video_enabled = False
# -------------------------------------------------
# Load Text Models
# -------------------------------------------------
def get_text_model(model_name):
    """Return the text-generation pipeline for *model_name*, loading and
    caching it in text_model_cache on first use."""
    cached = text_model_cache.get(model_name)
    if cached is None:
        repo = AVAILABLE_MODELS[model_name]
        # The "codette_lora" sentinel means the base-model + LoRA combo;
        # everything else is a plain Hugging Face repo id.
        if repo == "codette_lora":
            cached = load_codette_lora()
        else:
            cached = pipeline(
                "text-generation",
                model=repo,
                device=0 if device == "cuda" else -1
            )
        text_model_cache[model_name] = cached
    return cached
# -------------------------------------------------
# Codette Terminal Logic
# -------------------------------------------------
def codette_terminal(prompt, model_name, generate_image, generate_video,
                     session_id, batch_size, video_steps, fps):
    """Generator driving the terminal UI.

    Yields (chat_log_text, images_or_None, video_path_or_None) tuples so
    Gradio can stream the text reply character by character, then emits a
    final tuple carrying any generated images/video.
    """
    # First message from a session: start an empty transcript.
    if session_id not in chat_memory:
        chat_memory[session_id] = []
    # "exit" / "quit" clears the session history instead of generating.
    if prompt.lower() in ["exit", "quit"]:
        chat_memory[session_id] = []
        yield "🧠 Codette session reset.", None, None
        return
    try:
        model = get_text_model(model_name)
        result = model(
            prompt,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        # NOTE(review): HF text-generation pipelines typically include the
        # prompt at the start of generated_text — the echo appears to be
        # intentional here; confirm if a prompt-stripped reply is wanted.
        output = result[0]["generated_text"]
    except Exception as e:
        # Surface the failure in the chat pane rather than crashing the app.
        yield f"[Text generation error]: {e}", None, None
        return
    # -------------------------------------------------
    # Stream Text
    # -------------------------------------------------
    # Replay the reply one character at a time; each yield re-renders the
    # last 10 transcript lines so the UI looks like a live terminal.
    response_so_far = ""
    for char in output:
        response_so_far += char
        temp_log = chat_memory[session_id][:]
        temp_log.append(f"🖋️ You > {prompt}")
        temp_log.append(f"🧠 Codette > {response_so_far}")
        yield "\n".join(temp_log[-10:]), None, None
        time.sleep(0.01)  # ~100 chars/sec typing effect
    # Commit the finished exchange to the persistent session transcript.
    chat_memory[session_id].append(f"🖋️ You > {prompt}")
    chat_memory[session_id].append(f"🧠 Codette > {output}")
    imgs = None
    vid = None
    # -------------------------------------------------
    # Image Generation
    # -------------------------------------------------
    if generate_image and image_enabled:
        try:
            # SDXL-Turbo is tuned for very low step counts, hence 2 steps.
            result = image_generator(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=2
            )
            imgs = result.images
        except Exception as e:
            # Best-effort: a failed image run still yields the text reply.
            print(f"[Image error]: {e}")
    # -------------------------------------------------
    # Video Generation
    # -------------------------------------------------
    if generate_video and video_enabled:
        try:
            result = video_pipeline(
                prompt,
                num_inference_steps=video_steps
            )
            # NOTE(review): on newer diffusers versions result.frames is
            # batched (batch, frames, H, W, C); imageio may then need
            # result.frames[0] — verify against the pinned diffusers version.
            frames = result.frames
            # delete=False so the file survives for Gradio to serve it.
            temp_video_path = tempfile.NamedTemporaryFile(
                suffix=".mp4",
                delete=False
            ).name
            imageio.mimsave(temp_video_path, frames, fps=fps)
            vid = temp_video_path
        except Exception as e:
            # Best-effort: a failed video run still yields the text reply.
            print(f"[Video error]: {e}")
    # Final yield: full recent transcript plus any media produced.
    yield "\n".join(chat_memory[session_id][-10:]), imgs, vid
# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
# Declarative UI layout; `demo` is launched from the __main__ guard below.
with gr.Blocks(title="🧬 Codette Terminal") as demo:
    gr.Markdown("## 🧬 Codette Terminal")
    gr.Markdown("Chat with Codette, generate images, and create short videos.")
    # Hidden per-tab session key; codette_terminal uses it to pick the
    # right chat_memory transcript.
    session_id = gr.Textbox(value="default_session", visible=False)
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="Codette LoRA (Llama-3.1)",
            label="Language Model"
        )
    with gr.Row():
        # Media toggles are greyed out when the respective model failed to load.
        generate_image_toggle = gr.Checkbox(
            label="Generate Image(s)",
            value=False,
            interactive=image_enabled
        )
        generate_video_toggle = gr.Checkbox(
            label="Generate Video",
            value=False,
            interactive=video_enabled
        )
    with gr.Row():
        batch_size_slider = gr.Slider(
            label="Number of Images",
            minimum=1,
            maximum=4,
            step=1,
            value=1
        )
        video_steps_slider = gr.Slider(
            label="Video Inference Steps",
            minimum=10,
            maximum=50,
            step=10,
            value=20
        )
        fps_slider = gr.Slider(
            label="Video FPS",
            minimum=4,
            maximum=24,
            step=2,
            value=8
        )
    user_input = gr.Textbox(
        label="Your Prompt",
        placeholder="A robot dreaming on Mars...",
        lines=1
    )
    output_text = gr.Textbox(
        label="Codette Output",
        lines=15,
        interactive=False
    )
    with gr.Row():
        output_image = gr.Gallery(
            label="Generated Images",
            columns=2
        )
        output_video = gr.Video(
            label="Generated Video"
        )
    # Wire Enter-to-submit; codette_terminal is a generator, so Gradio
    # streams each yielded tuple into the three outputs.
    user_input.submit(
        codette_terminal,
        inputs=[
            user_input,
            model_dropdown,
            generate_image_toggle,
            generate_video_toggle,
            session_id,
            batch_size_slider,
            video_steps_slider,
            fps_slider
        ],
        outputs=[
            output_text,
            output_image,
            output_video
        ]
    )
# -------------------------------------------------
# Launch
# -------------------------------------------------
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()