| |
| """ |
| utils.segmentation |
| ----------------------------------------------------------------------------- |
| All high-quality person-segmentation code for BackgroundFX Pro. |
| |
| Exports |
| ------- |
| segment_person_hq(image, predictor, fallback_enabled=True) -> np.ndarray |
| segment_person_hq_original(image, predictor, fallback_enabled=True) -> np.ndarray |
| SegmentationError - Custom exception for segmentation errors |
| |
| Everything else is prefixed "_" and considered private. |
| """ |
|
|
| from __future__ import annotations |
| from typing import Any, Tuple, Optional, Dict |
| import logging, os, math |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| log = logging.getLogger(__name__) |
|
|
| |
| |
| |
class SegmentationError(Exception):
    """
    Error raised by this module when segmentation cannot proceed, e.g. for an
    invalid input image or unusable SAM2 mask output.
    """
|
|
| |
| |
| |
# Feature switches consulted by segment_person_hq().
USE_ENHANCED_SEGMENTATION: bool = True   # False delegates to segment_person_hq_original()
USE_INTELLIGENT_PROMPTING: bool = True   # False uses the fixed basic prompt layout
USE_ITERATIVE_REFINEMENT: bool = True    # True routes the SAM2 mask through _auto_refine_mask_iteratively()

# Tuning values for mask validation and the classical CV helpers.
MIN_AREA_RATIO: float = 0.015   # lower bound on foreground fraction in _validate_mask_quality()
MAX_AREA_RATIO: float = 0.97    # upper bound on foreground fraction in _validate_mask_quality()
SALIENCY_THRESH: float = 0.65   # cut-off applied to the normalised saliency map
GRABCUT_ITERS: int = 3          # iteration count handed to cv2.grabCut
|
|
| |
| |
| |
# Public API of this module; every other name is private by convention.
__all__ = ["segment_person_hq", "segment_person_hq_original", "SegmentationError"]
|
|
| |
| |
| |
def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
    """
    Convert SAM2 multi-mask output to a single best mask for MatAnyone.

    Args:
        masks: Candidate masks as a torch tensor, NumPy array or nested
            sequence, shaped (N, H, W) or (B, N, H, W). For 4-D input only
            the first batch entry is used.
        scores: Optional per-mask scores (tensor, array or sequence). The mask
            with the highest score is chosen; without scores the mask with
            the largest foreground area is chosen.

    Returns:
        A 2-D uint8 array of shape (H, W) containing only 0 and 255.

    Raises:
        SegmentationError: If no masks are supplied or the mask array does
            not have three dimensions after batch reduction.
    """
    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks returned from SAM2")

    # detach() so tensors that still track gradients can be converted too.
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    else:
        masks = np.asarray(masks)

    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim != 3:
        raise SegmentationError(f"Unexpected mask shape: {masks.shape}")

    best_idx = None
    if scores is not None:
        if isinstance(scores, torch.Tensor):
            scores = scores.detach().cpu().numpy()
        # Flatten and truncate so batched (B, N) scores line up with the N
        # masks kept above and can never index past the last mask.
        flat_scores = np.asarray(scores, dtype=np.float64).reshape(-1)[: len(masks)]
        if flat_scores.size:
            best_idx = int(np.argmax(flat_scores))
    if best_idx is None:
        areas = [np.count_nonzero(m > 0.5) for m in masks]
        best_idx = int(np.argmax(areas))

    mask = masks[best_idx]

    if mask.dtype.kind in "bf":
        # Boolean and floating-point masks are treated as 0..1 maps.
        foreground = mask > 0.5
    else:
        # Integer masks may be 0/1 or 0/255; a plain 127 cut-off would erase
        # the 0/1 case entirely.
        cutoff = 127 if mask.size and mask.max() > 1 else 0
        foreground = mask > cutoff

    return np.where(foreground, 255, 0).astype(np.uint8)
