""" export_models.py ---------------- Downloads publicly available pretrained weights for SRCNN and EDSR (HResNet-style) and exports them as ONNX files into the ./model/ directory. Run once before starting app.py: pip install torch torchvision huggingface_hub basicsr python export_models.py After this script finishes you should have: model/SRCNN_x4.onnx model/HResNet_x4.onnx Then upload both files to Google Drive, copy the file IDs into DRIVE_IDS in app.py, OR set LOCAL_ONLY = True below to skip Drive entirely and load straight from disk. """ import os import torch import torch.nn as nn import torch.onnx from pathlib import Path MODEL_DIR = Path("model") MODEL_DIR.mkdir(exist_ok=True) # --------------------------------------------------------------------------- # Set to True to skip Drive and have app.py load the ONNX files from disk # directly. In app.py, remove the download_from_drive call for these keys # (or just leave the placeholder Drive ID — the script already guards against # missing files gracefully). # --------------------------------------------------------------------------- LOCAL_ONLY = True # flip to False once you have Drive IDs # =========================================================================== # 1. SRCNN ×4 # Architecture: Dong et al. 2014 — 3 conv layers, no upsampling inside # the network. Input is bicubic-upscaled LR; output is the refined HR. # We bicubic-upsample inside a wrapper so the ONNX takes a raw LR image. 
# ===========================================================================


class SRCNN(nn.Module):
    """Original SRCNN (Dong et al., 2014): three convolutions, no upsampling."""

    def __init__(self, num_channels: int = 3):
        super().__init__()
        # kernel // 2 padding keeps the spatial size unchanged through each conv.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.conv3(x)  # final layer is linear (no activation)


class SRCNNx4Wrapper(nn.Module):
    """
    Wraps SRCNN so the ONNX input is a LOW-resolution image.

    Internally bicubic-upsamples by ×4 before feeding SRCNN, matching the
    interface expected by app.py's tile_upscale_model.
    """

    def __init__(self, srcnn: SRCNN, scale: int = 4):
        super().__init__()
        self.srcnn = srcnn
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (1, 3, H, W) — low-res, float32 in [0, 1]
        up = torch.nn.functional.interpolate(
            x, scale_factor=self.scale, mode="bicubic", align_corners=False
        )
        return self.srcnn(up)


def build_srcnn_x4() -> nn.Module:
    """
    Build an SRCNNx4Wrapper and try to load pretrained SRCNN weights.

    Downloads the checkpoint on first use (cached under MODEL_DIR).
    Falls back to random init with a warning if the download or the
    weight load fails.

    Returns:
        The (possibly randomly initialised) SRCNNx4Wrapper module.
    """
    srcnn = SRCNN(num_channels=3)
    wrapper = SRCNNx4Wrapper(srcnn, scale=4)

    # Pretrained weights from the basicsr / mmedit community
    # (original Caffe weights re-converted to PyTorch by
    # https://github.com/yjn870/SRCNN-pytorch)
    # NOTE(review): that repo's checkpoints are commonly single-channel
    # (Y only) — if so, the 3-channel convs here will not match and the
    # partial-load fallback below leaves them randomly initialised; confirm
    # the checkpoint's channel count.
    SRCNN_WEIGHTS_URL = (
        "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/"
        "srcnn_x4.pth"
    )
    weights_path = MODEL_DIR / "srcnn_x4.pth"

    if not weights_path.exists():
        print(" Downloading SRCNN ×4 weights …")
        try:
            import urllib.request
            urllib.request.urlretrieve(SRCNN_WEIGHTS_URL, weights_path)
            print(f" Saved → {weights_path}")
        except Exception as e:
            # BUG FIX: a failed download can leave a truncated file behind;
            # remove it so the next run retries instead of torch.load-ing junk.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download SRCNN weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return wrapper

    # SECURITY: torch.load unpickles the file — only load checkpoints from
    # trusted sources.
    state = torch.load(weights_path, map_location="cpu")
    # The yjn870 checkpoint uses keys conv1/conv2/conv3 matching our module.
    try:
        srcnn.load_state_dict(state, strict=True)
        print(" SRCNN weights loaded ✓")
    except RuntimeError as e:
        # BUG FIX: load_state_dict(strict=False) still raises on tensor
        # *shape* mismatches, so the old fallback could crash here.  Keep
        # only the entries whose shapes match before retrying.
        own = srcnn.state_dict()
        compatible = {
            k: v for k, v in state.items()
            if k in own and own[k].shape == v.shape
        }
        print(f" [WARN] Weight mismatch: {e}\n Proceeding with partial load.")
        srcnn.load_state_dict(compatible, strict=False)
    return wrapper


