Spaces:
Running on Zero
Running on Zero
| """CorridorKey Green Screen Matting - HuggingFace Space. | |
| Self-contained Gradio app with dual inference paths: | |
| - GPU (ZeroGPU H200): PyTorch batched inference via GreenFormer | |
| - CPU (fallback): ONNX Runtime sequential inference | |
| Usage: | |
| python app.py # Launch Gradio UI | |
| python app.py --input video.mp4 # CLI mode | |
| """ | |
| import os | |
| import sys | |
| import math | |
| import shutil | |
| import gc | |
| import time | |
| import tempfile | |
| import zipfile | |
| import subprocess | |
| import logging | |
| # Thread tuning for CPU (must be set before numpy/cv2/ort import) | |
| os.environ["OMP_NUM_THREADS"] = "2" | |
| os.environ["OPENBLAS_NUM_THREADS"] = "2" | |
| os.environ["MKL_NUM_THREADS"] = "2" | |
| import numpy as np | |
| import cv2 | |
| import gradio as gr | |
| import onnxruntime as ort | |
| try: | |
| import spaces | |
| HAS_SPACES = True | |
| except ImportError: | |
| HAS_SPACES = False | |
# Workaround: Gradio cache_examples bug with None outputs.
_original_read_from_flag = gr.components.Component.read_from_flag
def _patched_read_from_flag(self, payload):
    """Treat None or blank flagged payloads as missing instead of crashing."""
    is_blank_str = isinstance(payload, str) and not payload.strip()
    if payload is None or is_blank_str:
        return None
    return _original_read_from_flag(self, payload)
gr.components.Component.read_from_flag = _patched_read_from_flag
| from huggingface_hub import hf_hub_download | |
| cv2.setNumThreads(2) | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | |
| logger = logging.getLogger(__name__) | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# BiRefNet segmentation model (ONNX export) used to seed the foreground mask.
BIREFNET_REPO = "onnx-community/BiRefNet_lite-ONNX"
BIREFNET_FILE = "onnx/model.onnx"
# Local directory holding pre-exported CorridorKey ONNX models, one per input size.
MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
CORRIDORKEY_MODELS = {
    "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
    "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
}
# PyTorch checkpoint (GreenFormer weights) for the GPU inference path.
CORRIDORKEY_PTH_REPO = "nikopueringer/CorridorKey_v1.0"
CORRIDORKEY_PTH_FILE = "CorridorKey_v1.0.pth"
# ImageNet normalization stats, shaped [1, 1, 3] to broadcast over HWC frames.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
# Input-length caps: seconds allowed per backend, plus a hard frame ceiling.
MAX_DURATION_CPU = 5
MAX_DURATION_GPU = 60
MAX_FRAMES = 1800
# True when ONNX Runtime reports CUDA support at import time (may change
# later under ZeroGPU — see _ensure_gpu_sessions).
HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
# ---------------------------------------------------------------------------
# Preload model files at startup (OUTSIDE GPU function — don't waste GPU time on downloads)
# ---------------------------------------------------------------------------
logger.info("Preloading model files at startup...")
# Cached local paths; None means the download failed here and is retried
# lazily by get_birefnet()/_load_greenformer().
_preloaded_birefnet_path = None
_preloaded_pth_path = None
try:
    _preloaded_birefnet_path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
    logger.info("BiRefNet cached: %s", _preloaded_birefnet_path)
except Exception as e:
    # Best-effort: app startup must survive a temporarily unreachable Hub.
    logger.warning("BiRefNet preload failed (will retry later): %s", e)
try:
    _preloaded_pth_path = hf_hub_download(repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE)
    logger.info("CorridorKey.pth cached: %s", _preloaded_pth_path)
except Exception as e:
    logger.warning("CorridorKey.pth preload failed (will retry later): %s", e)
# Batch sizes for GPU inference (conservative for H200 80GB)
GPU_BATCH_SIZES = {"1024": 32, "2048": 16}  # 2048 uses only 5.7GB/batch=2, so 16 easily fits in 69.8GB
| # --------------------------------------------------------------------------- | |
| # Color utilities (numpy-only) | |
| # --------------------------------------------------------------------------- | |
def linear_to_srgb(x):
    """Encode linear-light values into sRGB using the piecewise transfer curve."""
    clipped = np.clip(x, 0.0, None)
    linear_part = clipped * 12.92
    gamma_part = 1.055 * np.power(clipped, 1.0 / 2.4) - 0.055
    return np.where(clipped <= 0.0031308, linear_part, gamma_part)
def srgb_to_linear(x):
    """Decode sRGB-encoded values back to linear light (inverse transfer curve)."""
    clipped = np.clip(x, 0.0, None)
    linear_part = clipped / 12.92
    gamma_part = np.power((clipped + 0.055) / 1.055, 2.4)
    return np.where(clipped <= 0.04045, linear_part, gamma_part)
def composite_straight(fg, bg, alpha):
    """Alpha-blend foreground over background using straight (unassociated) alpha."""
    inverse_alpha = 1.0 - alpha
    return fg * alpha + bg * inverse_alpha
def despill(image, green_limit_mode="average", strength=1.0):
    """Suppress green spill: clamp G to a limit derived from R and B,
    redistributing half of the removed excess into each of R and B.

    strength blends between the input (0.0) and fully despilled (1.0) image.
    """
    if strength <= 0.0:
        return image
    red = image[..., 0]
    green = image[..., 1]
    blue = image[..., 2]
    if green_limit_mode == "average":
        limit = (red + blue) / 2.0
    else:
        limit = np.maximum(red, blue)
    spill = np.maximum(green - limit, 0.0)
    half_spill = spill * 0.5
    corrected = np.stack([red + half_spill, green - spill, blue + half_spill], axis=-1)
    if strength < 1.0:
        return image * (1.0 - strength) + corrected * strength
    return corrected
