File size: 2,819 Bytes
6a07ce1
 
 
 
570384a
3631a8e
 
570384a
6a07ce1
 
3631a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
570384a
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
570384a
 
 
 
 
6a07ce1
570384a
 
 
 
 
 
6a07ce1
3631a8e
 
459ac47
6a07ce1
570384a
 
6a07ce1
570384a
6a07ce1
570384a
 
3631a8e
 
6a07ce1
570384a
 
6a07ce1
570384a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Image generation functions for SDXL Model Merger."""

import torch

from . import config
from .config import device, dtype
from .gpu_decorator import GPU
from .tiling import enable_seamless_tiling


@GPU(duration=120)
def _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator):
    """Run the diffusion pipeline under the GPU budget and return its raw result.

    Kept as a separate GPU-decorated helper so that only the actual
    inference call is billed against the ``duration=120`` allocation.
    Numeric parameters are normalized (width/height/steps to int, cfg to
    float) before being handed to the pipeline.
    """
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": int(width),
        "height": int(height),
        "num_inference_steps": int(steps),
        "guidance_scale": float(cfg),
        "generator": generator,
    }
    return pipe(**call_kwargs)


def generate_image(
    prompt: str,
    negative_prompt: str,
    cfg: float,
    steps: int,
    height: int,
    width: int,
    tile_x: bool = True,
    tile_y: bool = False,
    seed: int | None = None,
):
    """
    Generate an image using the loaded SDXL pipeline.

    Args:
        prompt: Positive prompt for image generation
        negative_prompt: Negative prompt to avoid certain elements
        cfg: Classifier-Free Guidance scale (1.0-20.0)
        steps: Number of inference steps (1-50)
        height: Output image height in pixels
        width: Output image width in pixels
        tile_x: Enable seamless tiling on x-axis
        tile_y: Enable seamless tiling on y-axis
        seed: Random seed for reproducibility (default: a fresh random seed)

    Yields:
        Tuple of (intermediate_image_or_none, status_message)
        - First yield: (None, progress_text) for initial progress update
        - Final yield: (PIL Image, final_status) with the generated image,
          or (None, error_message) if generation failed.
    """
    # Fetch the pipeline at call time — avoids the stale import-by-value problem.
    pipe = config.get_pipe()

    if not pipe:
        yield None, "⚠️ Please load a pipeline first."
        return

    # Ensure VAE stays in float32 to prevent colorful static output
    pipe.vae.to(dtype=torch.float32)

    # Enable seamless tiling on UNet & VAE decoder
    enable_seamless_tiling(pipe.unet, tile_x=tile_x, tile_y=tile_y)
    enable_seamless_tiling(pipe.vae.decoder, tile_x=tile_x, tile_y=tile_y)

    yield None, "🎨 Generating image..."

    try:
        # BUGFIX: torch.randint's exclusive upper bound must itself fit in a
        # signed 64-bit integer; 2**63 overflows int64 (max is 2**63 - 1) and
        # raised at runtime whenever no seed was supplied. Use 2**63 - 1.
        actual_seed = seed if seed is not None else int(torch.randint(0, 2**63 - 1, (1,)).item())
        generator = torch.Generator(device=device).manual_seed(actual_seed)
        result = _run_inference(pipe, prompt, negative_prompt, width, height, steps, cfg, generator)

        image = result.images[0]
        yield image, f"✅ Complete! ({int(width)}x{int(height)})"

    except Exception as e:
        import traceback
        error_msg = f"❌ Generation failed: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        yield None, error_msg