| | |
| | """ |
| | utils.refinement |
| | High-quality mask refinement for BackgroundFX Pro. |
| | """ |
| |
|
| | from __future__ import annotations |
| | from typing import Any, Optional, Tuple, List |
| | import logging |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| |
|
# Module-level logger. Refinement failures that have a fallback path are
# reported here (warnings/errors) instead of being raised to the caller.
log = logging.getLogger(__name__)
| |
|
| | |
| | |
| | |
class MaskRefinementError(Exception):
    """Raised when a mask cannot be refined.

    Used both for invalid inputs (missing arrays, mismatched image/mask
    shapes, frame/mask count mismatches) and for failures inside the
    MatAnyone refinement path.
    """
| |
|
| | |
| | |
| | |
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ["refine_mask_hq", "refine_masks_batch", "MaskRefinementError"]
| |
|
| | |
| | |
| | |
def refine_mask_hq(
    image: np.ndarray,
    mask: np.ndarray,
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> np.ndarray:
    """Refine one foreground mask, preferring MatAnyone when it is available.

    Strategy order:
      1. If ``matanyone_model`` is given, run it. Accept the result only if it
         passes ``_validate_refined_mask``.
      2. Otherwise, or if step 1 fails, apply classical CV refinement when
         ``fallback_enabled`` is true.
      3. If nothing succeeds, return the input ``mask`` unchanged.

    Args:
        image: Original BGR image.
        mask: Initial binary mask (0/255) with the same height/width as ``image``.
        matanyone_model: Optional MatAnyone model for AI refinement.
        fallback_enabled: Whether classical refinement may be used as a fallback.

    Returns:
        The refined binary mask (0/255), or the very same ``mask`` object when no
        refinement step produced a usable result.

    Raises:
        MaskRefinementError: If ``image`` or ``mask`` is None, or their spatial
            shapes differ.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    img_hw, mask_hw = image.shape[:2], mask.shape[:2]
    if img_hw != mask_hw:
        raise MaskRefinementError(f"Image shape {img_hw} doesn't match mask shape {mask_hw}")

    # Preferred path: AI refinement, accepted only if it looks plausible.
    if matanyone_model is not None:
        try:
            candidate = _refine_with_matanyone(image, mask, matanyone_model)
            accepted = _validate_refined_mask(candidate, mask)
        except Exception as exc:
            log.warning(f"MatAnyone refinement failed: {exc}")
        else:
            if accepted:
                return candidate
            log.warning("MatAnyone refinement failed validation")

    if not fallback_enabled:
        return mask

    # Best-effort classical fallback; on failure keep the caller's mask.
    try:
        return _classical_refinement(image, mask)
    except Exception as exc:
        log.warning(f"Classical refinement failed: {exc}")
        return mask
| |
|
| | |
| | |
| | |
def refine_masks_batch(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> List[np.ndarray]:
    """Refine a sequence of masks, using MatAnyone's temporal consistency when possible.

    If ``matanyone_model`` is given, the whole sequence goes through it. The
    result is used only if it has exactly one mask per frame and every mask
    passes ``_validate_refined_mask``. Otherwise, when ``fallback_enabled`` is
    true, each frame gets classical refinement. A frame whose classical
    refinement fails keeps its original mask, as in ``refine_mask_hq``.

    Args:
        frames: List of BGR images.
        masks: List of initial binary masks, one per frame.
        matanyone_model: MatAnyone InferenceCore model.
        fallback_enabled: Whether to use fallback methods.

    Returns:
        List of refined binary masks. Returns ``masks`` itself when either input
        is empty, or when no refinement is attempted.

    Raises:
        MaskRefinementError: If the numbers of frames and masks differ.
    """
    if not frames or not masks:
        return masks

    if len(frames) != len(masks):
        raise MaskRefinementError(f"Frame count {len(frames)} doesn't match mask count {len(masks)}")

    if matanyone_model is not None:
        try:
            refined = _refine_batch_with_matanyone(frames, masks, matanyone_model)
            # zip() stops at the shorter sequence, so the count must be checked
            # explicitly. Otherwise a truncated result would pass validation and
            # the caller would silently lose frames.
            if len(refined) != len(masks):
                log.warning(
                    f"Batch MatAnyone returned {len(refined)} masks for {len(masks)} frames"
                )
            elif all(_validate_refined_mask(r, m) for r, m in zip(refined, masks)):
                return refined
            else:
                log.warning("Batch MatAnyone refinement failed validation")
        except Exception as e:
            log.warning(f"Batch MatAnyone refinement failed: {e}")

    if not fallback_enabled:
        return masks

    # Per-frame best-effort fallback: one frame that cannot be refined keeps its
    # original mask instead of aborting the whole batch.
    results = []
    for idx, (frame, mask) in enumerate(zip(frames, masks)):
        try:
            results.append(_classical_refinement(frame, mask))
        except Exception as e:
            log.warning(f"Classical refinement failed for frame {idx}: {e}")
            results.append(mask)
    return results
| |
|
| | |
| | |
| | |
def _refine_with_matanyone(
    image: np.ndarray,
    mask: np.ndarray,
    model: Any
) -> np.ndarray:
    """Refine a single mask with a MatAnyone-style model and binarise the output.

    The BGR image becomes a float RGB tensor of shape (1, 3, H, W) in [0, 1], and
    the mask becomes a float tensor of shape (1, 1, H, W) in [0, 1]. The model is
    called through the first entry point it has among ``step``,
    ``process_frame``, ``forward`` and ``__call__``, as
    ``entry(image_tensor, mask_tensor)``.

    Args:
        image: BGR image (the /255 scaling assumes uint8 intensities).
        mask: Mask with the image's height/width. 3-channel masks are converted
            to gray; other 3-D masks keep channel 0. Non-uint8 masks whose max is
            <= 1 are scaled to 0..255.
        model: MatAnyone model / inference core.

    Returns:
        Binary uint8 mask (0/255) of shape (H, W).

    Raises:
        MaskRefinementError: For any failure: a bad mask shape, no recognised
            entry point, a ``None`` result or matte, or an error inside the
            model. The original exception is chained as ``__cause__``.
    """
    try:
        # NOTE(review): the device follows CUDA availability, not where the model
        # actually lives. A CPU-resident model on a CUDA host will raise here, and
        # the caller then falls back to classical refinement.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image_rgb.shape[:2]
        image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
        image_tensor = image_tensor.unsqueeze(0).to(device)

        # Reduce the mask to a single channel.
        if mask.ndim == 3:
            if mask.shape[2] == 3:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            else:
                mask = mask[:, :, 0]

        # Bring 0/1 (float or bool) masks up to 0..255 before the uint8 cast.
        if mask.dtype != np.uint8:
            mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)

        # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
        if mask.ndim != 2:
            raise MaskRefinementError(f"Mask must be 2D after conversion, got shape {mask.shape}")
        if mask.shape != (h, w):
            raise MaskRefinementError(f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})")

        mask_tensor = torch.from_numpy(mask).float() / 255.0
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)

        methods = [m for m in dir(model) if not m.startswith('_')]
        log.debug(f"MatAnyone InferenceCore methods: {methods}")

        with torch.no_grad():
            if hasattr(model, 'step'):
                result = model.step(image_tensor, mask_tensor)
            elif hasattr(model, 'process_frame'):
                result = model.process_frame(image_tensor, mask_tensor)
            elif hasattr(model, 'forward'):
                result = model.forward(image_tensor, mask_tensor)
            elif callable(model):
                result = model(image_tensor, mask_tensor)
            else:
                raise MaskRefinementError(f"No recognized method. Available: {methods}")

        if result is None:
            raise MaskRefinementError("MatAnyone returned None")

        alpha = _extract_alpha_from_result(result)
        if alpha is None:
            raise MaskRefinementError("MatAnyone result contained no alpha matte")

        # Same squeeze/scale/resize conversion the batch path uses, then binarise.
        return _process_mask(_tensor_to_mask(alpha, h, w))

    except Exception as e:
        log.error(f"MatAnyone processing error: {str(e)}")
        raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}") from e
| |
|
| | |
| | |
| | |
def _refine_batch_with_matanyone(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    model: Any
) -> List[np.ndarray]:
    """Run a sequence of frames through MatAnyone for temporally consistent masks.

    The first matching entry point on ``model`` is used:
      * ``process_batch(frames, first_mask)``: all frames as one float
        (N, 3, H, W) RGB tensor, plus the first mask as (1, 1, H, W).
      * ``step(frame, mask)``: frame by frame. The first mask is passed on
        frame 0 and ``None`` on later frames.
      * Otherwise ``model(frame, mask)`` is called per frame with that frame's
        own mask.

    Every output is converted with ``_tensor_to_mask`` at the frames' size and
    binarised with ``_process_mask``. This matches the single-frame path and the
    "binary masks" contract of ``refine_masks_batch``.

    Args:
        frames: BGR frames, all with the same height/width.
        masks: Initial masks, one per frame (2-D or 3-D, uint8 0/255 or 0/1).
        model: MatAnyone model / inference core.

    Returns:
        One binary uint8 mask (0/255) per frame.

    Raises:
        MaskRefinementError: On any failure. This includes frames of differing
            sizes, masks that do not match the frame size, and a result count
            that differs from the number of frames. The original exception is
            chained as ``__cause__``.
    """

    def to_mask_tensor(m: np.ndarray, h: int, w: int, device: torch.device) -> torch.Tensor:
        """Return mask ``m`` as a float (1, 1, H, W) tensor in [0, 1] on ``device``."""
        if m.ndim == 3:
            m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY) if m.shape[2] == 3 else m[:, :, 0]
        # Scale 0/1 (float or bool) masks to 0..255 before the uint8 cast.
        if m.dtype != np.uint8:
            m = (m * 255).astype(np.uint8) if m.max() <= 1 else m.astype(np.uint8)
        if m.shape != (h, w):
            raise MaskRefinementError(f"Mask shape {m.shape} doesn't match frame shape ({h}, {w})")
        t = torch.from_numpy(m).float() / 255.0
        return t.unsqueeze(0).unsqueeze(0).to(device)

    def to_binary_mask(result: Any, h: int, w: int) -> np.ndarray:
        """Extract the matte from one model result and binarise it at (h, w)."""
        return _process_mask(_tensor_to_mask(_extract_alpha_from_result(result), h, w))

    try:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        h, w = frames[0].shape[:2]
        for idx, frame in enumerate(frames):
            if frame.shape[:2] != (h, w):
                raise MaskRefinementError(
                    f"Frame {idx} shape {frame.shape[:2]} doesn't match first frame ({h}, {w})"
                )

        # Frames stay on the CPU; each branch moves only what it actually uses.
        frame_tensors = [
            torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1).float() / 255.0
            for frame in frames
        ]
        first_mask_tensor = to_mask_tensor(masks[0], h, w, device)

        refined_masks = []
        with torch.no_grad():
            if hasattr(model, 'process_batch'):
                batch_tensor = torch.stack(frame_tensors).to(device)
                for result in model.process_batch(batch_tensor, first_mask_tensor):
                    refined_masks.append(to_binary_mask(result, h, w))

            elif hasattr(model, 'step'):
                for i, frame_tensor in enumerate(frame_tensors):
                    # Only frame 0 gets a mask; later frames pass None (presumably
                    # the model propagates from its own memory; verify against
                    # the MatAnyone API). Each result is converted before the next
                    # step, in case the model reuses its output buffers.
                    seed = first_mask_tensor if i == 0 else None
                    result = model.step(frame_tensor.unsqueeze(0).to(device), seed)
                    refined_masks.append(to_binary_mask(result, h, w))

            else:
                log.warning("MatAnyone batch processing not available, using frame-by-frame")
                for frame_tensor, mask in zip(frame_tensors, masks):
                    result = model(frame_tensor.unsqueeze(0).to(device), to_mask_tensor(mask, h, w, device))
                    refined_masks.append(to_binary_mask(result, h, w))

        if len(refined_masks) != len(frames):
            raise MaskRefinementError(
                f"MatAnyone returned {len(refined_masks)} masks for {len(frames)} frames"
            )
        return refined_masks

    except Exception as e:
        log.error(f"Batch MatAnyone processing error: {str(e)}")
        raise MaskRefinementError(f"Batch processing failed: {str(e)}") from e
| |
|
| | |
| | |
| | |
def _extract_alpha_from_result(result):
    """Pull the alpha matte out of whatever a MatAnyone call returned.

    - dict: the value stored under the first key present among ``'alpha'``,
      ``'matte'``, ``'mask'`` (even if that value is None). Returns None when
      none of these keys exist.
    - tuple/list: its first element, or None when it is empty.
    - anything else: returned unchanged.
    """
    if isinstance(result, dict):
        for key in ('alpha', 'matte', 'mask'):
            if key in result:
                return result[key]
        return None

    if isinstance(result, (tuple, list)):
        return result[0] if result else None

    return result
| |
|
def _tensor_to_mask(tensor, target_h, target_w):
    """Convert a model output (torch tensor or array-like) to a uint8 mask.

    Steps:
      1. Squeeze away singleton dimensions. Tensors and numpy arrays get the
         same treatment, so a ``(1, 1, H, W)`` array works like a tensor.
      2. Average over the leading axis if three dimensions remain. A uint8
         input stays uint8 so it is not rescaled in step 3.
      3. Treat non-uint8 values as [0, 1]: scale by 255, clip, and cast to uint8.
      4. Resize bilinearly to ``(target_h, target_w)`` if the shape differs.

    The result is NOT thresholded.

    Raises:
        MaskRefinementError: If ``tensor`` is None (e.g. no matte was found).
    """
    if tensor is None:
        raise MaskRefinementError("No alpha matte to convert")

    if isinstance(tensor, torch.Tensor):
        # detach(): .numpy() refuses tensors that still require grad.
        mask = tensor.detach().squeeze().cpu().numpy()
    else:
        mask = np.squeeze(np.asarray(tensor))

    if mask.ndim == 3:
        reduced = mask.mean(axis=0)
        # mean() yields floats; keep uint8 data on its 0..255 scale, otherwise
        # the x255 step below would saturate it to 255.
        mask = reduced.round().astype(np.uint8) if mask.dtype == np.uint8 else reduced

    if mask.dtype != np.uint8:
        mask = (mask * 255).clip(0, 255).astype(np.uint8)

    if mask.shape != (target_h, target_w):
        mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

    return mask
| |
|
def _validate_refined_mask(
    refined: np.ndarray,
    original: np.ndarray,
    min_ratio: float = 0.5,
    max_ratio: float = 2.0,
) -> bool:
    """Sanity-check a refined mask against the mask it was derived from.

    The refined mask is accepted when all of these hold:
      * it is a non-empty array with the same height/width as ``original``;
      * its foreground (pixels > 127) is non-empty;
      * its foreground area is within ``[min_ratio, max_ratio]`` times the
        original's. An empty original counts as area 1, so the ratio stays
        finite.

    Args:
        refined: Candidate mask produced by a refinement step.
        original: The mask the refinement started from.
        min_ratio: Smallest accepted refined/original area ratio.
        max_ratio: Largest accepted refined/original area ratio.

    Returns:
        True if the refined mask looks plausible, False otherwise.
    """
    if refined is None or refined.size == 0 or original is None:
        return False

    # Area ratios are meaningless if the masks do not cover the same frame.
    if refined.shape[:2] != original.shape[:2]:
        return False

    refined_area = int(np.count_nonzero(refined > 127))
    if refined_area == 0:
        return False

    original_area = int(np.count_nonzero(original > 127))
    ratio = refined_area / max(original_area, 1)
    return min_ratio <= ratio <= max_ratio
| |
|
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Convert a mask in any common format to binary uint8 0/255.

    Conversions applied:
      * bool masks: True -> 255 (casting directly would give 0/1, which then
        thresholds to all-zero);
      * floating masks (any precision) with max <= 1.0: treated as [0, 1] and
        scaled by 255;
      * any remaining non-uint8 data: clipped to [0, 255] before the uint8 cast,
        so out-of-range values saturate instead of wrapping around;
      * (H, W, 1) masks: flattened to (H, W); other 3-channel masks: converted
        BGR -> gray.

    Pixels > 127 become 255 and everything else becomes 0 (the same rule as
    ``cv2.threshold(mask, 127, 255, THRESH_BINARY)``).
    """
    mask = np.asarray(mask)

    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    elif np.issubdtype(mask.dtype, np.floating):
        if mask.size and mask.max() <= 1.0:
            mask = mask * 255

    if mask.dtype != np.uint8:
        mask = np.clip(mask, 0, 255).astype(np.uint8)

    if mask.ndim == 3:
        mask = mask[:, :, 0] if mask.shape[2] == 1 else cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    return np.where(mask > 127, 255, 0).astype(np.uint8)
| |
|
| | |
| | |
| | |
def _classical_refinement(
    image: np.ndarray,
    mask: np.ndarray,
    kernel_size: int = 5,
    feather_radius: int = 3,
    min_area_ratio: float = 0.005,
) -> np.ndarray:
    """Refine a mask using classical CV operations only.

    Pipeline:
      1. Morphological close, then open, with an elliptical
         ``kernel_size`` x ``kernel_size`` kernel (fill small holes, then drop
         small specks).
      2. Guided-filter smoothing steered by ``image`` (``_edge_aware_smooth``).
      3. Blur and re-threshold to smooth the contour (``_feather_edges``).
      4. Drop connected components smaller than ``min_area_ratio`` of the frame,
         always keeping the largest one (``_remove_small_components``).

    Args:
        image: BGR guide image.
        mask: Binary mask (0/255) with the image's height/width.
        kernel_size: Side length of the morphology kernel (default matches the
            historical hard-coded 5).
        feather_radius: Blur radius used by ``_feather_edges``.
        min_area_ratio: Minimum component area as a fraction of the frame.

    Returns:
        Binary uint8 mask (0/255).
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    # morphologyEx returns a new array, so the caller's mask is never modified.
    refined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)
    refined = _edge_aware_smooth(image, refined)
    refined = _feather_edges(refined, radius=feather_radius)
    refined = _remove_small_components(refined, min_area_ratio=min_area_ratio)
    return refined
| |
|
def _edge_aware_smooth(
    image: np.ndarray,
    mask: np.ndarray,
    radius: int = 5,
    eps: float = 0.01,
) -> np.ndarray:
    """Smooth ``mask`` with a gray-scale guided filter steered by ``image``.

    This is the standard guided filter. Per window it fits
    ``q = a * I + b``, with ``a = cov(I, p) / (var(I) + eps)`` and
    ``b = mean(p) - a * mean(I)``. Here ``I`` is the gray guide and ``p`` is
    the mask, both scaled to [0, 1]. The averaged coefficients are then applied
    to ``I``.

    Args:
        image: BGR/BGRA image, or an already single-channel guide (2-D or
            (H, W, 1)). Intensities are assumed to be 0..255.
        mask: 2-D mask with values 0..255.
        radius: Side length of the box window passed to ``cv2.boxFilter``. The
            default 5 means a 5x5 window, despite the parameter name.
        eps: Regulariser. Larger values shrink ``a`` and smooth more across
            weak edges.

    Returns:
        uint8 mask with values 0..255 (not binarised).
    """
    ksize = (radius, radius)
    p = mask.astype(np.float32) / 255.0

    # Accept single-channel guides instead of crashing in cvtColor.
    if image.ndim == 2:
        guide = image
    elif image.shape[2] == 1:
        guide = image[:, :, 0]
    else:
        guide = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = guide.astype(np.float32) / 255.0

    mean_I = cv2.boxFilter(gray, -1, ksize)
    mean_p = cv2.boxFilter(p, -1, ksize)
    mean_Ip = cv2.boxFilter(gray * p, -1, ksize)
    mean_II = cv2.boxFilter(gray * gray, -1, ksize)

    cov_Ip = mean_Ip - mean_I * mean_p
    var_I = mean_II - mean_I * mean_I

    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    mean_a = cv2.boxFilter(a, -1, ksize)
    mean_b = cv2.boxFilter(b, -1, ksize)

    refined = mean_a * gray + mean_b
    return (refined * 255).clip(0, 255).astype(np.uint8)
| |
|
def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
    """Round off a mask's contour by Gaussian-blurring it and re-thresholding.

    The output is binary again (0/255, pixels > 127 -> 255). This smooths
    jagged boundaries; it does not produce a soft alpha ramp. The Gaussian
    kernel is ``(2*radius + 1)`` pixels wide with sigma ``radius / 2``.
    A ``radius`` of 0 or less returns ``mask`` itself, untouched.
    """
    if radius > 0:
        ksize = 2 * radius + 1
        smoothed = cv2.GaussianBlur(mask, (ksize, ksize), radius / 2)
        return cv2.threshold(smoothed, 127, 255, cv2.THRESH_BINARY)[1]
    return mask
| |
|
def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
    """Drop small connected foreground blobs from a single-channel 8-bit mask.

    Components are found with 8-connectivity, and any non-zero pixel counts as
    foreground. Components smaller than ``min_area_ratio`` of the image area
    are removed, except that the largest component is always kept. Kept pixels
    are set to 255. If the mask has no foreground, it is returned unchanged.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    if num_labels <= 1:
        return mask

    min_area = int(mask.shape[0] * mask.shape[1] * min_area_ratio)

    # Build a per-label keep table once, then apply it to every pixel in a
    # single vectorised lookup. The previous loop rescanned the whole image
    # once per label.
    areas = stats[:, cv2.CC_STAT_AREA]
    keep = areas >= min_area
    keep[0] = False                           # label 0 is the background
    keep[1 + np.argmax(areas[1:])] = True     # always keep the largest blob

    cleaned = np.zeros_like(mask)
    cleaned[keep[labels]] = 255
    return cleaned