|
|
| |
| |
| |
|
|
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation.

    Tries SAM-2 with smart prompts first, then a classical CV cascade, then a
    geometric fallback. When USE_ENHANCED_SEGMENTATION is off the call is
    delegated to segment_person_hq_original().

    Args:
        image: Input frame; must be non-empty.
        predictor: Object exposing ``set_image`` and ``predict``; anything
            else (including None) skips the SAM-2 stage.
        fallback_enabled: When True the classical and geometric fallbacks are
            used and the function does not raise for a valid image. When
            False a failed or unavailable SAM-2 stage raises instead.

    Returns:
        uint8 mask (0/255).

    Raises:
        SegmentationError: If ``image`` is None or empty, or if SAM-2 yields
            no valid mask while ``fallback_enabled`` is False.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    frame_shape = image.shape[:2]
    sam_error: Optional[Exception] = None

    # Stage 1: SAM-2 with prompt points.
    if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            if USE_INTELLIGENT_PROMPTING:
                mask = _segment_with_intelligent_prompts(image, predictor)
            else:
                mask = _segment_with_basic_prompts(image, predictor)
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, frame_shape):
                return mask
            log.warning("SAM2 mask failed validation -> fallback")
        except Exception as exc:
            sam_error = exc
            log.warning("SAM2 path failed: %s", exc)

    # Honour the caller's request not to fall back silently.
    if not fallback_enabled:
        raise SegmentationError("SAM2 unavailable or failed and fallback disabled") from sam_error

    # Stage 2: classical CV cascade.
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, frame_shape):
            return mask
        log.warning("Classical cascade weak -> geometric fallback")
    except Exception as exc:
        # Logged at warning level like the SAM-2 failure so it is not hidden.
        log.warning("Classical cascade error: %s", exc)

    # Stage 3: fixed geometric prior.
    return _geometric_person_mask(image)
