Spaces:
Build error
Build error
| import cv2 | |
| import numpy as np | |
| import os | |
| import urllib.request | |
# --- Configuration ---------------------------------------------------------

# Weights live in a "weights" folder one level above this module's directory.
MODEL_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "weights",
)
MODEL_FILENAME = "realesrgan_x4plus.onnx"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILENAME)

# Remote location the ONNX weights are fetched from on first use.
MODEL_URL = "/".join((
    "https://huggingface.co/Qualcomm/Real-ESRGAN-x4plus/resolve/main",
    "Real-ESRGAN-x4plus.onnx",
))

SCALE_FACTOR = 4    # Output size multiplier applied to both image dimensions
TILE_SIZE = 256     # Process in tiles to limit memory usage
TILE_OVERLAP = 16   # Overlap between tiles for seamless stitching

# Lazy-loaded ONNX session (created on first request, then reused)
_session = None
def _ensure_model():
    """Download the Real-ESRGAN ONNX model if it doesn't exist locally.

    The download is written to a temporary sibling file and only moved to
    ``MODEL_PATH`` once it has completed. Writing straight to ``MODEL_PATH``
    would leave a truncated file behind after an interrupted transfer, and
    the existence check below would then accept that broken file forever.

    Raises:
        Any exception raised by ``urllib.request.urlretrieve`` or by the
        filesystem calls; the temporary file is removed before re-raising.
    """
    if os.path.exists(MODEL_PATH):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print(f"Downloading Real-ESRGAN x4plus model to {MODEL_PATH} ...")
    print("(This is a one-time download, ~67 MB)")

    # The PID keeps concurrent processes from writing to the same temp file.
    partial_path = f"{MODEL_PATH}.{os.getpid()}.part"
    try:
        urllib.request.urlretrieve(MODEL_URL, partial_path)
        # Same-directory rename, so the finished file appears in one step.
        os.replace(partial_path, MODEL_PATH)
    except BaseException:
        # Best-effort cleanup; the original failure is what matters.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise

    print("Download complete.")
def _get_session():
    """Return the shared ONNX Runtime session, creating it on first use.

    onnxruntime is imported here rather than at module import time, and the
    model file is fetched if it is missing before the session is built.
    """
    global _session
    if _session is not None:
        return _session

    import onnxruntime as ort

    ort.set_default_logger_severity(3)  # Suppress verbose logs
    _ensure_model()

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    _session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    return _session
def _run_esrgan_tile(session, tile_bgr: np.ndarray) -> np.ndarray:
    """
    Run a single BGR tile through the Real-ESRGAN ONNX model.

    Args:
        session: Inference session; only ``get_inputs()[0].name`` and
            ``run(None, feeds)`` are used, and only the first output of the
            first batch element is read.
        tile_bgr: uint8 tile in HWC layout with channels in BGR order.

    Returns:
        uint8 BGR HWC array with the model output for the tile. Its spatial
        size is whatever the model produces (callers assume 4x the input).
    """
    # BGR -> RGB: take channels 2, 1, 0 (this also drops any extra channel),
    # then normalise to [0, 1].
    rgb = tile_bgr[..., 2::-1].astype(np.float32) / 255.0

    # HWC -> 1x3xHxW. Copy into a contiguous buffer so the runtime receives
    # a plain dense tensor rather than a strided view.
    tensor = np.ascontiguousarray(rgb.transpose(2, 0, 1)[np.newaxis])

    input_name = session.get_inputs()[0].name
    result = session.run(None, {input_name: tensor})[0][0]  # 3 x H' x W'

    # CHW -> HWC and back to the 0..255 range. Round to nearest before the
    # uint8 cast: a bare cast truncates toward zero, which systematically
    # darkens the output (e.g. 254.7 would become 254 instead of 255).
    scaled = np.rint(result.transpose(1, 2, 0) * 255.0)
    out_rgb = np.clip(scaled, 0, 255).astype(np.uint8)

    # RGB -> BGR; return a contiguous array so downstream OpenCV calls never
    # see a negative-stride view.
    return np.ascontiguousarray(out_rgb[..., ::-1])
