Spaces:
Running
Running
| """Neural face verification, distortion detection, and restoration pipeline. | |
| End-to-end system that: | |
| 1. Detects face distortions (blur, beauty filters, compression, warping, etc.) | |
| 2. Classifies distortion type and severity using no-reference quality metrics | |
| 3. Restores faces using cascaded neural networks (CodeFormer → GFPGAN → Real-ESRGAN) | |
| 4. Verifies output identity matches input via ArcFace embeddings | |
| 5. Scores output realism using learned perceptual metrics | |
| Designed for: | |
| - Cleaning scraped training data (reject/fix bad images before pair generation) | |
| - Post-diffusion quality gate (ensure generated faces pass realism threshold) | |
| - Filter removal (undo Snapchat/Instagram beauty filters for clinical use) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import cv2 | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Data structures | |
| # --------------------------------------------------------------------------- | |
@dataclass
class DistortionReport:
    """Analysis of detected distortions in a face image.

    Declared as a dataclass so it can be built with keyword arguments and so
    the ``field(default_factory=dict)`` default for ``details`` yields a fresh
    dict per instance instead of leaking a raw ``Field`` object.
    """

    # Overall quality score (0-100, higher = better)
    quality_score: float = 0.0

    # Individual distortion scores (0-1, higher = more distorted)
    blur_score: float = 0.0  # Laplacian variance-based
    noise_score: float = 0.0  # High-freq energy ratio
    compression_score: float = 0.0  # JPEG block artifact detection
    oversmooth_score: float = 0.0  # Beauty filter / airbrushed detection
    color_cast_score: float = 0.0  # Unnatural color shift
    geometric_distort: float = 0.0  # Face proportion anomalies
    lighting_score: float = 0.0  # Over/under exposure

    # Classification
    primary_distortion: str = "none"
    severity: str = "none"  # none, mild, moderate, severe
    is_usable: bool = True  # Whether image is worth restoring vs rejecting

    # Details (free-form per-metric data; one dict per instance)
    details: dict = field(default_factory=dict)

    def summary(self) -> str:
        """Return a multi-line, human-readable description of the report."""
        lines = [
            f"Quality Score: {self.quality_score:.1f}/100",
            f"Primary Issue: {self.primary_distortion} ({self.severity})",
            f"Usable: {self.is_usable}",
            "",
            "Distortion Breakdown:",
            f" Blur: {self.blur_score:.3f}",
            f" Noise: {self.noise_score:.3f}",
            f" Compression: {self.compression_score:.3f}",
            f" Oversmooth: {self.oversmooth_score:.3f}",
            f" Color Cast: {self.color_cast_score:.3f}",
            f" Geometric: {self.geometric_distort:.3f}",
            f" Lighting: {self.lighting_score:.3f}",
        ]
        return "\n".join(lines)
@dataclass
class RestorationResult:
    """Result of neural face restoration pipeline.

    Declared as a dataclass: callers construct it with keyword arguments and
    ``restoration_stages`` relies on ``field(default_factory=list)``.
    """

    restored: np.ndarray  # Restored BGR image
    original: np.ndarray  # Original BGR image
    distortion_report: DistortionReport  # Pre-restoration analysis
    post_quality_score: float = 0.0  # Quality after restoration
    identity_similarity: float = 0.0  # ArcFace cosine sim (original vs restored)
    identity_preserved: bool = True  # Whether identity check passed
    restoration_stages: list[str] = field(default_factory=list)  # Which nets ran
    improvement: float = 0.0  # quality_after - quality_before

    def summary(self) -> str:
        """Return a multi-line, human-readable description of the result."""
        lines = [
            f"Pre-restoration: {self.distortion_report.quality_score:.1f}/100",
            f"Post-restoration: {self.post_quality_score:.1f}/100",
            # The "+" sign flag prints "+2.5" / "-3.0"; a hard-coded "+" prefix
            # would render a regression as "+-3.0".
            f"Improvement: {self.improvement:+.1f}",
            f"Identity Sim: {self.identity_similarity:.3f}",
            f"Identity OK: {self.identity_preserved}",
            f"Stages Used: {' → '.join(self.restoration_stages) or 'none'}",
        ]
        return "\n".join(lines)
@dataclass
class BatchVerificationReport:
    """Summary of batch face verification/restoration.

    Declared as a dataclass: it is constructed with ``total=...`` and
    ``distortion_counts`` relies on ``field(default_factory=dict)`` so that
    every report owns its own counter dict.
    """

    total: int = 0
    passed: int = 0  # Good quality, no fix needed
    restored: int = 0  # Fixed and now usable
    rejected: int = 0  # Too distorted to salvage
    identity_failures: int = 0  # Restoration changed identity
    avg_quality_before: float = 0.0
    avg_quality_after: float = 0.0
    avg_identity_sim: float = 0.0
    distortion_counts: dict[str, int] = field(default_factory=dict)

    def summary(self) -> str:
        """Return the report as text, distortion types listed by descending count."""
        lines = [
            f"Total Images: {self.total}",
            f" Passed (good): {self.passed}",
            f" Restored: {self.restored}",
            f" Rejected: {self.rejected}",
            f" Identity Fail: {self.identity_failures}",
            f"Avg Quality Before: {self.avg_quality_before:.1f}",
            f"Avg Quality After: {self.avg_quality_after:.1f}",
            f"Avg Identity Sim: {self.avg_identity_sim:.3f}",
            "",
            "Distortion Breakdown:",
        ]
        # Stable sort: ties keep their insertion order.
        by_count = sorted(self.distortion_counts.items(), key=lambda item: -item[1])
        lines.extend(f" {dist_type}: {count}" for dist_type, count in by_count)
        return "\n".join(lines)
| # --------------------------------------------------------------------------- | |
| # Distortion Detection (classical + neural) | |
| # --------------------------------------------------------------------------- | |
def detect_blur(image: np.ndarray) -> float:
    """Estimate how blurry an image is, as a score in [0, 1] (1 = very blurry).

    Two sharpness measures are combined: the variance of the Laplacian
    (weight 0.6) and the mean Sobel gradient magnitude (weight 0.4). Each is
    turned into a blur component that reaches 0 once the measure hits its
    saturation value (800 for the variance, 50 for the gradient).

    Args:
        image: 3-channel image (converted with COLOR_BGR2GRAY) or a
            single-channel image used as-is.

    Returns:
        Blur score clipped to [0, 1].
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    # Primary metric: low Laplacian variance means few sharp transitions.
    laplacian_variance = cv2.Laplacian(gray, cv2.CV_64F).var()

    # Secondary metric: average gradient strength from 3x3 Sobel filters.
    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    mean_gradient = np.sqrt(sobel_x**2 + sobel_y**2).mean()

    lap_component = 1.0 - min(laplacian_variance / 800.0, 1.0)
    grad_component = 1.0 - min(mean_gradient / 50.0, 1.0)
    combined = 0.6 * lap_component + 0.4 * grad_component
    return float(np.clip(combined, 0, 1))
