SDXL-Model-Merger / src/generator.py
Kyle Pearson
Add zero-gpu support, enhance model export with quantization/gpu acceleration helpers, optimize inference pipeline with vae fixes, modernize pipeline loading with unified decorators, implement gpu decorator infrastructure.
3631a8e
"""Image generation functions for SDXL Model Merger."""
import torch
from . import config
from .config import device, dtype
from .gpu_decorator import GPU
from .tiling import enable_seamless_tiling
@GPU(duration=120)
def _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator):
    """Run a single diffusion pass through the loaded pipeline.

    Kept as a separate helper so that only the actual pipeline call sits
    under the GPU decorator (which reserves the device for up to 120 s).
    Numeric inputs are normalized (ints for sizes/steps, float for CFG)
    before being handed to the pipeline.
    """
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": int(width),
        "height": int(height),
        "num_inference_steps": int(steps),
        "guidance_scale": float(cfg),
        "generator": generator,
    }
    return pipe(**call_kwargs)
def generate_image(
    prompt: str,
    negative_prompt: str,
    cfg: float,
    steps: int,
    height: int,
    width: int,
    tile_x: bool = True,
    tile_y: bool = False,
    seed: int | None = None,
):
    """
    Generate an image using the loaded SDXL pipeline.

    Args:
        prompt: Positive prompt for image generation
        negative_prompt: Negative prompt to avoid certain elements
        cfg: Classifier-Free Guidance scale (1.0-20.0)
        steps: Number of inference steps (1-50)
        height: Output image height in pixels
        width: Output image width in pixels
        tile_x: Enable seamless tiling on x-axis
        tile_y: Enable seamless tiling on y-axis
        seed: Random seed for reproducibility (default: a random seed is drawn)

    Yields:
        Tuple of (intermediate_image_or_none, status_message)
        - First yield: (None, progress_text) for initial progress update
        - Final yield: (PIL Image, final_status) with the generated image,
          or (None, error_message) if generation failed.
    """
    # Fetch the pipeline at call time — avoids the stale import-by-value problem.
    pipe = config.get_pipe()
    # Explicit None check: pipeline objects may define truthiness in
    # surprising ways (e.g. via __len__), so don't rely on `not pipe`.
    if pipe is None:
        yield None, "⚠️ Please load a pipeline first."
        return

    # Ensure VAE stays in float32 to prevent colorful static output
    pipe.vae.to(dtype=torch.float32)

    # Enable seamless tiling on UNet & VAE decoder
    enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
    enable_seamless_tiling(pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)

    yield None, "🎨 Generating image..."
    try:
        # BUG FIX: torch.randint's `high` bound must fit in a signed int64.
        # The previous 2**63 overflows int64 and raises, so every seedless
        # call fell into the error path. 2**63 - 1 (exclusive) is the max.
        actual_seed = (
            seed
            if seed is not None
            else int(torch.randint(0, 2**63 - 1, (1,)).item())
        )
        generator = torch.Generator(device=device).manual_seed(actual_seed)
        result = _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator)
        image = result.images[0]
        yield image, f"✅ Complete! ({int(width)}x{int(height)})"
    except Exception as e:
        # UI boundary: surface the failure as a status message instead of
        # crashing the app; full traceback goes to the server log.
        import traceback
        error_msg = f"❌ Generation failed: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        yield None, error_msg