CorridorKey / app.py
Nekochu's picture
add ZeroGPU GPU inference (FP16, flash-attn, batch=32@1024/16@2048)
0b6961f
"""CorridorKey Green Screen Matting - HuggingFace Space.
Self-contained Gradio app with dual inference paths:
- GPU (ZeroGPU H200): PyTorch batched inference via GreenFormer
- CPU (fallback): ONNX Runtime sequential inference
Usage:
python app.py # Launch Gradio UI
python app.py --input video.mp4 # CLI mode
"""
import os
import sys
import math
import shutil
import gc
import time
import tempfile
import zipfile
import subprocess
import logging
# Thread tuning for CPU (must be set before numpy/cv2/ort import)
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["OPENBLAS_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
import numpy as np
import cv2
import gradio as gr
import onnxruntime as ort
try:
import spaces
HAS_SPACES = True
except ImportError:
HAS_SPACES = False
# Workaround: Gradio cache_examples bug with None outputs.
# Keep a handle on the stock implementation so the patch can delegate to it.
_original_read_from_flag = gr.components.Component.read_from_flag


def _patched_read_from_flag(self, payload):
    """Return None for missing/blank payloads, else defer to Gradio's own reader.

    A payload counts as blank when it is ``None`` or a string that is empty
    after stripping whitespace.
    """
    is_blank_text = isinstance(payload, str) and not payload.strip()
    if payload is None or is_blank_text:
        return None
    return _original_read_from_flag(self, payload)


gr.components.Component.read_from_flag = _patched_read_from_flag
from huggingface_hub import hf_hub_download
cv2.setNumThreads(2)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Hugging Face Hub repo and file of the BiRefNet ONNX model (used for AI masks).
BIREFNET_REPO = "onnx-community/BiRefNet_lite-ONNX"
BIREFNET_FILE = "onnx/model.onnx"
# Local "models" directory next to this file; CorridorKey ONNX exports are
# looked up here by resolution string.
MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
CORRIDORKEY_MODELS = {
"1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
"2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
}
# Hugging Face Hub repo and file of the PyTorch checkpoint used on the GPU path.
CORRIDORKEY_PTH_REPO = "nikopueringer/CorridorKey_v1.0"
CORRIDORKEY_PTH_FILE = "CorridorKey_v1.0.pth"
# Per-channel normalization constants, shaped (1, 1, 3) so they broadcast over
# H x W x 3 images.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
# Maximum accepted clip length in seconds (compared against frame_count / fps).
MAX_DURATION_CPU = 5
MAX_DURATION_GPU = 60
# Upper bound on the number of frames processed per video.
MAX_FRAMES = 1800
# True when ONNX Runtime lists the CUDA execution provider at import time.
HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
# ---------------------------------------------------------------------------
# Preload model files at startup (OUTSIDE GPU function — don't waste GPU time on downloads)
# ---------------------------------------------------------------------------
logger.info("Preloading model files at startup...")

# Each download is best-effort: on failure the path stays None and the loaders
# fall back to calling hf_hub_download themselves later.
_preloaded_birefnet_path = None
try:
    _preloaded_birefnet_path = hf_hub_download(
        repo_id=BIREFNET_REPO,
        filename=BIREFNET_FILE,
    )
    logger.info("BiRefNet cached: %s", _preloaded_birefnet_path)
except Exception as exc:
    logger.warning("BiRefNet preload failed (will retry later): %s", exc)

_preloaded_pth_path = None
try:
    _preloaded_pth_path = hf_hub_download(
        repo_id=CORRIDORKEY_PTH_REPO,
        filename=CORRIDORKEY_PTH_FILE,
    )
    logger.info("CorridorKey.pth cached: %s", _preloaded_pth_path)
except Exception as exc:
    logger.warning("CorridorKey.pth preload failed (will retry later): %s", exc)

# Batch sizes for GPU inference, keyed by resolution string (conservative for H200 80GB).
# 2048 uses only 5.7GB/batch=2, so 16 easily fits in 69.8GB
GPU_BATCH_SIZES = {"1024": 32, "2048": 16}
# ---------------------------------------------------------------------------
# Color utilities (numpy-only)
# ---------------------------------------------------------------------------
def linear_to_srgb(x):
    """Encode linear-light values with the piecewise sRGB transfer curve.

    Negative inputs are clamped to 0 first; values above 1 are not clamped.
    """
    linear = np.clip(x, 0.0, None)
    toe = linear * 12.92
    shoulder = 1.055 * np.power(linear, 1.0 / 2.4) - 0.055
    return np.where(linear <= 0.0031308, toe, shoulder)
def srgb_to_linear(x):
    """Decode sRGB-encoded values to linear light (piecewise inverse curve).

    Negative inputs are clamped to 0 first; values above 1 are not clamped.
    """
    encoded = np.clip(x, 0.0, None)
    toe = encoded / 12.92
    shoulder = np.power((encoded + 0.055) / 1.055, 2.4)
    return np.where(encoded <= 0.04045, toe, shoulder)
def composite_straight(fg, bg, alpha):
    """Blend a straight (un-premultiplied) foreground over a background."""
    foreground_part = fg * alpha
    background_part = bg * (1.0 - alpha)
    return foreground_part + background_part
def despill(image, green_limit_mode="average", strength=1.0):
    """Suppress green spill by capping green at a limit derived from red/blue.

    Args:
        image: array whose last axis holds R, G, B in positions 0, 1, 2.
        green_limit_mode: "average" caps green at (R + B) / 2; any other value
            caps it at max(R, B).
        strength: <= 0 returns ``image`` unchanged; values below 1 blend the
            original with the despilled result; 1 or more returns the fully
            despilled result.

    Returns:
        Array of the same shape. The removed green excess is redistributed
        half to red and half to blue.
    """
    if strength <= 0.0:
        return image

    red = image[..., 0]
    green = image[..., 1]
    blue = image[..., 2]

    if green_limit_mode == "average":
        ceiling = (red + blue) / 2.0
    else:
        ceiling = np.maximum(red, blue)

    excess = np.maximum(green - ceiling, 0.0)
    corrected = np.stack(
        [red + excess * 0.5, green - excess, blue + excess * 0.5],
        axis=-1,
    )
    if strength < 1.0:
        return image * (1.0 - strength) + corrected * strength
    return corrected
def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5):
    """Remove small isolated islands from an alpha matte.

    The matte is thresholded at 0.5, connected components (8-connectivity)
    smaller than ``area_threshold`` pixels are dropped, and the surviving
    region is optionally dilated and Gaussian-blurred into a soft support
    mask that is multiplied with the original alpha.

    Args:
        alpha_np: 2-D matte, or 3-D matte whose first channel is used.
        area_threshold: minimum component area (pixels) to keep.
        dilation: dilation radius; the elliptical kernel is 2*dilation+1 wide.
            Skipped when <= 0.
        blur_size: blur radius; the Gaussian kernel is 2*blur_size+1 wide.
            Skipped when <= 0.

    Returns:
        Cleaned matte; has a trailing singleton axis iff the input was 3-D.
    """
    had_channel_axis = alpha_np.ndim == 3
    plane = alpha_np[:, :, 0] if had_channel_axis else alpha_np

    binary = (plane > 0.5).astype(np.uint8) * 255
    label_count, label_map, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Label 0 is the background and is never kept.
    keep = np.zeros(label_count, dtype=bool)
    keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= area_threshold
    support = keep[label_map].astype(np.uint8) * 255

    if dilation > 0:
        side = int(dilation * 2 + 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        support = cv2.dilate(support, kernel)
    if blur_size > 0:
        side = int(blur_size * 2 + 1)
        support = cv2.GaussianBlur(support, (side, side), 0)

    cleaned = plane * (support.astype(np.float32) / 255.0)
    if had_channel_axis:
        return cleaned[:, :, np.newaxis]
    return cleaned
def create_checkerboard(w, h, checker_size=64, color1=0.15, color2=0.55):
    """Build an (h, w, 3) float32 grey checkerboard.

    Squares are ``checker_size`` pixels wide; the top-left square uses
    ``color1`` and neighbours alternate with ``color2``.
    """
    col_cell = np.arange(w) // checker_size
    row_cell = np.arange(h) // checker_size
    parity = (col_cell[np.newaxis, :] + row_cell[:, np.newaxis]) % 2
    tile = np.where(parity == 0, color1, color2).astype(np.float32)
    return np.stack([tile] * 3, axis=-1)
def premultiply(fg, alpha):
    """Scale the foreground colour by its alpha (premultiplied-alpha form)."""
    premultiplied = fg * alpha
    return premultiplied
# ---------------------------------------------------------------------------
# Fast classical green-screen mask
# ---------------------------------------------------------------------------
def fast_greenscreen_mask(frame_rgb_f32):
    """Classical HSV chroma-key mask with a simple confidence score.

    The background colour is estimated as the median of four corner patches
    (5% of each dimension, at least 4 px). If that colour is not clearly green
    (G must exceed both R and B by more than 0.05) the function gives up.

    Args:
        frame_rgb_f32: H x W x 3 RGB frame; values are clipped to [0, 1]
            before conversion to 8-bit.

    Returns:
        ``(mask, confidence)`` where ``mask`` is an H x W float32 foreground
        mask in [0, 1] and ``confidence`` approaches 1.0 as the mask becomes
        more binary; or ``(None, 0.0)`` when no green background is detected.
    """
    height, width = frame_rgb_f32.shape[:2]
    patch_h = max(int(height * 0.05), 4)
    patch_w = max(int(width * 0.05), 4)

    corner_patches = (
        frame_rgb_f32[:patch_h, :patch_w],
        frame_rgb_f32[:patch_h, -patch_w:],
        frame_rgb_f32[-patch_h:, :patch_w],
        frame_rgb_f32[-patch_h:, -patch_w:],
    )
    samples = np.concatenate([patch.reshape(-1, 3) for patch in corner_patches], axis=0)
    bg_color = np.median(samples, axis=0)

    looks_green = bg_color[1] > bg_color[0] + 0.05 and bg_color[1] > bg_color[2] + 0.05
    if not looks_green:
        return None, 0.0

    frame_u8 = (np.clip(frame_rgb_f32, 0, 1) * 255).astype(np.uint8)
    hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV)
    # Foreground is everything outside the green hue/saturation/value band.
    foreground = cv2.bitwise_not(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))
    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, close_kernel)
    foreground = cv2.GaussianBlur(foreground, (5, 5), 0)

    mask_f32 = foreground.astype(np.float32) / 255.0
    # Mean distance from the nearest extreme (0 or 1): 0 for a hard mask.
    softness = np.mean(np.minimum(mask_f32, 1.0 - mask_f32))
    confidence = 1.0 - 2.0 * softness
    return mask_f32, confidence
