"""Codette Terminal — Gradio app for chat, image, and short-video generation.

Wires three model families behind one Gradio Blocks UI:

* text  — a LoRA-adapted Llama-3.1 ("Codette") plus several stock HF models,
* image — SDXL-Turbo via ``diffusers.DiffusionPipeline``,
* video — Zeroscope v2 via ``diffusers.DiffusionPipeline``.

Image/video pipelines are loaded eagerly at import time and disabled
gracefully (checkboxes greyed out) if loading fails.
"""

import os
import tempfile
import time

import gradio as gr
import imageio
import torch
from diffusers import DiffusionPipeline
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# -------------------------------------------------
# Device Setup
# -------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily populated cache of text-generation pipelines, keyed by the
# display name in AVAILABLE_MODELS (models are large; load at most once).
text_model_cache = {}

# Per-session chat transcripts: session_id -> list of rendered log lines.
chat_memory = {}

# -------------------------------------------------
# Available Text Models (display name -> HF repo id, or the special
# key "codette_lora" which routes through load_codette_lora()).
# -------------------------------------------------
AVAILABLE_MODELS = {
    "Codette LoRA (Llama-3.1)": "codette_lora",
    "Mistral-7B Instruct": "mistralai/Mistral-7B-Instruct-v0.2",
    "Phi-3 Mini": "microsoft/phi-3-mini-4k-instruct",
    "GPT-2 (lightweight)": "gpt2",
}


# -------------------------------------------------
# Load Codette LoRA Adapter
# -------------------------------------------------
def load_codette_lora():
    """Build a text-generation pipeline for the Codette LoRA model.

    Loads the Meta-Llama-3.1-8B base model (fp16 on CUDA, fp32 on CPU),
    applies the ``Raiff1982/codette-lora-adapters`` PEFT adapter, and
    wraps the result in a ``transformers`` text-generation pipeline.

    Returns:
        transformers.Pipeline: a ready-to-call text-generation pipeline.
    """
    base_model = "meta-llama/Meta-Llama-3.1-8B"
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model,
        "Raiff1982/codette-lora-adapters",
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto",
    )


# -------------------------------------------------
# Image Generator (SDXL Turbo)
# -------------------------------------------------
# Best-effort load: if the model cannot be fetched/initialized, the UI
# simply disables the image checkbox instead of crashing the app.
try:
    image_generator = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        variant="fp16" if device == "cuda" else None,
    )
    image_generator.to(device)
    image_enabled = True
except Exception as e:
    print(f"[Image Model Load Error]: {e}")
    image_generator = None
    image_enabled = False

# -------------------------------------------------
# Video Generator (Zeroscope)
# -------------------------------------------------
try:
    video_pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    video_pipeline.to(device)
    video_enabled = True
except Exception as e:
    print(f"[Video Model Load Error]: {e}")
    video_pipeline = None
    video_enabled = False


# -------------------------------------------------
# Load Text Models
# -------------------------------------------------
def get_text_model(model_name):
    """Return the cached text pipeline for *model_name*, loading on first use.

    Args:
        model_name: a key of ``AVAILABLE_MODELS``.

    Returns:
        transformers.Pipeline: the text-generation pipeline.
    """
    if model_name not in text_model_cache:
        if AVAILABLE_MODELS[model_name] == "codette_lora":
            text_model_cache[model_name] = load_codette_lora()
        else:
            text_model_cache[model_name] = pipeline(
                "text-generation",
                model=AVAILABLE_MODELS[model_name],
                device=0 if device == "cuda" else -1,
            )
    return text_model_cache[model_name]