def detect_noise(image: np.ndarray) -> float:
    """Estimate the image noise level as a score in [0, 1].

    The noise sigma is estimated as the median of the absolute Laplacian
    response multiplied by 1.4826 (the usual MAD-to-standard-deviation
    factor). The score is that estimate divided by 25, so it saturates at 1
    once the estimate reaches 25.

    Args:
        image: 3-channel image (converted with COLOR_BGR2GRAY) or a
            single-channel image used as-is.

    Returns:
        Noise score clipped to [0, 1].
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    laplacian = cv2.Laplacian(gray.astype(np.float64), cv2.CV_64F)
    abs_response = np.abs(laplacian)
    sigma_est = np.median(abs_response) * 1.4826
    return float(np.clip(sigma_est / 25.0, 0, 1))
def detect_compression_artifacts(image: np.ndarray) -> float:
    """Score JPEG-style 8x8 blocking artifacts in [0, 1].

    Compares the mean absolute pixel difference across every 8th column/row
    seam with the mean absolute difference over the whole image. A ratio of
    1 maps to score 0; a ratio of 1.8 or more maps to score 1.

    Args:
        image: 3-channel image (converted with COLOR_BGR2GRAY) or a
            single-channel image used as-is.

    Returns:
        0.0 for images smaller than 16 pixels in either dimension or with
        (near-)zero overall difference along either axis; otherwise the
        clipped artifact score.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    rows, cols = gray.shape
    if min(rows, cols) < 16:
        return 0.0

    pixels = gray.astype(np.float64)
    horiz = np.abs(np.diff(pixels, axis=1))
    vert = np.abs(np.diff(pixels, axis=0))

    # Differences that straddle an 8-pixel block seam (columns/rows 7, 15, ...).
    horiz_seams = horiz[:, 7::8]
    vert_seams = vert[7::8, :]
    h_boundary = horiz_seams.mean() if horiz_seams.size > 0 else 0
    v_boundary = vert_seams.mean() if vert_seams.size > 0 else 0

    # Reference level: mean over all neighbouring differences (seams included).
    h_overall = horiz.mean()
    v_overall = vert.mean()
    if h_overall < 1e-6 or v_overall < 1e-6:
        return 0.0

    h_ratio = h_boundary / (h_overall + 1e-6)
    v_ratio = v_boundary / (v_overall + 1e-6)
    artifact_ratio = (h_ratio + v_ratio) / 2.0
    return float(np.clip((artifact_ratio - 1.0) / 0.8, 0, 1))
