import gradio as gr
import tempfile
import imageio
import torch
import time
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from diffusers import DiffusionPipeline
# -------------------------------------------------
# Device Setup
# -------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
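# text_model_cache keeps one loaded pipeline per model name so each
# model is downloaded and initialised only once per process;
# chat_memory keeps a per-session list of transcript lines.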
text_model_cache = {}
chat_memory = {}
# -------------------------------------------------
# Available Text Models
# -------------------------------------------------
AVAILABLE_MODELS = {
    "Codette LoRA (Llama-3.1)": "codette_lora",
    "Mistral-7B Instruct": "mistralai/Mistral-7B-Instruct-v0.2",
    "Phi-3 Mini": "microsoft/phi-3-mini-4k-instruct",
    "GPT-2 (lightweight)": "gpt2"
}
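# "codette_lora" is a sentinel resolved in get_text_model(); the
# other values are plain Hugging Face Hub model ids.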
# -------------------------------------------------
# Load Codette LoRA Adapter
# -------------------------------------------------
def load_codette_lora():
    base_model = "meta-llama/Meta-Llama-3.1-8B"
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto"
    )
    # Attach the LoRA weights on top of the frozen base model.
    model = PeftModel.from_pretrained(
        model,
        "Raiff1982/codette-lora-adapters"
    )
    # The model is already dispatched by accelerate via device_map
    # above, so the pipeline takes it as-is; passing device_map again
    # here would conflict with the preloaded model.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
    return pipe
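# Note: the Llama 3.1 base checkpoint is gated on the Hub, so this
# load only succeeds if the environment holds a token with approved
# access (assumption: the Space sets HF_TOKEN).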
# -------------------------------------------------
# Image Generator (SDXL Turbo)
# -------------------------------------------------
try:
    image_generator = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        variant="fp16" if device == "cuda" else None
    )
    image_generator.to(device)
    image_enabled = True
except Exception as e:
    print(f"[Image Model Load Error]: {e}")
    image_generator = None
    image_enabled = False
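# SDXL-Turbo is a few-step distilled model, which is why image
# generation below runs with num_inference_steps=2 and no guidance.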
# -------------------------------------------------
# Video Generator (Zeroscope)
# -------------------------------------------------
try:
    video_pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )
    video_pipeline.to(device)
    video_enabled = True
except Exception as e:
    print(f"[Video Model Load Error]: {e}")
    video_pipeline = None
    video_enabled = False
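# Zeroscope v2 576w is a short-clip text-to-video model trained at
# 576x320, so the generated clips stay small and quick to render.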
# -------------------------------------------------
# Load Text Models
# -------------------------------------------------
def get_text_model(model_name):
    if model_name not in text_model_cache:
        if AVAILABLE_MODELS[model_name] == "codette_lora":
            text_model_cache[model_name] = load_codette_lora()
        else:
            text_model_cache[model_name] = pipeline(
                "text-generation",
                model=AVAILABLE_MODELS[model_name],
                device=0 if device == "cuda" else -1
            )
    return text_model_cache[model_name]
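# Repeat calls with the same dropdown label hit the cache, so
# switching models in the UI pays the load cost only once each.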
# -------------------------------------------------
# Codette Terminal Logic
# -------------------------------------------------
def codette_terminal(prompt, model_name, generate_image, generate_video,
                     session_id, batch_size, video_steps, fps):
    if session_id not in chat_memory:
        chat_memory[session_id] = []
    if prompt.strip().lower() in ["exit", "quit"]:
        chat_memory[session_id] = []
        yield "🧠 Codette session reset.", None, None
        return
    try:
        model = get_text_model(model_name)
        result = model(
            prompt,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            return_full_text=False  # keep the prompt out of the reply
        )
        output = result[0]["generated_text"]
    except Exception as e:
        yield f"[Text generation error]: {e}", None, None
        return
    # -------------------------------------------------
    # Stream Text
    # -------------------------------------------------
    response_so_far = ""
    for char in output:
        response_so_far += char
        temp_log = chat_memory[session_id][:]
        temp_log.append(f"🖋️ You > {prompt}")
        temp_log.append(f"🧠 Codette > {response_so_far}")
        yield "\n".join(temp_log[-10:]), None, None
        time.sleep(0.01)
    chat_memory[session_id].append(f"🖋️ You > {prompt}")
    chat_memory[session_id].append(f"🧠 Codette > {output}")
    imgs = None
    vid = None
    # -------------------------------------------------
    # Image Generation
    # -------------------------------------------------
    if generate_image and image_enabled:
        try:
            result = image_generator(
                prompt,
                num_images_per_prompt=batch_size,
                num_inference_steps=2,
                guidance_scale=0.0  # SDXL-Turbo model card: disable CFG
            )
            imgs = result.images
        except Exception as e:
            print(f"[Image error]: {e}")
    # -------------------------------------------------
    # Video Generation
    # -------------------------------------------------
    if generate_video and video_enabled:
        try:
            result = video_pipeline(
                prompt,
                num_inference_steps=video_steps
            )
            frames = result.frames
            # Recent diffusers releases return frames with a leading
            # batch dimension; unwrap it so imageio gets a flat
            # sequence of frames.
            if len(frames) == 1:
                frames = frames[0]
            temp_video_path = tempfile.NamedTemporaryFile(
                suffix=".mp4",
                delete=False
            ).name
            # Writing .mp4 requires the imageio-ffmpeg plugin.
            imageio.mimsave(temp_video_path, frames, fps=fps)
            vid = temp_video_path
        except Exception as e:
            print(f"[Video error]: {e}")
    yield "\n".join(chat_memory[session_id][-10:]), imgs, vid
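# codette_terminal is a generator, so Gradio streams every yielded
# (text, images, video) triple straight to the output components.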
# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
with gr.Blocks(title="🧬 Codette Terminal") as demo:
    gr.Markdown("## 🧬 Codette Terminal")
    gr.Markdown("Chat with Codette, generate images, and create short videos.")
    session_id = gr.Textbox(value="default_session", visible=False)
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="Codette LoRA (Llama-3.1)",
            label="Language Model"
        )
    with gr.Row():
        generate_image_toggle = gr.Checkbox(
            label="Generate Image(s)",
            value=False,
            interactive=image_enabled
        )
        generate_video_toggle = gr.Checkbox(
            label="Generate Video",
            value=False,
            interactive=video_enabled
        )
    with gr.Row():
        batch_size_slider = gr.Slider(
            label="Number of Images",
            minimum=1,
            maximum=4,
            step=1,
            value=1
        )
        video_steps_slider = gr.Slider(
            label="Video Inference Steps",
            minimum=10,
            maximum=50,
            step=10,
            value=20
        )
        fps_slider = gr.Slider(
            label="Video FPS",
            minimum=4,
            maximum=24,
            step=2,
            value=8
        )
    user_input = gr.Textbox(
        label="Your Prompt",
        placeholder="A robot dreaming on Mars...",
        lines=1
    )
    output_text = gr.Textbox(
        label="Codette Output",
        lines=15,
        interactive=False
    )
    with gr.Row():
        output_image = gr.Gallery(
            label="Generated Images",
            columns=2
        )
        output_video = gr.Video(
            label="Generated Video"
        )
    user_input.submit(
        codette_terminal,
        inputs=[
            user_input,
            model_dropdown,
            generate_image_toggle,
            generate_video_toggle,
            session_id,
            batch_size_slider,
            video_steps_slider,
            fps_slider
        ],
        outputs=[
            output_text,
            output_image,
            output_video
        ]
    )
# -------------------------------------------------
# Launch
# -------------------------------------------------
if __name__ == "__main__":
    demo.queue()  # generator outputs need the queue on older Gradio
    demo.launch()