|
|
|
|
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    Very first implementation kept for rollback. Fewer smarts, still robust.

    Prompts SAM-2 with five fixed positive points around the image centre and
    falls back to the classical cascade when that yields no valid mask.

    Args:
        image: Input frame; must be non-empty.
        predictor: Object exposing ``set_image`` and ``predict``; anything
            else (including None) skips the SAM-2 stage.
        fallback_enabled: When False, a failed or unavailable SAM-2 stage
            raises instead of using the classical cascade.

    Returns:
        uint8 mask (0/255).

    Raises:
        SegmentationError: If ``image`` is None or empty, or if SAM-2 yields
            no valid mask while ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    sam_error: Optional[Exception] = None
    try:
        if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)

            points = np.array([
                [w // 2, h // 4],
                [w // 2, h // 2],
                [w // 2, 3 * h // 4],
                [w // 3, h // 2],
                [2 * w // 3, h // 2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)

            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )

            if masks is not None and len(masks):
                mask = _sam2_to_matanyone_mask(masks, scores)
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
    except Exception as exc:
        sam_error = exc
        log.warning("segment_person_hq_original error: %s", exc)

    # The fallback decision lives outside the try block so that the
    # "fallback disabled" error cannot be swallowed by the handler above.
    if not fallback_enabled:
        raise SegmentationError("SAM2 failed and fallback disabled") from sam_error
    return _classical_segmentation_cascade(image)
|
|
|
|
| |
| |
| |
|
|
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """
    Run the SAM2 prediction using the prompt points produced by
    _generate_smart_prompts().
    """
    prompts = _generate_smart_prompts(image)
    positive, negative = prompts
    return _sam2_predict(image, predictor, positive, negative)
|
|
|
|
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """
    Run the SAM2 prediction with a fixed prompt layout: three positive points
    down the vertical centre line and four negative points near the corners.
    """
    height, width = image.shape[:2]
    mid_x = width // 2
    margin = 10

    positive = np.array(
        [
            [mid_x, height // 3],
            [mid_x, height // 2],
            [mid_x, 2 * height // 3],
        ],
        np.float32,
    )
    negative = np.array(
        [
            [margin, margin],
            [width - margin, margin],
            [margin, height - margin],
            [width - margin, height - margin],
        ],
        np.float32,
    )
    return _sam2_predict(image, predictor, positive, negative)
|
|
|
|
def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    """
    Query the predictor with positive/negative point prompts and reduce its
    multi-mask output to a single mask.

    Args:
        image: Frame the prompts refer to; only its shape is read here, to
            place a centre point when no positive prompt is given.
        predictor: Object whose ``predict`` accepts ``point_coords``,
            ``point_labels`` and ``multimask_output``.
        pos_points: (K, 2) x/y coordinates labelled 1; may be empty.
        neg_points: (M, 2) x/y coordinates labelled 0; may be empty or None.

    Returns:
        The mask selected by _sam2_to_matanyone_mask().
    """
    no_points = np.zeros((0, 2), np.float32)

    # Normalise both prompt sets to (K, 2) so that an empty set of either
    # kind stacks cleanly instead of breaking on mismatched dimensions.
    pos = no_points if pos_points is None else np.asarray(pos_points, np.float32).reshape(-1, 2)
    neg = no_points if neg_points is None else np.asarray(neg_points, np.float32).reshape(-1, 2)

    if len(pos) == 0:
        pos = np.array([[image.shape[1] // 2, image.shape[0] // 2]], np.float32)

    points = np.concatenate([pos, neg], axis=0)
    labels = np.concatenate([
        np.ones(len(pos), np.int32),
        np.zeros(len(neg), np.int32),
    ])

    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )

    return _sam2_to_matanyone_mask(masks, scores)
|
|
|
|
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Place prompt points from a saliency map.

    Positive points are the centroids of up to three of the largest salient
    contours (the image centre when none are found); negative points sit ten
    pixels in from each image corner. Both are returned as float32 arrays.
    """
    height, width = image.shape[:2]
    saliency = _compute_saliency(image)

    positives = []
    if saliency is not None:
        salient = saliency > (SALIENCY_THRESH - 0.1)
        binary = (salient * 255).astype(np.uint8)
        found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        largest_first = sorted(found, key=cv2.contourArea, reverse=True)
        for contour in largest_first[:3]:
            moments = cv2.moments(contour)
            area = moments["m00"]
            # A zero-area contour has no defined centroid.
            if area:
                positives.append([int(moments["m10"] / area), int(moments["m01"] / area)])

    if not positives:
        positives = [[width // 2, height // 2]]

    negatives = [
        [10, 10],
        [width - 10, 10],
        [10, height - 10],
        [width - 10, height - 10],
    ]
    return np.asarray(positives, np.float32), np.asarray(negatives, np.float32)
|
|
| |
| |
| |
|
|
def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Edge-median background subtraction -> saliency flood-fill -> GrabCut.

    Each stage's mask is returned as soon as it passes
    _validate_mask_quality(); if none does, the geometric mask is returned.
    """
    frame_shape = image.shape[:2]

    # Stage 1: pixels that differ from the median border intensity.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    border = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    deviation = np.abs(gray.astype(float) - np.median(border))
    candidate = (deviation > 30).astype(np.uint8) * 255
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    candidate = cv2.morphologyEx(candidate, cv2.MORPH_CLOSE, kernel)
    if _validate_mask_quality(candidate, frame_shape):
        return candidate

    # Stages 2 and 3: each refinement is seeded with the previous mask.
    for refine in (_refine_with_saliency, _refine_with_grabcut):
        candidate = refine(image, candidate)
        if _validate_mask_quality(candidate, frame_shape):
            return candidate

    return _geometric_person_mask(image)
|
|
| |
|
|
def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    """
    Compute a spectral-residual saliency map normalised to the range 0..1.

    Returns:
        The normalised map, or None when the ``cv2.saliency`` module is not
        available, the detector reports failure, or the computation raises.
        Failures are logged at debug level; this helper is best-effort and
        never raises.
    """
    if not hasattr(cv2, "saliency"):
        return None
    try:
        detector = cv2.saliency.StaticSaliencySpectralResidual_create()
        ok, smap = detector.computeSaliency(image)
        if ok:
            # The floor on the divisor guards against a constant map.
            return (smap - smap.min()) / max(1e-6, smap.max() - smap.min())
    except Exception as exc:
        # Callers treat None as "no saliency available"; keep that contract
        # but leave a trace instead of discarding the error silently.
        log.debug("Saliency computation failed: %s", exc)
    return None
|
|
def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Estimate a bounding rectangle from the largest salient region.

    Returns:
        ``(x, y, w, h)`` of the largest salient contour, grown by 5% of the
        image size on every side and clipped to the image, or None when no
        saliency map or salient contour is available.
    """
    sal = _compute_saliency(image)
    if sal is None:
        return None

    salient = (sal > SALIENCY_THRESH).astype(np.uint8)
    cnts, _ = cv2.findContours(salient * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None

    x, y, box_w, box_h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
    img_h, img_w = image.shape[:2]

    # Use one margin (a fraction of the image size) for both the origin shift
    # and the far edge, so the padded rectangle always contains the original
    # box instead of sliding off its right/bottom side.
    pad = 0.05
    pad_x, pad_y = int(img_w * pad), int(img_h * pad)
    x0, y0 = max(0, x - pad_x), max(0, y - pad_y)
    x1, y1 = min(img_w, x + box_w + pad_x), min(img_h, y + box_h + pad_y)
    return x0, y0, x1 - x0, y1 - y0
|
|
def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """
    Refine a seed mask with GrabCut.

    Seed pixels above 200 are marked definite foreground and everything else
    probable background. When the seed provides only one of the two classes,
    GrabCut is initialised from a rectangle instead (the salient bounding box
    or a centred default), because mask initialisation needs samples of both.

    Args:
        image: Frame to segment, as accepted by ``cv2.grabCut``.
        seed: Initial mask with the same height/width as ``image``.

    Returns:
        uint8 mask (0/255) of definite + probable foreground, or ``seed``
        unchanged when GrabCut fails.
    """
    h, w = image.shape[:2]
    sure_fg = seed > 200

    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc[sure_fg] = cv2.GC_FGD

    rect = _auto_person_rect(image) or (w // 4, h // 6, w // 2, int(h * 0.7))

    # rect is only consulted in GC_INIT_WITH_RECT mode, so use that mode when
    # the seed cannot supply both foreground and background samples.
    has_both_classes = bool(sure_fg.any()) and not bool(sure_fg.all())
    mode = cv2.GC_INIT_WITH_MASK if has_both_classes else cv2.GC_INIT_WITH_RECT

    bgd = np.zeros((1, 65), np.float64)
    fgd = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, mode)
    except cv2.error as exc:
        # Returning the seed lets the caller's validation decide what to do.
        log.debug("GrabCut refinement failed: %s", exc)
        return seed

    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
|
|
def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """
    Refine a seed mask using the saliency map.

    The saliency map is thresholded at SALIENCY_THRESH and the connected
    salient region containing the seed's centroid is kept. If the centroid
    does not land on a salient pixel the whole thresholded map is returned.

    Args:
        image: Frame passed to _compute_saliency().
        seed: Initial mask; only the centroid of its pixels above 127 is used
            (the image centre when there are none).

    Returns:
        uint8 mask (0/255), or ``seed`` unchanged when no saliency map is
        available.
    """
    sal = _compute_saliency(image)
    if sal is None:
        return seed

    high = (sal > SALIENCY_THRESH).astype(np.uint8) * 255

    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    # Keep the lookup inside the saliency map even if shapes differ slightly.
    cy = min(max(cy, 0), high.shape[0] - 1)
    cx = min(max(cx, 0), high.shape[1] - 1)

    # Filling a 0/255 map with 255 can only add pixels, so starting from a
    # non-salient pixel would flood the background. Select the component
    # under the centroid instead.
    if not high[cy, cx]:
        return high

    _, labels = cv2.connectedComponents(high, connectivity=4)
    return np.where(labels == labels[cy, cx], 255, 0).astype(np.uint8)
|
|
| |
| |
| |
|
|
def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int,int]) -> bool:
    """
    Check that the share of mask pixels above 127, relative to the
    ``shape`` area, lies within [MIN_AREA_RATIO, MAX_AREA_RATIO].
    """
    height, width = shape
    foreground_pixels = np.sum(mask > 127)
    coverage = foreground_pixels / (height * width)
    return MIN_AREA_RATIO <= coverage <= MAX_AREA_RATIO
|
|
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """
    Legacy mask processor, kept for compatibility; mostly replaced by
    _sam2_to_matanyone_mask.

    Scales float masks whose maximum is at most 1.0 up to 0..255, casts to
    uint8, drops a leading/trailing channel axis and binarises at 127.
    """
    is_unit_float = mask.dtype in (np.float32, np.float64) and mask.max() <= 1.0
    if is_unit_float:
        mask = (mask * 255).astype(np.uint8)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    if mask.ndim == 3:
        mask = mask.squeeze()
        # Still three-dimensional after squeezing: keep the first channel.
        if mask.ndim == 3:
            mask = mask[:, :, 0]

    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return binary
|
|
def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    """
    Last-resort mask: a filled ellipse centred in the frame, with half-axes
    of a third of the width and 1/2.5 of the height.
    """
    height, width = image.shape[:2]
    canvas = np.zeros((height, width), np.uint8)

    center = (width // 2, height // 2)
    half_axes = (width // 3, int(height / 2.5))
    cv2.ellipse(canvas, center, half_axes, 0, 0, 360, 255, -1)
    return canvas
|
|
| |
| |
| |
|
|
def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1):
    """
    Placeholder refinement step: currently returns ``mask`` unchanged and
    does not use ``image``, ``predictor`` or ``max_iterations``.
    """
    refined = mask
    return refined