def detect_oversmoothing(image: np.ndarray) -> float:
    """Score beauty-filter style oversmoothing in [0, 1].

    Works on the central half of the image (middle 50% in each dimension).
    Texture energy is the variance of a high-pass layer (pixels minus a
    sigma-2 Gaussian blur); the base score is ``1 - texture_energy / 30``
    floored at 0. If Canny still finds edges on more than 2% of the region
    the score is multiplied by 1.3, since sharp edges with no texture are
    the signature this detector looks for.

    Args:
        image: 3-channel image (converted with COLOR_BGR2GRAY) or a
            single-channel image used as-is.

    Returns:
        0.0 when either dimension is under 8 pixels; otherwise the clipped
        oversmoothing score.
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    height, width = gray.shape
    if height < 8 or width < 8:
        return 0.0  # Too small to analyze

    # Central region only, to keep background out of the statistics.
    center = gray[height // 4 : 3 * height // 4, width // 4 : 3 * width // 4]
    center_f = center.astype(np.float64)

    detail_layer = center_f - cv2.GaussianBlur(center_f, (0, 0), 2.0)
    texture_energy = np.var(detail_layer)

    edge_map = cv2.Canny(center, 50, 150)
    edge_density = np.mean(edge_map > 0)

    smooth_score = 1.0 - min(texture_energy / 30.0, 1.0)
    if edge_density > 0.02:
        smooth_score *= 1.3
    return float(np.clip(smooth_score, 0, 1))
def detect_color_cast(image: np.ndarray) -> float:
    """Score an unnatural colour cast in [0, 1] using LAB chroma statistics.

    On the central half of the image, measures how far the mean A and B
    channel values sit from 128 (treated as neutral) and how narrow the
    chroma distribution is (combined A+B standard deviation below 30).

    Args:
        image: BGR image; it is passed straight to COLOR_BGR2LAB.

    Returns:
        ``0.5 * (a_dev + b_dev) + 0.3 * narrowness`` clipped to [0, 1].
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
    height, width = image.shape[:2]

    # Central region, intended to sample the face rather than the background.
    center = lab[height // 4 : 3 * height // 4, width // 4 : 3 * width // 4]
    a_chan = center[:, :, 1]  # green-red axis
    b_chan = center[:, :, 2]  # blue-yellow axis

    # Normalised distance of each channel mean from the neutral value.
    a_dev = abs(a_chan.mean() - 128) / 128.0
    b_dev = abs(b_chan.mean() - 128) / 128.0

    # A very tight chroma spread also counts towards the score.
    chroma_spread = a_chan.std() + b_chan.std()
    narrow_color = max(0, 1.0 - chroma_spread / 30.0)

    score = 0.5 * (a_dev + b_dev) + 0.3 * narrow_color
    return float(np.clip(score, 0, 1))
def detect_geometric_distortion(image: np.ndarray) -> float:
    """Score geometric face distortion (warping filters, lens distortion) in [0, 1].

    Landmarks come from ``landmarkdiff.landmarks.extract_landmarks`` (described
    by the original author as MediaPipe landmarks). Three proportions are
    checked: face height relative to inter-ocular distance, the nose-to-chin
    share of face height, and the vertical tilt of the eye line.

    Args:
        image: Face image handed unchanged to ``extract_landmarks``.

    Returns:
        0.0 if ``landmarkdiff`` cannot be imported; 0.5 if no face is found,
        fewer than 478 landmarks are returned, or the measured distances are
        degenerate (< 1 px); otherwise the weighted deviation score.
    """
    try:
        from landmarkdiff.landmarks import extract_landmarks
    except ImportError:
        return 0.0

    face = extract_landmarks(image)
    if face is None:
        return 0.5  # Can't detect face = possibly distorted

    coords = face.pixel_coords
    min_landmarks = 478
    if len(coords) < min_landmarks:
        return 0.5  # Incomplete landmark set

    # Landmark indices used for the proportion checks.
    idx_left_eye, idx_right_eye = 33, 263
    idx_nose_tip, idx_chin, idx_forehead = 1, 152, 10

    left_eye = coords[idx_left_eye]
    right_eye = coords[idx_right_eye]
    nose_tip = coords[idx_nose_tip]
    chin = coords[idx_chin]
    forehead = coords[idx_forehead]

    iod = np.linalg.norm(left_eye - right_eye)
    face_height = np.linalg.norm(forehead - chin)
    nose_to_chin = np.linalg.norm(nose_tip - chin)
    if iod < 1.0 or face_height < 1.0:
        return 0.5

    # Approximate norms: face_height / iod in 2.5-3.5,
    # nose_to_chin / face_height in 0.30-0.46. Only the excess beyond each
    # tolerance band contributes to the score.
    height_ratio = face_height / iod
    lower_ratio = nose_to_chin / face_height
    height_dev = max(0, abs(height_ratio - 3.0) - 0.5) / 1.5
    lower_dev = max(0, abs(lower_ratio - 0.38) - 0.08) / 0.15

    # Eye-line tilt: vertical eye offset relative to eye distance, with the
    # first 0.05 tolerated.
    eye_tilt = abs(left_eye[1] - right_eye[1]) / (iod + 1e-6)
    tilt_dev = max(0, eye_tilt - 0.05) / 0.15

    score = 0.4 * height_dev + 0.3 * lower_dev + 0.3 * tilt_dev
    return float(np.clip(score, 0, 1))
def detect_lighting_issues(image: np.ndarray) -> float:
    """Score over/under exposure and a concentrated tonal range in [0, 1].

    Uses the L channel of the LAB conversion. Three terms are combined: the
    fraction of pixels above 245 (x5, weight 0.4), the fraction below 10
    (x5, weight 0.4), and an entropy term ``1 - entropy / 7`` of the 256-bin
    luminance histogram (weight 0.2), which grows as the histogram becomes
    more concentrated.

    Args:
        image: BGR image; it is passed straight to COLOR_BGR2LAB.

    Returns:
        0.0 if the histogram is empty; otherwise the clipped lighting score.
    """
    luminance = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)[:, :, 0]

    # Clipping: share of pixels pinned near white / near black.
    overexposed = np.mean(luminance > 245) * 5
    underexposed = np.mean(luminance < 10) * 5

    counts = cv2.calcHist([luminance], [0], None, [256], [0, 256]).flatten()
    total = counts.sum()
    if total < 1e-10:
        return 0.0
    probs = counts / total

    # Shannon entropy (bits) over the occupied bins; low entropy means the
    # luminance values are concentrated in few bins.
    occupied = probs[probs > 0]
    entropy = -np.sum(occupied * np.log2(occupied + 1e-10))
    entropy_score = max(0, 1.0 - entropy / 7.0)

    score = 0.4 * overexposed + 0.4 * underexposed + 0.2 * entropy_score
    return float(np.clip(score, 0, 1))