# ---------------------------------------------------------------------------
# ONNX model loading (CPU fallback + BiRefNet)
# ---------------------------------------------------------------------------
# Cached BiRefNet ONNX session; created on first use by get_birefnet().
_birefnet_session = None
# Cached CorridorKey ONNX sessions, keyed by resolution string.
_corridorkey_sessions = {}
# Set to True by _ensure_gpu_sessions() once the caches above have been
# dropped so that sessions get rebuilt with the CUDA provider.
_sessions_on_gpu = False
def _get_providers():
    """Return the ONNX Runtime provider list: CUDA first when available, CPU always.

    Inside @spaces.GPU, CUDA is available.
    """
    if "CUDAExecutionProvider" not in ort.get_available_providers():
        return ["CPUExecutionProvider"]
    return ["CUDAExecutionProvider", "CPUExecutionProvider"]
def _ort_opts():
    """Build ONNX Runtime session options tuned for the current provider set.

    With CUDA available both thread counts are set to 0; otherwise the CPU
    path uses 2 intra-op threads and 1 inter-op thread. Full graph
    optimisation, sequential execution and memory patterns are always enabled.
    """
    opts = ort.SessionOptions()
    cuda_present = "CUDAExecutionProvider" in ort.get_available_providers()
    intra_threads, inter_threads = (0, 0) if cuda_present else (2, 1)
    opts.intra_op_num_threads = intra_threads
    opts.inter_op_num_threads = inter_threads
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    opts.enable_mem_pattern = True
    return opts
def _ensure_gpu_sessions():
    """Drop cached ONNX sessions the first time CUDA shows up (ZeroGPU).

    Clearing the caches forces the next get_birefnet()/get_corridorkey_onnx()
    call to rebuild the session with the CUDA provider. Runs at most once.
    """
    global _birefnet_session, _corridorkey_sessions, _sessions_on_gpu
    if _sessions_on_gpu:
        return
    if "CUDAExecutionProvider" not in ort.get_available_providers():
        return
    logger.info("CUDA available! Reloading ONNX sessions on GPU...")
    _birefnet_session = None
    _corridorkey_sessions = {}
    _sessions_on_gpu = True
def get_birefnet(force_cpu=False):
    """Return the BiRefNet ONNX session, creating it when necessary.

    Args:
        force_cpu: when True a new CPU-only session is built on every call
            and replaces the cached one.

    Returns:
        The (possibly freshly created) cached ``ort.InferenceSession``.
    """
    global _birefnet_session
    if _birefnet_session is not None and not force_cpu:
        return _birefnet_session

    path = _preloaded_birefnet_path or hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
    if force_cpu:
        providers = ["CPUExecutionProvider"]
    else:
        providers = _get_providers()
    logger.info("Loading BiRefNet ONNX: %s (providers: %s)", path, providers)

    opts = _ort_opts()
    if force_cpu:
        # Override whatever _ort_opts() picked with the CPU thread settings.
        opts.intra_op_num_threads = 2
        opts.inter_op_num_threads = 1
    _birefnet_session = ort.InferenceSession(path, opts, providers=providers)
    return _birefnet_session
