# spectraGAN / export_models.py
# (uploaded by ParamDev, commit 54d2540 — Hugging Face file-viewer header
#  converted to a comment so the file parses as Python)
"""
export_models.py
----------------
Downloads publicly available pretrained weights for SRCNN and EDSR (HResNet-style)
and exports them as ONNX files into the ./model/ directory.
Run once before starting app.py:
pip install torch torchvision huggingface_hub basicsr
python export_models.py
After this script finishes you should have:
model/SRCNN_x4.onnx
model/HResNet_x4.onnx
Then upload both files to Google Drive, copy the file IDs into DRIVE_IDS in app.py,
OR set LOCAL_ONLY = True below to skip Drive entirely and load straight from disk.
"""
import os
import torch
import torch.nn as nn
import torch.onnx
from pathlib import Path
# Output directory for both the downloaded checkpoints (.pth/.pt) and the
# exported .onnx files. Created eagerly at import time.
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)
# ---------------------------------------------------------------------------
# Set to True to skip Drive and have app.py load the ONNX files from disk
# directly. In app.py, remove the download_from_drive call for these keys
# (or just leave the placeholder Drive ID β€” the script already guards against
# missing files gracefully).
# ---------------------------------------------------------------------------
LOCAL_ONLY = True  # flip to False once you have Drive IDs
# ===========================================================================
# 1. SRCNN Γ—4
# Architecture: Dong et al. 2014 β€” 3 conv layers, no upsampling inside
# the network. Input is bicubic-upscaled LR; output is the refined HR.
# We bicubic-upsample inside a wrapper so the ONNX takes a raw LR image.
# ===========================================================================
class SRCNN(nn.Module):
    """Original SRCNN (Dong et al., 2014).

    Three stacked convolutions with ReLU in between: patch extraction
    (9x9) -> non-linear mapping (5x5) -> reconstruction (5x5). "Same"
    padding (k // 2) keeps H and W unchanged, so the network only refines
    an image that was already upscaled outside of it.
    """

    def __init__(self, num_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.relu(self.conv2(self.relu(self.conv1(x))))
        # No activation on the output layer: the raw conv is the prediction.
        return self.conv3(hidden)
class SRCNNx4Wrapper(nn.Module):
    """
    Makes SRCNN take a LOW-resolution input directly.

    SRCNN itself preserves spatial size, so this wrapper first bicubic
    upsamples the input by `scale` and only then runs SRCNN on the result β€”
    matching the interface expected by app.py's tile_upscale_model.
    """

    def __init__(self, srcnn: SRCNN, scale: int = 4):
        super().__init__()
        self.srcnn = srcnn
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (1, 3, H, W) β€” low-res, float32 in [0, 1]
        upscaled = nn.functional.interpolate(
            x,
            scale_factor=self.scale,
            mode="bicubic",
            align_corners=False,
        )
        return self.srcnn(upscaled)
def build_srcnn_x4() -> nn.Module:
    """
    Build the LR-input SRCNN Γ—4 wrapper, loading pretrained weights if possible.

    Downloads the yjn870/SRCNN-pytorch checkpoint into model/ on first run.
    Any failure (download, corrupt file, key mismatch) degrades gracefully to
    random init with a warning, so the export pipeline never hard-crashes.

    Returns:
        SRCNNx4Wrapper ready for ONNX export (consumes a raw LR image).
    """
    srcnn = SRCNN(num_channels=3)
    wrapper = SRCNNx4Wrapper(srcnn, scale=4)
    # Pretrained weights from the basicsr / mmedit community
    # (original Caffe weights re-converted to PyTorch by https://github.com/yjn870/SRCNN-pytorch)
    SRCNN_WEIGHTS_URL = (
        "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/"
        "srcnn_x4.pth"
    )
    weights_path = MODEL_DIR / "srcnn_x4.pth"
    if not weights_path.exists():
        print(" Downloading SRCNN Γ—4 weights …")
        try:
            import urllib.request
            urllib.request.urlretrieve(SRCNN_WEIGHTS_URL, weights_path)
            print(f" Saved β†’ {weights_path}")
        except Exception as e:
            # Remove any partially written file so a later run does not try
            # to torch.load() a truncated checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download SRCNN weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return wrapper
    try:
        # weights_only=True: the checkpoint is a plain tensor state_dict, and
        # this avoids unpickling arbitrary objects from a downloaded file.
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
    except Exception as e:
        print(f" [WARN] Could not read SRCNN checkpoint: {e}")
        print(" Continuing with random init (quality will be poor).")
        return wrapper
    # Some converters wrap the tensors; unwrap for consistency with the EDSR path.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # The yjn870 checkpoint uses keys conv1/conv2/conv3 matching our module
    try:
        srcnn.load_state_dict(state, strict=True)
        print(" SRCNN weights loaded βœ“")
    except RuntimeError as e:
        print(f" [WARN] Weight mismatch: {e}\n Proceeding with partial load.")
        srcnn.load_state_dict(state, strict=False)
    return wrapper
# ===========================================================================
# 2. EDSR (HResNet-style) Γ—4
# EDSR-baseline (Lim et al., 2017) is the canonical "deep residual" SR
# network. Pretrained weights from eugenesiow/torch-sr (HuggingFace).
# ===========================================================================
class ResBlock(nn.Module):
    """EDSR residual block: conv -> ReLU -> conv, plus a scaled identity skip.

    Note: no batch-norm β€” EDSR deliberately removes it from the classic
    residual block.
    """

    def __init__(self, n_feats: int, res_scale: float = 1.0):
        super().__init__()
        self.res_scale = res_scale
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
        )

    def forward(self, x):
        residual = self.body(x).mul(self.res_scale)
        return x + residual