def analyze_distortions(image: np.ndarray) -> DistortionReport:
    """Run every distortion detector and fold the results into one report.

    The quality score is ``100 * (1 - weighted distortion)`` with weights
    blur 0.25, oversmooth 0.20, noise 0.15 and 0.10 each for compression,
    colour cast, geometry and lighting. The primary distortion is the
    detector with the highest score; its value sets the severity
    (< 0.15 none, < 0.35 mild, < 0.60 moderate, otherwise severe).

    Args:
        image: Face image forwarded unchanged to each ``detect_*`` function.

    Returns:
        A populated ``DistortionReport``; ``details`` holds the per-detector
        scores keyed by distortion name.
    """
    scores = {
        "blur": detect_blur(image),
        "noise": detect_noise(image),
        "compression": detect_compression_artifacts(image),
        "oversmooth": detect_oversmoothing(image),
        "color_cast": detect_color_cast(image),
        "geometric": detect_geometric_distortion(image),
        "lighting": detect_lighting_issues(image),
    }

    # Weighted distortion, inverted so that 100 means "no distortion found".
    weighted = (
        0.25 * scores["blur"]
        + 0.15 * scores["noise"]
        + 0.10 * scores["compression"]
        + 0.20 * scores["oversmooth"]
        + 0.10 * scores["color_cast"]
        + 0.10 * scores["geometric"]
        + 0.10 * scores["lighting"]
    )
    quality = (1.0 - weighted) * 100.0

    # Dominant distortion and the severity band its score falls into.
    primary = max(scores, key=scores.get)
    primary_val = scores[primary]
    severity = "severe"
    for upper_bound, label in ((0.15, "none"), (0.35, "mild"), (0.60, "moderate")):
        if primary_val < upper_bound:
            severity = label
            break
    if severity == "none":
        primary = "none"

    # Usable means quality above 25 and geometric distortion below 0.7.
    is_usable = quality > 25 and scores["geometric"] < 0.7

    return DistortionReport(
        quality_score=quality,
        blur_score=scores["blur"],
        noise_score=scores["noise"],
        compression_score=scores["compression"],
        oversmooth_score=scores["oversmooth"],
        color_cast_score=scores["color_cast"],
        geometric_distort=scores["geometric"],
        lighting_score=scores["lighting"],
        primary_distortion=primary,
        severity=severity,
        is_usable=is_usable,
        details=scores,
    )
| # --------------------------------------------------------------------------- | |
| # Neural Face Quality Scoring (no-reference) | |
| # --------------------------------------------------------------------------- | |
_FACE_QUALITY_NET = None