def get_corridorkey_onnx(resolution="1024"):
    """Return the cached CorridorKey ONNX session for ``resolution``.

    Args:
        resolution: key into CORRIDORKEY_MODELS ("1024" or "2048").

    Raises:
        gr.Error: the resolution is unknown or its model file is missing.
    """
    global _corridorkey_sessions
    cached = _corridorkey_sessions.get(resolution)
    if resolution in _corridorkey_sessions:
        return cached

    onnx_path = CORRIDORKEY_MODELS.get(resolution)
    if not onnx_path or not os.path.exists(onnx_path):
        raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")

    providers = _get_providers()
    logger.info("Loading CorridorKey ONNX (%s): %s (providers: %s)", resolution, onnx_path, providers)
    session = ort.InferenceSession(onnx_path, _ort_opts(), providers=providers)
    _corridorkey_sessions[resolution] = session
    return session
# ---------------------------------------------------------------------------
# PyTorch model loading (GPU path)
# ---------------------------------------------------------------------------
# Cached GreenFormer instance and the img_size it was built for; managed by
# get_pytorch_model(), which reloads when a different size is requested.
_pytorch_model = None
_pytorch_model_size = None
def _load_greenformer(img_size):
    """Load the GreenFormer PyTorch model for GPU inference.

    Builds the model for ``img_size``, loads the CorridorKey checkpoint
    (stripping the ``_orig_mod.`` prefix left by torch.compile and bicubically
    resizing mismatched positional embeddings), moves it to CUDA as FP16 and,
    outside ZeroGPU, attempts torch.compile with a warm-up pass.

    Args:
        img_size: square model input resolution in pixels.

    Returns:
        The model in eval mode on CUDA (compiled when compilation succeeded).
    """
    import torch
    import torch.nn.functional as F
    from CorridorKeyModule.core.model_transformer import GreenFormer

    checkpoint_path = _preloaded_pth_path or hf_hub_download(
        repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE
    )
    logger.info("Using checkpoint: %s", checkpoint_path)
    logger.info("Initializing GreenFormer (img_size=%d)...", img_size)
    model = GreenFormer(
        encoder_name="hiera_base_plus_224.mae_in1k_ft_in1k",
        img_size=img_size,
        use_refiner=True,
    )

    # Load weights
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    source_weights = checkpoint.get("state_dict", checkpoint)
    target_weights = model.state_dict()
    compile_prefix = "_orig_mod."

    def _fit_pos_embed(name, tensor):
        """Bicubically resample a (1, N, C) positional embedding to the model's grid."""
        expected_shape = target_weights[name].shape
        logger.info("Resizing %s from %s to %s", name, tensor.shape, expected_shape)
        channels = tensor.shape[2]
        src_side = int(math.sqrt(tensor.shape[1]))
        dst_side = int(math.sqrt(expected_shape[1]))
        grid = tensor.permute(0, 2, 1).view(1, channels, src_side, src_side)
        grid = F.interpolate(grid, size=(dst_side, dst_side), mode="bicubic", align_corners=False)
        return grid.flatten(2).transpose(1, 2)

    # Fix compiled model prefix & handle PosEmbed mismatch
    remapped = {}
    for name, tensor in source_weights.items():
        if name.startswith(compile_prefix):
            name = name[len(compile_prefix):]
        is_known_pos_embed = "pos_embed" in name and name in target_weights
        if is_known_pos_embed and tensor.shape != target_weights[name].shape:
            tensor = _fit_pos_embed(name, tensor)
        remapped[name] = tensor

    missing, unexpected = model.load_state_dict(remapped, strict=False)
    if missing:
        logger.warning("Missing keys: %s", missing)
    if unexpected:
        logger.warning("Unexpected keys: %s", unexpected)

    model.eval()
    model = model.cuda().half()  # FP16 for speed on H200
    logger.info("Model loaded as FP16")

    try:
        import flash_attn
        logger.info("flash-attn v%s installed (prebuilt wheel)", getattr(flash_attn, '__version__', '?'))
    except ImportError:
        logger.info("flash-attn not available (using PyTorch SDPA)")
    logger.info("SDPA backends: flash=%s, mem_efficient=%s, math=%s",
                torch.backends.cuda.flash_sdp_enabled(),
                torch.backends.cuda.mem_efficient_sdp_enabled(),
                torch.backends.cuda.math_sdp_enabled())

    # Skip torch.compile on ZeroGPU — the 37s warmup eats too much of the 120s budget.
    may_compile = not HAS_SPACES and sys.platform in ("linux", "win32")
    if may_compile:
        try:
            compiled = torch.compile(model)
            # One forward pass on a dummy 4-channel input triggers compilation now.
            warmup_input = torch.zeros(1, 4, img_size, img_size, dtype=torch.float16, device="cuda")
            with torch.inference_mode():
                compiled(warmup_input)
            model = compiled
            logger.info("torch.compile() succeeded")
        except Exception as exc:
            logger.warning("torch.compile() failed, using eager mode: %s", exc)
            torch.cuda.empty_cache()
    else:
        logger.info("Skipping torch.compile() (ZeroGPU: saving GPU time for inference)")

    logger.info("GreenFormer loaded on CUDA (img_size=%d)", img_size)
    return model
def get_pytorch_model(img_size):
    """Return the cached GreenFormer for ``img_size``, (re)loading if needed.

    Only one model is kept at a time: asking for a different size releases the
    previous one (CUDA cache emptied, garbage collected) before loading.
    """
    global _pytorch_model, _pytorch_model_size
    if _pytorch_model is not None and _pytorch_model_size == img_size:
        return _pytorch_model

    if _pytorch_model is not None:
        # Free old model if switching resolution
        import torch
        _pytorch_model = None
        torch.cuda.empty_cache()
        gc.collect()

    _pytorch_model = _load_greenformer(img_size)
    _pytorch_model_size = img_size
    return _pytorch_model
# ---------------------------------------------------------------------------
# Per-frame inference: ONNX (CPU fallback)
# ---------------------------------------------------------------------------
def birefnet_frame(session, image_rgb_uint8):
    """Run BiRefNet on one frame and return a binary foreground mask.

    Args:
        session: ONNX Runtime session for BiRefNet; its first input is read
            as NCHW to determine the model resolution.
        image_rgb_uint8: H x W x 3 RGB frame with values in 0-255.

    Returns:
        H x W float32 array of 0.0/1.0: the sigmoid of the model's last
        output, resized to the frame size and thresholded at 0.04.
    """
    h, w = image_rgb_uint8.shape[:2]
    inp = session.get_inputs()[0]
    # The input is NCHW (shape[2] = height, shape[3] = width) but cv2.resize
    # takes dsize as (width, height), so the two must be swapped.
    model_h, model_w = inp.shape[2], inp.shape[3]
    img = cv2.resize(image_rgb_uint8, (model_w, model_h)).astype(np.float32) / 255.0
    img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
    logits = session.run(None, {inp.name: img})[-1]
    pred = 1.0 / (1.0 + np.exp(-logits))
    return (cv2.resize(pred[0, 0], (w, h)) > 0.04).astype(np.float32)