def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5):
    """Despeckle an alpha matte: drop connected components smaller than
    area_threshold, then dilate and feather the surviving region before
    multiplying it back into the original alpha."""
    restore_channel = alpha_np.ndim == 3
    if restore_channel:
        alpha_np = alpha_np[:, :, 0]
    binary = (alpha_np > 0.5).astype(np.uint8) * 255
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    keep = np.zeros(num_labels, dtype=bool)
    # Label 0 is background — never kept.
    keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= area_threshold
    region = keep[labels].astype(np.uint8) * 255
    if dilation > 0:
        ksize = int(dilation * 2 + 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        region = cv2.dilate(region, kernel)
    if blur_size > 0:
        bsize = int(blur_size * 2 + 1)
        region = cv2.GaussianBlur(region, (bsize, bsize), 0)
    cleaned_alpha = alpha_np * (region.astype(np.float32) / 255.0)
    return cleaned_alpha[:, :, np.newaxis] if restore_channel else cleaned_alpha
def create_checkerboard(w, h, checker_size=64, color1=0.15, color2=0.55):
    """Build an [h, w, 3] float32 gray checkerboard for preview compositing."""
    col_idx = np.arange(w) // checker_size
    row_idx = np.arange(h) // checker_size
    xg, yg = np.meshgrid(col_idx, row_idx)
    parity = (xg + yg) % 2
    plane = np.where(parity == 0, color1, color2).astype(np.float32)
    return np.stack([plane] * 3, axis=-1)
def premultiply(fg, alpha):
    """Convert a straight-alpha foreground to premultiplied (associated) alpha."""
    return fg * alpha
| # --------------------------------------------------------------------------- | |
| # Fast classical green-screen mask | |
| # --------------------------------------------------------------------------- | |
def fast_greenscreen_mask(frame_rgb_f32):
    """Classical chroma-key mask via HSV thresholding.

    Returns (mask, confidence) where mask is float32 [H, W] in 0..1
    (1 = foreground), or (None, 0.0) when the corner samples do not look
    like a green backdrop. confidence approaches 1.0 when the mask is
    mostly hard 0/1 and drops as more pixels sit near 0.5.
    """
    h, w = frame_rgb_f32.shape[:2]
    # Sample ~5% patches from all four corners (minimum 4 px each).
    ph, pw = max(int(h * 0.05), 4), max(int(w * 0.05), 4)
    corners = np.concatenate([
        frame_rgb_f32[:ph, :pw].reshape(-1, 3),
        frame_rgb_f32[:ph, -pw:].reshape(-1, 3),
        frame_rgb_f32[-ph:, :pw].reshape(-1, 3),
        frame_rgb_f32[-ph:, -pw:].reshape(-1, 3),
    ], axis=0)
    # Median is robust to the subject overlapping one or two corners.
    bg_color = np.median(corners, axis=0)
    # Require green to exceed both red and blue by a margin; otherwise this
    # is not a green screen and the caller must fall back to BiRefNet.
    if not (bg_color[1] > bg_color[0] + 0.05 and bg_color[1] > bg_color[2] + 0.05):
        return None, 0.0
    frame_u8 = (np.clip(frame_rgb_f32, 0, 1) * 255).astype(np.uint8)
    hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV)
    # Hue 35-85 on OpenCV's 0-179 scale covers the green band.
    green_mask = cv2.inRange(hsv, (35, 40, 40), (85, 255, 255))
    fg_mask = cv2.bitwise_not(green_mask)
    # Close small holes, then soften the hard threshold edge.
    fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
    fg_mask = cv2.GaussianBlur(fg_mask, (5, 5), 0)
    mask_f32 = fg_mask.astype(np.float32) / 255.0
    # min(m, 1-m) is 0 for hard pixels and 0.5 for fully ambiguous ones.
    confidence = 1.0 - 2.0 * np.mean(np.minimum(mask_f32, 1.0 - mask_f32))
    return mask_f32, confidence
| # --------------------------------------------------------------------------- | |
| # ONNX model loading (CPU fallback + BiRefNet) | |
| # --------------------------------------------------------------------------- | |
# Lazily-created ONNX sessions: one BiRefNet session plus one CorridorKey
# session per resolution key ("1024"/"2048").
_birefnet_session = None
_corridorkey_sessions = {}
# True once the sessions have been invalidated for re-creation with CUDA
# available (see _ensure_gpu_sessions).
_sessions_on_gpu = False
def _get_providers():
    """Get best available providers. Inside @spaces.GPU, CUDA is available."""
    available = ort.get_available_providers()
    if "CUDAExecutionProvider" not in available:
        return ["CPUExecutionProvider"]
    return ["CUDAExecutionProvider", "CPUExecutionProvider"]
def _ort_opts():
    """Build ONNX Runtime session options tuned for the active provider set."""
    opts = ort.SessionOptions()
    cuda_present = "CUDAExecutionProvider" in ort.get_available_providers()
    # On GPU let ORT auto-pick thread counts (0 = default); on CPU cap
    # threads to match the OMP/MKL limits exported at import time.
    opts.intra_op_num_threads = 0 if cuda_present else 2
    opts.inter_op_num_threads = 0 if cuda_present else 1
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    opts.enable_mem_pattern = True
    return opts
def _ensure_gpu_sessions():
    """Reload ONNX sessions on GPU if CUDA just became available (ZeroGPU)."""
    global _birefnet_session, _corridorkey_sessions, _sessions_on_gpu
    if _sessions_on_gpu:
        return
    if "CUDAExecutionProvider" not in ort.get_available_providers():
        return
    # Drop the cached CPU sessions; the next getter call rebuilds them
    # with the CUDA provider.
    logger.info("CUDA available! Reloading ONNX sessions on GPU...")
    _birefnet_session = None
    _corridorkey_sessions = {}
    _sessions_on_gpu = True
def get_birefnet(force_cpu=False):
    """Return the (cached) BiRefNet ONNX session, creating it on first call.

    NOTE(review): force_cpu=True rebuilds the session on every call AND the
    CPU session overwrites the module-level cache, so later force_cpu=False
    calls keep receiving the CPU session until _ensure_gpu_sessions resets
    it — confirm this is intended.
    """
    global _birefnet_session
    if _birefnet_session is None or force_cpu:
        # Prefer the path cached at startup; fall back to downloading now.
        path = _preloaded_birefnet_path or hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
        providers = ["CPUExecutionProvider"] if force_cpu else _get_providers()
        logger.info("Loading BiRefNet ONNX: %s (providers: %s)", path, providers)
        opts = _ort_opts()
        if force_cpu:
            # Match the CPU thread caps set via env vars at import time.
            opts.intra_op_num_threads = 2
            opts.inter_op_num_threads = 1
        _birefnet_session = ort.InferenceSession(path, opts, providers=providers)
    return _birefnet_session
