"""
Advanced matting algorithms for BackgroundFX Pro.
Implements multiple matting techniques with automatic fallback.
"""

from dataclasses import dataclass
from typing import Dict, Optional

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import get_logger
from utils.hardware.device_manager import DeviceManager
from utils.config import ConfigManager
from core.models import ModelFactory, ModelType
from core.quality import QualityAnalyzer
from core.edge import EdgeRefinement

logger = get_logger(__name__)


@dataclass
class MattingConfig:
    """Configuration for matting operations."""
    alpha_threshold: float = 0.5
    erode_iterations: int = 2
    dilate_iterations: int = 2
    blur_radius: int = 3
    trimap_size: int = 30
    confidence_threshold: float = 0.7
    use_guided_filter: bool = True
    guided_filter_radius: int = 8
    guided_filter_eps: float = 1e-6
    use_temporal_smoothing: bool = False
    temporal_window: int = 5
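
    # Illustrative override of the defaults (values are examples only):
    #     config = MattingConfig(trimap_size=20, blur_radius=5)
    #     matter = AlphaMatting(config)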


class AlphaMatting:
    """Advanced alpha matting using multiple techniques."""

    def __init__(self, config: Optional[MattingConfig] = None):
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    def create_trimap(self, mask: np.ndarray, dilation_size: Optional[int] = None) -> np.ndarray:
        """
        Create trimap from a binary mask.

        Args:
            mask: Binary mask (H, W) in {0, 255} or [0, 1]
            dilation_size: Size of uncertain region (pixels)

        Returns:
            Trimap with values 0 (background), 128 (unknown), 255 (foreground)
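
        Example (illustrative; ``person_mask`` is an assumed upstream binary mask):

            matter = AlphaMatting()
            trimap = matter.create_trimap(person_mask, dilation_size=20)
            # trimap now holds exactly the values {0, 128, 255}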
| """ |
        if dilation_size is None:
            dilation_size = self.config.trimap_size

        # Normalize to a binary uint8 mask in {0, 255}.
        if mask.dtype != np.uint8:
            mask = (mask * 255).astype(np.uint8)
        mask = (mask > 127).astype(np.uint8) * 255

        trimap = np.copy(mask)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))

        # Grow and shrink the mask to bracket the true boundary.
        dilated = cv2.dilate(mask, kernel, iterations=1)
        eroded = cv2.erode(mask, kernel, iterations=1)

        # Confident foreground survives erosion; the band between the
        # eroded and dilated masks is marked unknown.
        trimap[:] = 0
        trimap[eroded == 255] = 255
        unknown = (dilated == 255) & (eroded == 0)
        trimap[unknown] = 128

        return trimap

    def guided_filter(
        self,
        image: np.ndarray,
        guide: np.ndarray,
        radius: Optional[int] = None,
        eps: Optional[float] = None,
    ) -> np.ndarray:
        """
        Apply guided filter for edge-preserving smoothing.

        Args:
            image: Input image to filter (H, W) uint8
            guide: Guide image (H, W, 3) or (H, W)
            radius: Filter radius
            eps: Regularization parameter

        Returns:
            Filtered image (H, W) uint8
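
        Example (illustrative; ``rough_alpha`` and ``frame`` are assumed inputs):

            matter = AlphaMatting()
            smoothed = matter.guided_filter(rough_alpha, frame, radius=8, eps=1e-6)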
| """ |
        if radius is None:
            radius = self.config.guided_filter_radius
        if eps is None:
            eps = self.config.guided_filter_eps

        if guide.ndim == 3:
            guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)
        else:
            guide_gray = guide

        # Work in [0, 1] floats.
        I = guide_gray.astype(np.float32) / 255.0
        p = image.astype(np.float32) / 255.0

        def box_filter(img, r):
            return cv2.boxFilter(img, -1, (r, r))

        # Local linear model: q = a * I + b within each window.
        mean_I = box_filter(I, radius)
        mean_p = box_filter(p, radius)
        mean_Ip = box_filter(I * p, radius)
        cov_Ip = mean_Ip - mean_I * mean_p

        mean_II = box_filter(I * I, radius)
        var_I = mean_II - mean_I * mean_I

        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        # Average the coefficients over all windows covering each pixel.
        mean_a = box_filter(a, radius)
        mean_b = box_filter(b, radius)

        q = mean_a * I + mean_b
        return np.clip(q * 255.0, 0, 255).astype(np.uint8)

    def closed_form_matting(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
        """
        Closed-form-inspired fast matting using distance transforms + optional guided filtering.

        Args:
            image: RGB image (H, W, 3) uint8
            trimap: Trimap with values {0, 128, 255}

        Returns:
            Alpha matte in [0, 1] float32
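
        Example (illustrative; ``frame`` and ``person_mask`` are assumed inputs):

            matter = AlphaMatting()
            trimap = matter.create_trimap(person_mask)
            alpha = matter.closed_form_matting(frame, trimap)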
| """ |
        alpha = trimap.astype(np.float32) / 255.0

        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = trimap == 128

        if not np.any(is_unknown):
            return np.clip(alpha, 0.0, 1.0)

        # Distance transforms are taken on the complements so that every
        # pixel gets its distance to the nearest foreground / background
        # pixel (cv2.distanceTransform measures distance to the nearest
        # zero pixel, so transforming the masks directly would give all
        # unknown pixels a distance of zero).
        dist_to_fg = cv2.distanceTransform((~is_fg).astype(np.uint8), cv2.DIST_L2, 5)
        dist_to_bg = cv2.distanceTransform((~is_bg).astype(np.uint8), cv2.DIST_L2, 5)

        # Unknown pixels blend toward whichever region is closer.
        total = dist_to_fg + dist_to_bg + 1e-10
        alpha_unknown = dist_to_bg / total
        alpha[is_unknown] = alpha_unknown[is_unknown]

        if self.config.use_guided_filter:
            alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
            alpha_u8 = self.guided_filter(alpha_u8, image)
            alpha = alpha_u8.astype(np.float32) / 255.0

        return np.clip(alpha, 0.0, 1.0)

    def deep_matting(
        self,
        image: np.ndarray,
        mask: np.ndarray,
        model: Optional[nn.Module] = None,
    ) -> np.ndarray:
        """
        Apply deep learning-based matting refinement.

        Args:
            image: RGB image (H, W, 3) uint8
            mask: Initial mask (H, W) in {0..255} or [0, 1]
            model: Optional pre-trained model taking (img, mask) → alpha

        Returns:
            Refined alpha matte in [0, 1] float32
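
        Example (illustrative; without a model the built-in smoothing block
        is used, so treat the output as demo-quality):

            matter = AlphaMatting()
            alpha = matter.deep_matting(frame, person_mask)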
| """ |
        device = self.device_manager.get_device()

        h, w = image.shape[:2]
        input_size = (512, 512)

        # Resize to the network's fixed input resolution.
        img_rs = cv2.resize(image, input_size)
        msk_rs = cv2.resize(mask, input_size)

        # To tensors: (1, 3, H, W) image and (1, 1, H, W) mask, both in [0, 1].
        img_t = torch.from_numpy(img_rs.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
        msk_t = torch.from_numpy(msk_rs).float().unsqueeze(0).unsqueeze(0)
        if msk_t.max() > 1.0:
            msk_t = msk_t / 255.0

        img_t = img_t.to(device)
        msk_t = msk_t.to(device)

        with torch.no_grad():
            if model is None:
                x = torch.cat([img_t, msk_t], dim=1)
                refined = self._simple_refine_network(x)
            else:
                refined = model(img_t, msk_t)
            alpha = refined.squeeze().float().cpu().numpy()

        # Restore the original resolution.
        alpha = cv2.resize(alpha, (w, h))
        return np.clip(alpha, 0.0, 1.0)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Tiny non-learned refinement block (demo-quality)."""
        # Channel 3 is the mask appended after the RGB channels.
        mask = x[:, 3:4, :, :]

        # Repeated box blur, then a steep sigmoid to re-sharpen the boundary.
        refined = mask
        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
        refined = torch.sigmoid((refined - 0.5) * 10.0)

        return refined

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Apply morphological operations and boundary smoothing.

        Args:
            alpha: Alpha matte in [0, 1] float32

        Returns:
            Refined alpha in [0, 1] float32
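
        Example (illustrative; ``alpha`` is any soft matte in [0, 1]):

            matter = AlphaMatting()
            cleaned = matter.morphological_refinement(alpha)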
| """ |
        alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # Close small holes, then remove small speckles. The erode/dilate
        # iteration counts from the config are reused as the iteration
        # counts for the close/open passes.
        alpha_u8 = cv2.morphologyEx(
            alpha_u8, cv2.MORPH_CLOSE, kernel, iterations=self.config.erode_iterations
        )
        alpha_u8 = cv2.morphologyEx(
            alpha_u8, cv2.MORPH_OPEN, kernel, iterations=self.config.dilate_iterations
        )

        # Feather the boundary with an odd-sized Gaussian kernel.
        if self.config.blur_radius > 0:
            r = self.config.blur_radius * 2 + 1
            alpha_u8 = cv2.GaussianBlur(alpha_u8, (r, r), 0)

        return alpha_u8.astype(np.float32) / 255.0

    def process(self, image: np.ndarray, mask: np.ndarray, method: str = "auto") -> Dict[str, np.ndarray]:
        """
        Process an image with the selected matting method.

        Args:
            image: RGB image (H, W, 3) uint8
            mask: Initial segmentation mask (H, W)
            method: 'auto' | 'trimap' | 'deep' | 'guided'

        Returns:
            dict(alpha, confidence, method_used, quality_metrics[, error])
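
        Example (illustrative end-to-end call; ``frame`` and ``person_mask``
        are assumed inputs from an upstream segmenter):

            matter = AlphaMatting()
            result = matter.process(frame, person_mask, method="auto")
            alpha = result["alpha"]  # (H, W) float32 in [0, 1]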
| """ |
        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            chosen = method
            if method == "auto":
                # Pick a method from frame quality: blurry frames get the
                # cheap guided filter, crisp edges get trimap matting, and
                # everything else goes to the deep path.
                blur_score = quality_metrics.get("blur_score", 0.0)
                edge_clarity = quality_metrics.get("edge_clarity", 0.0)
                if blur_score > 50:
                    chosen = "guided"
                elif edge_clarity > 0.7:
                    chosen = "trimap"
                else:
                    chosen = "deep"

            logger.info(f"Using matting method: {chosen}")

            if chosen == "trimap":
                trimap = self.create_trimap(mask)
                alpha = self.closed_form_matting(image, trimap)
            elif chosen == "deep":
                alpha = self.deep_matting(image, mask)
            elif chosen == "guided":
                alpha = mask.astype(np.float32)
                if alpha.max() > 1.0:
                    alpha = alpha / 255.0
                if self.config.use_guided_filter:
                    alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
                    alpha = self.guided_filter(alpha_u8, image).astype(np.float32) / 255.0
            else:
                alpha = mask.astype(np.float32)
                if alpha.max() > 1.0:
                    alpha = alpha / 255.0

            # Common post-processing: morphology plus edge-aware refinement.
            alpha = self.morphological_refinement(alpha)
            alpha = self.edge_refinement.refine_edges(
                image, np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
            ).astype(np.float32) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)

            return {
                "alpha": np.clip(alpha, 0.0, 1.0),
                "confidence": float(np.clip(confidence, 0.0, 1.0)),
                "method_used": chosen,
                "quality_metrics": quality_metrics,
            }

        except Exception as e:
            logger.error(f"Matting processing failed: {e}")
            # Fall back to the raw input mask with zero confidence.
            fallback = mask.astype(np.float32)
            if fallback.max() > 1.0:
                fallback = fallback / 255.0
            return {
                "alpha": np.clip(fallback, 0.0, 1.0),
                "confidence": 0.0,
                "method_used": "fallback",
                "error": str(e),
            }

    def _calculate_confidence(self, alpha: np.ndarray, quality_metrics: Dict) -> float:
        """Calculate confidence score for the matting result."""
        confidence = float(quality_metrics.get("overall_quality", 0.5))

        alpha_mean = float(np.mean(alpha))
        alpha_std = float(np.std(alpha))

        # A balanced mean with high spread suggests clean FG/BG separation.
        if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
            confidence *= 1.2

        # Few alpha edges suggest a smooth, non-noisy matte.
        edges = cv2.Canny(np.clip(alpha * 255.0, 0, 255).astype(np.uint8), 50, 150)
        edge_ratio = float(np.sum(edges > 0) / edges.size)
        if edge_ratio < 0.1:
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))


class CompositingEngine:
    """Handle alpha compositing and blending."""

    def __init__(self):
        self.logger = get_logger(f"{__name__}.CompositingEngine")

    def composite(self, foreground: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image (H, W, 3) uint8
            background: Background image (H, W, 3) uint8
            alpha: Alpha matte (H, W) or (H, W, 1) in [0..255] or [0..1]

        Returns:
            Composited image (H, W, 3) uint8
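
        Example (illustrative; ``fg``, ``bg``, and ``alpha`` are assumed
        same-sized inputs):

            engine = CompositingEngine()
            out = engine.composite(fg, bg, alpha)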
| """ |
        # Broadcast alpha to three channels.
        if alpha.ndim == 2:
            alpha = np.expand_dims(alpha, axis=2)
        if alpha.shape[2] == 1:
            alpha = np.repeat(alpha, 3, axis=2)

        # Normalize alpha to [0, 1] floats.
        a = alpha.astype(np.float32)
        if a.max() > 1.0:
            a = a / 255.0

        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0

        # Standard over operator: out = fg * a + bg * (1 - a).
        result = fg * a + bg * (1.0 - a)
        return np.clip(result * 255.0, 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Premultiply RGB image by alpha channel.

        Args:
            image: (H, W, 3) uint8
            alpha: (H, W) or (H, W, 1) in [0..255] or [0..1]

        Returns:
            Premultiplied (H, W, 3) uint8
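
        Example (illustrative; useful before handing frames to a renderer
        that expects premultiplied color):

            engine = CompositingEngine()
            premul = engine.premultiply_alpha(frame, alpha)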
| """ |
        if alpha.ndim == 2:
            alpha = np.expand_dims(alpha, axis=2)
        if alpha.shape[2] == 1:
            alpha = np.repeat(alpha, 3, axis=2)

        a = alpha.astype(np.float32)
        if a.max() > 1.0:
            a = a / 255.0

        img_f = image.astype(np.float32)
        premul = img_f * a
        return np.clip(premul, 0.0, 255.0).astype(np.uint8)
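

# Minimal smoke test (illustrative only; assumes "frame.png" and "mask.png"
# exist locally; substitute real assets or an upstream segmenter's output).
if __name__ == "__main__":
    frame = cv2.imread("frame.png")                       # (H, W, 3) uint8
    mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)   # (H, W) uint8

    matter = AlphaMatting(MattingConfig(trimap_size=20))
    result = matter.process(frame, mask, method="auto")
    logger.info(f"method={result['method_used']} confidence={result['confidence']:.2f}")

    engine = CompositingEngine()
    green = np.full_like(frame, (0, 177, 64))  # flat green background
    out = engine.composite(frame, green, result["alpha"])
    cv2.imwrite("composite.png", out)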