def corridorkey_frame_onnx(session, image_f32, mask_f32, img_size,
                           despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
    """ONNX inference for a single frame (CPU path).

    Args:
        session: CorridorKey ONNX session taking an input named "input".
        image_f32: H x W x 3 float32 RGB frame in 0-1.
        mask_f32: H x W float32 coarse mask in 0-1.
        img_size: square model input resolution.
        despill_strength: forwarded to despill().
        auto_despeckle: when True the alpha is passed through clean_matte().
        despeckle_size: area threshold used by clean_matte().

    Returns:
        dict with 'alpha' (H x W x 1) and 'fg' (H x W x 3) at the frame's size.
    """
    src_h, src_w = image_f32.shape[:2]
    model_dims = (img_size, img_size)

    rgb = (cv2.resize(image_f32, model_dims) - IMAGENET_MEAN) / IMAGENET_STD
    hint = cv2.resize(mask_f32, model_dims)[:, :, np.newaxis]
    # Four channels (normalised RGB + mask), rearranged to a 1 x 4 x S x S batch.
    tensor = np.concatenate([rgb, hint], axis=-1)
    tensor = tensor.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)

    alpha_raw, fg_raw = session.run(None, {"input": tensor})

    def _to_source_size(chw):
        return cv2.resize(chw.transpose(1, 2, 0), (src_w, src_h), interpolation=cv2.INTER_LANCZOS4)

    alpha = _to_source_size(alpha_raw[0])
    fg = _to_source_size(fg_raw[0])
    # cv2.resize drops a trailing singleton channel; restore it.
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    if auto_despeckle:
        alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
    fg = despill(fg, green_limit_mode="average", strength=despill_strength)
    return {"alpha": alpha, "fg": fg}
# ---------------------------------------------------------------------------
# Batched inference: PyTorch (GPU path)
# ---------------------------------------------------------------------------
def corridorkey_batch_pytorch(model, images_f32, masks_f32, img_size,
                              despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
    """PyTorch batched inference for multiple frames on GPU.

    Args:
        model: GreenFormer model on CUDA
        images_f32: list of [H, W, 3] float32 numpy arrays (0-1, sRGB)
        masks_f32: list of [H, W] float32 numpy arrays (0-1)
        img_size: model input resolution (1024 or 2048)
        despill_strength: forwarded to despill().
        auto_despeckle: when True each alpha is passed through clean_matte().
        despeckle_size: area threshold used by clean_matte().

    Returns:
        list of dicts with 'alpha' [H,W,1] and 'fg' [H,W,3]
    """
    import torch

    batch_size = len(images_f32)
    if batch_size == 0:
        return []

    # Remember each frame's own (w, h) so outputs can be resized back.
    orig_sizes = [(frame.shape[1], frame.shape[0]) for frame in images_f32]
    model_dims = (img_size, img_size)

    def _prepare(frame, hint):
        """Resize, normalise and stack one frame + mask into a [4, S, S] array."""
        rgb = cv2.resize(frame, model_dims)
        hint_r = cv2.resize(hint, model_dims)[:, :, np.newaxis]
        stacked = np.concatenate([(rgb - IMAGENET_MEAN) / IMAGENET_STD, hint_r], axis=-1)
        return stacked.transpose(2, 0, 1)

    prepared = [_prepare(frame, hint) for frame, hint in zip(images_f32, masks_f32)]
    batch_np = np.stack(prepared, axis=0).astype(np.float32)  # [B, 4, H, W]
    batch_tensor = torch.from_numpy(batch_np).cuda().half()  # FP16 input

    # Forward pass — model is FP16, input is FP16, no autocast needed
    with torch.inference_mode():
        out = model(batch_tensor)

    alphas = out["alpha"].float().cpu().numpy()  # [B, 1, H, W]
    fgs = out["fg"].float().cpu().numpy()  # [B, 3, H, W]
    del batch_tensor
    # Don't empty cache per batch - too expensive. Let PyTorch manage.

    results = []
    for idx, (src_w, src_h) in enumerate(orig_sizes):
        alpha = cv2.resize(alphas[idx].transpose(1, 2, 0), (src_w, src_h),
                           interpolation=cv2.INTER_LANCZOS4)
        fg = cv2.resize(fgs[idx].transpose(1, 2, 0), (src_w, src_h),
                        interpolation=cv2.INTER_LANCZOS4)
        if alpha.ndim == 2:
            alpha = alpha[:, :, np.newaxis]
        if auto_despeckle:
            alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
        fg = despill(fg, green_limit_mode="average", strength=despill_strength)
        results.append({"alpha": alpha, "fg": fg})
    return results
# ---------------------------------------------------------------------------
# Video stitching
# ---------------------------------------------------------------------------
def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420p",
                   codec="libx264", extra_args=None):
    """Encode a numbered image sequence into a video with ffmpeg.

    Args:
        frame_dir: directory containing the frames.
        out_path: destination video file (overwritten if present).
        fps: input frame rate passed to ``-framerate``.
        pattern: printf-style frame file name pattern inside ``frame_dir``.
        pix_fmt: output pixel format.
        codec: video codec passed to ``-c:v``.
        extra_args: optional list of extra ffmpeg arguments inserted before
            the output path.

    Returns:
        True on success; False when ffmpeg is missing, times out (300 s) or
        exits with a non-zero status. Failures are logged, never raised.
    """
    cmd = ["ffmpeg", "-y", "-framerate", str(fps), "-i", os.path.join(frame_dir, pattern),
           "-c:v", codec, "-pix_fmt", pix_fmt]
    if extra_args:
        cmd.extend(extra_args)
    cmd.append(out_path)
    try:
        subprocess.run(cmd, capture_output=True, timeout=300, check=True)
        return True
    except subprocess.CalledProcessError as e:
        logger.warning("ffmpeg failed: %s", e)
        # The exception text only carries the exit status; the actual reason
        # is in the captured stderr, so surface its tail for diagnosis.
        stderr_tail = (e.stderr or b"").decode("utf-8", errors="replace").strip()[-2000:]
        if stderr_tail:
            logger.warning("ffmpeg stderr (tail): %s", stderr_tail)
        return False
    except (FileNotFoundError, subprocess.TimeoutExpired) as e:
        logger.warning("ffmpeg failed: %s", e)
        return False
# ---------------------------------------------------------------------------
# Output writing helper
# ---------------------------------------------------------------------------
# cv2.imwrite parameters for PNG output: compression level 1, chosen for
# write speed (original author's note: "instead of default 3").
_PNG_FAST = [cv2.IMWRITE_PNG_COMPRESSION, 1]
# cv2.imwrite parameters for JPEG output at quality 95; used for the opaque
# outputs (comp/fg). Original author's note: 10x faster than PNG at 4K.
_JPG_QUALITY = [cv2.IMWRITE_JPEG_QUALITY, 95]
def _write_frame_fast(i, alpha, fg, w, h, bg_lin, comp_dir, matte_dir, fg_dir):
    """Fast write: comp (JPEG) + matte (PNG) + fg (JPEG). No heavy PNG/npz.

    Args:
        i: frame index, used as the zero-padded 5-digit file name.
        alpha: H x W or H x W x 1 matte in 0-1.
        fg: H x W x 3 sRGB foreground in 0-1.
        w, h: unused; kept for signature compatibility.
        bg_lin: linear-light background the foreground is composited over.
        comp_dir, matte_dir, fg_dir: output directories.
    """
    def _as_u8(values):
        return (np.clip(values, 0, 1) * 255).astype(np.uint8)

    alpha_hw1 = alpha[:, :, np.newaxis] if alpha.ndim == 2 else alpha
    stem = f"{i:05d}"

    # Composite in linear light, then encode back to sRGB for display.
    comp = linear_to_srgb(composite_straight(srgb_to_linear(fg), bg_lin, alpha_hw1))

    # [:, :, ::-1] flips RGB to the BGR channel order cv2.imwrite expects.
    cv2.imwrite(os.path.join(comp_dir, stem + ".jpg"), _as_u8(comp)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(fg_dir, stem + ".jpg"), _as_u8(fg)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(matte_dir, stem + ".png"), _as_u8(alpha_hw1[:, :, 0]), _PNG_FAST)
