"""Seamless tiling utilities for SDXL Model Merger."""

import torch


def _make_asymmetric_forward(
    module,
    pad_h: int,
    pad_w: int,
    tile_x: bool,
    tile_y: bool,
    original_forward=None,
):
    """Create a replacement ``_conv_forward`` that pads circularly on tiled axes.

    Args:
        module: The ``torch.nn.Conv2d`` whose stride, dilation and groups are
            used for the convolution.
        pad_h: Padding applied to the top and bottom of the input.
        pad_w: Padding applied to the left and right of the input.
        tile_x: Wrap the input around its left/right edges.
        tile_y: Wrap the input around its top/bottom edges.
        original_forward: Callable used when neither axis is tiled. Defaults to
            the module's current ``_conv_forward``.

    Returns:
        A callable taking ``(input, weight, bias)``. When neither axis is
        tiled, ``original_forward`` itself is returned.
    """
    if original_forward is None:
        original_forward = module._conv_forward

    # Nothing to wrap: hand back the untouched forward instead of adding a
    # pass-through layer on top of it.
    if not (tile_x or tile_y):
        return original_forward

    pad = torch.nn.functional.pad

    def patched_conv_forward(input, weight, bias):
        if tile_x and tile_y:
            # Circular padding on both axes.
            input = pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
        elif tile_x:
            # Circular on left/right, zeros on top/bottom (asymmetric padding,
            # e.g. for 360-degree panorama tiling).
            input = pad(input, (pad_w, pad_w, 0, 0), mode="circular")
            input = pad(input, (0, 0, pad_h, pad_h), mode="constant", value=0)
        else:
            # Circular on top/bottom, zeros on left/right.
            input = pad(input, (0, 0, pad_h, pad_h), mode="circular")
            input = pad(input, (pad_w, pad_w, 0, 0), mode="constant", value=0)
        # The input is already padded, so the convolution itself uses no padding.
        return torch.nn.functional.conv2d(
            input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
        )

    return patched_conv_forward


def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
    """
    Enable seamless tiling on a model's Conv2d layers.

    Each affected layer gets its ``_conv_forward`` replaced in place. The
    layer's pre-patch ``_conv_forward`` is remembered on the first call, so
    later calls with a different configuration rebuild the patch from it
    rather than stacking on top of an earlier patch. Calling with
    ``tile_x=False, tile_y=False`` therefore restores the layer's original
    behaviour.

    Layers with zero padding are left untouched, as are layers whose padding
    is given as a string (``'same'`` / ``'valid'``), since they carry no
    numeric per-axis padding to convert.

    Args:
        model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
        tile_x: Enable tiling along x-axis
        tile_y: Enable tiling along y-axis
    """
    for module in model.modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue

        padding = module.padding
        if isinstance(padding, str):
            # String padding cannot be indexed into (pad_h, pad_w) amounts.
            continue

        pad_h = padding[0]
        pad_w = padding[1]
        if pad_h == 0 and pad_w == 0:
            continue

        if getattr(module, "_tiling_config", None) == (tile_x, tile_y):
            continue  # already patched with same config

        # Capture the pristine forward exactly once; every later patch is
        # built from it so a previous tiling wrapper is never wrapped again.
        original_forward = getattr(module, "_tiling_original_forward", None)
        if original_forward is None:
            original_forward = module._conv_forward
            module._tiling_original_forward = original_forward

        module._tiling_config = (tile_x, tile_y)
        module._conv_forward = _make_asymmetric_forward(
            module, pad_h, pad_w, tile_x, tile_y, original_forward=original_forward
        )