def get_corridorkey_onnx(resolution="1024"):
    """Return the cached CorridorKey ONNX session for the given resolution,
    loading it from disk on first use."""
    global _corridorkey_sessions
    if resolution in _corridorkey_sessions:
        return _corridorkey_sessions[resolution]
    onnx_path = CORRIDORKEY_MODELS.get(resolution)
    if not onnx_path or not os.path.exists(onnx_path):
        raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")
    providers = _get_providers()
    logger.info("Loading CorridorKey ONNX (%s): %s (providers: %s)", resolution, onnx_path, providers)
    session = ort.InferenceSession(onnx_path, _ort_opts(), providers=providers)
    _corridorkey_sessions[resolution] = session
    return session
| # --------------------------------------------------------------------------- | |
| # PyTorch model loading (GPU path) | |
| # --------------------------------------------------------------------------- | |
# Cached GreenFormer model and the img_size it was loaded for; a resolution
# switch frees the old model before loading the new one (see get_pytorch_model).
_pytorch_model = None
_pytorch_model_size = None
def _load_greenformer(img_size):
    """Load the GreenFormer PyTorch model for GPU inference.

    Downloads the checkpoint if not preloaded, strips torch.compile key
    prefixes, resizes positional embeddings when the checkpoint grid differs
    from the requested img_size, and moves the model to CUDA in FP16.

    Args:
        img_size: square input resolution the model is built for (1024/2048).
    Returns:
        GreenFormer model in eval mode on CUDA (possibly torch.compile'd).
    """
    import torch
    import torch.nn.functional as F
    from CorridorKeyModule.core.model_transformer import GreenFormer
    checkpoint_path = _preloaded_pth_path or hf_hub_download(repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE)
    logger.info("Using checkpoint: %s", checkpoint_path)
    logger.info("Initializing GreenFormer (img_size=%d)...", img_size)
    model = GreenFormer(
        encoder_name="hiera_base_plus_224.mae_in1k_ft_in1k",
        img_size=img_size,
        use_refiner=True,
    )
    # Load weights
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Fix compiled model prefix & handle PosEmbed mismatch
    new_state_dict = {}
    model_state = model.state_dict()
    for k, v in state_dict.items():
        if k.startswith("_orig_mod."):
            k = k[10:]  # len("_orig_mod.") == 10
        if "pos_embed" in k and k in model_state:
            if v.shape != model_state[k].shape:
                logger.info("Resizing %s from %s to %s", k, v.shape, model_state[k].shape)
                # Bicubically resample the [1, N, C] pos-embed on its 2D grid.
                # assumes N is a perfect square — TODO confirm for this encoder
                N_src = v.shape[1]
                C = v.shape[2]
                grid_src = int(math.sqrt(N_src))
                grid_dst = int(math.sqrt(model_state[k].shape[1]))
                v_img = v.permute(0, 2, 1).view(1, C, grid_src, grid_src)
                v_resized = F.interpolate(v_img, size=(grid_dst, grid_dst), mode="bicubic", align_corners=False)
                v = v_resized.flatten(2).transpose(1, 2)
        new_state_dict[k] = v
    # strict=False: tolerate refiner/auxiliary keys absent from the checkpoint.
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    if missing:
        logger.warning("Missing keys: %s", missing)
    if unexpected:
        logger.warning("Unexpected keys: %s", unexpected)
    model.eval()
    model = model.cuda().half()  # FP16 for speed on H200
    logger.info("Model loaded as FP16")
    # Diagnostics only: report which attention backends are in play.
    try:
        import flash_attn
        logger.info("flash-attn v%s installed (prebuilt wheel)", getattr(flash_attn, '__version__', '?'))
    except ImportError:
        logger.info("flash-attn not available (using PyTorch SDPA)")
    logger.info("SDPA backends: flash=%s, mem_efficient=%s, math=%s",
                torch.backends.cuda.flash_sdp_enabled(),
                torch.backends.cuda.mem_efficient_sdp_enabled(),
                torch.backends.cuda.math_sdp_enabled())
    # Skip torch.compile on ZeroGPU — the 37s warmup eats too much of the 120s budget.
    if not HAS_SPACES and sys.platform in ("linux", "win32"):
        try:
            compiled = torch.compile(model)
            # Warm the compiled graph once with a dummy 4-channel input
            # (RGB + mask) so the first real batch isn't slowed by tracing.
            dummy = torch.zeros(1, 4, img_size, img_size, dtype=torch.float16, device="cuda")
            with torch.inference_mode():
                compiled(dummy)
            model = compiled
            logger.info("torch.compile() succeeded")
        except Exception as e:
            logger.warning("torch.compile() failed, using eager mode: %s", e)
            torch.cuda.empty_cache()
    else:
        logger.info("Skipping torch.compile() (ZeroGPU: saving GPU time for inference)")
    logger.info("GreenFormer loaded on CUDA (img_size=%d)", img_size)
    return model
def get_pytorch_model(img_size):
    """Get or load the PyTorch GreenFormer model for the given resolution."""
    global _pytorch_model, _pytorch_model_size
    cache_hit = _pytorch_model is not None and _pytorch_model_size == img_size
    if not cache_hit:
        if _pytorch_model is not None:
            # Resolution switch: drop the old model and reclaim VRAM first.
            import torch
            del _pytorch_model
            _pytorch_model = None
            torch.cuda.empty_cache()
            gc.collect()
        _pytorch_model = _load_greenformer(img_size)
        _pytorch_model_size = img_size
    return _pytorch_model
| # --------------------------------------------------------------------------- | |
| # Per-frame inference: ONNX (CPU fallback) | |
| # --------------------------------------------------------------------------- | |
def birefnet_frame(session, image_rgb_uint8):
    """Run BiRefNet on one RGB uint8 frame; return a binary float32 mask
    (1 = foreground) at the frame's original resolution."""
    h, w = image_rgb_uint8.shape[:2]
    inp = session.get_inputs()[0]
    model_res = (inp.shape[2], inp.shape[3])
    x = cv2.resize(image_rgb_uint8, model_res).astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    x = x.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
    logits = session.run(None, {inp.name: x})[-1]
    prob = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    upscaled = cv2.resize(prob[0, 0], (w, h))
    return (upscaled > 0.04).astype(np.float32)
