SDXL-Model-Merger / src /tiling.py
Kyle Pearson
cleaned up code
570384a
"""Seamless tiling utilities for SDXL Model Merger."""
import torch
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
"""Create patched forward for seamless tiling on Conv2d layers."""
original_forward = module._conv_forward
def patched_conv_forward(input, weight, bias):
if tile_x and tile_y:
# Circular padding on both axes
input = torch.nn.functional.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
elif tile_x:
# Circular padding only on left/right edges, constant (zero) on top/bottom
# Asymmetric padding for 360° panorama tiling
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="circular")
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
elif tile_y:
# Circular padding only on top/bottom edges, constant (zero) on left/right
input = torch.nn.functional.pad(input, (0, 0, pad_h, pad_h), mode="circular")
input = torch.nn.functional.pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
else:
return original_forward(input, weight, bias)
return torch.nn.functional.conv2d(
input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
)
return patched_conv_forward
def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False) -> None:
    """
    Enable seamless tiling on a model's Conv2d layers.

    Every padded ``torch.nn.Conv2d`` in ``model`` (including ``model`` itself)
    gets a patched ``_conv_forward`` that pads circularly on the tiled axes.
    Calling again with the same config is a no-op. Calling with a different
    config replaces the previous patch rather than stacking on top of it.
    Passing ``tile_x=False, tile_y=False`` restores the unpatched forward.

    Layers with no padding, or with string padding (``"same"`` / ``"valid"``),
    are left untouched.

    Args:
        model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
        tile_x: Enable tiling along x-axis
        tile_y: Enable tiling along y-axis
    """
    config = (tile_x, tile_y)
    for module in model.modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue
        # String padding ("same"/"valid") has no per-axis ints to index into.
        # Indexing it would yield characters and break F.pad at forward time.
        if isinstance(module.padding, str):
            continue
        pad_h, pad_w = module.padding
        if pad_h == 0 and pad_w == 0:
            continue
        if getattr(module, "_tiling_config", None) == config:
            continue  # already patched with same config
        if hasattr(module, "_tiling_config"):
            # Remove our earlier instance-level patch, so that the new one wraps
            # the real Conv2d forward instead of the old tiled closure.
            module.__dict__.pop("_conv_forward", None)
        module._tiling_config = config
        if tile_x or tile_y:
            module._conv_forward = _make_asymmetric_forward(
                module, pad_h, pad_w, tile_x, tile_y
            )