def _write_frame_deferred(i, raw_path, w, h, bg_lin, fg_dir, processed_dir):
    """Deferred write: FG (JPEG) + Processed (RGBA PNG). Runs after GPU release.

    Args:
        i: frame index, used as the zero-padded 5-digit file name.
        raw_path: numpy archive holding 'alpha' and 'fg'; deleted afterwards.
        w, h, bg_lin: unused; kept for signature compatibility.
        fg_dir: output directory for the straight foreground JPEG.
        processed_dir: output directory for the premultiplied RGBA PNG.
    """
    d = np.load(raw_path)
    try:
        alpha, fg = d["alpha"], d["fg"]
    finally:
        # np.load keeps .npz archives open; close before deleting the file so
        # the handle is not leaked (and the delete works on Windows).
        close = getattr(d, "close", None)
        if close is not None:
            close()
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    alpha_2d = alpha[:, :, 0]
    cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"),
                (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY)
    # Premultiply in linear light, then encode back to sRGB.
    fg_lin = srgb_to_linear(fg)
    fg_premul = premultiply(fg_lin, alpha)
    fg_premul_srgb = linear_to_srgb(fg_premul)
    fg_u8 = (np.clip(fg_premul_srgb, 0, 1) * 255).astype(np.uint8)
    a_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
    # BGR + alpha, the channel order cv2.imwrite expects for 4-channel PNGs.
    rgba = np.concatenate([fg_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1)
    cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba, _PNG_FAST)
    os.remove(raw_path)  # cleanup
def _write_frame_outputs(i, alpha, fg, w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir):
    """Full write: all 4 outputs. Used by CPU path.

    Writes the checkerboard composite (JPEG), straight foreground (JPEG),
    matte (PNG) and premultiplied RGBA (PNG), each named by the zero-padded
    frame index ``i``. ``w`` and ``h`` are unused.
    """
    def _as_u8(values):
        return (np.clip(values, 0, 1) * 255).astype(np.uint8)

    alpha_hw1 = alpha[:, :, np.newaxis] if alpha.ndim == 2 else alpha
    matte_u8 = _as_u8(alpha_hw1[:, :, 0])
    stem = f"{i:05d}"

    fg_lin = srgb_to_linear(fg)
    comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha_hw1))

    # [:, :, ::-1] flips RGB to the BGR channel order cv2.imwrite expects.
    cv2.imwrite(os.path.join(comp_dir, stem + ".jpg"), _as_u8(comp)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(fg_dir, stem + ".jpg"), _as_u8(fg)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(matte_dir, stem + ".png"), matte_u8, _PNG_FAST)

    # Premultiplied foreground, encoded back to sRGB, with alpha appended.
    premul_u8 = _as_u8(linear_to_srgb(premultiply(fg_lin, alpha_hw1)))
    rgba = np.concatenate([premul_u8[:, :, ::-1], matte_u8[:, :, np.newaxis]], axis=-1)
    cv2.imwrite(os.path.join(processed_dir, stem + ".png"), rgba, _PNG_FAST)
# ---------------------------------------------------------------------------
# Shared storage: GPU function stores results here instead of returning them.
# This avoids ZeroGPU serializing gigabytes of numpy arrays on return.
# ---------------------------------------------------------------------------
# NOTE(review): nothing in the visible code reads or writes this dict —
# confirm it is still needed before relying on it.
_shared_results = {"data": None}
# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------
def _gpu_decorator(fn):
    """Wrap ``fn`` with ``spaces.GPU(duration=120)`` when `spaces` is installed.

    Without the `spaces` package the function is returned unchanged.
    """
    if not HAS_SPACES:
        return fn
    return spaces.GPU(duration=120)(fn)