def _get_face_quality_scorer() -> Any:
    """Return the cached face-quality model, creating it on first use.

    The model is built with FaceXLib's ``init_assessment_model("hypernet")``
    and stored in the module-level ``_FACE_QUALITY_NET``.

    Returns:
        The model, or None if FaceXLib cannot be imported or initialisation
        raises. A failure is not cached, so the next call tries again.
    """
    global _FACE_QUALITY_NET
    if _FACE_QUALITY_NET is None:
        try:
            from facexlib.assessment import init_assessment_model

            _FACE_QUALITY_NET = init_assessment_model("hypernet")
        except Exception:
            # Best effort: callers fall back to classical scoring on None.
            return None
    return _FACE_QUALITY_NET
def neural_quality_score(image: np.ndarray) -> float:
    """Score face quality on a 0-100 scale (higher = better).

    Uses the FaceXLib assessment model when ``_get_face_quality_scorer``
    provides one; the model output is multiplied by 100 and clipped to
    [0, 100]. If no model is available, or anything in the neural path
    raises, the classical ``analyze_distortions`` quality score is returned
    instead.

    Args:
        image: BGR image; values are divided by 255 before being converted
            to a tensor.

    Returns:
        Quality score as a float.
    """
    model = _get_face_quality_scorer()
    if model is not None:
        try:
            import torch
            from facexlib.utils import img2tensor

            tensor = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
            batch = tensor.unsqueeze(0)
            if torch.cuda.is_available():
                batch = batch.cuda()
                model = model.cuda()
            with torch.no_grad():
                raw_score = model(batch).item()
            return float(np.clip(raw_score * 100, 0, 100))
        except Exception:
            # Any failure in the neural path drops through to the fallback.
            pass

    # Fallback: composite classical score.
    return analyze_distortions(image).quality_score
| # --------------------------------------------------------------------------- | |
| # Neural Face Restoration (cascaded) | |
| # --------------------------------------------------------------------------- | |
def restore_face(
    image: np.ndarray,
    distortion: DistortionReport | None = None,
    mode: str = "auto",
    codeformer_fidelity: float = 0.7,
) -> tuple[np.ndarray, list[str]]:
    """Cascaded face restoration driven by the detected distortions.

    Stages, in order (each is skipped when not needed or not available):
    0. Colour-cast correction (classical LAB shift) if colour-cast score > 0.25.
    1. Lighting fix (CLAHE) if lighting score > 0.35.
    2. Neural face restoration: CodeFormer, falling back to GFPGAN when
       CodeFormer yields nothing; or GFPGAN alone in 'gfpgan' mode.
    3. Real-ESRGAN enhancement if either image dimension is below 400 px.
    4. Mild frequency-aware sharpening if the result still scores > 0.3 blur.

    Geometric distortion is not corrected by any stage.

    Args:
        image: BGR face image to restore (not modified; a copy is processed).
        distortion: Pre-computed distortion report (computed if None).
        mode: 'auto' picks CodeFormer when blur, oversmoothing, noise or
            compression exceed their thresholds and otherwise runs no neural
            face stage; 'codeformer' and 'all' both run CodeFormer with a
            GFPGAN fallback; 'gfpgan' runs GFPGAN only. Any other value
            skips stage 2.
        codeformer_fidelity: CodeFormer quality-fidelity tradeoff.

    Returns:
        Tuple of (restored BGR image, list of stage names applied, in order).
    """
    if distortion is None:
        distortion = analyze_distortions(image)

    result = image.copy()
    stages: list[str] = []

    # Step 0: Fix color cast first (classical and fast).
    if distortion.color_cast_score > 0.25:
        result = _fix_color_cast(result)
        stages.append("color_correction")

    # Step 1: Fix lighting issues (classical).
    if distortion.lighting_score > 0.35:
        result = _fix_lighting(result)
        stages.append("lighting_fix")

    # Step 2: Neural face restoration.
    if mode == "auto":
        needs_face_restore = (
            distortion.blur_score > 0.2
            or distortion.oversmooth_score > 0.25
            or distortion.noise_score > 0.25
            or distortion.compression_score > 0.2
        )
        if needs_face_restore:
            mode = "codeformer"  # CodeFormer is the default choice for these

    if mode in ("codeformer", "all"):
        restored = _try_codeformer(result, fidelity=codeformer_fidelity)
        if restored is not None:
            result = restored
            stages.append("codeformer")
        else:
            # CodeFormer unavailable or ineffective: fall back to GFPGAN.
            restored = _try_gfpgan(result)
            if restored is not None:
                result = restored
                stages.append("gfpgan")
    elif mode == "gfpgan":
        restored = _try_gfpgan(result)
        if restored is not None:
            result = restored
            stages.append("gfpgan")

    # Step 3: Real-ESRGAN enhancement for low-resolution images.
    h, w = result.shape[:2]
    if h < 400 or w < 400:
        enhanced = _try_realesrgan(result)
        if enhanced is not None:
            result = enhanced
            stages.append("realesrgan")

    # Step 4: Mild sharpening if still soft after restoration.
    if detect_blur(result) > 0.3:
        try:
            from landmarkdiff.postprocess import frequency_aware_sharpen
        except ImportError:
            # Sharpening is optional like every other stage: a missing
            # landmarkdiff package must not abort the whole restoration.
            logger.debug("landmarkdiff.postprocess unavailable; skipping sharpen stage")
        else:
            result = frequency_aware_sharpen(result, strength=0.3)
            stages.append("sharpen")

    return result, stages
