Spaces:
Running
Running
| """4-term loss function module for ControlNet fine-tuning. | |
| L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual | |
| Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against | |
| rubbery TPS warps — it would penalize realism. | |
| Phase B (FEM/clinical data): All 4 terms enabled. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn.functional as F | |
@dataclass
class LossWeights:
    """Per-term weights for the combined 4-term loss.

    Defaults keep diffusion as the primary signal and scale the auxiliary
    terms down so they do not dominate training.
    """

    diffusion: float = 1.0    # epsilon-prediction MSE (always on)
    landmark: float = 0.1     # normalized landmark distance (Phase B)
    identity: float = 0.1     # ArcFace cosine identity loss (Phase B)
    perceptual: float = 0.05  # LPIPS outside surgical mask (Phase B)
class DiffusionLoss:
    """Standard epsilon-prediction MSE loss (primary training signal)."""

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
    ) -> torch.Tensor:
        # Mean squared error, identical to F.mse_loss with default reduction.
        squared_error = (noise_pred - noise_target) ** 2
        return squared_error.mean()
| class LandmarkLoss: | |
| """L2 landmark distance normalized by inter-ocular distance. | |
| Computed INSIDE surgical mask only. Requires MediaPipe re-extraction | |
| from generated image (done at eval, not every training step for speed). | |
| """ | |
| def __call__( | |
| self, | |
| pred_landmarks: torch.Tensor, # (B, N, 2) | |
| target_landmarks: torch.Tensor, # (B, N, 2) | |
| mask: torch.Tensor | None = None, # (B, N) binary | |
| iod: torch.Tensor | None = None, # (B,) inter-ocular distance | |
| ) -> torch.Tensor: | |
| diff = pred_landmarks - target_landmarks # (B, N, 2) | |
| dist = torch.norm(diff, dim=-1) # (B, N) | |
| if mask is not None: | |
| dist = dist * mask | |
| count = mask.sum(dim=-1).clamp(min=1) | |
| mean_dist = dist.sum(dim=-1) / count | |
| else: | |
| mean_dist = dist.mean(dim=-1) | |
| if iod is not None: | |
| mean_dist = mean_dist / iod.clamp(min=1.0) | |
| return mean_dist.mean() | |
| class IdentityLoss: | |
| """ArcFace cosine similarity loss with procedure-dependent masking. | |
| Uses InsightFace ArcFace model (buffalo_l) for 512-dim identity embeddings. | |
| Falls back to pixel-level cosine similarity if InsightFace is unavailable. | |
| - Full face for blepharoplasty | |
| - Upper-face crop for rhinoplasty | |
| - Disabled for orthognathic | |
| Input images MUST be normalized to [-1, 1] and cropped to 112x112 | |
| before passing to ArcFace (AdaFace outputs garbage for 1024x1024). | |
| """ | |
| def __init__(self, device: torch.device | None = None): | |
| self._model = None | |
| self._device = device | |
| self._has_arcface = None # None = not checked yet | |
| def _ensure_loaded(self, device: torch.device) -> None: | |
| """Lazy-load ArcFace model on first use.""" | |
| if self._has_arcface is not None: | |
| return | |
| try: | |
| from insightface.app import FaceAnalysis | |
| self._app = FaceAnalysis( | |
| name="buffalo_l", | |
| providers=["CUDAExecutionProvider", "CPUExecutionProvider"], | |
| ) | |
| ctx_id = device.index if device.type == "cuda" and device.index is not None else (0 if device.type == "cuda" else -1) | |
| self._app.prepare(ctx_id=ctx_id, det_size=(320, 320)) | |
| self._has_arcface = True | |
| except Exception: | |
| self._has_arcface = False | |
| def _extract_embedding(self, image_tensor: torch.Tensor) -> torch.Tensor: | |
| """Extract ArcFace embedding from a batch of images. | |
| Args: | |
| image_tensor: (B, 3, 112, 112) in [-1, 1] | |
| Returns: | |
| (B, 512) identity embeddings, or (B, D) pixel-level if fallback. | |
| """ | |
| if self._has_arcface: | |
| import numpy as np | |
| embeddings = [] | |
| valid_mask = [] | |
| for i in range(image_tensor.shape[0]): | |
| # Convert to uint8 BGR for InsightFace | |
| img = ((image_tensor[i].permute(1, 2, 0) + 1) / 2 * 255).clamp(0, 255) | |
| img_np = img.cpu().numpy().astype(np.uint8) | |
| img_bgr = img_np[:, :, ::-1].copy() | |
| faces = self._app.get(img_bgr) | |
| if faces and hasattr(faces[0], "embedding") and faces[0].embedding is not None: | |
| embeddings.append(torch.from_numpy(faces[0].embedding)) | |
| valid_mask.append(True) | |
| else: | |
| embeddings.append(torch.zeros(512)) | |
| valid_mask.append(False) | |
| return torch.stack(embeddings).to(image_tensor.device), valid_mask | |
| else: | |
| # Fallback: pixel-level features | |
| return image_tensor.flatten(1), [True] * image_tensor.shape[0] | |
| def __call__( | |
| self, | |
| pred_image: torch.Tensor, # (B, 3, H, W) in [0, 1] | |
| target_image: torch.Tensor, | |
| procedure: str = "rhinoplasty", | |
| ) -> torch.Tensor: | |
| if procedure == "orthognathic": | |
| return torch.tensor(0.0, device=pred_image.device) | |
| self._ensure_loaded(pred_image.device) | |
| # Crop based on procedure | |
| pred_crop = self._procedure_crop(pred_image, procedure) | |
| target_crop = self._procedure_crop(target_image, procedure) | |
| # Resize to 112x112 for ArcFace | |
| pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False) | |
| target_112 = F.interpolate(target_crop, size=(112, 112), mode="bilinear", align_corners=False) | |
| # Normalize to [-1, 1] | |
| pred_norm = pred_112 * 2 - 1 | |
| target_norm = target_112 * 2 - 1 | |
| # Extract embeddings (ArcFace or fallback) | |
| pred_emb, pred_valid = self._extract_embedding(pred_norm) | |
| target_emb, target_valid = self._extract_embedding(target_norm) | |
| # Only compute loss for samples where both faces were detected | |
| valid = [p and t for p, t in zip(pred_valid, target_valid)] | |
| if not any(valid): | |
| return torch.tensor(0.0, device=pred_image.device) | |
| valid_indices = [i for i, v in enumerate(valid) if v] | |
| valid_idx_t = torch.tensor(valid_indices, device=pred_image.device, dtype=torch.long) | |
| # Select ONLY valid embeddings before normalization to avoid 0/0 NaN | |
| pred_valid_emb = pred_emb[valid_idx_t].float() | |
| target_valid_emb = target_emb[valid_idx_t].float() | |
| # L2 normalize (safe — zero vectors excluded above) | |
| pred_valid_emb = F.normalize(pred_valid_emb, dim=1) | |
| target_valid_emb = F.normalize(target_valid_emb, dim=1) | |
| cosine_sim = (pred_valid_emb * target_valid_emb).sum(dim=1) | |
| return (1 - cosine_sim).mean() | |
| def _procedure_crop( | |
| self, | |
| image: torch.Tensor, | |
| procedure: str, | |
| ) -> torch.Tensor: | |
| """Crop image based on procedure for identity comparison.""" | |
| _, _, h, w = image.shape | |
| if procedure == "rhinoplasty": | |
| # Upper face crop (forehead to nose tip) | |
| return image[:, :, : h * 2 // 3, :] | |
| elif procedure == "blepharoplasty": | |
| # Full face | |
| return image | |
| elif procedure == "rhytidectomy": | |
| # Upper face (above jawline) | |
| return image[:, :, : h * 3 // 4, :] | |
| else: | |
| return image | |
class PerceptualLoss:
    """LPIPS perceptual loss on regions OUTSIDE surgical mask only.

    LPIPS expects [-1, 1] input. VAE outputs [0, 1].
    Must apply (x * 2) - 1 before every call.
    """

    def __init__(self):
        self._lpips = None

    def _ensure_loaded(self, device: torch.device) -> None:
        # Lazy import: the module stays usable without the lpips package.
        if self._lpips is not None:
            return
        try:
            import lpips

            net = lpips.LPIPS(net="alex").to(device)
            net.eval()
            for param in net.parameters():
                param.requires_grad_(False)
            self._lpips = net
        except ImportError:
            # Sentinel value: every subsequent call falls back to L1.
            self._lpips = "unavailable"

    def __call__(
        self,
        pred: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target: torch.Tensor,
        mask: torch.Tensor,  # (B, 1, H, W) surgical mask [0, 1]
    ) -> torch.Tensor:
        self._ensure_loaded(pred.device)
        # Shift both images into LPIPS's expected [-1, 1] range.
        pred_scaled = pred * 2 - 1
        target_scaled = target * 2 - 1
        # A mask covering >=99% of pixels means "no mask file": score the full
        # image. Otherwise invert it and score OUTSIDE the surgical region.
        if mask.sum() < mask.numel() * 0.99:
            keep = 1 - mask
            kernel = 5
            if keep.shape[-1] >= kernel and keep.shape[-2] >= kernel:
                # Erode the kept region (min-pool via negated max-pool) so the
                # loss ignores pixels right at the mask boundary.
                keep = -F.max_pool2d(
                    -keep,
                    kernel_size=kernel,
                    stride=1,
                    padding=kernel // 2,
                )
            pred_scaled = pred_scaled * keep
            target_scaled = target_scaled * keep
        if self._lpips == "unavailable":
            # lpips package missing: plain L1 as a stand-in.
            return F.l1_loss(pred_scaled, target_scaled)
        return self._lpips(pred_scaled, target_scaled).mean()
class CombinedLoss:
    """Weighted combination of the 4 training loss terms.

    ``phase='A'``: diffusion term only (Phase A, synthetic TPS data).
    ``phase='B'``: diffusion plus landmark / identity / perceptual terms,
    each added only when its required kwargs were supplied.

    For Phase B, set ``use_differentiable_arcface=True`` to use the
    PyTorch-native ArcFace backbone (``arcface_torch.py``) that provides
    actual gradient signal. The default ONNX-based IdentityLoss produces
    zero gradients (DA2-03).
    """

    def __init__(
        self,
        weights: LossWeights | None = None,
        phase: str = "A",
        use_differentiable_arcface: bool = False,
        arcface_weights_path: str | None = None,
    ):
        self.weights = weights if weights is not None else LossWeights()
        self.phase = phase
        self.diffusion_loss = DiffusionLoss()
        self.landmark_loss = LandmarkLoss()
        self.perceptual_loss = PerceptualLoss()
        # Identity term: differentiable PyTorch ArcFace for Phase B, or the
        # ONNX-based fallback implementation.
        if use_differentiable_arcface:
            from landmarkdiff.arcface_torch import ArcFaceLoss

            self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
        else:
            self.identity_loss = IdentityLoss()

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Return each weighted loss term plus their running sum under 'total'."""
        losses: dict[str, torch.Tensor] = {}
        # Diffusion MSE is computed in every phase.
        losses["diffusion"] = self.weights.diffusion * self.diffusion_loss(
            noise_pred, noise_target
        )
        losses["total"] = losses["diffusion"]
        if self.phase != "B":
            # Phase A: the diffusion term is the whole loss.
            return losses
        have_images = "pred_image" in kwargs and "target_image" in kwargs
        if "pred_landmarks" in kwargs and "target_landmarks" in kwargs:
            losses["landmark"] = self.weights.landmark * self.landmark_loss(
                kwargs["pred_landmarks"],
                kwargs["target_landmarks"],
                kwargs.get("landmark_mask"),
                kwargs.get("iod"),
            )
            losses["total"] = losses["total"] + losses["landmark"]
        if have_images:
            losses["identity"] = self.weights.identity * self.identity_loss(
                kwargs["pred_image"],
                kwargs["target_image"],
                kwargs.get("procedure", "rhinoplasty"),
            )
            losses["total"] = losses["total"] + losses["identity"]
        if have_images and "mask" in kwargs:
            losses["perceptual"] = self.weights.perceptual * self.perceptual_loss(
                kwargs["pred_image"],
                kwargs["target_image"],
                kwargs["mask"],
            )
            losses["total"] = losses["total"] + losses["perceptual"]
        return losses