def _upscale_tiled(session, img_bgr: np.ndarray) -> np.ndarray:
    """
    Upscale a full BGR image using tiled inference with overlap blending.

    The image is padded (reflect-101) on the bottom/right to a multiple of
    ``TILE_SIZE``, processed tile by tile, and overlapping output regions
    are averaged. Tiling bounds the per-inference memory; the averaging
    hides seams between tiles.

    Args:
        session: Inference session handed through to ``_run_esrgan_tile``.
        img_bgr: uint8 BGR image in HWC layout.

    Returns:
        uint8 BGR image of size ``(h * SCALE_FACTOR, w * SCALE_FACTOR)``.
        Assumes the model enlarges each tile by exactly ``SCALE_FACTOR``.
    """
    h, w = img_bgr.shape[:2]
    sf = SCALE_FACTOR

    # Pad image so dimensions are divisible by TILE_SIZE
    pad_h = (TILE_SIZE - h % TILE_SIZE) % TILE_SIZE
    pad_w = (TILE_SIZE - w % TILE_SIZE) % TILE_SIZE
    padded = cv2.copyMakeBorder(
        img_bgr, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101
    )
    ph, pw = padded.shape[:2]

    # Accumulators. Sums of a handful of uint8 values are exactly
    # representable in float32, so float64 would only double the memory of
    # the (already 16x larger) output canvas.
    out_h, out_w = ph * sf, pw * sf
    output = np.zeros((out_h, out_w, 3), dtype=np.float32)
    weight = np.zeros((out_h, out_w, 1), dtype=np.float32)

    # Tile origins: step through with overlap and clamp the last tile so it
    # stays inside the padded image. Clamping can map several steps onto the
    # same origin (e.g. both 0 and 240 clamp to 0 on a 256-px axis), so
    # de-duplicate: re-running the same tile wastes a full inference and
    # would count that tile twice when averaging against its neighbour.
    step = TILE_SIZE - TILE_OVERLAP
    tile_rows = sorted({min(y, ph - TILE_SIZE) for y in range(0, ph, step)})
    tile_cols = sorted({min(x, pw - TILE_SIZE) for x in range(0, pw, step)})

    for ty in tile_rows:
        for tx in tile_cols:
            tile = padded[ty : ty + TILE_SIZE, tx : tx + TILE_SIZE]
            upscaled_tile = _run_esrgan_tile(session, tile)

            # Position of this tile on the output canvas
            oy, ox = ty * sf, tx * sf
            th, tw = upscaled_tile.shape[:2]

            # Accumulate; overlapping regions are averaged afterwards
            output[oy : oy + th, ox : ox + tw] += upscaled_tile
            weight[oy : oy + th, ox : ox + tw] += 1.0

    # Average overlapping regions (guard against division by zero for any
    # pixel no tile touched), in place to avoid another full-size temporary.
    np.maximum(weight, 1.0, out=weight)
    output /= weight

    # Drop the padding before the uint8 conversion so the returned array is
    # exactly the requested size.
    return output[: h * sf, : w * sf].clip(0, 255).astype(np.uint8)
def upscale_image(img: np.ndarray) -> np.ndarray:
    """
    Upscale an image 4x using Real-ESRGAN via ONNX Runtime.

    A 4-channel input is treated as BGRA: the colour planes go through the
    model, while the alpha plane is resized with Lanczos and then binarised
    at 127, so the returned alpha contains only 0 or 255.
    Any exception raised on the ONNX path is printed and the image is
    handed to the local Lanczos fallback instead.
    """
    is_bgra = img.ndim == 3 and img.shape[2] == 4
    color = img[:, :, :3] if is_bgra else img
    alpha_plane = img[:, :, 3] if is_bgra else None

    try:
        enlarged = _upscale_tiled(_get_session(), color)
        if alpha_plane is None:
            return enlarged

        # Bring the alpha plane to the same size as the upscaled colour data.
        out_h, out_w = enlarged.shape[:2]
        resized_alpha = cv2.resize(
            alpha_plane, (out_w, out_h), interpolation=cv2.INTER_LANCZOS4
        )
        _, hard_alpha = cv2.threshold(
            resized_alpha, 127, 255, cv2.THRESH_BINARY
        )

        planes = [enlarged[:, :, channel] for channel in range(3)]
        planes.append(hard_alpha)
        return cv2.merge(planes)
    except Exception as exc:
        print(f"Real-ESRGAN upscale failed: {exc}")
        print("Falling back to local Lanczos upscaling...")
        return _local_fallback_upscale(img)
def _local_fallback_upscale(img: np.ndarray) -> np.ndarray:
    """
    Fallback: local Lanczos upscaling + sharpening if ONNX is unavailable.

    The image is enlarged by ``SCALE_FACTOR`` so that the result has the
    same dimensions the ONNX path would have produced; callers therefore
    get a consistently sized image whether or not the model could be used.

    Args:
        img: BGR (3-channel) or BGRA (4-channel) image; other layouts are
            passed to OpenCV unchanged.

    Returns:
        The enlarged image. For BGRA input the alpha plane is resized with
        Lanczos and binarised at 127 (values become 0 or 255).
    """
    has_alpha = len(img.shape) == 3 and img.shape[2] == 4
    if has_alpha:
        bgr = img[:, :, :3]
        alpha = img[:, :, 3]
    else:
        bgr = img
        alpha = None

    h, w = bgr.shape[:2]
    # cv2.resize takes (width, height)
    target_size = (w * SCALE_FACTOR, h * SCALE_FACTOR)
    upscaled = cv2.resize(bgr, target_size, interpolation=cv2.INTER_LANCZOS4)

    # Edge-preserving smoothing to tame interpolation artefacts
    upscaled = cv2.bilateralFilter(upscaled, d=5, sigmaColor=40, sigmaSpace=40)

    # Unsharp mask: 2 * image - blurred
    blurred = cv2.GaussianBlur(upscaled, (0, 0), 2.0)
    upscaled = cv2.addWeighted(upscaled, 2.0, blurred, -1.0, 0)

    if alpha is not None:
        uh, uw = upscaled.shape[:2]
        upscaled_alpha = cv2.resize(
            alpha, (uw, uh), interpolation=cv2.INTER_LANCZOS4
        )
        _, upscaled_alpha = cv2.threshold(
            upscaled_alpha, 127, 255, cv2.THRESH_BINARY
        )
        return cv2.merge((
            upscaled[:, :, 0],
            upscaled[:, :, 1],
            upscaled[:, :, 2],
            upscaled_alpha,
        ))
    return upscaled