def _try_codeformer(image: np.ndarray, fidelity: float = 0.7) -> np.ndarray | None:
    """Attempt CodeFormer restoration via ``landmarkdiff.postprocess``.

    Args:
        image: Image handed to ``restore_face_codeformer``.
        fidelity: Forwarded as the ``fidelity`` keyword.

    Returns:
        The restored image, or None when the import or the call raises, or
        when the helper hands back the very same object it was given (taken
        to mean that nothing was restored).
    """
    try:
        from landmarkdiff.postprocess import restore_face_codeformer

        output = restore_face_codeformer(image, fidelity=fidelity)
    except Exception:
        return None
    return None if output is image else output
def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
    """Attempt GFPGAN restoration via ``landmarkdiff.postprocess``.

    Args:
        image: Image handed to ``restore_face_gfpgan``.

    Returns:
        The restored image, or None when the import or the call raises, or
        when the helper hands back the very same object it was given (taken
        to mean that nothing was restored).
    """
    try:
        from landmarkdiff.postprocess import restore_face_gfpgan

        output = restore_face_gfpgan(image)
    except Exception:
        return None
    return None if output is image else output
_FV_REALESRGAN = None


def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
    """Attempt Real-ESRGAN enhancement followed by a resize to 512x512.

    The upsampler (RealESRGAN_x4plus weights on an RRDBNet backbone) is built
    once and cached in the module-level ``_FV_REALESRGAN``. It is run with
    ``outscale=2`` and the output is then resized to exactly 512x512 with
    Lanczos interpolation, regardless of the input's aspect ratio.

    Args:
        image: Image handed to ``RealESRGANer.enhance``.

    Returns:
        The 512x512 enhanced image, or None if any import, model setup or
        inference step raises.
    """
    global _FV_REALESRGAN
    try:
        import torch
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        if _FV_REALESRGAN is None:
            backbone = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )
            _FV_REALESRGAN = RealESRGANer(
                scale=4,
                model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                model=backbone,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=torch.cuda.is_available(),  # half precision only with CUDA
            )

        upscaled, _ = _FV_REALESRGAN.enhance(image, outscale=2)
        # Fixed 512x512 output for pipeline consistency.
        return cv2.resize(upscaled, (512, 512), interpolation=cv2.INTER_LANCZOS4)
    except Exception:
        return None
def _fix_color_cast(image: np.ndarray) -> np.ndarray:
    """Reduce a colour cast by pulling the LAB A/B channel means towards 128.

    Each chroma channel is shifted by 60% of the distance between its mean
    and 128, so the cast is reduced rather than fully neutralised (the
    original author's stated intent is to preserve natural skin tone).

    Args:
        image: BGR image.

    Returns:
        Corrected BGR image.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
    for chroma_idx in (1, 2):
        plane = lab[:, :, chroma_idx]
        offset = (128.0 - plane.mean()) * 0.6
        lab[:, :, chroma_idx] = np.clip(plane + offset, 0, 255)
    as_bytes = lab.astype(np.uint8)
    return cv2.cvtColor(as_bytes, cv2.COLOR_LAB2BGR)
def _fix_lighting(image: np.ndarray) -> np.ndarray:
    """Even out exposure by applying CLAHE to the LAB luminance channel.

    Only the L channel is equalised (clip limit 2.0, 8x8 tiles); the A and B
    channels are left untouched.

    Args:
        image: BGR image.

    Returns:
        Corrected BGR image.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    luminance = lab[:, :, 0]
    lab[:, :, 0] = equalizer.apply(luminance)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
| # --------------------------------------------------------------------------- | |
| # ArcFace Identity Verification | |
| # --------------------------------------------------------------------------- | |
_ARCFACE_APP = None


def _get_arcface() -> Any:
    """Return the cached InsightFace ``FaceAnalysis`` app, creating it on first use.

    The app is created with the ``buffalo_l`` model pack, CUDA then CPU
    execution providers, and a 320x320 detection size. ``ctx_id`` is 0 when
    torch reports CUDA and -1 otherwise.

    Returns:
        The prepared app, or None if any import or setup step raises. A
        failure is not cached, so the next call tries again.
    """
    global _ARCFACE_APP
    if _ARCFACE_APP is None:
        try:
            import torch
            from insightface.app import FaceAnalysis

            analysis = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            device_id = 0 if torch.cuda.is_available() else -1
            analysis.prepare(ctx_id=device_id, det_size=(320, 320))
        except Exception:
            return None
        _ARCFACE_APP = analysis
    return _ARCFACE_APP
