File size: 2,274 Bytes
570384a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Seamless tiling utilities for SDXL Model Merger."""

import torch


def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
    """Create patched forward for seamless tiling on Conv2d layers."""
    original_forward = module._conv_forward

    def patched_conv_forward(input, weight, bias):
        if tile_x and tile_y:
            # Circular padding on both axes
            input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
        elif tile_x:
            # Circular padding only on left/right edges, constant (zero) on top/bottom
            # Asymmetric padding for 360° panorama tiling
            input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
            input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
        elif tile_y:
            # Circular padding only on top/bottom edges, constant (zero) on left/right
            input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
            input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
        else:
            return original_forward(input, weight, bias)

        return torch.nn.functional.conv2d(
            input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
        )

    return patched_conv_forward


def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
    """
    Enable seamless tiling on a model's Conv2d layers.

    Every ``torch.nn.Conv2d`` with non-zero numeric padding gets its
    ``_conv_forward`` replaced by a version that pads circularly along the
    requested axes. The function may be called repeatedly with different
    settings: the forward that was in place before the first patch is
    remembered on the module and restored before re-patching, so calling
    with ``tile_x=False, tile_y=False`` switches tiling off again.

    Args:
        model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
        tile_x: Enable tiling along x-axis
        tile_y: Enable tiling along y-axis
    """
    requested = (tile_x, tile_y)

    for module in model.modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue

        padding = module.padding
        # Conv2d also accepts padding='same' / 'valid'; indexing a string
        # would yield characters rather than sizes, so leave those layers
        # untouched.
        if isinstance(padding, str):
            continue

        pad_h = padding[0]
        pad_w = padding[1]
        if pad_h == 0 and pad_w == 0:
            continue  # nothing to wrap on an unpadded convolution

        if getattr(module, "_tiling_config", None) == requested:
            continue  # already patched with same config

        saved_forward = getattr(module, "_tiling_original_forward", None)
        if saved_forward is None:
            # First patch: remember the untouched forward for later calls.
            module._tiling_original_forward = module._conv_forward
        else:
            # Re-patch: put the untouched forward back first, otherwise the
            # new wrapper would capture the previous wrapper as its
            # "original" and a pass-through config would keep tiling.
            module._conv_forward = saved_forward

        module._tiling_config = requested
        module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)