# ===========================================================================
# 2. EDSR (HResNet-style) ×4
#    EDSR-baseline (Lim et al., 2017) is the canonical "deep residual" SR
#    network. Pretrained weights from eugenesiow/torch-sr (HuggingFace).
# ===========================================================================


class ResBlock(nn.Module):
    """EDSR residual block: conv-ReLU-conv with a scaled identity skip."""

    def __init__(self, n_feats: int, res_scale: float = 1.0):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
        )
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x) * self.res_scale


class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampler supporting scales 2, 3 and 4."""

    def __init__(self, scale: int, n_feats: int):
        layers = []
        if scale in (2, 4):
            # ×4 is built as two successive ×2 PixelShuffle stages.
            steps = {2: 1, 4: 2}[scale]
            for _ in range(steps):
                layers += [
                    nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1),
                    nn.PixelShuffle(2),
                ]
        elif scale == 3:
            layers += [
                nn.Conv2d(n_feats, 9 * n_feats, 3, padding=1),
                nn.PixelShuffle(3),
            ]
        else:
            # ROBUSTNESS: an unsupported scale previously produced a silent
            # identity module; fail loudly instead.
            raise ValueError(f"Unsupported upsampling scale: {scale}")
        super().__init__(*layers)


class EDSR(nn.Module):
    """
    EDSR-baseline: 16 residual blocks, 64 feature channels.
    Matches the publicly released weights from eugenesiow/torch-sr.
    """

    def __init__(self, n_resblocks: int = 16, n_feats: int = 64,
                 scale: int = 4, num_channels: int = 3):
        super().__init__()
        self.head = nn.Conv2d(num_channels, n_feats, 3, padding=1)
        self.body = nn.Sequential(*[ResBlock(n_feats) for _ in range(n_resblocks)])
        self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.tail = nn.Sequential(
            Upsampler(scale, n_feats),
            nn.Conv2d(n_feats, num_channels, 3, padding=1),
        )

    def forward(self, x):
        x = self.head(x)
        res = self.body(x)
        res = self.body_tail(res)
        # Long skip connection over the whole residual trunk.
        x = x + res
        return self.tail(x)


def build_edsr_x4() -> nn.Module:
    """
    Download EDSR-baseline ×4 weights and load them into an EDSR model.

    Source: eugenesiow/torch-sr (Apache-2.0 licensed). Falls back to random
    init with a warning if the download fails.

    Returns:
        The (possibly randomly initialised) EDSR module.
    """
    model = EDSR(n_resblocks=16, n_feats=64, scale=4)

    # Direct link to the EDSR-baseline ×4 checkpoint
    EDSR_WEIGHTS_URL = (
        "https://huggingface.co/eugenesiow/edsr-base/resolve/main/"
        "pytorch_model_4x.pt"
    )
    weights_path = MODEL_DIR / "edsr_x4.pt"

    if not weights_path.exists():
        print(" Downloading EDSR ×4 weights from HuggingFace …")
        try:
            import urllib.request
            urllib.request.urlretrieve(EDSR_WEIGHTS_URL, weights_path)
            print(f" Saved → {weights_path}")
        except Exception as e:
            # BUG FIX: drop any partially written file so a rerun retries
            # the download instead of torch.load-ing a truncated checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download EDSR weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return model

    # SECURITY: torch.load unpickles the file — only load checkpoints from
    # trusted sources.
    state = torch.load(weights_path, map_location="cpu")
    # eugenesiow checkpoints may wrap the state_dict under a 'model' or
    # 'state_dict' key; guard the membership tests in case torch.load
    # returned something that is not a mapping.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip a leading 'module.' prefix left by DataParallel wrapping.
    # (Prefix-only: the old replace() also mangled 'module.' mid-key.)
    state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state.items()
    }

    try:
        model.load_state_dict(state, strict=True)
        print(" EDSR weights loaded ✓")
    except RuntimeError as e:
        # BUG FIX: load_state_dict(strict=False) still raises on tensor
        # *shape* mismatches; filter to shape-compatible entries before the
        # partial retry so this fallback cannot crash.
        own = model.state_dict()
        compatible = {
            k: v for k, v in state.items()
            if k in own and own[k].shape == v.shape
        }
        print(f" [WARN] Weight mismatch ({e}). Trying strict=False …")
        model.load_state_dict(compatible, strict=False)
        print(" EDSR weights loaded (partial) ✓")
    return model


# ===========================================================================
# ONNX export helper
# ===========================================================================
def export_onnx(model: nn.Module, out_path: Path,
                tile_h: int = 128, tile_w: int = 128):
    """
    Export *model* to ONNX with dynamic H/W axes.

    Args:
        model: the module to export (switched to eval mode here).
        out_path: destination .onnx file.
        tile_h, tile_w: spatial size of the dummy tracing input; the exported
            graph still accepts arbitrary H/W thanks to dynamic_axes.
    """
    model.eval()
    dummy = torch.zeros(1, 3, tile_h, tile_w)
    torch.onnx.export(
        model,
        dummy,
        str(out_path),
        opset_version=17,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch", 2: "H", 3: "W"},
            "output": {0: "batch", 2: "H_out", 3: "W_out"},
        },
    )
    size_mb = out_path.stat().st_size / 1_048_576
    print(f" Exported → {out_path} ({size_mb:.1f} MB)")


# ===========================================================================
# Main
# ===========================================================================
if __name__ == "__main__":
    print("=" * 60)
    print("SpectraGAN — ONNX model exporter")
    print("=" * 60)

    # -- SRCNN ×4 ------------------------------------------------------------
    srcnn_out = MODEL_DIR / "SRCNN_x4.onnx"
    if srcnn_out.exists():
        print(f"\n[SKIP] {srcnn_out} already exists.")
    else:
        print("\n[1/2] Building SRCNN ×4 …")
        srcnn_model = build_srcnn_x4()
        print(" Exporting to ONNX …")
        export_onnx(srcnn_model, srcnn_out, tile_h=128, tile_w=128)

    # -- EDSR (HResNet) ×4 ---------------------------------------------------
    edsr_out = MODEL_DIR / "HResNet_x4.onnx"
    if edsr_out.exists():
        print(f"\n[SKIP] {edsr_out} already exists.")
    else:
        print("\n[2/2] Building EDSR (HResNet) ×4 …")
        edsr_model = build_edsr_x4()
        print(" Exporting to ONNX …")
        export_onnx(edsr_model, edsr_out, tile_h=128, tile_w=128)

    print("\n" + "=" * 60)
    print("Done! Files created:")
    for p in [srcnn_out, edsr_out]:
        status = "✓" if p.exists() else "✗ MISSING"
        print(f" {status} {p}")
    print()
    if LOCAL_ONLY:
        print("LOCAL_ONLY = True:")
        print(" app.py will load these files directly from disk.")
        print(" No Google Drive upload needed.")
    else:
        print("Next step:")
        print(" Upload the .onnx files to Google Drive and paste")
        print(" the file IDs into DRIVE_IDS in app.py.")
    print("=" * 60)