# dreamlessx's picture
# Update landmarkdiff/safety.py to v0.3.2
# ded6c17 verified
"""Clinical safety validation for responsible deployment.
Implements safety checks for surgical outcome predictions:
1. Identity preservation: verify output preserves patient identity
2. Anatomical plausibility: check landmark displacements are realistic
3. Out-of-distribution detection: flag unusual inputs
4. Watermarking: mark AI-generated images
5. Consent metadata: embed provenance information
Usage:
from landmarkdiff.safety import SafetyValidator
validator = SafetyValidator()
result = validator.validate(
input_image=image,
output_image=generated,
landmarks_original=face.landmarks,
landmarks_manipulated=manip.landmarks,
procedure="rhinoplasty",
)
if not result.passed:
print(f"Safety check failed: {result.failures}")
"""
from __future__ import annotations
from dataclasses import dataclass, field
import cv2
import numpy as np
@dataclass
class SafetyResult:
    """Accumulates the outcome of a safety validation run.

    Individual checks report into one instance through ``add_failure``,
    ``add_warning`` and ``add_pass``; ``summary`` renders the collected
    state as a short multi-line report.
    """

    # Overall verdict; starts True and is set to False by add_failure().
    passed: bool = True
    # Messages of every failure recorded so far, in insertion order.
    failures: list[str] = field(default_factory=list)
    # Messages of every warning recorded so far, in insertion order.
    warnings: list[str] = field(default_factory=list)
    # Check name -> True (passed) / False (failed).
    checks: dict[str, bool] = field(default_factory=dict)
    # Free-form extra values the checks choose to attach.
    details: dict[str, object] = field(default_factory=dict)

    def __repr__(self) -> str:
        """Return a constructor-like representation listing every field."""
        rendered_fields = (
            f"passed={self.passed}",
            f"failures={self.failures}",
            f"warnings={self.warnings}",
            f"checks={self.checks}",
            f"details={self.details}",
        )
        return "SafetyResult(" + ", ".join(rendered_fields) + ")"

    def __eq__(self, other: object) -> bool:
        """Compare field by field against another ``SafetyResult``.

        Returns ``NotImplemented`` for any other type so Python can fall
        back to the reflected comparison.
        """
        if isinstance(other, SafetyResult):
            mine = (self.passed, self.failures, self.warnings, self.checks, self.details)
            theirs = (other.passed, other.failures, other.warnings, other.checks, other.details)
            return mine == theirs
        return NotImplemented

    def add_failure(self, name: str, message: str) -> None:
        """Record a failed check: mark ``name`` False and fail the whole result."""
        self.checks[name] = False
        self.failures.append(message)
        self.passed = False

    def add_warning(self, name: str, message: str) -> None:
        """Record a warning message.

        ``name`` is accepted for symmetry with the other recorders but is
        not stored; warnings do not touch ``checks`` or ``passed``.
        """
        self.warnings.append(message)

    def add_pass(self, name: str) -> None:
        """Mark the check called ``name`` as passed."""
        self.checks[name] = True

    def summary(self) -> str:
        """Return a report: verdict line, one line per check, one per warning."""
        verdict = "PASS" if self.passed else "FAIL"
        report = [f"Safety: {verdict}"]
        report.extend(
            f" [{'OK' if ok else 'FAIL'}] {name}" for name, ok in self.checks.items()
        )
        report.extend(f" [WARN] {w}" for w in self.warnings)
        return "\n".join(report)