# -------------------------------------------------
# Codette Terminal Logic
# -------------------------------------------------
def codette_terminal(prompt, model_name, generate_image, generate_video,
                     session_id, batch_size, video_steps, fps):
    """Generator backing the terminal UI.

    Streams the text response character by character, then yields a final
    tuple that also carries any generated images/video.

    Args:
        prompt: user input; "exit"/"quit" resets the session.
        model_name: key of ``AVAILABLE_MODELS``.
        generate_image: whether to run SDXL-Turbo on the prompt.
        generate_video: whether to run Zeroscope on the prompt.
        session_id: key into ``chat_memory``.
        batch_size: number of images to generate.
        video_steps: Zeroscope inference steps.
        fps: frame rate of the saved MP4.

    Yields:
        (log_text, images_or_None, video_path_or_None)
    """
    if session_id not in chat_memory:
        chat_memory[session_id] = []

    # Fix: strip whitespace so " exit " also resets; also guard empty input.
    command = prompt.strip().lower()
    if command in ("exit", "quit"):
        chat_memory[session_id] = []
        yield "🧠 Codette session reset.", None, None
        return
    if not command:
        yield "\n".join(chat_memory[session_id][-10:]), None, None
        return

    try:
        model = get_text_model(model_name)
        result = model(
            prompt,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
        output = result[0]["generated_text"]
    except Exception as e:
        yield f"[Text generation error]: {e}", None, None
        return

    # ---- Stream text: re-yield the rolling log once per character. ----
    response_so_far = ""
    for char in output:
        response_so_far += char
        temp_log = chat_memory[session_id][:]
        temp_log.append(f"🖋️ You > {prompt}")
        temp_log.append(f"🧠 Codette > {response_so_far}")
        yield "\n".join(temp_log[-10:]), None, None
        time.sleep(0.01)

    chat_memory[session_id].append(f"🖋️ You > {prompt}")
    chat_memory[session_id].append(f"🧠 Codette > {output}")

    imgs = None
    vid = None

    # ---- Image Generation ----
    if generate_image and image_enabled:
        try:
            result = image_generator(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=2,  # SDXL-Turbo is tuned for 1-4 steps
            )
            imgs = result.images
        except Exception as e:
            print(f"[Image error]: {e}")

    # ---- Video Generation ----
    if generate_video and video_enabled:
        try:
            result = video_pipeline(
                prompt,
                num_inference_steps=video_steps,
            )
            frames = result.frames
            # Fix: recent diffusers returns frames nested per prompt
            # (a list-of-lists, or a 5-D (batch, f, h, w, c) array);
            # imageio needs the flat per-frame sequence. TODO confirm
            # against the pinned diffusers version.
            if getattr(frames, "ndim", 0) == 5:
                frames = frames[0]
            elif frames and isinstance(frames[0], (list, tuple)):
                frames = frames[0]
            # Fix: close the descriptor before imageio reopens the path
            # (NamedTemporaryFile(...).name leaked the fd and fails on
            # Windows while the file is still open).
            fd, temp_video_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            imageio.mimsave(temp_video_path, frames, fps=fps)
            vid = temp_video_path
        except Exception as e:
            print(f"[Video error]: {e}")

    yield "\n".join(chat_memory[session_id][-10:]), imgs, vid


# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
with gr.Blocks(title="🧬 Codette Terminal") as demo:
    gr.Markdown("## 🧬 Codette Terminal")
    gr.Markdown("Chat with Codette, generate images, and create short videos.")

    # Hidden per-tab session key for chat_memory.
    session_id = gr.Textbox(value="default_session", visible=False)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="Codette LoRA (Llama-3.1)",
            label="Language Model",
        )

    with gr.Row():
        generate_image_toggle = gr.Checkbox(
            label="Generate Image(s)",
            value=False,
            interactive=image_enabled,
        )
        generate_video_toggle = gr.Checkbox(
            label="Generate Video",
            value=False,
            interactive=video_enabled,
        )

    with gr.Row():
        batch_size_slider = gr.Slider(
            label="Number of Images",
            minimum=1, maximum=4, step=1, value=1,
        )
        video_steps_slider = gr.Slider(
            label="Video Inference Steps",
            minimum=10, maximum=50, step=10, value=20,
        )
        fps_slider = gr.Slider(
            label="Video FPS",
            minimum=4, maximum=24, step=2, value=8,
        )

    user_input = gr.Textbox(
        label="Your Prompt",
        placeholder="A robot dreaming on Mars...",
        lines=1,
    )
    output_text = gr.Textbox(
        label="Codette Output",
        lines=15,
        interactive=False,
    )

    with gr.Row():
        output_image = gr.Gallery(
            label="Generated Images",
            columns=2,
        )
        output_video = gr.Video(
            label="Generated Video",
        )

    user_input.submit(
        codette_terminal,
        inputs=[
            user_input,
            model_dropdown,
            generate_image_toggle,
            generate_video_toggle,
            session_id,
            batch_size_slider,
            video_steps_slider,
            fps_slider,
        ],
        outputs=[
            output_text,
            output_image,
            output_video,
        ],
    )

# -------------------------------------------------
# Launch
# -------------------------------------------------
if __name__ == "__main__":
    demo.launch()