Spaces:
Running on Zero
Running on Zero
Kyle Pearson
Add zero-gpu support, enhance model export with quantization/gpu acceleration helpers, optimize inference pipeline with vae fixes, modernize pipeline loading with unified decorators, implement gpu decorator infrastructure.
3631a8e | """Image generation functions for SDXL Model Merger.""" | |
| import torch | |
| from . import config | |
| from .config import device, dtype | |
| from .gpu_decorator import GPU | |
| from .tiling import enable_seamless_tiling | |
| def _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator): | |
| """GPU-decorated helper that runs the actual inference.""" | |
| return pipe( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| width=int(width), | |
| height=int(height), | |
| num_inference_steps=int(steps), | |
| guidance_scale=float(cfg), | |
| generator=generator, | |
| ) | |
def generate_image(
    prompt: str,
    negative_prompt: str,
    cfg: float,
    steps: int,
    height: int,
    width: int,
    tile_x: bool = True,
    tile_y: bool = False,
    seed: int | None = None,
):
    """
    Generate an image using the loaded SDXL pipeline.

    Args:
        prompt: Positive prompt for image generation
        negative_prompt: Negative prompt to avoid certain elements
        cfg: Classifier-Free Guidance scale (converted with ``float``)
        steps: Number of inference steps (converted with ``int``)
        height: Output image height in pixels (converted with ``int``)
        width: Output image width in pixels (converted with ``int``)
        tile_x: Enable seamless tiling on x-axis
        tile_y: Enable seamless tiling on y-axis
        seed: Seed for reproducibility (converted with ``int``). When
            ``None``, a seed is drawn from torch's default random generator.

    Yields:
        Tuple of (image_or_none, status_message)
        - If no pipeline is loaded: a single (None, warning_message).
        - Otherwise first (None, progress_text) once setup has succeeded,
          then (PIL Image, final_status) with the generated image,
          or (None, error_message) if setup or generation failed.
    """
    # Fetch the pipeline at call time — avoids the stale import-by-value problem.
    pipe = config.get_pipe()
    if not pipe:
        yield None, "⚠️ Please load a pipeline first."
        return

    try:
        # Setup lives inside the try so a failure here is reported through the
        # same (None, error_message) channel as an inference failure.
        # Ensure VAE stays in float32 to prevent colorful static output
        pipe.vae.to(dtype=torch.float32)

        # Enable seamless tiling on UNet & VAE decoder
        enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
        enable_seamless_tiling(pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)

        yield None, "🎨 Generating image..."

        if seed is None:
            # randint's upper bound is exclusive and must itself fit in a
            # signed 64-bit integer; 2**63 does not, so stop one below it.
            actual_seed = int(torch.randint(0, 2**63 - 1, (1,)).item())
        else:
            # UI widgets may deliver the seed as a float; manual_seed needs an int.
            actual_seed = int(seed)

        generator = torch.Generator(device=device).manual_seed(actual_seed)
        result = _run_inference(
            pipe, prompt, negative_prompt, width, height, steps, cfg, generator
        )
        image = result.images[0]
    except Exception as e:
        import traceback

        error_msg = f"❌ Generation failed: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        yield None, error_msg
        return

    yield image, f"✅ Complete! ({int(width)}x{int(height)})"