def corridorkey_frame_onnx(session, image_f32, mask_f32, img_size,
                           despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
    """ONNX inference for a single frame (CPU path)."""
    h, w = image_f32.shape[:2]
    # Build the 4-channel input: ImageNet-normalized RGB + mask channel.
    resized_img = cv2.resize(image_f32, (img_size, img_size))
    resized_mask = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis]
    normalized = (resized_img - IMAGENET_MEAN) / IMAGENET_STD
    net_input = np.concatenate([normalized, resized_mask], axis=-1)
    net_input = net_input.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
    alpha_raw, fg_raw = session.run(None, {"input": net_input})
    # Lanczos resampling back to the original frame size.
    alpha = cv2.resize(alpha_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
    fg = cv2.resize(fg_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    if auto_despeckle:
        alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
    fg = despill(fg, green_limit_mode="average", strength=despill_strength)
    return {"alpha": alpha, "fg": fg}
| # --------------------------------------------------------------------------- | |
| # Batched inference: PyTorch (GPU path) | |
| # --------------------------------------------------------------------------- | |
def corridorkey_batch_pytorch(model, images_f32, masks_f32, img_size,
                              despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
    """PyTorch batched inference for multiple frames on GPU.
    Args:
        model: GreenFormer model on CUDA
        images_f32: list of [H, W, 3] float32 numpy arrays (0-1, sRGB)
        masks_f32: list of [H, W] float32 numpy arrays (0-1)
        img_size: model input resolution (1024 or 2048)
        despill_strength: 0-1 blend factor forwarded to despill()
        auto_despeckle: apply clean_matte() to each alpha when True
        despeckle_size: minimum component area kept by clean_matte()
    Returns:
        list of dicts with 'alpha' [H,W,1] and 'fg' [H,W,3]
    """
    import torch
    batch_size = len(images_f32)
    if batch_size == 0:
        return []
    # Store original sizes per frame
    orig_sizes = [(img.shape[1], img.shape[0]) for img in images_f32]  # (w, h)
    # Preprocess: resize, normalize, concatenate into batch tensor
    batch_inputs = []
    for img, mask in zip(images_f32, masks_f32):
        img_r = cv2.resize(img, (img_size, img_size))
        mask_r = cv2.resize(mask, (img_size, img_size))[:, :, np.newaxis]
        # 4-channel input: ImageNet-normalized RGB + raw mask channel.
        inp = np.concatenate([(img_r - IMAGENET_MEAN) / IMAGENET_STD, mask_r], axis=-1)
        batch_inputs.append(inp.transpose(2, 0, 1))  # [4, H, W]
    batch_np = np.stack(batch_inputs, axis=0).astype(np.float32)  # [B, 4, H, W]
    batch_tensor = torch.from_numpy(batch_np).cuda().half()  # FP16 input
    # Forward pass — model is FP16, input is FP16, no autocast needed
    with torch.inference_mode():
        out = model(batch_tensor)
    # Extract results
    alphas_gpu = out["alpha"].float().cpu().numpy()  # [B, 1, H, W]
    fgs_gpu = out["fg"].float().cpu().numpy()  # [B, 3, H, W]
    del batch_tensor
    # Don't empty cache per batch - too expensive. Let PyTorch manage.
    # Postprocess each frame
    results = []
    for i in range(batch_size):
        w, h = orig_sizes[i]
        # Lanczos resampling back to each frame's original resolution.
        alpha = cv2.resize(alphas_gpu[i].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
        fg = cv2.resize(fgs_gpu[i].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
        if alpha.ndim == 2:
            alpha = alpha[:, :, np.newaxis]
        if auto_despeckle:
            alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
        fg = despill(fg, green_limit_mode="average", strength=despill_strength)
        results.append({"alpha": alpha, "fg": fg})
    return results
| # --------------------------------------------------------------------------- | |
| # Video stitching | |
| # --------------------------------------------------------------------------- | |
def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420p",
                   codec="libx264", extra_args=None):
    """Encode numbered frames in frame_dir into a video with ffmpeg.

    Returns True on success, False if ffmpeg is missing, times out, or exits
    non-zero (argument list, shell=False — no injection risk).
    """
    cmd = [
        "ffmpeg", "-y",
        "-framerate", str(fps),
        "-i", os.path.join(frame_dir, pattern),
        "-c:v", codec,
        "-pix_fmt", pix_fmt,
    ]
    cmd.extend(extra_args or [])
    cmd.append(out_path)
    try:
        subprocess.run(cmd, capture_output=True, timeout=300, check=True)
    except (FileNotFoundError, subprocess.TimeoutExpired, subprocess.CalledProcessError) as e:
        logger.warning("ffmpeg failed: %s", e)
        return False
    return True
| # --------------------------------------------------------------------------- | |
| # Output writing helper | |
| # --------------------------------------------------------------------------- | |
# Fastest PNG params: compression 1 (instead of default 3)
# (used for mattes and RGBA outputs where alpha must be lossless)
_PNG_FAST = [cv2.IMWRITE_PNG_COMPRESSION, 1]
# JPEG for opaque outputs (comp/fg) — 10x faster than PNG at 4K
_JPG_QUALITY = [cv2.IMWRITE_JPEG_QUALITY, 95]
def _write_frame_fast(i, alpha, fg, w, h, bg_lin, comp_dir, matte_dir, fg_dir):
    """Fast write: comp (JPEG) + matte (PNG) + fg (JPEG). No heavy PNG/npz."""
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    alpha_2d = alpha[:, :, 0]
    # Composite over the checkerboard in linear light, encode back to sRGB.
    fg_lin = srgb_to_linear(fg)
    comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
    comp_u8 = (np.clip(comp, 0, 1) * 255).astype(np.uint8)
    fg_u8 = (np.clip(fg, 0, 1) * 255).astype(np.uint8)
    matte_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
    # [:, :, ::-1] flips RGB -> BGR for cv2.imwrite.
    cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.jpg"), comp_u8[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"), fg_u8[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(matte_dir, f"{i:05d}.png"), matte_u8, _PNG_FAST)
def _write_frame_deferred(i, raw_path, w, h, bg_lin, fg_dir, processed_dir):
    """Deferred write: FG (JPEG) + Processed (RGBA PNG). Runs after GPU release."""
    data = np.load(raw_path)
    alpha = data["alpha"]
    fg = data["fg"]
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    alpha_2d = alpha[:, :, 0]
    # Opaque foreground preview as JPEG (RGB flipped to BGR for cv2).
    fg_u8 = (np.clip(fg, 0, 1) * 255).astype(np.uint8)
    cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"), fg_u8[:, :, ::-1], _JPG_QUALITY)
    # Premultiply in linear light, then encode back to sRGB for the RGBA PNG.
    fg_lin = srgb_to_linear(fg)
    premul_srgb = linear_to_srgb(premultiply(fg_lin, alpha))
    rgb_u8 = (np.clip(premul_srgb, 0, 1) * 255).astype(np.uint8)
    a_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
    rgba = np.concatenate([rgb_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1)
    cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba, _PNG_FAST)
    os.remove(raw_path)  # cleanup
def _write_frame_outputs(i, alpha, fg, w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir):
    """Full write: all 4 outputs. Used by CPU path."""
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    alpha_2d = alpha[:, :, 0]
    # Checkerboard composite computed in linear light, encoded to sRGB.
    fg_lin = srgb_to_linear(fg)
    comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
    comp_u8 = (np.clip(comp, 0, 1) * 255).astype(np.uint8)
    cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.jpg"), comp_u8[:, :, ::-1], _JPG_QUALITY)
    # Straight (un-premultiplied) foreground preview.
    fg_u8 = (np.clip(fg, 0, 1) * 255).astype(np.uint8)
    cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"), fg_u8[:, :, ::-1], _JPG_QUALITY)
    # Grayscale matte, lossless.
    a_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
    cv2.imwrite(os.path.join(matte_dir, f"{i:05d}.png"), a_u8, _PNG_FAST)
    # Premultiplied RGBA for downstream compositing.
    premul_srgb = linear_to_srgb(premultiply(fg_lin, alpha))
    premul_u8 = (np.clip(premul_srgb, 0, 1) * 255).astype(np.uint8)
    rgba = np.concatenate([premul_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1)
    cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba, _PNG_FAST)
# ---------------------------------------------------------------------------
# Shared storage: GPU function stores results here instead of returning them.
# This avoids ZeroGPU serializing gigabytes of numpy arrays on return.
# ---------------------------------------------------------------------------
# NOTE(review): not read or written anywhere in this chunk — presumably used
# by code outside this view; verify before removing.
_shared_results = {"data": None}
| # --------------------------------------------------------------------------- | |
| # Main pipeline | |
| # --------------------------------------------------------------------------- | |
def _gpu_decorator(fn):
    """Wrap fn with spaces.GPU(duration=120) on HF Spaces; no-op elsewhere."""
    if not HAS_SPACES:
        return fn
    return spaces.GPU(duration=120)(fn)
def _gpu_phase(video_path, resolution, despill_val, mask_mode,
               auto_despeckle, despeckle_size, progress=gr.Progress(),
               precompute_dir=None, precompute_count=0):
    """ALL GPU work: load models, read video, generate masks, run inference.
    Returns raw numpy results in RAM. No disk I/O.

    Args:
        video_path: path to the uploaded video.
        resolution: "1024" or "2048" — CorridorKey model input size.
        despill_val: slider value; divided by 10 to get despill strength.
        mask_mode: "Fast (classical)", "Hybrid (auto)", or AI/BiRefNet-only.
        auto_despeckle: apply clean_matte() to each alpha when True.
        despeckle_size: minimum component area kept by clean_matte().
        progress: Gradio progress callback (gr.Progress() default is the
            Gradio injection convention).
        precompute_dir: optional dir of frame_/mask_/rgb_*.npy files written
            by the CPU pre-pass (large args can't cross the ZeroGPU boundary).
        precompute_count: number of precomputed frames in precompute_dir.

    Returns:
        dict with keys results, frame_times, use_gpu, batch_size, w, h, fps, tmpdir.

    Raises:
        gr.Error: missing/unreadable video, over-length video, inference failure.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")
    _ensure_gpu_sessions()
    # Probe torch CUDA: chooses between the batched PyTorch path and the
    # sequential ONNX fallback below.
    try:
        import torch
        has_torch_cuda = torch.cuda.is_available()
    except ImportError:
        has_torch_cuda = False
    use_gpu = has_torch_cuda
    logger.info("[GPU phase] CUDA=%s, mode=%s", has_torch_cuda,
                "PyTorch batched" if use_gpu else "ONNX sequential")
    img_size = int(resolution)
    max_dur = MAX_DURATION_GPU if use_gpu else MAX_DURATION_CPU
    despill_strength = despill_val / 10.0
    # Read video metadata
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # cap.get returns 0.0 when unknown
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    if total_frames == 0:
        raise gr.Error("Could not read video frames.")
    duration = total_frames / fps
    if duration > max_dur:
        raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s.")
    frames_to_process = min(total_frames, MAX_FRAMES)
    # Load BiRefNet only if masks need it (skip if all precomputed)
    birefnet = None
    needs_birefnet = precompute_dir is None or precompute_count == 0
    if not needs_birefnet and mask_mode != "Fast (classical)":
        # Check if any frames need BiRefNet (missing mask files)
        for i in range(min(frames_to_process, precompute_count)):
            if not os.path.exists(os.path.join(precompute_dir, f"mask_{i:05d}.npy")):
                needs_birefnet = True
                break
    if needs_birefnet:
        progress(0.02, desc="Loading BiRefNet...")
        birefnet = get_birefnet()
        logger.info("BiRefNet loaded (needed for some frames)")
    else:
        logger.info("Skipping BiRefNet load (all masks precomputed)")
    # resolution is a string key into GPU_BATCH_SIZES; CPU runs one at a time.
    batch_size = GPU_BATCH_SIZES.get(resolution, 16) if use_gpu else 1
    if use_gpu:
        progress(0.05, desc=f"Loading GreenFormer ({resolution})...")
        pytorch_model = get_pytorch_model(img_size)
    else:
        progress(0.05, desc=f"Loading CorridorKey ONNX ({resolution})...")
        corridorkey_onnx = get_corridorkey_onnx(resolution)
    logger.info("[GPU phase] %d frames (%dx%d @ %.1ffps), res=%d, mask=%s, batch=%d",
                frames_to_process, w, h, fps, img_size, mask_mode, batch_size)
    # Read all frames + generate masks + run inference
    tmpdir = tempfile.mkdtemp(prefix="ck_")
    frame_times = []
    total_start = time.time()
    try:
        cap = cv2.VideoCapture(video_path)
        if use_gpu:
            import torch
            vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logger.info("VRAM: %.1f/%.1fGB",
                        torch.cuda.memory_allocated() / 1024**3, vram_total)
            all_results = []
            frame_idx = 0
            # Load precomputed frames from disk (no serialization overhead)
            use_precomputed = precompute_dir is not None and precompute_count > 0
            while frame_idx < frames_to_process:
                t_batch = time.time()
                batch_images, batch_masks, batch_indices = [], [], []
                t_mask = 0
                fast_n, biref_n = 0, 0
                # Assemble one batch of frames + masks.
                for _ in range(batch_size):
                    if frame_idx >= frames_to_process:
                        break
                    if use_precomputed:
                        frame_f32 = np.load(os.path.join(precompute_dir, f"frame_{frame_idx:05d}.npy"))
                        mask_path = os.path.join(precompute_dir, f"mask_{frame_idx:05d}.npy")
                        if os.path.exists(mask_path):
                            mask = np.load(mask_path)
                            fast_n += 1
                        else:
                            # BiRefNet fallback — load original RGB, run on GPU
                            rgb_path = os.path.join(precompute_dir, f"rgb_{frame_idx:05d}.npy")
                            frame_rgb = np.load(rgb_path)
                            tm = time.time()
                            mask = birefnet_frame(birefnet, frame_rgb)
                            t_mask += time.time() - tm
                            biref_n += 1
                    else:
                        ret, frame_bgr = cap.read()
                        if not ret:
                            break
                        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                        frame_f32 = frame_rgb.astype(np.float32) / 255.0
                        tm = time.time()
                        if mask_mode == "Fast (classical)":
                            # NOTE(review): fast_greenscreen_mask can return
                            # None; unlike the CPU path below there is no
                            # guard here — confirm upstream guarantees a
                            # green backdrop in this mode.
                            mask, _ = fast_greenscreen_mask(frame_f32)
                            fast_n += 1
                        elif mask_mode == "Hybrid (auto)":
                            # Use the classical mask only when confident.
                            mask, conf = fast_greenscreen_mask(frame_f32)
                            if mask is None or conf < 0.7:
                                mask = birefnet_frame(birefnet, frame_rgb)
                                biref_n += 1
                            else:
                                fast_n += 1
                        else:
                            mask = birefnet_frame(birefnet, frame_rgb)
                            biref_n += 1
                        t_mask += time.time() - tm
                    batch_images.append(frame_f32)
                    batch_masks.append(mask)
                    batch_indices.append(frame_idx)
                    frame_idx += 1
                if not batch_images:
                    break
                # Batched GPU inference
                t_inf = time.time()
                results = corridorkey_batch_pytorch(
                    pytorch_model, batch_images, batch_masks, img_size,
                    despill_strength=despill_strength,
                    auto_despeckle=auto_despeckle,
                    despeckle_size=int(despeckle_size),
                )
                t_inf = time.time() - t_inf
                for j, result in enumerate(results):
                    all_results.append((batch_indices[j], result["alpha"], result["fg"]))
                n = len(batch_images)
                elapsed = time.time() - t_batch
                vram_peak = torch.cuda.max_memory_allocated() / 1024**3
                logger.info("Batch %d: mask=%.1fs(fast=%d,biref=%d) infer=%.1fs total=%.1fs(%.2fs/fr) VRAM=%.1fGB",
                            n, t_mask, fast_n, biref_n, t_inf, elapsed, elapsed/n, vram_peak)
                per_frame = elapsed / n
                frame_times.extend([per_frame] * n)
                # ETA from the last ~20 per-frame timings.
                remaining = (frames_to_process - frame_idx) * (np.mean(frame_times[-20:]) if len(frame_times) > 1 else per_frame)
                progress(0.10 + 0.75 * frame_idx / frames_to_process,
                         desc=f"Frame {frame_idx}/{frames_to_process} ({per_frame:.2f}s/fr) ~{remaining:.0f}s left")
            cap.release()
            gpu_elapsed = time.time() - total_start
            logger.info("[GPU phase] done: %d frames in %.1fs (%.2fs/fr)",
                        len(all_results), gpu_elapsed, gpu_elapsed / max(len(all_results), 1))
            # FAST WRITE inside GPU: only comp (JPEG) + matte (PNG) + raw numpy.
            # FG + Processed written AFTER GPU release (deferred).
            from concurrent.futures import ThreadPoolExecutor
            bg_lin = srgb_to_linear(create_checkerboard(w, h))
            comp_dir = os.path.join(tmpdir, "Comp")
            matte_dir = os.path.join(tmpdir, "Matte")
            fg_dir = os.path.join(tmpdir, "FG")
            processed_dir = os.path.join(tmpdir, "Processed")
            for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
                os.makedirs(d, exist_ok=True)
            t_write = time.time()
            progress(0.86, desc="Writing preview frames...")
            # Parallel encode: cv2.imwrite releases the GIL, so threads help.
            with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as pool:
                futs = [pool.submit(_write_frame_fast, idx, alpha, fg, w, h, bg_lin,
                                    comp_dir, matte_dir, fg_dir)
                        for idx, alpha, fg in all_results]
                for f in futs:
                    f.result()
            del all_results
            gc.collect()
            logger.info("[GPU phase] Fast write in %.1fs", time.time() - t_write)
            return {
                "results": "written", "frame_times": frame_times,
                "use_gpu": True, "batch_size": batch_size,
                "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
            }
        else:
            # CPU PATH: sequential ONNX + inline writes (no GPU budget concern)
            bg_lin = srgb_to_linear(create_checkerboard(w, h))
            comp_dir, fg_dir = os.path.join(tmpdir, "Comp"), os.path.join(tmpdir, "FG")
            matte_dir, processed_dir = os.path.join(tmpdir, "Matte"), os.path.join(tmpdir, "Processed")
            for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
                os.makedirs(d, exist_ok=True)
            for i in range(frames_to_process):
                t0 = time.time()
                ret, frame_bgr = cap.read()
                if not ret:
                    break
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                frame_f32 = frame_rgb.astype(np.float32) / 255.0
                if mask_mode == "Fast (classical)":
                    mask, _ = fast_greenscreen_mask(frame_f32)
                    if mask is None:
                        raise gr.Error("Fast mask failed. Try 'AI (BiRefNet)' mode.")
                elif mask_mode == "Hybrid (auto)":
                    mask, conf = fast_greenscreen_mask(frame_f32)
                    if mask is None or conf < 0.7:
                        mask = birefnet_frame(birefnet, frame_rgb)
                else:
                    mask = birefnet_frame(birefnet, frame_rgb)
                result = corridorkey_frame_onnx(corridorkey_onnx, frame_f32, mask, img_size,
                                                despill_strength=despill_strength,
                                                auto_despeckle=auto_despeckle,
                                                despeckle_size=int(despeckle_size))
                _write_frame_outputs(i, result["alpha"], result["fg"],
                                     w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir)
                elapsed = time.time() - t0
                frame_times.append(elapsed)
                remaining = (frames_to_process - i - 1) * (np.mean(frame_times[-5:]) if len(frame_times) > 1 else elapsed)
                progress(0.10 + 0.80 * (i+1) / frames_to_process,
                         desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) ~{remaining:.0f}s left")
            cap.release()
            return {
                "results": None, "frame_times": frame_times,
                "use_gpu": False, "batch_size": 1,
                "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
            }
    except gr.Error:
        # User-facing errors pass through untouched.
        raise
    except Exception as e:
        logger.exception("Inference failed")
        raise gr.Error(f"Inference failed: {e}")
def process_video(video_path, resolution, despill_val, mask_mode,
                  auto_despeckle, despeckle_size, progress=gr.Progress()):
    """Orchestrate the pipeline: CPU mask precompute → GPU/CPU inference → outputs.

    Args:
        video_path: Path to the uploaded video (None raises gr.Error).
        resolution: Processing resolution as a string ("1024" or "2048").
        despill_val: Despill strength, 0 (off) to 10 (max).
        mask_mode: "Hybrid (auto)", "AI (BiRefNet)", or "Fast (classical)".
        auto_despeckle: Whether to remove small disconnected artifacts.
        despeckle_size: Minimum pixel area to keep when despeckling.
        progress: Gradio progress callback; any ``callable(val, desc=...)`` works.

    Returns:
        Tuple ``(comp_video, matte_video, zip_path, status)``; the video paths
        are None when stitching produced no file.

    Raises:
        gr.Error: On missing input, mask failure, or inference/output errors.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")
    # Phase 0: precompute fast masks on CPU and save to disk.
    # IMPORTANT: Can't pass large data as args to @spaces.GPU (ZeroGPU
    # serializes args) — save numpy files and pass only the directory path.
    logger.info("[Phase 0] Precomputing fast masks on CPU")
    t_mask = time.time()
    precompute_dir = tempfile.mkdtemp(prefix="ck_pre_")
    try:
        frame_count, needs_birefnet = _precompute_masks(video_path, mask_mode, precompute_dir)
        logger.info("[Phase 0] %d frames saved to %s in %.1fs (needs_birefnet=%s)",
                    frame_count, precompute_dir, time.time() - t_mask, needs_birefnet)
        # Phase 1: GPU inference — pass only paths (tiny strings), not data.
        logger.info("[Phase 1] Starting GPU phase")
        t0 = time.time()
        data = _gpu_phase(video_path, resolution, despill_val, mask_mode,
                          auto_despeckle, despeckle_size, progress,
                          precompute_dir=precompute_dir, precompute_count=frame_count)
        logger.info("[process_video] GPU phase done in %.1fs", time.time() - t0)
        # Phases 2-4: stitch previews, package ZIP, build status line.
        return _write_outputs(data, resolution, progress)
    finally:
        # FIX: the precompute scratch dir was previously never removed,
        # leaking one temp dir full of .npy frames per processed video.
        shutil.rmtree(precompute_dir, ignore_errors=True)
def _precompute_masks(video_path, mask_mode, precompute_dir):
    """Phase 0 helper: decode every frame, compute classical masks where the
    mode allows, and dump per-frame .npy files into *precompute_dir*.

    Returns:
        (frame_count, needs_birefnet) — number of frames written, and whether
        at least one frame was deferred to BiRefNet for AI masking.
    """
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    needs_birefnet = False
    try:
        while True:
            ret, frame_bgr = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            frame_f32 = frame_rgb.astype(np.float32) / 255.0
            if mask_mode == "Fast (classical)":
                mask, _ = fast_greenscreen_mask(frame_f32)
                if mask is None:
                    raise gr.Error("Fast mask failed. Try 'Hybrid' or 'AI' mode.")
            elif mask_mode == "Hybrid (auto)":
                mask, conf = fast_greenscreen_mask(frame_f32)
                # Low-confidence classical mask → defer this frame to BiRefNet.
                if mask is None or conf < 0.7:
                    mask = None
                    needs_birefnet = True
            else:  # "AI (BiRefNet)": always defer to the AI model
                mask = None
                needs_birefnet = True
            # Save as numpy (fast to load, no serialization overhead).
            np.save(os.path.join(precompute_dir, f"frame_{frame_count:05d}.npy"), frame_f32)
            if mask is not None:
                np.save(os.path.join(precompute_dir, f"mask_{frame_count:05d}.npy"), mask)
            else:
                # Frames without a precomputed mask also keep their uint8 RGB,
                # presumably for BiRefNet preprocessing downstream.
                np.save(os.path.join(precompute_dir, f"rgb_{frame_count:05d}.npy"), frame_rgb)
            frame_count += 1
    finally:
        # FIX: previously leaked the capture when gr.Error was raised mid-loop.
        cap.release()
    return frame_count, needs_birefnet
def _write_outputs(data, resolution, progress):
    """Phases 2-4 helper: stitch preview videos, package the ZIP, build status.

    Args:
        data: Dict returned by the inference phase (tmpdir, w, h, fps,
            frame_times, use_gpu, batch_size).
        resolution: Resolution string, used only for the status message.
        progress: Progress callback.

    Returns:
        (comp_video, matte_video, zip_path, status) for the Gradio outputs.
    """
    tmpdir = data["tmpdir"]
    w, h, fps = data["w"], data["h"], data["fps"]
    frame_times = data["frame_times"]
    use_gpu = data["use_gpu"]
    batch_size = data["batch_size"]
    comp_dir = os.path.join(tmpdir, "Comp")
    fg_dir = os.path.join(tmpdir, "FG")
    matte_dir = os.path.join(tmpdir, "Matte")
    processed_dir = os.path.join(tmpdir, "Processed")
    for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
        os.makedirs(d, exist_ok=True)
    try:
        logger.info("[Phase 2] Frames written by GPU/CPU phase (comp+fg+matte)")
        # Phase 3: stitch videos from written frames.
        logger.info("[Phase 3] Stitching videos")
        progress(0.93, desc="Stitching videos...")
        comp_video = os.path.join(tmpdir, "comp_preview.mp4")
        matte_video = os.path.join(tmpdir, "matte_preview.mp4")
        # Comp frames are JPEG, Matte frames are PNG.
        _stitch_ffmpeg(comp_dir, comp_video, fps, pattern="%05d.jpg", extra_args=["-crf", "18"])
        _stitch_ffmpeg(matte_dir, matte_video, fps, pattern="%05d.png", extra_args=["-crf", "18"])
        # Phase 4: ZIP (no GPU). ZIP_STORED: frames are already compressed images.
        logger.info("[Phase 4] Packaging ZIP")
        progress(0.96, desc="Packaging ZIP...")
        zip_path = os.path.join(tmpdir, "CorridorKey_Output.zip")
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
            for folder in ["Comp", "FG", "Matte", "Processed"]:
                src = os.path.join(tmpdir, folder)
                if os.path.isdir(src):
                    for f in sorted(os.listdir(src)):
                        zf.write(os.path.join(src, f), f"Output/{folder}/{f}")
        progress(1.0, desc="Done!")
        n = len(frame_times)
        avg = np.mean(frame_times) if frame_times else 0
        engine = "PyTorch GPU" if use_gpu else "ONNX CPU"
        status = (f"Processed {n} frames ({w}x{h}) at {resolution}px | "
                  f"{avg:.2f}s/frame | {engine}" +
                  (f" batch={batch_size}" if use_gpu else ""))
        return (
            comp_video if os.path.exists(comp_video) else None,
            matte_video if os.path.exists(matte_video) else None,
            zip_path,
            status,
        )
    except gr.Error:
        raise
    except Exception as e:
        logger.exception("Output writing failed")
        raise gr.Error(f"Output failed: {e}")
    finally:
        # Remove per-frame image dirs: the ZIP already holds them, and the
        # preview mp4s + zip remain in tmpdir for Gradio to serve.
        for d in ["Comp", "FG", "Matte", "Processed"]:
            p = os.path.join(tmpdir, d)
            if os.path.isdir(p):
                shutil.rmtree(p, ignore_errors=True)
        gc.collect()
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size):
    """Delegate an example run to process_video (default progress tracker)."""
    result = process_video(
        video_path,
        resolution,
        despill,
        mask_mode,
        despeckle,
        despeckle_size,
    )
    return result
# Markdown header shown at the top of the Space UI (runtime string — unchanged).
DESCRIPTION = """# CorridorKey Green Screen Matting
Remove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital.
ZeroGPU H200: batched PyTorch inference (up to 32 frames at once). CPU fallback via ONNX."""
# Build the Gradio UI: left column = input video + settings, right column =
# previews, ZIP download, and status text.
with gr.Blocks(title="CorridorKey") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # --- Left column: upload + processing settings ---
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Green Screen Video")
            with gr.Accordion("Settings", open=True):
                # Resolution is a string here; downstream code converts it.
                resolution = gr.Radio(
                    choices=["1024", "2048"], value="1024",
                    label="Processing Resolution",
                    info="1024 = fast (batch 32 on GPU), 2048 = max quality (batch 8 on GPU)"
                )
                # Mask strategy: classical green detection, AI, or hybrid.
                mask_mode = gr.Radio(
                    choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
                    value="Hybrid (auto)", label="Mask Mode",
                    info="Hybrid = fast green detection + AI fallback. Fast = classical only. AI = always BiRefNet"
                )
                despill_slider = gr.Slider(
                    0, 10, value=5, step=1, label="Despill Strength",
                    info="Remove green reflections (0=off, 10=max)"
                )
                despeckle_check = gr.Checkbox(
                    value=True, label="Auto Despeckle",
                    info="Remove small disconnected artifacts"
                )
                despeckle_size = gr.Number(
                    value=400, precision=0, label="Despeckle Size",
                    info="Min pixel area to keep"
                )
            process_btn = gr.Button("Process Video", variant="primary", size="lg")
        # --- Right column: outputs ---
        with gr.Column(scale=1):
            with gr.Row():
                comp_video = gr.Video(label="Composite Preview")
                matte_video = gr.Video(label="Alpha Matte")
            download_zip = gr.File(label="Download Full Package (Comp + FG + Matte + Processed)")
            status_text = gr.Textbox(label="Status", interactive=False)
    # Cached example (lazy: processed on first click, then served from cache).
    gr.Examples(
        examples=[
            ["examples/corridor_greenscreen_demo.mp4", "1024", 5, "Hybrid (auto)", True, 400],
        ],
        inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size],
        outputs=[comp_video, matte_video, download_zip, status_text],
        fn=process_example,
        cache_examples=True,
        cache_mode="lazy",
        label="Examples (click to load)"
    )
    # Main action: same inputs/outputs as the example runner.
    process_btn.click(
        fn=process_video,
        inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size],
        outputs=[comp_video, matte_video, download_zip, status_text],
    )
