# Kyle Pearson
# Refactor entry point, consolidate app creation, update main block, simplify Gradio init
# d723e62
#!/usr/bin/env python3
"""
SDXL Model Merger - Modernized with modular architecture and improved UI/UX.
This application allows you to:
- Load SDXL checkpoints with optional VAE and multiple LoRAs
- Generate images with seamless tiling support
- Export merged models with quantization options
Author: Qwen Code Assistant
"""
try:
import spaces # noqa: F401 — must be imported before torch/CUDA packages
except ImportError:
pass
import gradio as gr
def create_app():
    """Create and configure the Gradio app.

    Builds the full UI — header, feature cards, and the three workflow tabs
    (Load Pipeline, Generate Image, Export Model) — wires every event handler
    inside the Blocks context, and returns the assembled app.

    Returns:
        gr.Blocks: the fully wired application, ready to ``launch()``.
    """
    # Custom CSS consumed by the inline HTML below (the .header-gradient,
    # .feature-card, and .status-* classes) and by Gradio's own containers.
    header_css = """
    .header-gradient {
        background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }
    .feature-card {
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 16px;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        transition: transform 0.2s ease;
    }
    .feature-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }
    .gradio-container .label {
        font-weight: 600;
        color: #374151;
        margin-bottom: 8px;
    }
    .status-success { color: #059669 !important; font-weight: 600; }
    .status-error { color: #dc2626 !important; font-weight: 600; }
    .status-warning { color: #d97706 !important; font-weight: 600; }
    .gradio-container .btn {
        border-radius: 8px;
        padding: 12px 24px;
        font-weight: 600;
    }
    .gradio-container textarea,
    .gradio-container input[type="number"],
    .gradio-container input[type="text"] {
        border-radius: 8px;
        border-color: #d1d5db;
    }
    .gradio-container textarea:focus,
    .gradio-container input:focus {
        outline: none;
        border-color: #6366f1;
        box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
    }
    .gradio-container .tabitem {
        background: transparent;
        border-radius: 12px;
    }
    .progress-text {
        font-weight: 500;
        color: #6b7280 !important;
    }
    """

    # Imported lazily so importing this module doesn't pull in the heavy
    # pipeline dependencies until the app is actually constructed.
    from src.pipeline import load_pipeline
    from src.generator import generate_image
    from src.exporter import export_merged_model
    from src.config import get_cached_checkpoints, get_cached_vaes, get_cached_loras

    # BUGFIX: header_css was defined but never passed to Blocks, so none of
    # the custom classes referenced by the HTML below ever applied.
    with gr.Blocks(title="SDXL Model Merger", css=header_css) as demo:
        # ---------- Header section ----------
        with gr.Column(elem_classes=["feature-card"]):
            gr.HTML("""
                <div style="text-align: center; margin-bottom: 24px;">
                    <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
                        <span class="header-gradient">SDXL Model Merger</span>
                    </h1>
                    <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
                        Merge checkpoints, LoRAs, and VAEs - then bake LoRAs into a single exportable
                        checkpoint with optional quantization.
                    </p>
                </div>
            """)

        # ---------- Feature highlights ----------
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                    <div style="text-align: center; padding: 16px;">
                        <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
                        <strong>Fast Loading</strong>
                        <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
                    </div>
                """)
            with gr.Column(scale=1):
                gr.HTML("""
                    <div style="text-align: center; padding: 16px;">
                        <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
                        <strong>Panorama Gen</strong>
                        <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
                    </div>
                """)
            with gr.Column(scale=1):
                gr.HTML("""
                    <div style="text-align: center; padding: 16px;">
                        <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
                        <strong>Export Ready</strong>
                        <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
                    </div>
                """)

        gr.Markdown("---")

        # ---------- Tab 1: load checkpoint / VAE / LoRAs ----------
        with gr.Tab("Load Pipeline"):
            gr.Markdown("### Load SDXL Pipeline with Checkpoint, VAE, and LoRAs")
            # Progress indicator for pipeline loading
            load_progress = gr.Textbox(
                label="Loading Progress",
                placeholder="Ready to start...",
                show_label=True,
                info="Real-time status of model downloads and pipeline setup"
            )
            with gr.Row():
                with gr.Column(scale=2):
                    # Checkpoint URL with cached models dropdown
                    checkpoint_url = gr.Textbox(
                        label="Base Model (.safetensors) URL",
                        value="https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
                        placeholder="e.g., https://civitai.com/api/download/models/...",
                        info="Download link for the base SDXL checkpoint"
                    )
                    # Dropdown of cached checkpoints
                    cached_checkpoints = gr.Dropdown(
                        choices=["(None found)"] + get_cached_checkpoints(),
                        label="Cached Checkpoints",
                        value="(None found)" if not get_cached_checkpoints() else None,
                        info="Models already downloaded to .cache/"
                    )
                    # VAE URL
                    vae_url = gr.Textbox(
                        label="VAE (.safetensors) URL",
                        value="https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
                        placeholder="Leave blank to use model's built-in VAE",
                        info="Optional custom VAE for improved quality"
                    )
                    # Dropdown of cached VAEs
                    cached_vaes = gr.Dropdown(
                        choices=["(None found)"] + get_cached_vaes(),
                        label="Cached VAEs",
                        value="(None found)" if not get_cached_vaes() else None,
                        info="Select a VAE to load"
                    )
                with gr.Column(scale=1):
                    # LoRA URLs input (one per line; strengths below must align)
                    lora_urls = gr.Textbox(
                        label="LoRA URLs (one per line)",
                        lines=5,
                        value="https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor",
                        placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
                        info="Multiple LoRAs can be loaded and fused together"
                    )
                    # Dropdown of cached LoRAs
                    cached_loras = gr.Dropdown(
                        choices=["(None found)"] + get_cached_loras(),
                        label="Cached LoRAs",
                        value="(None found)" if not get_cached_loras() else None,
                        info="Select a LoRA to add to the list below"
                    )
                    lora_strengths = gr.Textbox(
                        label="LoRA Strengths",
                        value="1.0",
                        placeholder="e.g., 0.8,1.0,0.5",
                        info="Comma-separated strength values for each LoRA"
                    )
            with gr.Row():
                load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")
            # Detailed status display
            load_status = gr.HTML(
                label="Status",
                value='<div class="status-success">✅ Ready to load pipeline</div>',
            )

        # ---------- Tab 2: image generation ----------
        with gr.Tab("Generate Image"):
            gr.Markdown("### Generate Panorama Images with Seamless Tiling")
            # Progress indicator for image generation
            gen_progress = gr.Textbox(
                label="Generation Progress",
                placeholder="Ready to generate...",
                show_label=True,
                info="Real-time status of image generation"
            )
            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(
                        label="Positive Prompt",
                        value="Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic",
                        lines=4,
                        placeholder="Describe the image you want to generate..."
                    )
                    cfg = gr.Slider(
                        minimum=1.0, maximum=20.0, value=3.0, step=0.5,
                        label="CFG Scale",
                        info="Higher values make outputs match prompt more strictly"
                    )
                    height = gr.Number(
                        value=1024, precision=0,
                        label="Height (pixels)",
                        info="Output image height"
                    )
                with gr.Column(scale=1):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="boring, text, signature, watermark, low quality, bad quality",
                        lines=4,
                        placeholder="Elements to avoid in generation..."
                    )
                    steps = gr.Slider(
                        minimum=1, maximum=100, value=8, step=1,
                        label="Inference Steps",
                        info="More steps = better quality but slower"
                    )
                    width = gr.Number(
                        value=2048, precision=0,
                        label="Width (pixels)",
                        info="Output image width"
                    )
            with gr.Row():
                tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
                tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")
                seed = gr.Number(
                    value=80484030936239,
                    precision=0,
                    label="Seed",
                    info="Random seed for reproducible generation"
                )
            with gr.Row():
                gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")
            with gr.Row():
                image_output = gr.Image(
                    label="Result",
                    height=400,
                    show_label=True
                )
                with gr.Column():
                    gen_status = gr.HTML(
                        label="Generation Status",
                        value='<div class="status-success">✅ Ready to generate</div>',
                    )
                    gr.HTML("""
                        <div style="margin-top: 16px; padding: 12px; background-color: #e5e7eb !important; border-radius: 8px;">
                            <strong style="color: #1f2937 !important;">💡 Tips:</strong>
                            <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em; color: #1f2937 !important;">
                                <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
                                <li>Enable seamless tiling for texture-like outputs</li>
                                <li>Lower CFG (3-5) for more creative results</li>
                            </ul>
                        </div>
                    """)

        # ---------- Tab 3: export merged checkpoint ----------
        with gr.Tab("Export Model"):
            gr.Markdown("### Export Merged Checkpoint with Quantization Options")
            # Progress indicator for export
            export_progress = gr.Textbox(
                label="Export Progress",
                placeholder="Ready to export...",
                show_label=True,
                info="Real-time status of model export and quantization"
            )
            with gr.Row():
                include_lora = gr.Checkbox(
                    True,
                    label="Include Fused LoRAs",
                    info="Bake the loaded LoRAs into the exported model"
                )
                quantize_toggle = gr.Checkbox(
                    False,
                    label="Apply Quantization",
                    info="Reduce model size with quantization"
                )
            # BUGFIX: start hidden to match quantize_toggle's False default;
            # the quantize_toggle.change handler below keeps it in sync.
            qtype_row = gr.Row(visible=False)
            with qtype_row:
                qtype_dropdown = gr.Dropdown(
                    choices=["none", "int8", "int4", "float8"],
                    value="int8",
                    label="Quantization Method",
                    info="Trade quality for smaller file size"
                )
            with gr.Row():
                format_dropdown = gr.Dropdown(
                    choices=["safetensors", "bin"],
                    value="safetensors",
                    label="Export Format",
                    info="safetensors is recommended for safety"
                )
            with gr.Row():
                export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")
            with gr.Row():
                download_link = gr.File(
                    label="Download Merged File",
                    show_label=True,
                )
                with gr.Column():
                    export_status = gr.HTML(
                        label="Export Status",
                        value='<div class="status-success">✅ Ready to export</div>',
                    )
                    gr.HTML("""
                        <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
                            <strong>ℹ️ About Quantization:</strong>
                            <p style="font-size: 0.9em; margin: 8px 0;">
                                Reduces model size by lowering precision. Int8 is typically
                                lossless for inference while cutting size in half.
                            </p>
                        </div>
                    """)

        # ---------- Event handlers - all inside Blocks context ----------

        def on_load_pipeline_start():
            """Called when pipeline loading starts: show busy state, lock the button."""
            return (
                '<div class="status-warning">⏳ Loading started...</div>',
                "Starting download...",
                gr.update(interactive=False)
            )

        def on_load_pipeline_complete(status_msg, progress_text):
            """Called when pipeline loading completes; classify outcome by status text."""
            if "✅" in status_msg:
                return (
                    '<div class="status-success">✅ Pipeline loaded successfully!</div>',
                    progress_text,
                    gr.update(interactive=True)
                )
            elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
                return (
                    '<div class="status-warning">⚠️ Download cancelled</div>',
                    progress_text,
                    gr.update(interactive=True)
                )
            else:
                return (
                    f'<div class="status-error">{status_msg}</div>',
                    progress_text,
                    gr.update(interactive=True)
                )

        # Chain: busy state -> load -> classify result -> refresh cache dropdowns.
        load_btn.click(
            fn=on_load_pipeline_start,
            inputs=[],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=load_pipeline,
            inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
            outputs=[load_status, load_progress],
            show_progress="full",
        ).then(
            fn=on_load_pipeline_complete,
            inputs=[load_status, load_progress],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=lambda: (
                gr.update(choices=["(None found)"] + get_cached_checkpoints()),
                gr.update(choices=["(None found)"] + get_cached_vaes()),
                gr.update(choices=["(None found)"] + get_cached_loras()),
            ),
            inputs=[],
            outputs=[cached_checkpoints, cached_vaes, cached_loras],
        )

        def on_cached_checkpoint_change(cached_path):
            """Update URL when a cached checkpoint is selected."""
            if cached_path and cached_path != "(None found)":
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        # BUGFIX: this wiring previously used a duplicate inline lambda and
        # left on_cached_checkpoint_change dead; the lambda also cleared the
        # URL on the "(None found)" placeholder, unlike the sibling VAE/LoRA
        # handlers. Use the named handler for consistent no-op behavior.
        cached_checkpoints.change(
            fn=on_cached_checkpoint_change,
            inputs=cached_checkpoints,
            outputs=checkpoint_url,
        )

        def on_cached_vae_change(cached_path):
            """Update VAE URL when a cached VAE is selected."""
            if cached_path and cached_path != "(None found)":
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        cached_vaes.change(
            fn=on_cached_vae_change,
            inputs=cached_vaes,
            outputs=vae_url,
        )

        def on_cached_lora_change(cached_path, current_urls):
            """Add cached LoRA to the list (deduplicated, one URL per line)."""
            if cached_path and cached_path != "(None found)":
                urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
                file_url = f"file://{cached_path}"
                if file_url not in urls_list:
                    urls_list.append(file_url)
                return gr.update(value="\n".join(urls_list))
            return gr.update()

        cached_loras.change(
            fn=on_cached_lora_change,
            inputs=[cached_loras, lora_urls],
            outputs=lora_urls,
        )

        def on_generate_start():
            """Called when image generation starts: show busy state, lock the button."""
            return (
                '<div class="status-warning">⏳ Generating image...</div>',
                "Starting generation...",
                gr.update(interactive=False)
            )

        def on_generate_complete(status_msg, progress_text, image):
            """Called when image generation completes; a None image means failure."""
            if image is None:
                return (
                    f'<div class="status-error">{status_msg}</div>',
                    "",
                    gr.update(interactive=True),
                    gr.update()
                )
            else:
                return (
                    '<div class="status-success">✅ Generation complete!</div>',
                    "Done",
                    gr.update(interactive=True),
                    gr.update(value=image)
                )

        gen_btn.click(
            fn=on_generate_start,
            inputs=[],
            outputs=[gen_status, gen_progress, gen_btn],
        ).then(
            fn=generate_image,
            inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y, seed],
            outputs=[image_output, gen_progress],
        ).then(
            fn=lambda img, msg: on_generate_complete(msg, "Done", img),
            inputs=[image_output, gen_progress],
            outputs=[gen_status, gen_progress, gen_btn, image_output],
        )

        def on_export_start():
            """Called when export starts: show busy state, lock the button."""
            return (
                '<div class="status-warning">⏳ Export started...</div>',
                "Starting export...",
                gr.update(interactive=False)
            )

        def on_export_complete(status_msg, progress_text, file_path):
            """Called when export completes; a None path means failure."""
            if file_path is None:
                return (
                    f'<div class="status-error">{status_msg}</div>',
                    "",
                    gr.update(interactive=True),
                    gr.update(value=None)
                )
            else:
                return (
                    '<div class="status-success">✅ Export complete!</div>',
                    "Exported successfully",
                    gr.update(interactive=True),
                    gr.update(value=file_path)
                )

        export_btn.click(
            fn=on_export_start,
            inputs=[],
            outputs=[export_status, export_progress, export_btn],
        ).then(
            fn=lambda inc, q, qt, fmt: export_merged_model(
                include_lora=inc,
                quantize=q and (qt != "none"),
                qtype=qt,  # always pass the string value; exporter handles "none" correctly
                save_format=fmt,
            ),
            inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
            outputs=[download_link, export_progress],
        ).then(
            fn=lambda path, msg: on_export_complete(msg, "Exported", path),
            inputs=[download_link, export_progress],
            outputs=[export_status, export_progress, export_btn, download_link],
        )

        # Keep the quantization-method row's visibility in sync with the toggle.
        quantize_toggle.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=[quantize_toggle],
            outputs=qtype_row,
        )

    return demo
# Build the app at import time so hosts that import this module (e.g.
# Hugging Face Spaces) can discover the module-level `demo` object.
demo = create_app()
if __name__ == "__main__":
    # Run as a script: start the Gradio server with default settings.
    demo.launch()