def get_face_embedding(image: np.ndarray) -> np.ndarray | None:
    """Extract the embedding of the first detected face.

    Args:
        image: Image handed to the InsightFace app's ``get`` method.

    Returns:
        The ``embedding`` of the first face returned by the app, or None if
        InsightFace is unavailable, no face is found, the embedding norm is
        below 1e-6 (a warning is logged), or detection raises.
    """
    app = _get_arcface()
    if app is None:
        return None

    try:
        detections = app.get(image)
        if not detections:
            return None
        embedding = detections[0].embedding
        if np.linalg.norm(embedding) < 1e-6:
            logger.warning("ArcFace returned near-zero embedding (occluded face?)")
            return None
        return embedding
    except Exception:
        return None
def verify_identity(
    original: np.ndarray,
    restored: np.ndarray,
    threshold: float = 0.5,
) -> tuple[float, bool]:
    """Compare the identity of two face images via embedding cosine similarity.

    Args:
        original: Reference image.
        restored: Image to compare against the reference.
        threshold: Minimum similarity for the check to pass.

    Returns:
        ``(similarity, passed)``. When either embedding is unavailable the
        check cannot be performed and ``(-1.0, True)`` is returned: the
        sentinel similarity marks "unverified" and the pair is assumed OK.
        Otherwise similarity is the cosine similarity clipped to [-1, 1] and
        ``passed`` is ``similarity >= threshold``.
    """
    reference = get_face_embedding(original)
    candidate = get_face_embedding(restored)
    if reference is None or candidate is None:
        return -1.0, True  # Can't verify: assume OK

    # Small epsilon keeps the division finite for tiny norms.
    norm_product = np.linalg.norm(reference) * np.linalg.norm(candidate)
    cosine = float(np.dot(reference, candidate) / (norm_product + 1e-8))
    cosine = float(np.clip(cosine, -1, 1))
    return cosine, cosine >= threshold
| # --------------------------------------------------------------------------- | |
| # Full Verification + Restoration Pipeline | |
| # --------------------------------------------------------------------------- | |
def verify_and_restore(
    image: np.ndarray,
    quality_threshold: float = 60.0,
    identity_threshold: float = 0.5,
    restore_mode: str = "auto",
    codeformer_fidelity: float = 0.7,
) -> RestorationResult:
    """Analyze a face image, restore it if needed, and verify identity.

    Flow:
    1. ``analyze_distortions`` scores the input.
    2. If quality meets the threshold and severity is none/mild, the image is
       returned unchanged with no stages.
    3. If the report marks the image unusable, it is returned unchanged with
       stages ``["rejected"]`` and ``identity_preserved=False``.
    4. Otherwise ``restore_face`` runs, the result is re-scored with
       ``neural_quality_score`` and identity is checked with
       ``verify_identity``.

    NOTE: the pre-restoration score always comes from the classical analysis
    while the post-restoration score may come from the neural scorer, so
    ``improvement`` can mix two different metrics.

    Args:
        image: BGR face image.
        quality_threshold: Min quality to skip restoration (0-100).
        identity_threshold: Min similarity for the identity check to pass.
        restore_mode: Forwarded to ``restore_face`` as ``mode``.
        codeformer_fidelity: CodeFormer quality-fidelity balance.

    Returns:
        RestorationResult with the (possibly restored) image and metrics.
    """
    report = analyze_distortions(image)

    def unchanged(similarity: float, preserved: bool, stages: list) -> RestorationResult:
        # Result for the two paths that return the input untouched.
        return RestorationResult(
            restored=image.copy(),
            original=image.copy(),
            distortion_report=report,
            post_quality_score=report.quality_score,
            identity_similarity=similarity,
            identity_preserved=preserved,
            restoration_stages=stages,
            improvement=0.0,
        )

    good_enough = report.quality_score >= quality_threshold and report.severity in (
        "none",
        "mild",
    )
    if good_enough:
        return unchanged(1.0, True, [])

    if not report.is_usable:
        # Too distorted to salvage.
        return unchanged(0.0, False, ["rejected"])

    restored, stages = restore_face(
        image,
        distortion=report,
        mode=restore_mode,
        codeformer_fidelity=codeformer_fidelity,
    )
    post_quality = neural_quality_score(restored)
    similarity, identity_ok = verify_identity(image, restored, threshold=identity_threshold)

    return RestorationResult(
        restored=restored,
        original=image.copy(),
        distortion_report=report,
        post_quality_score=post_quality,
        identity_similarity=similarity,
        identity_preserved=identity_ok,
        restoration_stages=stages,
        improvement=post_quality - report.quality_score,
    )