class SafetyValidator:
    """Clinical safety validation for surgical predictions.

    Runs a sequence of independent checks on a generated prediction and
    records each outcome on a ``SafetyResult``. Also provides helpers to
    watermark an output image and to write a provenance sidecar file.
    """

    def __init__(
        self,
        identity_threshold: float = 0.5,
        max_displacement_fraction: float = 0.05,
        min_face_confidence: float = 0.5,
        max_yaw_degrees: float = 45.0,
        watermark_enabled: bool = True,
        watermark_text: str = "AI-GENERATED PREDICTION",
    ):
        """Store thresholds and watermark settings.

        Args:
            identity_threshold: Minimum identity similarity for the identity
                check to pass.
            max_displacement_fraction: Largest per-landmark displacement
                accepted by the anatomical check, in the same units as the
                landmark coordinates.
            min_face_confidence: Minimum accepted face detection confidence.
            max_yaw_degrees: Stored on the instance; none of the checks in
                this class read it.
            watermark_enabled: When False, ``apply_watermark`` is a no-op.
            watermark_text: Default text drawn by ``apply_watermark``.
        """
        self.identity_threshold = identity_threshold
        self.max_displacement_fraction = max_displacement_fraction
        self.min_face_confidence = min_face_confidence
        self.max_yaw_degrees = max_yaw_degrees
        self.watermark_enabled = watermark_enabled
        self.watermark_text = watermark_text

    def validate(
        self,
        input_image: np.ndarray,
        output_image: np.ndarray,
        landmarks_original: np.ndarray | None = None,
        landmarks_manipulated: np.ndarray | None = None,
        procedure: str | None = None,
        face_confidence: float = 1.0,
    ) -> SafetyResult:
        """Run all safety checks on a prediction.

        Args:
            input_image: Original patient image (BGR, uint8).
            output_image: Generated prediction (BGR, uint8).
            landmarks_original: Original landmarks (N, 2-3), normalized [0, 1].
            landmarks_manipulated: Manipulated landmarks (N, 2-3), normalized [0, 1].
            procedure: Surgical procedure name.
            face_confidence: MediaPipe face detection confidence.

        Returns:
            SafetyResult with all check results.
        """
        result = SafetyResult()

        # 1. Face detection confidence
        self._check_face_confidence(result, face_confidence)

        # 2. Identity preservation
        self._check_identity(result, input_image, output_image)

        # 3. Anatomical plausibility (only when both landmark sets are given)
        if landmarks_original is not None and landmarks_manipulated is not None:
            self._check_anatomical_plausibility(
                result, landmarks_original, landmarks_manipulated, procedure
            )

        # 4. Output quality
        self._check_output_quality(result, output_image)

        # 5. OOD detection (basic)
        self._check_ood(result, input_image)

        return result

    def _check_face_confidence(self, result: SafetyResult, confidence: float) -> None:
        """Fail when the face detection confidence is below the threshold.

        The confidence value is stored in ``result.details`` either way.
        """
        # Written as `not (x >= t)` rather than `x < t` so that a NaN
        # confidence is rejected instead of silently passing.
        if not confidence >= self.min_face_confidence:
            result.add_failure(
                "face_confidence",
                f"Face detection confidence {confidence:.2f} below threshold "
                f"{self.min_face_confidence}",
            )
        else:
            result.add_pass("face_confidence")
        result.details["face_confidence"] = confidence

    def _check_identity(
        self,
        result: SafetyResult,
        input_image: np.ndarray,
        output_image: np.ndarray,
    ) -> None:
        """Check identity preservation using ArcFace similarity.

        Best-effort: if the similarity backend cannot be imported or raises,
        a warning is recorded and no pass/fail entry is written.
        """
        try:
            from landmarkdiff.evaluation import compute_identity_similarity

            sim = float(compute_identity_similarity(output_image, input_image))
        except Exception as e:
            result.add_warning("identity", f"Identity check failed: {e}")
            return

        result.details["identity_similarity"] = sim
        # NaN-safe: a NaN similarity must not count as a pass.
        if not sim >= self.identity_threshold:
            result.add_failure(
                "identity",
                f"Identity similarity {sim:.3f} below threshold {self.identity_threshold}",
            )
        else:
            result.add_pass("identity")

    def _check_anatomical_plausibility(
        self,
        result: SafetyResult,
        landmarks_orig: np.ndarray,
        landmarks_manip: np.ndarray,
        procedure: str | None,
    ) -> None:
        """Check that landmark displacements are anatomically plausible.

        Records an ``anatomical`` failure for malformed, mismatched or empty
        landmark sets; otherwise records ``anatomical_magnitude`` based on the
        largest per-landmark displacement and, when ``procedure`` is given,
        runs the procedure-specific region check.
        """
        orig_arr = np.asarray(landmarks_orig)
        manip_arr = np.asarray(landmarks_manip)

        # Each set must be a 2-D array with at least x and y columns.
        for label, arr in (("original", orig_arr), ("manipulated", manip_arr)):
            if arr.ndim != 2 or arr.shape[1] < 2:
                result.add_failure(
                    "anatomical",
                    f"Invalid {label} landmark array shape: {arr.shape}",
                )
                return

        if len(orig_arr) != len(manip_arr):
            result.add_failure(
                "anatomical",
                f"Landmark count mismatch: {len(orig_arr)} vs {len(manip_arr)}",
            )
            return

        # max()/mean() of an empty array would raise; report it instead.
        if len(orig_arr) == 0:
            result.add_failure("anatomical", "No landmarks provided for plausibility check")
            return

        # Compute displacement magnitudes on the x/y columns only.
        orig = orig_arr[:, :2]
        manip = manip_arr[:, :2]
        displacements = np.linalg.norm(manip - orig, axis=1)
        max_disp = float(displacements.max())
        mean_disp = float(displacements.mean())
        result.details["max_displacement"] = max_disp
        result.details["mean_displacement"] = mean_disp

        # `not (x <= t)` instead of `x > t`: NaN coordinates propagate into
        # max_disp and must fail rather than pass.
        if not max_disp <= self.max_displacement_fraction:
            result.add_failure(
                "anatomical_magnitude",
                f"Maximum displacement {max_disp:.4f} exceeds threshold "
                f"{self.max_displacement_fraction}",
            )
        else:
            result.add_pass("anatomical_magnitude")

        # Check procedure-specific regions
        if procedure:
            self._check_procedure_regions(result, orig, manip, displacements, procedure)

    def _check_procedure_regions(
        self,
        result: SafetyResult,
        orig: np.ndarray,
        manip: np.ndarray,
        displacements: np.ndarray,
        procedure: str,
    ) -> None:
        """Verify displacement is concentrated in expected anatomical regions.

        Records a warning (not a failure) when landmarks outside the regions
        expected for ``procedure`` move, on average, more than twice as much
        as those inside and by more than 0.005. Unknown procedures and
        unknown region names pass.
        """
        from landmarkdiff.landmarks import LANDMARK_REGIONS

        # Expected regions by procedure
        expected_regions = {
            "rhinoplasty": ["nose"],
            "blepharoplasty": ["eye_left", "eye_right"],
            "rhytidectomy": ["jawline"],
            "orthognathic": ["jawline", "lips"],
        }
        expected = expected_regions.get(procedure, [])
        if not expected:
            result.add_pass("procedure_region")
            return

        # Collect the landmark indices belonging to the expected regions.
        expected_indices = set()
        for region in expected:
            if region in LANDMARK_REGIONS:
                expected_indices.update(LANDMARK_REGIONS[region])
        if not expected_indices:
            result.add_pass("procedure_region")
            return

        # Check: is most displacement in expected regions?
        n = min(len(displacements), len(orig))
        # dtype=bool keeps `~mask` valid even when n == 0.
        expected_mask = np.array([i in expected_indices for i in range(n)], dtype=bool)
        if expected_mask.sum() > 0 and (~expected_mask).sum() > 0:
            expected_disp = displacements[expected_mask].mean()
            unexpected_disp = displacements[~expected_mask].mean()
            result.details["expected_region_disp"] = float(expected_disp)
            result.details["unexpected_region_disp"] = float(unexpected_disp)
            # Expected regions should have more displacement
            if unexpected_disp > expected_disp * 2 and unexpected_disp > 0.005:
                result.add_warning(
                    "procedure_region",
                    f"{procedure}: unexpected regions displaced more than expected "
                    f"({unexpected_disp:.4f} vs {expected_disp:.4f})",
                )
            else:
                result.add_pass("procedure_region")
        else:
            # All landmarks fall on one side of the mask: nothing to compare.
            result.add_pass("procedure_region")

    def _check_output_quality(self, result: SafetyResult, output: np.ndarray) -> None:
        """Check output image quality (not blank, not corrupted).

        Fails on empty, non-finite, nearly black (mean < 5) or nearly white
        (mean > 250) output; warns on very low variance (std < 10). Mean and
        std are stored in ``result.details`` only when the check passes.
        """
        if output is None or output.size == 0:
            result.add_failure("output_quality", "Output image is empty")
            return

        mean_val = output.mean()
        # A NaN/inf mean would slip through both range comparisons below.
        if not np.isfinite(mean_val):
            result.add_failure("output_quality", "Output contains non-finite pixel values")
            return

        # Check for blank/black images
        if mean_val < 5:
            result.add_failure("output_quality", f"Output is nearly black (mean={mean_val:.1f})")
            return
        if mean_val > 250:
            result.add_failure("output_quality", f"Output is nearly white (mean={mean_val:.1f})")
            return

        # Very low variance suggests a flat, featureless image.
        std_val = output.std()
        if std_val < 10:
            result.add_warning(
                "output_quality",
                f"Output has very low variance (std={std_val:.1f}), may be uniform",
            )
        result.add_pass("output_quality")
        result.details["output_mean"] = float(mean_val)
        result.details["output_std"] = float(std_val)

    def _check_ood(self, result: SafetyResult, image: np.ndarray) -> None:
        """Basic out-of-distribution detection.

        Checks image properties against expected ranges for face photos.
        Resolution, aspect-ratio and colour findings are warnings only; the
        ``ood_basic`` check fails only when ``image`` is missing or has fewer
        than two dimensions.
        """
        if image is None or getattr(image, "ndim", 0) < 2:
            result.add_failure("ood_basic", "Input image is missing or is not an image array")
            return

        h, w = image.shape[:2]

        # Resolution check
        if min(h, w) < 128:
            result.add_warning("ood", f"Image resolution too low: {w}x{h}")

        # Aspect ratio (faces should be roughly square after preprocessing)
        aspect = max(h, w) / max(min(h, w), 1)
        if aspect > 3.0:
            result.add_warning("ood", f"Unusual aspect ratio: {aspect:.1f}")

        # Color distribution (face photos should have some skin tones)
        if len(image.shape) == 3 and image.shape[2] == 3:
            mean_b, mean_g, mean_r = image.mean(axis=(0, 1))
            # Face images typically have red channel > blue channel
            if mean_b > mean_r * 1.5:
                result.add_warning("ood", "Image appears very blue (not typical face photo)")

        result.add_pass("ood_basic")

    def apply_watermark(
        self,
        image: np.ndarray,
        text: str | None = None,
        opacity: float = 0.3,
    ) -> np.ndarray:
        """Apply a text watermark to the output image.

        Places semi-transparent text at the bottom of the image to indicate
        it is AI-generated.

        Args:
            image: Image to watermark; it is not modified.
            text: Watermark text; falls back to ``self.watermark_text`` when
                None or empty.
            opacity: Blend weight of the black background bar.

        Returns:
            A watermarked copy, or ``image`` itself (same object) when
            watermarking is disabled.
        """
        if not self.watermark_enabled:
            return image

        text = text or self.watermark_text
        result = image.copy()
        h, w = result.shape[:2]

        # Scale font and stroke with image width, with a floor for small images.
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = max(0.3, w / 1500)
        thickness = max(1, int(w / 500))
        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        x = (w - text_size[0]) // 2  # horizontally centred
        y = h - 10  # text baseline, 10 px above the bottom edge

        # Semi-transparent background bar
        bar_y1 = y - text_size[1] - 10
        bar_y2 = h
        overlay = result.copy()
        cv2.rectangle(overlay, (0, bar_y1), (w, bar_y2), (0, 0, 0), -1)
        # overlay differs from result only inside the bar, so blending the
        # full frames darkens just that strip.
        cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)

        # White text
        cv2.putText(result, text, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
        return result

    def embed_metadata(
        self,
        image_path: str,
        procedure: str,
        intensity: float,
        model_version: str = "0.3.0",
    ) -> None:
        """Write provenance metadata for an output image as a sidecar JSON file.

        The metadata is not embedded in the image itself. It is written next
        to it, with the image's final suffix replaced by ``.meta.json``
        (``out.png`` -> ``out.meta.json``).

        Args:
            image_path: Path of the generated image.
            procedure: Procedure name recorded in the metadata.
            intensity: Intensity value recorded in the metadata.
            model_version: Version string recorded in the metadata.
        """
        import json
        from pathlib import Path

        meta = {
            "generator": "LandmarkDiff",
            "version": model_version,
            "procedure": procedure,
            "intensity": intensity,
            "disclaimer": "AI-generated surgical prediction for visualization only. "
            "Not a guarantee of surgical outcome.",
        }

        # Save as sidecar JSON (PNG doesn't have easy EXIF support)
        meta_path = Path(image_path).with_suffix(".meta.json")
        # Explicit encoding so the file does not depend on the platform default.
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)