"""CorridorKey Green Screen Matting - HuggingFace Space. Self-contained Gradio app with dual inference paths: - GPU (ZeroGPU H200): PyTorch batched inference via GreenFormer - CPU (fallback): ONNX Runtime sequential inference Usage: python app.py # Launch Gradio UI python app.py --input video.mp4 # CLI mode """ import os import sys import math import shutil import gc import time import tempfile import zipfile import subprocess import logging # Thread tuning for CPU (must be set before numpy/cv2/ort import) os.environ["OMP_NUM_THREADS"] = "2" os.environ["OPENBLAS_NUM_THREADS"] = "2" os.environ["MKL_NUM_THREADS"] = "2" import numpy as np import cv2 import gradio as gr import onnxruntime as ort try: import spaces HAS_SPACES = True except ImportError: HAS_SPACES = False # Workaround: Gradio cache_examples bug with None outputs. _original_read_from_flag = gr.components.Component.read_from_flag def _patched_read_from_flag(self, payload): if payload is None or (isinstance(payload, str) and payload.strip() == ""): return None return _original_read_from_flag(self, payload) gr.components.Component.read_from_flag = _patched_read_from_flag from huggingface_hub import hf_hub_download cv2.setNumThreads(2) logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- BIREFNET_REPO = "onnx-community/BiRefNet_lite-ONNX" BIREFNET_FILE = "onnx/model.onnx" MODELS_DIR = os.path.join(os.path.dirname(__file__), "models") CORRIDORKEY_MODELS = { "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"), "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"), } CORRIDORKEY_PTH_REPO = "nikopueringer/CorridorKey_v1.0" CORRIDORKEY_PTH_FILE = "CorridorKey_v1.0.pth" IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) MAX_DURATION_CPU = 5 MAX_DURATION_GPU = 60 MAX_FRAMES = 1800 HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers() # --------------------------------------------------------------------------- # Preload model files at startup (OUTSIDE GPU function — don't waste GPU time on downloads) # --------------------------------------------------------------------------- logger.info("Preloading model files at startup...") _preloaded_birefnet_path = None _preloaded_pth_path = None try: _preloaded_birefnet_path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE) logger.info("BiRefNet cached: %s", _preloaded_birefnet_path) except Exception as e: logger.warning("BiRefNet preload failed (will retry later): %s", e) try: _preloaded_pth_path = hf_hub_download(repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE) logger.info("CorridorKey.pth cached: %s", _preloaded_pth_path) except Exception as e: logger.warning("CorridorKey.pth preload failed (will retry later): %s", e) # Batch sizes for GPU inference (conservative for H200 80GB) GPU_BATCH_SIZES = {"1024": 32, "2048": 16} # 2048 uses only 5.7GB/batch=2, so 16 easily fits in 69.8GB # --------------------------------------------------------------------------- # Color utilities (numpy-only) # --------------------------------------------------------------------------- def linear_to_srgb(x): x = np.clip(x, 0.0, None) return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055) def srgb_to_linear(x): x = np.clip(x, 0.0, None) return np.where(x <= 0.04045, x / 12.92, np.power((x + 0.055) / 1.055, 2.4)) def composite_straight(fg, bg, alpha): return fg * alpha + bg * (1.0 - alpha) def despill(image, green_limit_mode="average", strength=1.0): if strength <= 0.0: return image r, g, b = image[..., 0], image[..., 1], image[..., 2] limit = (r + b) / 2.0 if green_limit_mode == "average" else np.maximum(r, b) spill = np.maximum(g - limit, 0.0) despilled = np.stack([r + spill * 0.5, g - spill, b + spill * 0.5], axis=-1) return image * (1.0 - strength) + despilled * strength if strength < 1.0 else despilled def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5): is_3d = alpha_np.ndim == 3 if is_3d: alpha_np = alpha_np[:, :, 0] mask_8u = (alpha_np > 0.5).astype(np.uint8) * 255 num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_8u, connectivity=8) valid = np.zeros(num_labels, dtype=bool) valid[1:] = stats[1:, cv2.CC_STAT_AREA] >= area_threshold cleaned = (valid[labels].astype(np.uint8) * 255) if dilation > 0: k = int(dilation * 2 + 1) cleaned = cv2.dilate(cleaned, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))) if blur_size > 0: b = int(blur_size * 2 + 1) cleaned = cv2.GaussianBlur(cleaned, (b, b), 0) result = alpha_np * (cleaned.astype(np.float32) / 255.0) return result[:, :, np.newaxis] if is_3d else result def create_checkerboard(w, h, checker_size=64, color1=0.15, color2=0.55): xg, yg = np.meshgrid(np.arange(w) // checker_size, np.arange(h) // checker_size) bg = np.where(((xg + yg) % 2) == 0, color1, color2).astype(np.float32) return np.stack([bg, bg, bg], axis=-1) def premultiply(fg, alpha): return fg * alpha # --------------------------------------------------------------------------- # Fast classical green-screen mask # --------------------------------------------------------------------------- def fast_greenscreen_mask(frame_rgb_f32): h, w = frame_rgb_f32.shape[:2] ph, pw = max(int(h * 0.05), 4), max(int(w * 0.05), 4) corners = np.concatenate([ frame_rgb_f32[:ph, :pw].reshape(-1, 3), frame_rgb_f32[:ph, -pw:].reshape(-1, 3), frame_rgb_f32[-ph:, :pw].reshape(-1, 3), frame_rgb_f32[-ph:, -pw:].reshape(-1, 3), ], axis=0) bg_color = np.median(corners, axis=0) if not (bg_color[1] > bg_color[0] + 0.05 and bg_color[1] > bg_color[2] + 0.05): return None, 0.0 frame_u8 = (np.clip(frame_rgb_f32, 0, 1) * 255).astype(np.uint8) hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV) green_mask = cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)) fg_mask = cv2.bitwise_not(green_mask) fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))) fg_mask = cv2.GaussianBlur(fg_mask, (5, 5), 0) mask_f32 = fg_mask.astype(np.float32) / 255.0 confidence = 1.0 - 2.0 * np.mean(np.minimum(mask_f32, 1.0 - mask_f32)) return mask_f32, confidence # --------------------------------------------------------------------------- # ONNX model loading (CPU fallback + BiRefNet) # --------------------------------------------------------------------------- _birefnet_session = None _corridorkey_sessions = {} _sessions_on_gpu = False def _get_providers(): """Get best available providers. Inside @spaces.GPU, CUDA is available.""" providers = ort.get_available_providers() if "CUDAExecutionProvider" in providers: return ["CUDAExecutionProvider", "CPUExecutionProvider"] return ["CPUExecutionProvider"] def _ort_opts(): opts = ort.SessionOptions() if "CUDAExecutionProvider" in ort.get_available_providers(): opts.intra_op_num_threads = 0 opts.inter_op_num_threads = 0 else: opts.intra_op_num_threads = 2 opts.inter_op_num_threads = 1 opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL opts.enable_mem_pattern = True return opts def _ensure_gpu_sessions(): """Reload ONNX sessions on GPU if CUDA just became available (ZeroGPU).""" global _birefnet_session, _corridorkey_sessions, _sessions_on_gpu has_cuda_now = "CUDAExecutionProvider" in ort.get_available_providers() if has_cuda_now and not _sessions_on_gpu: logger.info("CUDA available! Reloading ONNX sessions on GPU...") _birefnet_session = None _corridorkey_sessions = {} _sessions_on_gpu = True def get_birefnet(force_cpu=False): global _birefnet_session if _birefnet_session is None or force_cpu: path = _preloaded_birefnet_path or hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE) providers = ["CPUExecutionProvider"] if force_cpu else _get_providers() logger.info("Loading BiRefNet ONNX: %s (providers: %s)", path, providers) opts = _ort_opts() if force_cpu: opts.intra_op_num_threads = 2 opts.inter_op_num_threads = 1 _birefnet_session = ort.InferenceSession(path, opts, providers=providers) return _birefnet_session def get_corridorkey_onnx(resolution="1024"): global _corridorkey_sessions if resolution not in _corridorkey_sessions: onnx_path = CORRIDORKEY_MODELS.get(resolution) if not onnx_path or not os.path.exists(onnx_path): raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.") providers = _get_providers() logger.info("Loading CorridorKey ONNX (%s): %s (providers: %s)", resolution, onnx_path, providers) _corridorkey_sessions[resolution] = ort.InferenceSession(onnx_path, _ort_opts(), providers=providers) return _corridorkey_sessions[resolution] # --------------------------------------------------------------------------- # PyTorch model loading (GPU path) # --------------------------------------------------------------------------- _pytorch_model = None _pytorch_model_size = None def _load_greenformer(img_size): """Load the GreenFormer PyTorch model for GPU inference.""" import torch import torch.nn.functional as F from CorridorKeyModule.core.model_transformer import GreenFormer checkpoint_path = _preloaded_pth_path or hf_hub_download(repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE) logger.info("Using checkpoint: %s", checkpoint_path) logger.info("Initializing GreenFormer (img_size=%d)...", img_size) model = GreenFormer( encoder_name="hiera_base_plus_224.mae_in1k_ft_in1k", img_size=img_size, use_refiner=True, ) # Load weights checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) state_dict = checkpoint.get("state_dict", checkpoint) # Fix compiled model prefix & handle PosEmbed mismatch new_state_dict = {} model_state = model.state_dict() for k, v in state_dict.items(): if k.startswith("_orig_mod."): k = k[10:] if "pos_embed" in k and k in model_state: if v.shape != model_state[k].shape: logger.info("Resizing %s from %s to %s", k, v.shape, model_state[k].shape) N_src = v.shape[1] C = v.shape[2] grid_src = int(math.sqrt(N_src)) grid_dst = int(math.sqrt(model_state[k].shape[1])) v_img = v.permute(0, 2, 1).view(1, C, grid_src, grid_src) v_resized = F.interpolate(v_img, size=(grid_dst, grid_dst), mode="bicubic", align_corners=False) v = v_resized.flatten(2).transpose(1, 2) new_state_dict[k] = v missing, unexpected = model.load_state_dict(new_state_dict, strict=False) if missing: logger.warning("Missing keys: %s", missing) if unexpected: logger.warning("Unexpected keys: %s", unexpected) model.eval() model = model.cuda().half() # FP16 for speed on H200 logger.info("Model loaded as FP16") try: import flash_attn logger.info("flash-attn v%s installed (prebuilt wheel)", getattr(flash_attn, '__version__', '?')) except ImportError: logger.info("flash-attn not available (using PyTorch SDPA)") logger.info("SDPA backends: flash=%s, mem_efficient=%s, math=%s", torch.backends.cuda.flash_sdp_enabled(), torch.backends.cuda.mem_efficient_sdp_enabled(), torch.backends.cuda.math_sdp_enabled()) # Skip torch.compile on ZeroGPU — the 37s warmup eats too much of the 120s budget. if not HAS_SPACES and sys.platform in ("linux", "win32"): try: compiled = torch.compile(model) dummy = torch.zeros(1, 4, img_size, img_size, dtype=torch.float16, device="cuda") with torch.inference_mode(): compiled(dummy) model = compiled logger.info("torch.compile() succeeded") except Exception as e: logger.warning("torch.compile() failed, using eager mode: %s", e) torch.cuda.empty_cache() else: logger.info("Skipping torch.compile() (ZeroGPU: saving GPU time for inference)") logger.info("GreenFormer loaded on CUDA (img_size=%d)", img_size) return model def get_pytorch_model(img_size): """Get or load the PyTorch GreenFormer model for the given resolution.""" global _pytorch_model, _pytorch_model_size if _pytorch_model is None or _pytorch_model_size != img_size: # Free old model if switching resolution if _pytorch_model is not None: import torch del _pytorch_model _pytorch_model = None torch.cuda.empty_cache() gc.collect() _pytorch_model = _load_greenformer(img_size) _pytorch_model_size = img_size return _pytorch_model # --------------------------------------------------------------------------- # Per-frame inference: ONNX (CPU fallback) # --------------------------------------------------------------------------- def birefnet_frame(session, image_rgb_uint8): h, w = image_rgb_uint8.shape[:2] inp = session.get_inputs()[0] res = (inp.shape[2], inp.shape[3]) img = cv2.resize(image_rgb_uint8, res).astype(np.float32) / 255.0 img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)[np.newaxis, :].astype(np.float32) pred = 1.0 / (1.0 + np.exp(-session.run(None, {inp.name: img})[-1])) return (cv2.resize(pred[0, 0], (w, h)) > 0.04).astype(np.float32) def corridorkey_frame_onnx(session, image_f32, mask_f32, img_size, despill_strength=0.5, auto_despeckle=True, despeckle_size=400): """ONNX inference for a single frame (CPU path).""" h, w = image_f32.shape[:2] img_r = cv2.resize(image_f32, (img_size, img_size)) mask_r = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis] inp = np.concatenate([(img_r - IMAGENET_MEAN) / IMAGENET_STD, mask_r], axis=-1) inp = inp.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32) alpha_raw, fg_raw = session.run(None, {"input": inp}) alpha = cv2.resize(alpha_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4) fg = cv2.resize(fg_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4) if alpha.ndim == 2: alpha = alpha[:, :, np.newaxis] if auto_despeckle: alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) fg = despill(fg, green_limit_mode="average", strength=despill_strength) return {"alpha": alpha, "fg": fg} # --------------------------------------------------------------------------- # Batched inference: PyTorch (GPU path) # --------------------------------------------------------------------------- def corridorkey_batch_pytorch(model, images_f32, masks_f32, img_size, despill_strength=0.5, auto_despeckle=True, despeckle_size=400): """PyTorch batched inference for multiple frames on GPU. Args: model: GreenFormer model on CUDA images_f32: list of [H, W, 3] float32 numpy arrays (0-1, sRGB) masks_f32: list of [H, W] float32 numpy arrays (0-1) img_size: model input resolution (1024 or 2048) Returns: list of dicts with 'alpha' [H,W,1] and 'fg' [H,W,3] """ import torch batch_size = len(images_f32) if batch_size == 0: return [] # Store original sizes per frame orig_sizes = [(img.shape[1], img.shape[0]) for img in images_f32] # (w, h) # Preprocess: resize, normalize, concatenate into batch tensor batch_inputs = [] for img, mask in zip(images_f32, masks_f32): img_r = cv2.resize(img, (img_size, img_size)) mask_r = cv2.resize(mask, (img_size, img_size))[:, :, np.newaxis] inp = np.concatenate([(img_r - IMAGENET_MEAN) / IMAGENET_STD, mask_r], axis=-1) batch_inputs.append(inp.transpose(2, 0, 1)) # [4, H, W] batch_np = np.stack(batch_inputs, axis=0).astype(np.float32) # [B, 4, H, W] batch_tensor = torch.from_numpy(batch_np).cuda().half() # FP16 input # Forward pass — model is FP16, input is FP16, no autocast needed with torch.inference_mode(): out = model(batch_tensor) # Extract results alphas_gpu = out["alpha"].float().cpu().numpy() # [B, 1, H, W] fgs_gpu = out["fg"].float().cpu().numpy() # [B, 3, H, W] del batch_tensor # Don't empty cache per batch - too expensive. Let PyTorch manage. # Postprocess each frame results = [] for i in range(batch_size): w, h = orig_sizes[i] alpha = cv2.resize(alphas_gpu[i].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4) fg = cv2.resize(fgs_gpu[i].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4) if alpha.ndim == 2: alpha = alpha[:, :, np.newaxis] if auto_despeckle: alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) fg = despill(fg, green_limit_mode="average", strength=despill_strength) results.append({"alpha": alpha, "fg": fg}) return results # --------------------------------------------------------------------------- # Video stitching # --------------------------------------------------------------------------- def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420p", codec="libx264", extra_args=None): cmd = ["ffmpeg", "-y", "-framerate", str(fps), "-i", os.path.join(frame_dir, pattern), "-c:v", codec, "-pix_fmt", pix_fmt] if extra_args: cmd.extend(extra_args) cmd.append(out_path) try: subprocess.run(cmd, capture_output=True, timeout=300, check=True) return True except (FileNotFoundError, subprocess.TimeoutExpired, subprocess.CalledProcessError) as e: logger.warning("ffmpeg failed: %s", e) return False # --------------------------------------------------------------------------- # Output writing helper # --------------------------------------------------------------------------- # Fastest PNG params: compression 1 (instead of default 3) _PNG_FAST = [cv2.IMWRITE_PNG_COMPRESSION, 1] # JPEG for opaque outputs (comp/fg) — 10x faster than PNG at 4K _JPG_QUALITY = [cv2.IMWRITE_JPEG_QUALITY, 95] def _write_frame_fast(i, alpha, fg, w, h, bg_lin, comp_dir, matte_dir, fg_dir): """Fast write: comp (JPEG) + matte (PNG) + fg (JPEG). No heavy PNG/npz.""" if alpha.ndim == 2: alpha = alpha[:, :, np.newaxis] alpha_2d = alpha[:, :, 0] fg_lin = srgb_to_linear(fg) comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha)) cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.jpg"), (np.clip(comp, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY) cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"), (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY) cv2.imwrite(os.path.join(matte_dir, f"{i:05d}.png"), (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8), _PNG_FAST) def _write_frame_deferred(i, raw_path, w, h, bg_lin, fg_dir, processed_dir): """Deferred write: FG (JPEG) + Processed (RGBA PNG). Runs after GPU release.""" d = np.load(raw_path) alpha, fg = d["alpha"], d["fg"] if alpha.ndim == 2: alpha = alpha[:, :, np.newaxis] alpha_2d = alpha[:, :, 0] cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"), (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY) fg_lin = srgb_to_linear(fg) fg_premul = premultiply(fg_lin, alpha) fg_premul_srgb = linear_to_srgb(fg_premul) fg_u8 = (np.clip(fg_premul_srgb, 0, 1) * 255).astype(np.uint8) a_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8) rgba = np.concatenate([fg_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1) cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba, _PNG_FAST) os.remove(raw_path) # cleanup def _write_frame_outputs(i, alpha, fg, w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir): """Full write: all 4 outputs. Used by CPU path.""" if alpha.ndim == 2: alpha = alpha[:, :, np.newaxis] alpha_2d = alpha[:, :, 0] fg_lin = srgb_to_linear(fg) comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha)) cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.jpg"), (np.clip(comp, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY) cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"), (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY) cv2.imwrite(os.path.join(matte_dir, f"{i:05d}.png"), (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8), _PNG_FAST) fg_premul = premultiply(fg_lin, alpha) fg_premul_srgb = linear_to_srgb(fg_premul) fg_u8 = (np.clip(fg_premul_srgb, 0, 1) * 255).astype(np.uint8) a_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8) rgba = np.concatenate([fg_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1) cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba, _PNG_FAST) # --------------------------------------------------------------------------- # Shared storage: GPU function stores results here instead of returning them. # This avoids ZeroGPU serializing gigabytes of numpy arrays on return. # --------------------------------------------------------------------------- _shared_results = {"data": None} # --------------------------------------------------------------------------- # Main pipeline # --------------------------------------------------------------------------- def _gpu_decorator(fn): if HAS_SPACES: return spaces.GPU(duration=120)(fn) return fn @_gpu_decorator def _gpu_phase(video_path, resolution, despill_val, mask_mode, auto_despeckle, despeckle_size, progress=gr.Progress(), precompute_dir=None, precompute_count=0): """ALL GPU work: load models, read video, generate masks, run inference. Returns raw numpy results in RAM. No disk I/O. """ if video_path is None: raise gr.Error("Please upload a video.") _ensure_gpu_sessions() try: import torch has_torch_cuda = torch.cuda.is_available() except ImportError: has_torch_cuda = False use_gpu = has_torch_cuda logger.info("[GPU phase] CUDA=%s, mode=%s", has_torch_cuda, "PyTorch batched" if use_gpu else "ONNX sequential") img_size = int(resolution) max_dur = MAX_DURATION_GPU if use_gpu else MAX_DURATION_CPU despill_strength = despill_val / 10.0 # Read video metadata cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() if total_frames == 0: raise gr.Error("Could not read video frames.") duration = total_frames / fps if duration > max_dur: raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s.") frames_to_process = min(total_frames, MAX_FRAMES) # Load BiRefNet only if masks need it (skip if all precomputed) birefnet = None needs_birefnet = precompute_dir is None or precompute_count == 0 if not needs_birefnet and mask_mode != "Fast (classical)": # Check if any frames need BiRefNet (missing mask files) for i in range(min(frames_to_process, precompute_count)): if not os.path.exists(os.path.join(precompute_dir, f"mask_{i:05d}.npy")): needs_birefnet = True break if needs_birefnet: progress(0.02, desc="Loading BiRefNet...") birefnet = get_birefnet() logger.info("BiRefNet loaded (needed for some frames)") else: logger.info("Skipping BiRefNet load (all masks precomputed)") batch_size = GPU_BATCH_SIZES.get(resolution, 16) if use_gpu else 1 if use_gpu: progress(0.05, desc=f"Loading GreenFormer ({resolution})...") pytorch_model = get_pytorch_model(img_size) else: progress(0.05, desc=f"Loading CorridorKey ONNX ({resolution})...") corridorkey_onnx = get_corridorkey_onnx(resolution) logger.info("[GPU phase] %d frames (%dx%d @ %.1ffps), res=%d, mask=%s, batch=%d", frames_to_process, w, h, fps, img_size, mask_mode, batch_size) # Read all frames + generate masks + run inference tmpdir = tempfile.mkdtemp(prefix="ck_") frame_times = [] total_start = time.time() try: cap = cv2.VideoCapture(video_path) if use_gpu: import torch vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 logger.info("VRAM: %.1f/%.1fGB", torch.cuda.memory_allocated() / 1024**3, vram_total) all_results = [] frame_idx = 0 # Load precomputed frames from disk (no serialization overhead) use_precomputed = precompute_dir is not None and precompute_count > 0 while frame_idx < frames_to_process: t_batch = time.time() batch_images, batch_masks, batch_indices = [], [], [] t_mask = 0 fast_n, biref_n = 0, 0 for _ in range(batch_size): if frame_idx >= frames_to_process: break if use_precomputed: frame_f32 = np.load(os.path.join(precompute_dir, f"frame_{frame_idx:05d}.npy")) mask_path = os.path.join(precompute_dir, f"mask_{frame_idx:05d}.npy") if os.path.exists(mask_path): mask = np.load(mask_path) fast_n += 1 else: # BiRefNet fallback — load original RGB, run on GPU rgb_path = os.path.join(precompute_dir, f"rgb_{frame_idx:05d}.npy") frame_rgb = np.load(rgb_path) tm = time.time() mask = birefnet_frame(birefnet, frame_rgb) t_mask += time.time() - tm biref_n += 1 else: ret, frame_bgr = cap.read() if not ret: break frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) frame_f32 = frame_rgb.astype(np.float32) / 255.0 tm = time.time() if mask_mode == "Fast (classical)": mask, _ = fast_greenscreen_mask(frame_f32) fast_n += 1 elif mask_mode == "Hybrid (auto)": mask, conf = fast_greenscreen_mask(frame_f32) if mask is None or conf < 0.7: mask = birefnet_frame(birefnet, frame_rgb) biref_n += 1 else: fast_n += 1 else: mask = birefnet_frame(birefnet, frame_rgb) biref_n += 1 t_mask += time.time() - tm batch_images.append(frame_f32) batch_masks.append(mask) batch_indices.append(frame_idx) frame_idx += 1 if not batch_images: break # Batched GPU inference t_inf = time.time() results = corridorkey_batch_pytorch( pytorch_model, batch_images, batch_masks, img_size, despill_strength=despill_strength, auto_despeckle=auto_despeckle, despeckle_size=int(despeckle_size), ) t_inf = time.time() - t_inf for j, result in enumerate(results): all_results.append((batch_indices[j], result["alpha"], result["fg"])) n = len(batch_images) elapsed = time.time() - t_batch vram_peak = torch.cuda.max_memory_allocated() / 1024**3 logger.info("Batch %d: mask=%.1fs(fast=%d,biref=%d) infer=%.1fs total=%.1fs(%.2fs/fr) VRAM=%.1fGB", n, t_mask, fast_n, biref_n, t_inf, elapsed, elapsed/n, vram_peak) per_frame = elapsed / n frame_times.extend([per_frame] * n) remaining = (frames_to_process - frame_idx) * (np.mean(frame_times[-20:]) if len(frame_times) > 1 else per_frame) progress(0.10 + 0.75 * frame_idx / frames_to_process, desc=f"Frame {frame_idx}/{frames_to_process} ({per_frame:.2f}s/fr) ~{remaining:.0f}s left") cap.release() gpu_elapsed = time.time() - total_start logger.info("[GPU phase] done: %d frames in %.1fs (%.2fs/fr)", len(all_results), gpu_elapsed, gpu_elapsed / max(len(all_results), 1)) # FAST WRITE inside GPU: only comp (JPEG) + matte (PNG) + raw numpy. # FG + Processed written AFTER GPU release (deferred). from concurrent.futures import ThreadPoolExecutor bg_lin = srgb_to_linear(create_checkerboard(w, h)) comp_dir = os.path.join(tmpdir, "Comp") matte_dir = os.path.join(tmpdir, "Matte") fg_dir = os.path.join(tmpdir, "FG") processed_dir = os.path.join(tmpdir, "Processed") for d in [comp_dir, fg_dir, matte_dir, processed_dir]: os.makedirs(d, exist_ok=True) t_write = time.time() progress(0.86, desc="Writing preview frames...") with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as pool: futs = [pool.submit(_write_frame_fast, idx, alpha, fg, w, h, bg_lin, comp_dir, matte_dir, fg_dir) for idx, alpha, fg in all_results] for f in futs: f.result() del all_results gc.collect() logger.info("[GPU phase] Fast write in %.1fs", time.time() - t_write) return { "results": "written", "frame_times": frame_times, "use_gpu": True, "batch_size": batch_size, "w": w, "h": h, "fps": fps, "tmpdir": tmpdir, } else: # CPU PATH: sequential ONNX + inline writes (no GPU budget concern) bg_lin = srgb_to_linear(create_checkerboard(w, h)) comp_dir, fg_dir = os.path.join(tmpdir, "Comp"), os.path.join(tmpdir, "FG") matte_dir, processed_dir = os.path.join(tmpdir, "Matte"), os.path.join(tmpdir, "Processed") for d in [comp_dir, fg_dir, matte_dir, processed_dir]: os.makedirs(d, exist_ok=True) for i in range(frames_to_process): t0 = time.time() ret, frame_bgr = cap.read() if not ret: break frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) frame_f32 = frame_rgb.astype(np.float32) / 255.0 if mask_mode == "Fast (classical)": mask, _ = fast_greenscreen_mask(frame_f32) if mask is None: raise gr.Error("Fast mask failed. Try 'AI (BiRefNet)' mode.") elif mask_mode == "Hybrid (auto)": mask, conf = fast_greenscreen_mask(frame_f32) if mask is None or conf < 0.7: mask = birefnet_frame(birefnet, frame_rgb) else: mask = birefnet_frame(birefnet, frame_rgb) result = corridorkey_frame_onnx(corridorkey_onnx, frame_f32, mask, img_size, despill_strength=despill_strength, auto_despeckle=auto_despeckle, despeckle_size=int(despeckle_size)) _write_frame_outputs(i, result["alpha"], result["fg"], w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir) elapsed = time.time() - t0 frame_times.append(elapsed) remaining = (frames_to_process - i - 1) * (np.mean(frame_times[-5:]) if len(frame_times) > 1 else elapsed) progress(0.10 + 0.80 * (i+1) / frames_to_process, desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) ~{remaining:.0f}s left") cap.release() return { "results": None, "frame_times": frame_times, "use_gpu": False, "batch_size": 1, "w": w, "h": h, "fps": fps, "tmpdir": tmpdir, } except gr.Error: raise except Exception as e: logger.exception("Inference failed") raise gr.Error(f"Inference failed: {e}") def process_video(video_path, resolution, despill_val, mask_mode, auto_despeckle, despeckle_size, progress=gr.Progress()): """Orchestrator: precompute fast masks (CPU) → GPU inference → CPU I/O.""" if video_path is None: raise gr.Error("Please upload a video.") # Phase 0: Precompute fast masks on CPU and save to disk. # IMPORTANT: Can't pass large data as args to @spaces.GPU (ZeroGPU serializes args). # Save to a numpy file, pass only the path. logger.info("[Phase 0] Precomputing fast masks on CPU") t_mask = time.time() precompute_dir = tempfile.mkdtemp(prefix="ck_pre_") cap = cv2.VideoCapture(video_path) frame_count = 0 needs_birefnet = False while True: ret, frame_bgr = cap.read() if not ret: break frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) frame_f32 = frame_rgb.astype(np.float32) / 255.0 if mask_mode == "Fast (classical)": mask, _ = fast_greenscreen_mask(frame_f32) if mask is None: raise gr.Error("Fast mask failed. Try 'Hybrid' or 'AI' mode.") elif mask_mode == "Hybrid (auto)": mask, conf = fast_greenscreen_mask(frame_f32) if mask is None or conf < 0.7: mask = None needs_birefnet = True else: mask = None needs_birefnet = True # Save as compressed numpy (fast to load, no serialization overhead) np.save(os.path.join(precompute_dir, f"frame_{frame_count:05d}.npy"), frame_f32) if mask is not None: np.save(os.path.join(precompute_dir, f"mask_{frame_count:05d}.npy"), mask) if mask is None: np.save(os.path.join(precompute_dir, f"rgb_{frame_count:05d}.npy"), frame_rgb) frame_count += 1 cap.release() logger.info("[Phase 0] %d frames saved to %s in %.1fs (needs_birefnet=%s)", frame_count, precompute_dir, time.time() - t_mask, needs_birefnet) # Phase 1: GPU inference — pass only paths (tiny strings), not data logger.info("[Phase 1] Starting GPU phase") t0 = time.time() data = _gpu_phase(video_path, resolution, despill_val, mask_mode, auto_despeckle, despeckle_size, progress, precompute_dir=precompute_dir, precompute_count=frame_count) logger.info("[process_video] GPU phase done in %.1fs", time.time() - t0) tmpdir = data["tmpdir"] w, h, fps = data["w"], data["h"], data["fps"] frame_times = data["frame_times"] use_gpu = data["use_gpu"] batch_size = data["batch_size"] comp_dir = os.path.join(tmpdir, "Comp") fg_dir = os.path.join(tmpdir, "FG") matte_dir = os.path.join(tmpdir, "Matte") processed_dir = os.path.join(tmpdir, "Processed") for d in [comp_dir, fg_dir, matte_dir, processed_dir]: os.makedirs(d, exist_ok=True) try: from concurrent.futures import ThreadPoolExecutor logger.info("[Phase 2] Frames written by GPU/CPU phase (comp+fg+matte)") # Phase 3: stitch videos from written frames logger.info("[Phase 3] Stitching videos") progress(0.93, desc="Stitching videos...") comp_video = os.path.join(tmpdir, "comp_preview.mp4") matte_video = os.path.join(tmpdir, "matte_preview.mp4") # Comp uses JPEG, Matte uses PNG _stitch_ffmpeg(comp_dir, comp_video, fps, pattern="%05d.jpg", extra_args=["-crf", "18"]) _stitch_ffmpeg(matte_dir, matte_video, fps, pattern="%05d.png", extra_args=["-crf", "18"]) # Phase 4: ZIP (no GPU) logger.info("[Phase 4] Packaging ZIP") progress(0.96, desc="Packaging ZIP...") zip_path = os.path.join(tmpdir, "CorridorKey_Output.zip") with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf: for folder in ["Comp", "FG", "Matte", "Processed"]: src = os.path.join(tmpdir, folder) if os.path.isdir(src): for f in sorted(os.listdir(src)): zf.write(os.path.join(src, f), f"Output/{folder}/{f}") progress(1.0, desc="Done!") total_elapsed = sum(frame_times) if frame_times else 0 n = len(frame_times) avg = np.mean(frame_times) if frame_times else 0 engine = "PyTorch GPU" if use_gpu else "ONNX CPU" status = (f"Processed {n} frames ({w}x{h}) at {resolution}px | " f"{avg:.2f}s/frame | {engine}" + (f" batch={batch_size}" if use_gpu else "")) return ( comp_video if os.path.exists(comp_video) else None, matte_video if os.path.exists(matte_video) else None, zip_path, status, ) except gr.Error: raise except Exception as e: logger.exception("Output writing failed") raise gr.Error(f"Output failed: {e}") finally: for d in ["Comp", "FG", "Matte", "Processed"]: p = os.path.join(tmpdir, d) if os.path.isdir(p): shutil.rmtree(p, ignore_errors=True) gc.collect() # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size): return process_video(video_path, resolution, despill, mask_mode, despeckle, despeckle_size) DESCRIPTION = """# CorridorKey Green Screen Matting Remove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. ZeroGPU H200: batched PyTorch inference (up to 32 frames at once). CPU fallback via ONNX.""" with gr.Blocks(title="CorridorKey") as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=1): input_video = gr.Video(label="Upload Green Screen Video") with gr.Accordion("Settings", open=True): resolution = gr.Radio( choices=["1024", "2048"], value="1024", label="Processing Resolution", info="1024 = fast (batch 32 on GPU), 2048 = max quality (batch 8 on GPU)" ) mask_mode = gr.Radio( choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"], value="Hybrid (auto)", label="Mask Mode", info="Hybrid = fast green detection + AI fallback. Fast = classical only. AI = always BiRefNet" ) despill_slider = gr.Slider( 0, 10, value=5, step=1, label="Despill Strength", info="Remove green reflections (0=off, 10=max)" ) despeckle_check = gr.Checkbox( value=True, label="Auto Despeckle", info="Remove small disconnected artifacts" ) despeckle_size = gr.Number( value=400, precision=0, label="Despeckle Size", info="Min pixel area to keep" ) process_btn = gr.Button("Process Video", variant="primary", size="lg") with gr.Column(scale=1): with gr.Row(): comp_video = gr.Video(label="Composite Preview") matte_video = gr.Video(label="Alpha Matte") download_zip = gr.File(label="Download Full Package (Comp + FG + Matte + Processed)") status_text = gr.Textbox(label="Status", interactive=False) gr.Examples( examples=[ ["examples/corridor_greenscreen_demo.mp4", "1024", 5, "Hybrid (auto)", True, 400], ], inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size], outputs=[comp_video, matte_video, download_zip, status_text], fn=process_example, cache_examples=True, cache_mode="lazy", label="Examples (click to load)" ) process_btn.click( fn=process_video, inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size], outputs=[comp_video, matte_video, download_zip, status_text], ) # --------------------------------------------------------------------------- # CLI mode # --------------------------------------------------------------------------- def cli_main(): import argparse parser = argparse.ArgumentParser(description="CorridorKey Green Screen Matting") parser.add_argument("--input", required=True) parser.add_argument("--output", default="output") parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"]) parser.add_argument("--resolution", default="1024", choices=["1024", "2048"]) parser.add_argument("--mask-mode", default="Hybrid (auto)", choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"]) parser.add_argument("--despill", type=int, default=5) parser.add_argument("--no-despeckle", action="store_true") parser.add_argument("--despeckle-size", type=int, default=400) args = parser.parse_args() global HAS_CUDA if args.device == "cpu": HAS_CUDA = False elif args.device == "cuda": HAS_CUDA = True print(f"Device: {'CUDA' if HAS_CUDA else 'CPU'}") class CLIProgress: def __call__(self, val, desc=""): if desc: print(f" [{val:.0%}] {desc}") comp, matte, zipf, status = process_video( args.input, args.resolution, args.despill, args.mask_mode, not args.no_despeckle, args.despeckle_size, progress=CLIProgress() ) print(f"\n{status}") os.makedirs(args.output, exist_ok=True) if zipf: dst = os.path.join(args.output, os.path.basename(zipf)) shutil.copy2(zipf, dst) print(f"Output: {dst}") if __name__ == "__main__": if len(sys.argv) > 1 and "--input" in sys.argv: cli_main() else: demo.queue(default_concurrency_limit=1) demo.launch(ssr_mode=False, mcp_server=True)