| # --------------------------------------------------------------------------- | |
| # CLI mode | |
| # --------------------------------------------------------------------------- | |
def cli_main():
    """Run the matting pipeline from the command line (no Gradio UI)."""
    import argparse
    parser = argparse.ArgumentParser(description="CorridorKey Green Screen Matting")
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", default="output")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--resolution", default="1024", choices=["1024", "2048"])
    parser.add_argument("--mask-mode", default="Hybrid (auto)",
                        choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"])
    parser.add_argument("--despill", type=int, default=5)
    parser.add_argument("--no-despeckle", action="store_true")
    parser.add_argument("--despeckle-size", type=int, default=400)
    opts = parser.parse_args()
    # Honor an explicit device override; "auto" keeps the detected value.
    global HAS_CUDA
    if opts.device != "auto":
        HAS_CUDA = (opts.device == "cuda")
    print(f"Device: {'CUDA' if HAS_CUDA else 'CPU'}")
    def cli_progress(val, desc=""):
        # Stand-in for gr.Progress: print "[pct] description" lines.
        if desc:
            print(f" [{val:.0%}] {desc}")
    comp, matte, zipf, status = process_video(
        opts.input, opts.resolution, opts.despill, opts.mask_mode,
        not opts.no_despeckle, opts.despeckle_size, progress=cli_progress,
    )
    print(f"\n{status}")
    os.makedirs(opts.output, exist_ok=True)
    # Copy the packaged ZIP into the requested output directory.
    if zipf:
        dst = os.path.join(opts.output, os.path.basename(zipf))
        shutil.copy2(zipf, dst)
        print(f"Output: {dst}")
# Entry point: CLI mode when --input is on the command line, otherwise the UI.
if __name__ == "__main__":
    if len(sys.argv) > 1 and "--input" in sys.argv:
        cli_main()
    else:
        # Queue serializes requests at concurrency 1 — presumably to avoid
        # overlapping GPU jobs on ZeroGPU; confirm before raising the limit.
        demo.queue(default_concurrency_limit=1)
        demo.launch(ssr_mode=False, mcp_server=True)