"""Image generation functions for SDXL Model Merger."""

import traceback

import torch

from . import config
from .config import device, dtype
from .gpu_decorator import GPU
from .tiling import enable_seamless_tiling

# Exclusive upper bound for auto-generated seeds.  torch.randint takes its
# bounds as signed 64-bit integers, so 2**63 itself cannot be passed; the
# largest representable bound is 2**63 - 1.
_SEED_UPPER_BOUND = 2**63 - 1


@GPU(duration=120)
def _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator):
    """Run the pipeline once and return its raw result.

    Wrapped by the project's ``GPU`` decorator (``duration=120``); the
    decorator's exact semantics live in ``gpu_decorator`` and are not
    visible here.

    Args:
        pipe: Callable pipeline object; invoked with keyword arguments.
        prompt: Positive prompt, passed through unchanged.
        negative_prompt: Negative prompt, passed through unchanged.
        width: Output width; coerced with ``int()``.
        height: Output height; coerced with ``int()``.
        steps: Number of inference steps; coerced with ``int()``.
        cfg: Guidance scale; coerced with ``float()``.
        generator: Random generator forwarded to the pipeline.

    Returns:
        Whatever ``pipe(...)`` returns.
    """
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        generator=generator,
    )


def generate_image(
    prompt: str,
    negative_prompt: str,
    cfg: float,
    steps: int,
    height: int,
    width: int,
    tile_x: bool = True,
    tile_y: bool = False,
    seed: int | None = None,
):
    """Generate an image using the loaded SDXL pipeline.

    Before generating, this casts the pipeline's VAE to float32 and applies
    seamless tiling to the pipeline's UNet and VAE decoder, so the loaded
    pipeline object is touched on every call.

    Args:
        prompt: Positive prompt for image generation.
        negative_prompt: Negative prompt to avoid certain elements.
        cfg: Classifier-Free Guidance scale (nominally 1.0-20.0; the range
            is not validated here).
        steps: Number of inference steps (nominally 1-50; the range is not
            validated here).
        height: Output image height in pixels.
        width: Output image width in pixels.
        tile_x: Enable seamless tiling on the x-axis.
        tile_y: Enable seamless tiling on the y-axis.
        seed: Random seed for reproducibility.  When ``None``, a seed is
            drawn from torch's global random number generator.

    Yields:
        Tuples of ``(image_or_none, status_message)``:

        - If no pipeline is loaded: a single ``(None, warning)`` and the
          generator ends.
        - Otherwise, first ``(None, progress_text)`` as a progress update.
        - Then either ``(PIL image, final_status)`` on success, or
          ``(None, error_message)`` if generation failed.
    """
    # Fetch the pipeline at call time — avoids the stale import-by-value problem.
    pipe = config.get_pipe()
    if not pipe:
        yield None, "⚠️ Please load a pipeline first."
        return

    # Ensure VAE stays in float32 to prevent colorful static output
    pipe.vae.to(dtype=torch.float32)

    # Enable seamless tiling on UNet & VAE decoder
    enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
    enable_seamless_tiling(pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)

    yield None, "🎨 Generating image..."

    try:
        if seed is not None:
            actual_seed = seed
        else:
            # The upper bound must fit in int64; passing 2**63 overflows
            # when torch unpacks the argument, which made every unseeded
            # call fail.
            actual_seed = int(torch.randint(0, _SEED_UPPER_BOUND, (1,)).item())

        generator = torch.Generator(device=device).manual_seed(actual_seed)
        result = _run_inference(
            pipe, prompt, negative_prompt, width, height, steps, cfg, generator
        )
        image = result.images[0]
    except Exception as e:
        # Top-level boundary for the UI: report the failure as a status
        # message instead of propagating, and log the traceback to stdout.
        error_msg = f"❌ Generation failed: {e}"
        print(error_msg)
        print(traceback.format_exc())
        yield None, error_msg
    else:
        yield image, f"✅ Complete! ({int(width)}x{int(height)})"