Spaces:
Running on Zero
Running on Zero
| """Seamless tiling utilities for SDXL Model Merger.""" | |
| import torch | |
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
    """Create a patched ``_conv_forward`` for seamless tiling on a Conv2d layer.

    The returned callable pads the input itself -- circularly along every
    tiled axis and with zeros along every non-tiled axis -- and then runs the
    convolution with padding ``(0, 0)``. When neither axis is tiled it simply
    delegates to the module's un-patched forward.

    Args:
        module: Conv2d layer being patched; its ``stride``, ``dilation``,
            ``groups`` and current ``_conv_forward`` are read.
        pad_h: Rows of padding added above and below the input.
        pad_w: Columns of padding added left and right of the input.
        tile_x: Wrap the left/right edges (circular padding along width).
        tile_y: Wrap the top/bottom edges (circular padding along height).

    Returns:
        A function ``(input, weight, bias)`` meant to be assigned to
        ``module._conv_forward``.
    """
    # If the module already carries a forward built by this factory, reach
    # through it to the forward it wrapped. Otherwise re-patching with both
    # flags off would delegate to the previous *tiling* forward rather than the
    # layer's real one, and every re-patch would add another closure layer.
    current_forward = module._conv_forward
    original_forward = getattr(
        current_forward, "_seamless_original_forward", current_forward
    )

    def patched_conv_forward(input, weight, bias):
        if tile_x and tile_y:
            # Circular padding on both axes
            input = torch.nn.functional.pad(
                input, (pad_w, pad_w, pad_h, pad_h), mode="circular"
            )
        elif tile_x:
            # Circular padding only on left/right edges, constant (zero) on
            # top/bottom: asymmetric padding for 360-degree panorama tiling
            input = torch.nn.functional.pad(
                input, (pad_w, pad_w, 0, 0), mode="circular"
            )
            input = torch.nn.functional.pad(
                input, (0, 0, pad_h, pad_h), mode="constant", value=0
            )
        elif tile_y:
            # Circular padding only on top/bottom edges, constant (zero) on
            # left/right
            input = torch.nn.functional.pad(
                input, (0, 0, pad_h, pad_h), mode="circular"
            )
            input = torch.nn.functional.pad(
                input, (pad_w, pad_w, 0, 0), mode="constant", value=0
            )
        else:
            # No tiling requested: behave exactly like the un-patched layer.
            return original_forward(input, weight, bias)

        # The input is already padded, so the convolution itself adds none.
        return torch.nn.functional.conv2d(
            input, weight, bias, module.stride, (0, 0), module.dilation, module.groups
        )

    # Marker consumed by the unwrap step above on any later re-patch.
    patched_conv_forward._seamless_original_forward = original_forward
    return patched_conv_forward
def enable_seamless_tiling(model, tile_x: bool = True, tile_y: bool = False):
    """
    Enable, reconfigure or disable seamless tiling on a model's Conv2d layers.

    Every ``torch.nn.Conv2d`` with non-zero integer padding gets its
    ``_conv_forward`` replaced so that the tiled axes are padded circularly.
    The forward that was in place before the first patch is kept on the layer
    and restored before a different configuration is applied, so calling this
    again with ``tile_x=False, tile_y=False`` returns the layer to its
    original behaviour.

    Layers whose padding is given as a string (``"same"`` / ``"valid"``) are
    left untouched and therefore are not tiled.

    Args:
        model: PyTorch model with Conv2d layers (e.g., pipe.unet, pipe.vae.decoder)
        tile_x: Enable tiling along x-axis
        tile_y: Enable tiling along y-axis
    """
    config = (tile_x, tile_y)
    for module in model.modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue

        padding = module.padding
        if isinstance(padding, str):
            # String padding has no per-axis integers to re-create by hand;
            # indexing it would yield characters and break the patched forward.
            continue

        pad_h, pad_w = padding
        if pad_h == 0 and pad_w == 0:
            continue  # nothing is padded, so there is no edge to wrap

        if getattr(module, "_tiling_config", None) == config:
            continue  # already patched with same config

        # Undo any earlier patch first so that patches never stack.
        pristine = getattr(module, "_tiling_original_forward", None)
        if pristine is not None:
            module._conv_forward = pristine

        module._tiling_config = config
        if not (tile_x or tile_y):
            continue  # tiling disabled: the original forward stays in place

        if pristine is None:
            # First patch on this layer: remember what to restore later.
            module._tiling_original_forward = module._conv_forward
        module._conv_forward = _make_asymmetric_forward(module, pad_h, pad_w, tile_x, tile_y)