Spaces:
Running
Running
| """Clinical safety validation for responsible deployment. | |
| Implements safety checks for surgical outcome predictions: | |
| 1. Identity preservation: verify output preserves patient identity | |
| 2. Anatomical plausibility: check landmark displacements are realistic | |
| 3. Out-of-distribution detection: flag unusual inputs | |
| 4. Watermarking: mark AI-generated images | |
| 5. Consent metadata: embed provenance information | |
| Usage: | |
| from landmarkdiff.safety import SafetyValidator | |
| validator = SafetyValidator() | |
| result = validator.validate( | |
| input_image=image, | |
| output_image=generated, | |
| landmarks_original=face.landmarks, | |
| landmarks_manipulated=manip.landmarks, | |
| procedure="rhinoplasty", | |
| ) | |
| if not result.passed: | |
| print(f"Safety check failed: {result.failures}") | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| import cv2 | |
| import numpy as np | |
| class SafetyResult: | |
| """Result of safety validation checks.""" | |
| passed: bool = True | |
| failures: list[str] = field(default_factory=list) | |
| warnings: list[str] = field(default_factory=list) | |
| checks: dict[str, bool] = field(default_factory=dict) | |
| details: dict[str, object] = field(default_factory=dict) | |
| def __repr__(self) -> str: | |
| return ( | |
| f"SafetyResult(passed={self.passed}, " | |
| f"failures={self.failures}, " | |
| f"warnings={self.warnings}, " | |
| f"checks={self.checks}, " | |
| f"details={self.details})" | |
| ) | |
| def __eq__(self, other: object) -> bool: | |
| if not isinstance(other, SafetyResult): | |
| return NotImplemented | |
| return ( | |
| self.passed == other.passed | |
| and self.failures == other.failures | |
| and self.warnings == other.warnings | |
| and self.checks == other.checks | |
| and self.details == other.details | |
| ) | |
| def add_failure(self, name: str, message: str) -> None: | |
| self.passed = False | |
| self.failures.append(message) | |
| self.checks[name] = False | |
| def add_warning(self, name: str, message: str) -> None: | |
| self.warnings.append(message) | |
| def add_pass(self, name: str) -> None: | |
| self.checks[name] = True | |
| def summary(self) -> str: | |
| lines = [f"Safety: {'PASS' if self.passed else 'FAIL'}"] | |
| for name, ok in self.checks.items(): | |
| lines.append(f" [{'OK' if ok else 'FAIL'}] {name}") | |
| for w in self.warnings: | |
| lines.append(f" [WARN] {w}") | |
| return "\n".join(lines) | |
| class SafetyValidator: | |
| """Clinical safety validation for surgical predictions.""" | |
| def __init__( | |
| self, | |
| identity_threshold: float = 0.5, | |
| max_displacement_fraction: float = 0.05, | |
| min_face_confidence: float = 0.5, | |
| max_yaw_degrees: float = 45.0, | |
| watermark_enabled: bool = True, | |
| watermark_text: str = "AI-GENERATED PREDICTION", | |
| ): | |
| self.identity_threshold = identity_threshold | |
| self.max_displacement_fraction = max_displacement_fraction | |
| self.min_face_confidence = min_face_confidence | |
| self.max_yaw_degrees = max_yaw_degrees | |
| self.watermark_enabled = watermark_enabled | |
| self.watermark_text = watermark_text | |
| def validate( | |
| self, | |
| input_image: np.ndarray, | |
| output_image: np.ndarray, | |
| landmarks_original: np.ndarray | None = None, | |
| landmarks_manipulated: np.ndarray | None = None, | |
| procedure: str | None = None, | |
| face_confidence: float = 1.0, | |
| ) -> SafetyResult: | |
| """Run all safety checks on a prediction. | |
| Args: | |
| input_image: Original patient image (BGR, uint8). | |
| output_image: Generated prediction (BGR, uint8). | |
| landmarks_original: Original landmarks (N, 2-3), normalized [0, 1]. | |
| landmarks_manipulated: Manipulated landmarks (N, 2-3), normalized [0, 1]. | |
| procedure: Surgical procedure name. | |
| face_confidence: MediaPipe face detection confidence. | |
| Returns: | |
| SafetyResult with all check results. | |
| """ | |
| result = SafetyResult() | |
| # 1. Face detection confidence | |
| self._check_face_confidence(result, face_confidence) | |
| # 2. Identity preservation | |
| self._check_identity(result, input_image, output_image) | |
| # 3. Anatomical plausibility | |
| if landmarks_original is not None and landmarks_manipulated is not None: | |
| self._check_anatomical_plausibility( | |
| result, landmarks_original, landmarks_manipulated, procedure | |
| ) | |
| # 4. Output quality | |
| self._check_output_quality(result, output_image) | |
| # 5. OOD detection (basic) | |
| self._check_ood(result, input_image) | |
| return result | |
| def _check_face_confidence(self, result: SafetyResult, confidence: float) -> None: | |
| """Check face detection confidence.""" | |
| if confidence < self.min_face_confidence: | |
| result.add_failure( | |
| "face_confidence", | |
| f"Face detection confidence {confidence:.2f} below threshold " | |
| f"{self.min_face_confidence}", | |
| ) | |
| else: | |
| result.add_pass("face_confidence") | |
| result.details["face_confidence"] = confidence | |
| def _check_identity( | |
| self, | |
| result: SafetyResult, | |
| input_image: np.ndarray, | |
| output_image: np.ndarray, | |
| ) -> None: | |
| """Check identity preservation using ArcFace similarity.""" | |
| try: | |
| from landmarkdiff.evaluation import compute_identity_similarity | |
| sim = compute_identity_similarity(output_image, input_image) | |
| result.details["identity_similarity"] = float(sim) | |
| if sim < self.identity_threshold: | |
| result.add_failure( | |
| "identity", | |
| f"Identity similarity {sim:.3f} below threshold {self.identity_threshold}", | |
| ) | |
| else: | |
| result.add_pass("identity") | |
| except Exception as e: | |
| result.add_warning("identity", f"Identity check failed: {e}") | |
| def _check_anatomical_plausibility( | |
| self, | |
| result: SafetyResult, | |
| landmarks_orig: np.ndarray, | |
| landmarks_manip: np.ndarray, | |
| procedure: str | None, | |
| ) -> None: | |
| """Check that landmark displacements are anatomically plausible.""" | |
| if len(landmarks_orig) != len(landmarks_manip): | |
| result.add_failure( | |
| "anatomical", | |
| f"Landmark count mismatch: {len(landmarks_orig)} vs {len(landmarks_manip)}", | |
| ) | |
| return | |
| # Compute displacement magnitudes | |
| n = min(len(landmarks_orig), len(landmarks_manip)) | |
| orig = landmarks_orig[:n, :2] # (N, 2), normalized [0, 1] | |
| manip = landmarks_manip[:n, :2] | |
| displacements = np.linalg.norm(manip - orig, axis=1) | |
| max_disp = float(displacements.max()) | |
| mean_disp = float(displacements.mean()) | |
| result.details["max_displacement"] = max_disp | |
| result.details["mean_displacement"] = mean_disp | |
| # Check maximum displacement | |
| if max_disp > self.max_displacement_fraction: | |
| result.add_failure( | |
| "anatomical_magnitude", | |
| f"Maximum displacement {max_disp:.4f} exceeds threshold " | |
| f"{self.max_displacement_fraction}", | |
| ) | |
| else: | |
| result.add_pass("anatomical_magnitude") | |
| # Check procedure-specific regions | |
| if procedure: | |
| self._check_procedure_regions(result, orig, manip, displacements, procedure) | |
| def _check_procedure_regions( | |
| self, | |
| result: SafetyResult, | |
| orig: np.ndarray, | |
| manip: np.ndarray, | |
| displacements: np.ndarray, | |
| procedure: str, | |
| ) -> None: | |
| """Verify displacement is concentrated in expected anatomical regions.""" | |
| from landmarkdiff.landmarks import LANDMARK_REGIONS | |
| # Expected regions by procedure | |
| expected_regions = { | |
| "rhinoplasty": ["nose"], | |
| "blepharoplasty": ["eye_left", "eye_right"], | |
| "rhytidectomy": ["jawline"], | |
| "orthognathic": ["jawline", "lips"], | |
| } | |
| expected = expected_regions.get(procedure, []) | |
| if not expected: | |
| result.add_pass("procedure_region") | |
| return | |
| # Get expected region indices | |
| expected_indices = set() | |
| for region in expected: | |
| if region in LANDMARK_REGIONS: | |
| expected_indices.update(LANDMARK_REGIONS[region]) | |
| if not expected_indices: | |
| result.add_pass("procedure_region") | |
| return | |
| # Check: is most displacement in expected regions? | |
| n = min(len(displacements), len(orig)) | |
| expected_mask = np.array([i in expected_indices for i in range(n)]) | |
| if expected_mask.sum() > 0 and (~expected_mask).sum() > 0: | |
| expected_disp = displacements[expected_mask].mean() | |
| unexpected_disp = displacements[~expected_mask].mean() | |
| result.details["expected_region_disp"] = float(expected_disp) | |
| result.details["unexpected_region_disp"] = float(unexpected_disp) | |
| # Expected regions should have more displacement | |
| if unexpected_disp > expected_disp * 2 and unexpected_disp > 0.005: | |
| result.add_warning( | |
| "procedure_region", | |
| f"{procedure}: unexpected regions displaced more than expected " | |
| f"({unexpected_disp:.4f} vs {expected_disp:.4f})", | |
| ) | |
| else: | |
| result.add_pass("procedure_region") | |
| else: | |
| result.add_pass("procedure_region") | |
| def _check_output_quality(self, result: SafetyResult, output: np.ndarray) -> None: | |
| """Check output image quality (not blank, not corrupted).""" | |
| if output is None or output.size == 0: | |
| result.add_failure("output_quality", "Output image is empty") | |
| return | |
| # Check for blank/black images | |
| mean_val = output.mean() | |
| if mean_val < 5: | |
| result.add_failure("output_quality", f"Output is nearly black (mean={mean_val:.1f})") | |
| return | |
| if mean_val > 250: | |
| result.add_failure("output_quality", f"Output is nearly white (mean={mean_val:.1f})") | |
| return | |
| # Check for artifacts (extreme variance) | |
| std_val = output.std() | |
| if std_val < 10: | |
| result.add_warning( | |
| "output_quality", | |
| f"Output has very low variance (std={std_val:.1f}), may be uniform", | |
| ) | |
| result.add_pass("output_quality") | |
| result.details["output_mean"] = float(mean_val) | |
| result.details["output_std"] = float(std_val) | |
| def _check_ood(self, result: SafetyResult, image: np.ndarray) -> None: | |
| """Basic out-of-distribution detection. | |
| Checks image properties against expected ranges for face photos. | |
| """ | |
| h, w = image.shape[:2] | |
| # Resolution check | |
| if min(h, w) < 128: | |
| result.add_warning("ood", f"Image resolution too low: {w}x{h}") | |
| # Aspect ratio (faces should be roughly square after preprocessing) | |
| aspect = max(h, w) / max(min(h, w), 1) | |
| if aspect > 3.0: | |
| result.add_warning("ood", f"Unusual aspect ratio: {aspect:.1f}") | |
| # Color distribution (face photos should have some skin tones) | |
| if len(image.shape) == 3 and image.shape[2] == 3: | |
| mean_b, mean_g, mean_r = image.mean(axis=(0, 1)) | |
| # Face images typically have red channel > blue channel | |
| if mean_b > mean_r * 1.5: | |
| result.add_warning("ood", "Image appears very blue (not typical face photo)") | |
| result.add_pass("ood_basic") | |
| def apply_watermark( | |
| self, | |
| image: np.ndarray, | |
| text: str | None = None, | |
| opacity: float = 0.3, | |
| ) -> np.ndarray: | |
| """Apply a text watermark to the output image. | |
| Places semi-transparent text at the bottom of the image to indicate | |
| it is AI-generated. | |
| """ | |
| if not self.watermark_enabled: | |
| return image | |
| text = text or self.watermark_text | |
| result = image.copy() | |
| h, w = result.shape[:2] | |
| # Create text overlay | |
| font = cv2.FONT_HERSHEY_SIMPLEX | |
| font_scale = max(0.3, w / 1500) | |
| thickness = max(1, int(w / 500)) | |
| text_size = cv2.getTextSize(text, font, font_scale, thickness)[0] | |
| x = (w - text_size[0]) // 2 | |
| y = h - 10 | |
| # Semi-transparent background bar | |
| bar_y1 = y - text_size[1] - 10 | |
| bar_y2 = h | |
| overlay = result.copy() | |
| cv2.rectangle(overlay, (0, bar_y1), (w, bar_y2), (0, 0, 0), -1) | |
| cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result) | |
| # White text | |
| cv2.putText(result, text, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA) | |
| return result | |
| def embed_metadata( | |
| self, | |
| image_path: str, | |
| procedure: str, | |
| intensity: float, | |
| model_version: str = "0.3.0", | |
| ) -> None: | |
| """Embed provenance metadata in the output image. | |
| Writes EXIF/PNG metadata with generation parameters for traceability. | |
| """ | |
| import json | |
| from pathlib import Path | |
| meta = { | |
| "generator": "LandmarkDiff", | |
| "version": model_version, | |
| "procedure": procedure, | |
| "intensity": intensity, | |
| "disclaimer": "AI-generated surgical prediction for visualization only. " | |
| "Not a guarantee of surgical outcome.", | |
| } | |
| # Save as sidecar JSON (PNG doesn't have easy EXIF support) | |
| meta_path = Path(image_path).with_suffix(".meta.json") | |
| with open(meta_path, "w") as f: | |
| json.dump(meta, f, indent=2) | |