@_gpu_decorator
def _gpu_phase(video_path, resolution, despill_val, mask_mode,
               auto_despeckle, despeckle_size, progress=gr.Progress(),
               precompute_dir=None, precompute_count=0):
    """Load models, generate masks, run CorridorKey inference and write frames.

    When torch reports CUDA the batched PyTorch path is used, reading
    precomputed frames/masks from ``precompute_dir`` if provided. Otherwise
    the sequential ONNX path decodes the video directly. Both paths write
    their frames into a fresh temp directory (Comp/FG/Matte, plus Processed
    on the CPU path).

    Args:
        video_path: path of the input video.
        resolution: model resolution as a string ("1024" or "2048").
        despill_val: UI despill value; divided by 10 to get the strength.
        mask_mode: "Fast (classical)", "Hybrid (auto)" or anything else for
            BiRefNet-only masks.
        auto_despeckle: enable matte clean-up.
        despeckle_size: area threshold for the clean-up.
        progress: Gradio progress callback.
        precompute_dir: directory with ``frame_#####.npy`` and either
            ``mask_#####.npy`` or ``rgb_#####.npy`` per frame, or None.
        precompute_count: number of frames available in ``precompute_dir``.

    Returns:
        dict with keys "results", "frame_times", "use_gpu", "batch_size",
        "w", "h", "fps" and "tmpdir" (only small values, no frame data).

    Raises:
        gr.Error: missing/unreadable/too-long video, fast mask failure, or
            any inference error (the temp directory is removed in that case).
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")
    _ensure_gpu_sessions()
    try:
        import torch
        has_torch_cuda = torch.cuda.is_available()
    except ImportError:
        has_torch_cuda = False
    use_gpu = has_torch_cuda
    logger.info("[GPU phase] CUDA=%s, mode=%s", has_torch_cuda,
                "PyTorch batched" if use_gpu else "ONNX sequential")
    img_size = int(resolution)
    max_dur = MAX_DURATION_GPU if use_gpu else MAX_DURATION_CPU
    despill_strength = despill_val / 10.0

    # Read video metadata; always release the capture, even on error.
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    # Some containers report a negative count; treat that as unreadable too.
    if total_frames <= 0:
        raise gr.Error("Could not read video frames.")
    duration = total_frames / fps
    if duration > max_dur:
        raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s.")
    frames_to_process = min(total_frames, MAX_FRAMES)

    have_precompute = precompute_dir is not None and precompute_count > 0
    # Only the GPU path consumes precomputed files; the CPU path re-decodes.
    use_precomputed = use_gpu and have_precompute
    if use_precomputed:
        # The container's frame count can exceed what was actually decoded
        # and saved; never index past the files that exist.
        frames_to_process = min(frames_to_process, precompute_count)

    # Load BiRefNet only if some mask will actually need it.
    birefnet = None
    if have_precompute:
        needs_birefnet = mask_mode != "Fast (classical)" and any(
            not os.path.exists(os.path.join(precompute_dir, f"mask_{i:05d}.npy"))
            for i in range(min(frames_to_process, precompute_count))
        )
    else:
        # Fast mode never falls back to BiRefNet, so skip the load there.
        needs_birefnet = mask_mode != "Fast (classical)"
    if needs_birefnet:
        progress(0.02, desc="Loading BiRefNet...")
        birefnet = get_birefnet()
        logger.info("BiRefNet loaded (needed for some frames)")
    else:
        logger.info("Skipping BiRefNet load (all masks precomputed)")

    batch_size = GPU_BATCH_SIZES.get(resolution, 16) if use_gpu else 1
    if use_gpu:
        progress(0.05, desc=f"Loading GreenFormer ({resolution})...")
        pytorch_model = get_pytorch_model(img_size)
    else:
        progress(0.05, desc=f"Loading CorridorKey ONNX ({resolution})...")
        corridorkey_onnx = get_corridorkey_onnx(resolution)
    logger.info("[GPU phase] %d frames (%dx%d @ %.1ffps), res=%d, mask=%s, batch=%d",
                frames_to_process, w, h, fps, img_size, mask_mode, batch_size)

    tmpdir = tempfile.mkdtemp(prefix="ck_")
    frame_times = []
    total_start = time.time()
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        bg_lin = srgb_to_linear(create_checkerboard(w, h))
        comp_dir = os.path.join(tmpdir, "Comp")
        fg_dir = os.path.join(tmpdir, "FG")
        matte_dir = os.path.join(tmpdir, "Matte")
        processed_dir = os.path.join(tmpdir, "Processed")

        if use_gpu:
            vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logger.info("VRAM: %.1f/%.1fGB",
                        torch.cuda.memory_allocated() / 1024**3, vram_total)
            all_results = []
            frame_idx = 0
            while frame_idx < frames_to_process:
                t_batch = time.time()
                batch_images, batch_masks, batch_indices = [], [], []
                t_mask = 0
                fast_n, biref_n = 0, 0
                for _ in range(batch_size):
                    if frame_idx >= frames_to_process:
                        break
                    if use_precomputed:
                        # Load precomputed frames from disk (no serialization overhead)
                        frame_f32 = np.load(os.path.join(precompute_dir, f"frame_{frame_idx:05d}.npy"))
                        mask_path = os.path.join(precompute_dir, f"mask_{frame_idx:05d}.npy")
                        if os.path.exists(mask_path):
                            mask = np.load(mask_path)
                            fast_n += 1
                        else:
                            # BiRefNet fallback — load original RGB, run on GPU
                            rgb_path = os.path.join(precompute_dir, f"rgb_{frame_idx:05d}.npy")
                            frame_rgb = np.load(rgb_path)
                            if birefnet is None:
                                birefnet = get_birefnet()
                            tm = time.time()
                            mask = birefnet_frame(birefnet, frame_rgb)
                            t_mask += time.time() - tm
                            biref_n += 1
                    else:
                        ret, frame_bgr = cap.read()
                        if not ret:
                            break
                        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                        frame_f32 = frame_rgb.astype(np.float32) / 255.0
                        tm = time.time()
                        # Raises gr.Error instead of feeding a None mask into
                        # inference when the fast mask fails in Fast mode.
                        mask, used_birefnet = _select_live_mask(
                            frame_f32, frame_rgb, mask_mode, birefnet)
                        t_mask += time.time() - tm
                        if used_birefnet:
                            biref_n += 1
                        else:
                            fast_n += 1
                    batch_images.append(frame_f32)
                    batch_masks.append(mask)
                    batch_indices.append(frame_idx)
                    frame_idx += 1
                if not batch_images:
                    break

                # Batched GPU inference
                t_inf = time.time()
                results = corridorkey_batch_pytorch(
                    pytorch_model, batch_images, batch_masks, img_size,
                    despill_strength=despill_strength,
                    auto_despeckle=auto_despeckle,
                    despeckle_size=int(despeckle_size),
                )
                t_inf = time.time() - t_inf
                for j, result in enumerate(results):
                    all_results.append((batch_indices[j], result["alpha"], result["fg"]))

                n = len(batch_images)
                elapsed = time.time() - t_batch
                vram_peak = torch.cuda.max_memory_allocated() / 1024**3
                logger.info("Batch %d: mask=%.1fs(fast=%d,biref=%d) infer=%.1fs total=%.1fs(%.2fs/fr) VRAM=%.1fGB",
                            n, t_mask, fast_n, biref_n, t_inf, elapsed, elapsed/n, vram_peak)
                per_frame = elapsed / n
                frame_times.extend([per_frame] * n)
                remaining = (frames_to_process - frame_idx) * (
                    np.mean(frame_times[-20:]) if len(frame_times) > 1 else per_frame)
                progress(0.10 + 0.75 * frame_idx / frames_to_process,
                         desc=f"Frame {frame_idx}/{frames_to_process} ({per_frame:.2f}s/fr) ~{remaining:.0f}s left")
            cap.release()
            gpu_elapsed = time.time() - total_start
            logger.info("[GPU phase] done: %d frames in %.1fs (%.2fs/fr)",
                        len(all_results), gpu_elapsed, gpu_elapsed / max(len(all_results), 1))

            # Fast write: comp (JPEG) + fg (JPEG) + matte (PNG), in parallel.
            from concurrent.futures import ThreadPoolExecutor
            for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
                os.makedirs(d, exist_ok=True)
            t_write = time.time()
            progress(0.86, desc="Writing preview frames...")
            with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as pool:
                futs = [pool.submit(_write_frame_fast, idx, alpha, fg, w, h, bg_lin,
                                    comp_dir, matte_dir, fg_dir)
                        for idx, alpha, fg in all_results]
                for f in futs:
                    f.result()  # re-raise any writer exception
            del all_results
            gc.collect()
            logger.info("[GPU phase] Fast write in %.1fs", time.time() - t_write)
            return {
                "results": "written", "frame_times": frame_times,
                "use_gpu": True, "batch_size": batch_size,
                "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
            }

        # CPU PATH: sequential ONNX + inline writes (no GPU budget concern)
        for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
            os.makedirs(d, exist_ok=True)
        for i in range(frames_to_process):
            t0 = time.time()
            ret, frame_bgr = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            frame_f32 = frame_rgb.astype(np.float32) / 255.0
            mask, _ = _select_live_mask(frame_f32, frame_rgb, mask_mode, birefnet)
            result = corridorkey_frame_onnx(corridorkey_onnx, frame_f32, mask, img_size,
                                            despill_strength=despill_strength,
                                            auto_despeckle=auto_despeckle,
                                            despeckle_size=int(despeckle_size))
            _write_frame_outputs(i, result["alpha"], result["fg"],
                                 w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir)
            elapsed = time.time() - t0
            frame_times.append(elapsed)
            remaining = (frames_to_process - i - 1) * (
                np.mean(frame_times[-5:]) if len(frame_times) > 1 else elapsed)
            progress(0.10 + 0.80 * (i+1) / frames_to_process,
                     desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) ~{remaining:.0f}s left")
        return {
            "results": None, "frame_times": frame_times,
            "use_gpu": False, "batch_size": 1,
            "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
        }
    except gr.Error:
        # Nothing will consume the partial output; don't leave it on disk.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    except Exception as e:
        logger.exception("Inference failed")
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise gr.Error(f"Inference failed: {e}") from e
    finally:
        if cap is not None:
            cap.release()  # safe to call again after the early release above


def _select_live_mask(frame_f32, frame_rgb, mask_mode, birefnet):
    """Pick the coarse mask for one freshly decoded frame.

    Returns ``(mask, used_birefnet)``. Raises gr.Error when "Fast (classical)"
    mode cannot find a green background. ``birefnet`` may be None, in which
    case the session is fetched on demand via get_birefnet().
    """
    if mask_mode == "Fast (classical)":
        mask, _ = fast_greenscreen_mask(frame_f32)
        if mask is None:
            raise gr.Error("Fast mask failed. Try 'AI (BiRefNet)' mode.")
        return mask, False
    if mask_mode == "Hybrid (auto)":
        mask, conf = fast_greenscreen_mask(frame_f32)
        if not (mask is None or conf < 0.7):
            return mask, False
    session = birefnet if birefnet is not None else get_birefnet()
    return birefnet_frame(session, frame_rgb), True
def process_video(video_path, resolution, despill_val, mask_mode,
                  auto_despeckle, despeckle_size, progress=gr.Progress()):
    """Orchestrator: precompute fast masks (CPU) → GPU inference → CPU I/O.

    Args:
        video_path: path of the uploaded video.
        resolution: model resolution as a string ("1024" or "2048").
        despill_val: UI despill value, forwarded to the GPU phase.
        mask_mode: "Fast (classical)", "Hybrid (auto)" or AI-only.
        auto_despeckle: enable matte clean-up.
        despeckle_size: area threshold for the clean-up.
        progress: Gradio progress callback.

    Returns:
        ``(comp_video_or_None, matte_video_or_None, zip_path, status_text)``.

    Raises:
        gr.Error: no video given, fast mask failure, inference failure or
            output packaging failure.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")

    # Phase 0: Precompute fast masks on CPU and save to disk.
    # IMPORTANT: Can't pass large data as args to @spaces.GPU (ZeroGPU serializes args).
    # Save to numpy files, pass only the directory path.
    logger.info("[Phase 0] Precomputing fast masks on CPU")
    t_mask = time.time()
    precompute_dir = tempfile.mkdtemp(prefix="ck_pre_")
    try:
        frame_count, needs_birefnet = _precompute_masks(video_path, mask_mode, precompute_dir)
        logger.info("[Phase 0] %d frames saved to %s in %.1fs (needs_birefnet=%s)",
                    frame_count, precompute_dir, time.time() - t_mask, needs_birefnet)

        # Phase 1: GPU inference — pass only paths (tiny strings), not data
        logger.info("[Phase 1] Starting GPU phase")
        t0 = time.time()
        data = _gpu_phase(video_path, resolution, despill_val, mask_mode,
                          auto_despeckle, despeckle_size, progress,
                          precompute_dir=precompute_dir, precompute_count=frame_count)
        logger.info("[process_video] GPU phase done in %.1fs", time.time() - t0)
    finally:
        # The uncompressed float32 frame dumps are large and no longer needed
        # once the GPU phase has returned (or failed).
        shutil.rmtree(precompute_dir, ignore_errors=True)

    tmpdir = data["tmpdir"]
    w, h, fps = data["w"], data["h"], data["fps"]
    frame_times = data["frame_times"]
    use_gpu = data["use_gpu"]
    batch_size = data["batch_size"]
    comp_dir = os.path.join(tmpdir, "Comp")
    fg_dir = os.path.join(tmpdir, "FG")
    matte_dir = os.path.join(tmpdir, "Matte")
    processed_dir = os.path.join(tmpdir, "Processed")
    for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
        os.makedirs(d, exist_ok=True)
    try:
        logger.info("[Phase 2] Frames written by GPU/CPU phase (comp+fg+matte)")

        # Phase 3: stitch videos from written frames
        logger.info("[Phase 3] Stitching videos")
        progress(0.93, desc="Stitching videos...")
        comp_video = os.path.join(tmpdir, "comp_preview.mp4")
        matte_video = os.path.join(tmpdir, "matte_preview.mp4")
        # Comp uses JPEG, Matte uses PNG
        _stitch_ffmpeg(comp_dir, comp_video, fps, pattern="%05d.jpg", extra_args=["-crf", "18"])
        _stitch_ffmpeg(matte_dir, matte_video, fps, pattern="%05d.png", extra_args=["-crf", "18"])

        # Phase 4: ZIP (no GPU). Stored, not deflated: frames are already compressed.
        logger.info("[Phase 4] Packaging ZIP")
        progress(0.96, desc="Packaging ZIP...")
        zip_path = os.path.join(tmpdir, "CorridorKey_Output.zip")
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
            for folder in ["Comp", "FG", "Matte", "Processed"]:
                src = os.path.join(tmpdir, folder)
                if os.path.isdir(src):
                    for f in sorted(os.listdir(src)):
                        zf.write(os.path.join(src, f), f"Output/{folder}/{f}")
        progress(1.0, desc="Done!")

        n = len(frame_times)
        avg = np.mean(frame_times) if frame_times else 0
        engine = "PyTorch GPU" if use_gpu else "ONNX CPU"
        status = (f"Processed {n} frames ({w}x{h}) at {resolution}px | "
                  f"{avg:.2f}s/frame | {engine}" +
                  (f" batch={batch_size}" if use_gpu else ""))
        return (
            comp_video if os.path.exists(comp_video) else None,
            matte_video if os.path.exists(matte_video) else None,
            zip_path,
            status,
        )
    except gr.Error:
        raise
    except Exception as e:
        logger.exception("Output writing failed")
        raise gr.Error(f"Output failed: {e}") from e
    finally:
        # Per-frame images are inside the ZIP now; keep only videos + ZIP.
        for d in ["Comp", "FG", "Matte", "Processed"]:
            p = os.path.join(tmpdir, d)
            if os.path.isdir(p):
                shutil.rmtree(p, ignore_errors=True)
        gc.collect()