class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampling head.

    For scale 2 or 4: one or two (conv -> PixelShuffle(2)) stages; for
    scale 3: a single (conv -> PixelShuffle(3)) stage. Each conv expands
    channels by the shuffle factor squared so spatial size grows instead.

    Raises:
        ValueError: if `scale` is not 2, 3 or 4. (Previously an unsupported
        scale silently built an EMPTY Sequential β€” an identity β€” which would
        export a model that does not upscale at all.)
    """

    def __init__(self, scale: int, n_feats: int):
        layers = []
        if scale in (2, 4):
            # Γ—4 is realized as two consecutive Γ—2 sub-pixel stages.
            steps = {2: 1, 4: 2}[scale]
            for _ in range(steps):
                layers += [
                    nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1),
                    nn.PixelShuffle(2),
                ]
        elif scale == 3:
            layers += [
                nn.Conv2d(n_feats, 9 * n_feats, 3, padding=1),
                nn.PixelShuffle(3),
            ]
        else:
            raise ValueError(f"Unsupported upsampling scale: {scale}")
        super().__init__(*layers)
class EDSR(nn.Module):
    """
    EDSR-baseline: 16 residual blocks, 64 feature channels.
    Layer names (head / body / body_tail / tail) are chosen to match the
    publicly released weights from eugenesiow/torch-sr.
    """

    def __init__(self, n_resblocks: int = 16, n_feats: int = 64,
                 scale: int = 4, num_channels: int = 3):
        super().__init__()
        # Shallow feature extraction.
        self.head = nn.Conv2d(num_channels, n_feats, 3, padding=1)
        # Deep trunk: stack of residual blocks plus one closing conv.
        blocks = [ResBlock(n_feats) for _ in range(n_resblocks)]
        self.body = nn.Sequential(*blocks)
        self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        # Upsample and project back to image channels.
        self.tail = nn.Sequential(
            Upsampler(scale, n_feats),
            nn.Conv2d(n_feats, num_channels, 3, padding=1),
        )

    def forward(self, x):
        shallow = self.head(x)
        deep = self.body_tail(self.body(shallow))
        # Long skip connection around the entire trunk.
        return self.tail(shallow + deep)
def build_edsr_x4() -> nn.Module:
    """
    Build EDSR-baseline Γ—4 and load pretrained weights if possible.

    Source: eugenesiow/torch-sr (Apache-2.0 licensed). Any failure
    (download, corrupt file, key mismatch) degrades gracefully to random
    init / partial load with a warning, never a hard crash.

    Returns:
        EDSR model ready for ONNX export.
    """
    model = EDSR(n_resblocks=16, n_feats=64, scale=4)
    # Direct link to the EDSR-baseline Γ—4 checkpoint
    EDSR_WEIGHTS_URL = (
        "https://huggingface.co/eugenesiow/edsr-base/resolve/main/"
        "pytorch_model_4x.pt"
    )
    weights_path = MODEL_DIR / "edsr_x4.pt"
    if not weights_path.exists():
        print(" Downloading EDSR Γ—4 weights from HuggingFace …")
        try:
            import urllib.request
            urllib.request.urlretrieve(EDSR_WEIGHTS_URL, weights_path)
            print(f" Saved β†’ {weights_path}")
        except Exception as e:
            # Remove any partially written file so a later run does not try
            # to torch.load() a truncated checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download EDSR weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return model
    try:
        # weights_only=True avoids unpickling arbitrary objects from a
        # downloaded file; the checkpoint is expected to be tensors only.
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
    except Exception as e:
        print(f" [WARN] Could not read EDSR checkpoint: {e}")
        print(" Continuing with random init (quality will be poor).")
        return model
    # eugenesiow checkpoints may wrap state_dict under a 'model' key
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip a leading 'module.' prefix from DataParallel wrapping. Only at
    # the start of the key: a blanket str.replace would also mangle keys
    # that merely contain the substring 'module.' somewhere inside.
    state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state.items()
    }
    try:
        model.load_state_dict(state, strict=True)
        print(" EDSR weights loaded βœ“")
    except RuntimeError as e:
        print(f" [WARN] Weight mismatch ({e}). Trying strict=False …")
        model.load_state_dict(state, strict=False)
        print(" EDSR weights loaded (partial) βœ“")
    return model
# ===========================================================================
# ONNX export helper
# ===========================================================================
def export_onnx(model: nn.Module, out_path: Path, tile_h: int = 128, tile_w: int = 128):
"""Export *model* to ONNX with dynamic H/W axes."""
model.eval()
dummy = torch.zeros(1, 3, tile_h, tile_w)
torch.onnx.export(
model,
dummy,
str(out_path),
opset_version=17,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch", 2: "H", 3: "W"},
"output": {0: "batch", 2: "H_out", 3: "W_out"},
},
)
size_mb = out_path.stat().st_size / 1_048_576
print(f" Exported β†’ {out_path} ({size_mb:.1f} MB)")
# ===========================================================================
# Main
# ===========================================================================
if __name__ == "__main__":
    rule = "=" * 60
    print(rule)
    print("SpectraGAN β€” ONNX model exporter")
    print(rule)
    # Each entry: (banner printed while building, builder fn, target path).
    # Both exports share the same skip-if-present / build / export flow.
    jobs = [
        ("\n[1/2] Building SRCNN Γ—4 …", build_srcnn_x4,
         MODEL_DIR / "SRCNN_x4.onnx"),
        ("\n[2/2] Building EDSR (HResNet) Γ—4 …", build_edsr_x4,
         MODEL_DIR / "HResNet_x4.onnx"),
    ]
    for banner, build_fn, onnx_path in jobs:
        if onnx_path.exists():
            print(f"\n[SKIP] {onnx_path} already exists.")
            continue
        print(banner)
        net = build_fn()
        print(" Exporting to ONNX …")
        export_onnx(net, onnx_path, tile_h=128, tile_w=128)
    # Summary: report which of the two targets actually exist on disk.
    print("\n" + rule)
    print("Done! Files created:")
    for _, _, onnx_path in jobs:
        status = "βœ“" if onnx_path.exists() else "βœ— MISSING"
        print(f" {status} {onnx_path}")
    print()
    if LOCAL_ONLY:
        print("LOCAL_ONLY = True:")
        print(" app.py will load these files directly from disk.")
        print(" No Google Drive upload needed.")
    else:
        print("Next step:")
        print(" Upload the .onnx files to Google Drive and paste")
        print(" the file IDs into DRIVE_IDS in app.py.")
    print(rule)