Spaces:
Running on Zero
Running on Zero
Kyle Pearson
Refactor entry point, consolidate app creation, update main block, simplify Gradio init
d723e62 | #!/usr/bin/env python3 | |
| """ | |
| SDXL Model Merger - Modernized with modular architecture and improved UI/UX. | |
| This application allows you to: | |
| - Load SDXL checkpoints with optional VAE and multiple LoRAs | |
| - Generate images with seamless tiling support | |
| - Export merged models with quantization options | |
| Author: Qwen Code Assistant | |
| """ | |
| try: | |
| import spaces # noqa: F401 — must be imported before torch/CUDA packages | |
| except ImportError: | |
| pass | |
| import gradio as gr | |
def create_app():
    """Build the SDXL Model Merger Gradio UI and wire up its event handlers.

    The app has three tabs:

    * **Load Pipeline** - checkpoint / VAE / LoRA URLs (or cached files) that
      are passed to ``load_pipeline``.
    * **Generate Image** - prompt and sampler settings passed to
      ``generate_image``.
    * **Export Model** - LoRA-baking / quantization / format options passed to
      ``export_merged_model``.

    Each long-running action is a three-step chain: a "start" step that
    disables the button and shows a busy status, the worker itself, and a
    "complete" step that re-enables the button and summarises the outcome.

    Returns:
        gr.Blocks: the assembled, not yet launched, application.
    """
    from html import escape

    # Project modules are imported when the app is built, not at module import.
    from src.pipeline import load_pipeline
    from src.generator import generate_image
    from src.exporter import export_merged_model
    from src.config import get_cached_checkpoints, get_cached_vaes, get_cached_loras

    header_css = """
    .header-gradient {
        background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }
    .feature-card {
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 16px;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        transition: transform 0.2s ease;
    }
    .feature-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }
    .gradio-container .label {
        font-weight: 600;
        color: #374151;
        margin-bottom: 8px;
    }
    .status-success { color: #059669 !important; font-weight: 600; }
    .status-error { color: #dc2626 !important; font-weight: 600; }
    .status-warning { color: #d97706 !important; font-weight: 600; }
    .gradio-container .btn {
        border-radius: 8px;
        padding: 12px 24px;
        font-weight: 600;
    }
    .gradio-container textarea,
    .gradio-container input[type="number"],
    .gradio-container input[type="text"] {
        border-radius: 8px;
        border-color: #d1d5db;
    }
    .gradio-container textarea:focus,
    .gradio-container input:focus {
        outline: none;
        border-color: #6366f1;
        box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
    }
    .gradio-container .tabitem {
        background: transparent;
        border-radius: 12px;
    }
    .progress-text {
        font-weight: 500;
        color: #6b7280 !important;
    }
    """

    # Placeholder entry shown first in every "cached ..." dropdown.
    none_found = "(None found)"

    def cached_choices(paths):
        """Return the dropdown choices: the placeholder followed by *paths*."""
        return [none_found] + list(paths)

    def cached_dropdown(paths, label, info):
        """Create a dropdown listing cached files.

        The placeholder is pre-selected only when nothing is cached; otherwise
        the dropdown starts with no selection.
        """
        return gr.Dropdown(
            choices=cached_choices(paths),
            label=label,
            value=none_found if not paths else None,
            info=info,
        )

    # Scan each cache once; the result feeds both `choices` and `value`.
    checkpoint_files = get_cached_checkpoints()
    vae_files = get_cached_vaes()
    lora_files = get_cached_loras()

    with gr.Blocks(title="SDXL Model Merger") as demo:
        # The stylesheet was previously built but never given to Gradio, so
        # none of the classes used below had any effect. Injecting it as a
        # <style> element applies it without relying on a `css=` argument.
        gr.HTML("<style>" + header_css + "</style>")

        # Header section
        with gr.Column(elem_classes=["feature-card"]):
            gr.HTML("""
            <div style="text-align: center; margin-bottom: 24px;">
                <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
                    <span class="header-gradient">SDXL Model Merger</span>
                </h1>
                <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
                    Merge checkpoints, LoRAs, and VAEs - then bake LoRAs into a single exportable
                    checkpoint with optional quantization.
                </p>
            </div>
            """)

        # Feature highlights
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="text-align: center; padding: 16px;">
                    <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
                    <strong>Fast Loading</strong>
                    <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
                </div>
                """)
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="text-align: center; padding: 16px;">
                    <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
                    <strong>Panorama Gen</strong>
                    <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
                </div>
                """)
            with gr.Column(scale=1):
                gr.HTML("""
                <div style="text-align: center; padding: 16px;">
                    <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
                    <strong>Export Ready</strong>
                    <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
                </div>
                """)

        gr.Markdown("---")

        # ------------------------------------------------------------------
        # Tab 1: load the pipeline
        # ------------------------------------------------------------------
        with gr.Tab("Load Pipeline"):
            gr.Markdown("### Load SDXL Pipeline with Checkpoint, VAE, and LoRAs")

            # Progress indicator for pipeline loading
            load_progress = gr.Textbox(
                label="Loading Progress",
                placeholder="Ready to start...",
                show_label=True,
                info="Real-time status of model downloads and pipeline setup",
            )

            with gr.Row():
                with gr.Column(scale=2):
                    checkpoint_url = gr.Textbox(
                        label="Base Model (.safetensors) URL",
                        value="https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
                        placeholder="e.g., https://civitai.com/api/download/models/...",
                        info="Download link for the base SDXL checkpoint",
                    )
                    cached_checkpoints = cached_dropdown(
                        checkpoint_files,
                        label="Cached Checkpoints",
                        info="Models already downloaded to .cache/",
                    )
                    vae_url = gr.Textbox(
                        label="VAE (.safetensors) URL",
                        value="https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
                        placeholder="Leave blank to use model's built-in VAE",
                        info="Optional custom VAE for improved quality",
                    )
                    cached_vaes = cached_dropdown(
                        vae_files,
                        label="Cached VAEs",
                        info="Select a VAE to load",
                    )

                with gr.Column(scale=1):
                    lora_urls = gr.Textbox(
                        label="LoRA URLs (one per line)",
                        lines=5,
                        value="https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor",
                        placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
                        info="Multiple LoRAs can be loaded and fused together",
                    )
                    cached_loras = cached_dropdown(
                        lora_files,
                        label="Cached LoRAs",
                        info="Select a LoRA to add to the list below",
                    )
                    lora_strengths = gr.Textbox(
                        label="LoRA Strengths",
                        value="1.0",
                        placeholder="e.g., 0.8,1.0,0.5",
                        info="Comma-separated strength values for each LoRA",
                    )

            with gr.Row():
                load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")

            # Detailed status display
            load_status = gr.HTML(
                label="Status",
                value='<div class="status-success">✅ Ready to load pipeline</div>',
            )

        # ------------------------------------------------------------------
        # Tab 2: generate an image
        # ------------------------------------------------------------------
        with gr.Tab("Generate Image"):
            gr.Markdown("### Generate Panorama Images with Seamless Tiling")

            # Progress indicator for image generation
            gen_progress = gr.Textbox(
                label="Generation Progress",
                placeholder="Ready to generate...",
                show_label=True,
                info="Real-time status of image generation",
            )

            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(
                        label="Positive Prompt",
                        value="Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic",
                        lines=4,
                        placeholder="Describe the image you want to generate...",
                    )
                    cfg = gr.Slider(
                        minimum=1.0, maximum=20.0, value=3.0, step=0.5,
                        label="CFG Scale",
                        info="Higher values make outputs match prompt more strictly",
                    )
                    height = gr.Number(
                        value=1024, precision=0,
                        label="Height (pixels)",
                        info="Output image height",
                    )

                with gr.Column(scale=1):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="boring, text, signature, watermark, low quality, bad quality",
                        lines=4,
                        placeholder="Elements to avoid in generation...",
                    )
                    steps = gr.Slider(
                        minimum=1, maximum=100, value=8, step=1,
                        label="Inference Steps",
                        info="More steps = better quality but slower",
                    )
                    width = gr.Number(
                        value=2048, precision=0,
                        label="Width (pixels)",
                        info="Output image width",
                    )

            with gr.Row():
                tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
                tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
                seed = gr.Number(
                    value=80484030936239,
                    precision=0,
                    label="Seed",
                    info="Random seed for reproducible generation",
                )

            with gr.Row():
                gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")

            with gr.Row():
                image_output = gr.Image(
                    label="Result",
                    height=400,
                    show_label=True,
                )
                with gr.Column():
                    gen_status = gr.HTML(
                        label="Generation Status",
                        value='<div class="status-success">✅ Ready to generate</div>',
                    )
                    gr.HTML("""
                    <div style="margin-top: 16px; padding: 12px; background-color: #e5e7eb !important; border-radius: 8px;">
                        <strong style="color: #1f2937 !important;">💡 Tips:</strong>
                        <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em; color: #1f2937 !important;">
                            <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
                            <li>Enable seamless tiling for texture-like outputs</li>
                            <li>Lower CFG (3-5) for more creative results</li>
                        </ul>
                    </div>
                    """)

        # ------------------------------------------------------------------
        # Tab 3: export the merged model
        # ------------------------------------------------------------------
        with gr.Tab("Export Model"):
            gr.Markdown("### Export Merged Checkpoint with Quantization Options")

            # Progress indicator for export
            export_progress = gr.Textbox(
                label="Export Progress",
                placeholder="Ready to export...",
                show_label=True,
                info="Real-time status of model export and quantization",
            )

            with gr.Row():
                include_lora = gr.Checkbox(
                    True,
                    label="Include Fused LoRAs",
                    info="Bake the loaded LoRAs into the exported model",
                )
                quantize_toggle = gr.Checkbox(
                    False,
                    label="Apply Quantization",
                    info="Reduce model size with quantization",
                )

            # Starts hidden because the toggle above starts unchecked; the
            # toggle's change handler keeps the two in sync afterwards.
            qtype_row = gr.Row(visible=False)
            with qtype_row:
                qtype_dropdown = gr.Dropdown(
                    choices=["none", "int8", "int4", "float8"],
                    value="int8",
                    label="Quantization Method",
                    info="Trade quality for smaller file size",
                )

            with gr.Row():
                format_dropdown = gr.Dropdown(
                    choices=["safetensors", "bin"],
                    value="safetensors",
                    label="Export Format",
                    info="safetensors is recommended for safety",
                )

            with gr.Row():
                export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")

            with gr.Row():
                download_link = gr.File(
                    label="Download Merged File",
                    show_label=True,
                )
                with gr.Column():
                    export_status = gr.HTML(
                        label="Export Status",
                        value='<div class="status-success">✅ Ready to export</div>',
                    )
                    gr.HTML("""
                    <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
                        <strong>ℹ️ About Quantization:</strong>
                        <p style="font-size: 0.9em; margin: 8px 0;">
                            Reduces model size by lowering precision. Int8 is typically
                            lossless for inference while cutting size in half.
                        </p>
                    </div>
                    """)

        # ------------------------------------------------------------------
        # Event handlers - all inside the Blocks context
        # ------------------------------------------------------------------

        # --- Load pipeline -------------------------------------------------
        def on_load_pipeline_start():
            """Show the busy state and disable the load button."""
            return (
                '<div class="status-warning">⏳ Loading started...</div>',
                "Starting download...",
                gr.update(interactive=False),
            )

        def on_load_pipeline_complete(status_msg, progress_text):
            """Summarise the load result and re-enable the load button.

            The outcome is inferred from markers in the status text written by
            the previous step: "✅" means success, "⚠️" or "cancelled" means
            the download was cancelled, anything else is shown as an error.
            """
            status_msg = status_msg or ""
            if "✅" in status_msg:
                status_html = '<div class="status-success">✅ Pipeline loaded successfully!</div>'
            elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
                status_html = '<div class="status-warning">⚠️ Download cancelled</div>'
            else:
                # Not escaped: this value is read back from an HTML component
                # and may already contain markup.
                status_html = f'<div class="status-error">{status_msg}</div>'
            return status_html, progress_text, gr.update(interactive=True)

        def refresh_cached_dropdowns():
            """Re-scan the caches so newly downloaded files become selectable."""
            return (
                gr.update(choices=cached_choices(get_cached_checkpoints())),
                gr.update(choices=cached_choices(get_cached_vaes())),
                gr.update(choices=cached_choices(get_cached_loras())),
            )

        load_btn.click(
            fn=on_load_pipeline_start,
            inputs=[],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=load_pipeline,
            inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
            outputs=[load_status, load_progress],
            show_progress="full",
        ).then(
            fn=on_load_pipeline_complete,
            inputs=[load_status, load_progress],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=refresh_cached_dropdowns,
            inputs=[],
            outputs=[cached_checkpoints, cached_vaes, cached_loras],
        )

        # --- Cached file pickers --------------------------------------------
        def on_cached_file_change(cached_path):
            """Copy a selected cached file into its URL box as a file:// URL.

            Selecting the placeholder (or clearing the dropdown) leaves the URL
            box untouched. Checkpoints previously used a separate lambda that
            blanked the URL in that case; both pickers now behave the same.
            """
            if cached_path and cached_path != none_found:
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        cached_checkpoints.change(
            fn=on_cached_file_change,
            inputs=cached_checkpoints,
            outputs=checkpoint_url,
        )

        cached_vaes.change(
            fn=on_cached_file_change,
            inputs=cached_vaes,
            outputs=vae_url,
        )

        def on_cached_lora_change(cached_path, current_urls):
            """Append the selected cached LoRA to the URL list, skipping duplicates."""
            if not cached_path or cached_path == none_found:
                return gr.update()
            # `or ""` guards against an empty textbox being delivered as None.
            urls_list = [u.strip() for u in (current_urls or "").split("\n") if u.strip()]
            file_url = f"file://{cached_path}"
            if file_url not in urls_list:
                urls_list.append(file_url)
            return gr.update(value="\n".join(urls_list))

        cached_loras.change(
            fn=on_cached_lora_change,
            inputs=[cached_loras, lora_urls],
            outputs=lora_urls,
        )

        # --- Generate image -------------------------------------------------
        def on_generate_start():
            """Show the busy state and disable the generate button."""
            return (
                '<div class="status-warning">⏳ Generating image...</div>',
                "Starting generation...",
                gr.update(interactive=False),
            )

        def on_generate_complete(image, status_msg):
            """Summarise the generation result and re-enable the generate button.

            A missing image is treated as failure, and the message left in the
            progress box by the previous step is shown as the error.
            """
            if image is None:
                return (
                    # The message is plain text from a Textbox; escape it so
                    # characters like "<" cannot break the status markup.
                    f'<div class="status-error">{escape(str(status_msg))}</div>',
                    "",
                    gr.update(interactive=True),
                    gr.update(),
                )
            return (
                '<div class="status-success">✅ Generation complete!</div>',
                "Done",
                gr.update(interactive=True),
                gr.update(value=image),
            )

        gen_btn.click(
            fn=on_generate_start,
            inputs=[],
            outputs=[gen_status, gen_progress, gen_btn],
        ).then(
            fn=generate_image,
            inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y, seed],
            outputs=[image_output, gen_progress],
        ).then(
            fn=on_generate_complete,
            inputs=[image_output, gen_progress],
            outputs=[gen_status, gen_progress, gen_btn, image_output],
        )

        # --- Export ---------------------------------------------------------
        def on_export_start():
            """Show the busy state and disable the export button."""
            return (
                '<div class="status-warning">⏳ Export started...</div>',
                "Starting export...",
                gr.update(interactive=False),
            )

        def run_export(include_loras, quantize, qtype, save_format):
            """Call the exporter with the options chosen in the UI.

            Quantization is requested only when the toggle is on AND a method
            other than "none" is selected. The method string is always passed
            through unchanged.
            """
            return export_merged_model(
                include_lora=include_loras,
                quantize=quantize and (qtype != "none"),
                qtype=qtype,
                save_format=save_format,
            )

        def on_export_complete(file_path, status_msg):
            """Summarise the export result and re-enable the export button.

            A missing file path is treated as failure, and the message left in
            the progress box by the previous step is shown as the error.
            """
            if file_path is None:
                return (
                    # Plain text from a Textbox; escaped for the same reason as
                    # in on_generate_complete.
                    f'<div class="status-error">{escape(str(status_msg))}</div>',
                    "",
                    gr.update(interactive=True),
                    gr.update(value=None),
                )
            return (
                '<div class="status-success">✅ Export complete!</div>',
                "Exported successfully",
                gr.update(interactive=True),
                gr.update(value=file_path),
            )

        export_btn.click(
            fn=on_export_start,
            inputs=[],
            outputs=[export_status, export_progress, export_btn],
        ).then(
            fn=run_export,
            inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
            outputs=[download_link, export_progress],
        ).then(
            fn=on_export_complete,
            inputs=[download_link, export_progress],
            outputs=[export_status, export_progress, export_btn, download_link],
        )

        # Show the quantization-method row only while quantization is enabled.
        quantize_toggle.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=[quantize_toggle],
            outputs=qtype_row,
        )

    return demo
# Build the UI at import time; `demo` is the module-level app object.
demo = create_app()


if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()