def _precompute_masks(video_path, mask_mode, precompute_dir):
    """Decode the video and dump frames plus fast masks as .npy files.

    For every frame ``frame_#####.npy`` (float32 RGB in 0-1) is written,
    together with ``mask_#####.npy`` when the classical mask is usable or
    ``rgb_#####.npy`` (uint8 RGB) when the frame must go through BiRefNet.
    At most MAX_FRAMES frames are dumped, so an over-long upload cannot fill
    the disk before it is rejected.

    Returns ``(frame_count, needs_birefnet)``. Raises gr.Error when
    "Fast (classical)" mode cannot find a green background.
    """
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    needs_birefnet = False
    try:
        while frame_count < MAX_FRAMES:
            ret, frame_bgr = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            frame_f32 = frame_rgb.astype(np.float32) / 255.0
            mask = None
            if mask_mode == "Fast (classical)":
                mask, _ = fast_greenscreen_mask(frame_f32)
                if mask is None:
                    raise gr.Error("Fast mask failed. Try 'Hybrid' or 'AI' mode.")
            elif mask_mode == "Hybrid (auto)":
                mask, conf = fast_greenscreen_mask(frame_f32)
                if mask is None or conf < 0.7:
                    mask = None
            # Plain (uncompressed) .npy: fast to load, no serialization overhead.
            np.save(os.path.join(precompute_dir, f"frame_{frame_count:05d}.npy"), frame_f32)
            if mask is not None:
                np.save(os.path.join(precompute_dir, f"mask_{frame_count:05d}.npy"), mask)
            else:
                needs_birefnet = True
                np.save(os.path.join(precompute_dir, f"rgb_{frame_count:05d}.npy"), frame_rgb)
            frame_count += 1
    finally:
        cap.release()
    return frame_count, needs_birefnet
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size):
    """Callback used by the examples gallery; delegates to ``process_video``.

    The six example inputs are forwarded positionally, in order, and no
    progress argument is supplied. Whatever ``process_video`` returns is
    returned unchanged.
    """
    example_args = (video_path, resolution, despill, mask_mode, despeckle, despeckle_size)
    return process_video(*example_args)