| # --------------------------------------------------------------------------- | |
| # Batch Processing | |
| # --------------------------------------------------------------------------- | |
def verify_batch(
    image_dir: str,
    output_dir: str | None = None,
    quality_threshold: float = 60.0,
    identity_threshold: float = 0.5,
    restore_mode: str = "auto",
    save_rejected: bool = False,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp", ".bmp"),
) -> BatchVerificationReport:
    """Process a directory of face images: analyze, restore, verify, sort.

    Outputs:
    - {output_dir}/passed/ — good images (no fix needed)
    - {output_dir}/restored/ — fixed images
    - {output_dir}/rejected/ — unusable, unrestorable or identity-failed
      images (only if save_rejected=True)
    - {output_dir}/report.txt — batch verification report

    Every readable image is resized to 512x512 before processing. Files that
    cannot be decoded are counted as rejected.

    Args:
        image_dir: Directory of face images to process.
        output_dir: Where to save results (default: {image_dir}_verified/).
        quality_threshold: Min quality to pass without restoration.
        identity_threshold: Min identity similarity after restoration.
        restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
        save_rejected: Whether to copy rejected images to rejected/ subdir.
        extensions: Lower-case file extensions to process.

    Returns:
        BatchVerificationReport with summary statistics.

    Raises:
        OSError: If ``image_dir`` cannot be listed (e.g. it does not exist).
    """
    image_path = Path(image_dir)
    if output_dir is None:
        out_path = image_path.parent / f"{image_path.name}_verified"
    else:
        out_path = Path(output_dir)

    # List inputs before creating anything, so a bad image_dir fails without
    # leaving empty output directories behind.
    image_files = sorted(
        f for f in image_path.iterdir() if f.suffix.lower() in extensions and f.is_file()
    )

    passed_dir = out_path / "passed"
    restored_dir = out_path / "restored"
    rejected_dir = out_path / "rejected"
    passed_dir.mkdir(parents=True, exist_ok=True)
    restored_dir.mkdir(parents=True, exist_ok=True)
    if save_rejected:
        rejected_dir.mkdir(parents=True, exist_ok=True)

    report = BatchVerificationReport(total=len(image_files))
    quality_before: list[float] = []
    quality_after: list[float] = []
    identity_sims: list[float] = []

    # verify_identity returns exactly this similarity when it could not
    # compute embeddings; it is a marker, not a measurement.
    unverified_sim = -1.0

    for i, img_file in enumerate(image_files):
        if (i + 1) % 50 == 0 or i == 0:
            logger.info("Processing %d/%d: %s", i + 1, len(image_files), img_file.name)

        image = cv2.imread(str(img_file))
        if image is None:
            report.rejected += 1
            continue

        # Resize to 512x512 for consistency
        image = cv2.resize(image, (512, 512))

        result = verify_and_restore(
            image,
            quality_threshold=quality_threshold,
            identity_threshold=identity_threshold,
            restore_mode=restore_mode,
        )
        analysis = result.distortion_report
        quality_before.append(analysis.quality_score)
        quality_after.append(result.post_quality_score)

        # Track distortion types
        dist_type = analysis.primary_distortion
        report.distortion_counts[dist_type] = report.distortion_counts.get(dist_type, 0) + 1

        if not analysis.is_usable or "rejected" in result.restoration_stages:
            report.rejected += 1
            if save_rejected:
                cv2.imwrite(str(rejected_dir / img_file.name), image)
        elif not result.restoration_stages:
            # No stage ran. That is a pass only if the image actually met the
            # quality gate; otherwise restoration was needed but nothing
            # could be applied, and the image must not land in passed/.
            met_gate = analysis.quality_score >= quality_threshold and analysis.severity in (
                "none",
                "mild",
            )
            if met_gate:
                report.passed += 1
                cv2.imwrite(str(passed_dir / img_file.name), image)
            else:
                report.rejected += 1
                if save_rejected:
                    cv2.imwrite(str(rejected_dir / img_file.name), image)
        elif result.identity_preserved:
            report.restored += 1
            cv2.imwrite(str(restored_dir / img_file.name), result.restored)
            # Keep the "could not verify" sentinel out of the average.
            if result.identity_similarity != unverified_sim:
                identity_sims.append(result.identity_similarity)
        else:
            report.identity_failures += 1
            if save_rejected:
                cv2.imwrite(str(rejected_dir / img_file.name), image)

    # Compute averages (0.0 when there is nothing to average)
    report.avg_quality_before = float(np.mean(quality_before)) if quality_before else 0.0
    report.avg_quality_after = float(np.mean(quality_after)) if quality_after else 0.0
    report.avg_identity_sim = float(np.mean(identity_sims)) if identity_sims else 0.0

    # Save report
    report_text = report.summary()
    (out_path / "report.txt").write_text(report_text, encoding="utf-8")
    logger.info("\n%s", report_text)
    logger.info("Results saved to %s/", out_path)
    return report