# Markdown intro text; rendered at the top of the page via gr.Markdown(DESCRIPTION).
DESCRIPTION = """# CorridorKey Green Screen Matting
Remove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital.
ZeroGPU H200: batched PyTorch inference (up to 32 frames at once). CPU fallback via ONNX."""
# Page layout: inputs and settings on the left, results on the right, with the
# examples gallery and the button wiring registered last.
with gr.Blocks(title="CorridorKey") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: upload, processing settings, run button.
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Green Screen Video")
            with gr.Accordion("Settings", open=True):
                resolution = gr.Radio(
                    label="Processing Resolution",
                    choices=["1024", "2048"],
                    value="1024",
                    info="1024 = fast (batch 32 on GPU), 2048 = max quality (batch 8 on GPU)",
                )
                mask_mode = gr.Radio(
                    label="Mask Mode",
                    choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
                    value="Hybrid (auto)",
                    info="Hybrid = fast green detection + AI fallback. Fast = classical only. AI = always BiRefNet",
                )
                despill_slider = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Despill Strength",
                    info="Remove green reflections (0=off, 10=max)",
                )
                despeckle_check = gr.Checkbox(
                    label="Auto Despeckle",
                    value=True,
                    info="Remove small disconnected artifacts",
                )
                despeckle_size = gr.Number(
                    label="Despeckle Size",
                    value=400,
                    precision=0,
                    info="Min pixel area to keep",
                )
            process_btn = gr.Button("Process Video", variant="primary", size="lg")

        # Right column: the two preview videos, the ZIP download, and the status line.
        with gr.Column(scale=1):
            with gr.Row():
                comp_video = gr.Video(label="Composite Preview")
                matte_video = gr.Video(label="Alpha Matte")
            download_zip = gr.File(label="Download Full Package (Comp + FG + Matte + Processed)")
            status_text = gr.Textbox(label="Status", interactive=False)

    # The examples gallery and the button use the same component wiring, so
    # the input/output lists are declared once and shared.
    pipeline_inputs = [
        input_video,
        resolution,
        despill_slider,
        mask_mode,
        despeckle_check,
        despeckle_size,
    ]
    pipeline_outputs = [comp_video, matte_video, download_zip, status_text]

    # Each example row lists its values in the same order as pipeline_inputs.
    example_rows = [
        ["examples/corridor_greenscreen_demo.mp4", "1024", 5, "Hybrid (auto)", True, 400],
    ]

    gr.Examples(
        examples=example_rows,
        inputs=pipeline_inputs,
        outputs=pipeline_outputs,
        fn=process_example,
        cache_examples=True,
        cache_mode="lazy",
        label="Examples (click to load)",
    )

    process_btn.click(
        fn=process_video,
        inputs=pipeline_inputs,
        outputs=pipeline_outputs,
    )
# ---------------------------------------------------------------------------
# CLI mode
# ---------------------------------------------------------------------------
def cli_main():
    """Command-line entry point: process one video and copy the ZIP to ``--output``.

    Parses the command-line flags, optionally overrides the module-level
    ``HAS_CUDA`` switch according to ``--device`` ("auto" leaves it untouched),
    runs ``process_video`` with a console progress reporter, prints the
    returned status line, and copies the resulting ZIP (when one is returned)
    into the output directory.
    """
    global HAS_CUDA
    import argparse

    mask_choices = ["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"]

    parser = argparse.ArgumentParser(description="CorridorKey Green Screen Matting")
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", default="output")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--resolution", default="1024", choices=["1024", "2048"])
    parser.add_argument("--mask-mode", default="Hybrid (auto)", choices=mask_choices)
    parser.add_argument("--despill", type=int, default=5)
    parser.add_argument("--no-despeckle", action="store_true")
    parser.add_argument("--despeckle-size", type=int, default=400)
    args = parser.parse_args()

    # Only an explicit "cpu" or "cuda" overrides the global flag.
    forced_cuda = {"cpu": False, "cuda": True}.get(args.device)
    if forced_cuda is not None:
        HAS_CUDA = forced_cuda
    device_label = "CUDA" if HAS_CUDA else "CPU"
    print(f"Device: {device_label}")

    class _ConsoleProgress:
        """Callable progress reporter that prints described updates to stdout."""

        def __call__(self, val, desc=""):
            # Updates without a description are not printed.
            if not desc:
                return
            print(f" [{val:.0%}] {desc}")

    auto_despeckle = not args.no_despeckle
    _comp_path, _matte_path, zip_file, status = process_video(
        args.input,
        args.resolution,
        args.despill,
        args.mask_mode,
        auto_despeckle,
        args.despeckle_size,
        progress=_ConsoleProgress(),
    )
    print(f"\n{status}")

    os.makedirs(args.output, exist_ok=True)
    if not zip_file:
        return
    dst = os.path.join(args.output, os.path.basename(zip_file))
    shutil.copy2(zip_file, dst)
    print(f"Output: {dst}")
def is_cli_invocation(argv):
    """Return True when *argv* asks for CLI mode, i.e. carries ``--input``.

    Both spellings accepted by argparse are recognised: ``--input PATH`` and
    ``--input=PATH``. A plain membership test for "--input" misses the second
    form and would launch the web UI instead of processing the file.

    Args:
        argv: Full argument vector, with the program name at index 0.

    Returns:
        bool: True if any argument after the program name is ``--input`` or
        starts with ``--input=``.
    """
    return any(
        arg == "--input" or arg.startswith("--input=")
        for arg in argv[1:]
    )


if __name__ == "__main__":
    if is_cli_invocation(sys.argv):
        # CLI mode: process a single file and exit.
        cli_main()
    else:
        # Web mode: launch the Gradio app with the queue enabled.
        demo.queue(default_concurrency_limit=1)
        demo.launch(ssr_mode=False, mcp_server=True)