Spaces:
Running
Running
| """PyTorch-native ArcFace model for differentiable identity loss. | |
| Drop-in replacement for the ONNX-based InsightFace ArcFace used in losses.py. | |
| The original IdentityLoss extracts embeddings under @torch.no_grad(), which | |
| means the identity loss term contributes zero gradients during Phase B training. | |
| This module provides a fully differentiable path so that gradients flow back | |
| through the predicted image into the ControlNet. | |
| Architecture: IResNet-50 matching the InsightFace w600k_r50 ONNX model. | |
| conv1(3->64, 3x3, bias) -> PReLU -> | |
| 4 IResNet stages [3, 4, 14, 3] with channels [64, 128, 256, 512] -> | |
| BN2d -> Flatten -> FC(512*7*7 -> 512) -> BN1d -> L2-normalize | |
| Each IBasicBlock: BN -> conv3x3(bias) -> PReLU -> conv3x3(bias) + residual. | |
| No SE module. Convolutions use bias=True. | |
| Pretrained weights: converted from the InsightFace buffalo_l w600k_r50.onnx | |
| model to a PyTorch state dict (backbone.pth). The conversion extracts weights | |
| from the ONNX graph and maps them to matching PyTorch module keys. | |
| Usage in losses.py: | |
| from landmarkdiff.arcface_torch import ArcFaceLoss | |
| identity_loss = ArcFaceLoss(device=device) | |
| loss = identity_loss(pred_image, target_image) # gradients flow through pred | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import warnings | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| logger = logging.getLogger(__name__) | |
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------


class IBasicBlock(nn.Module):
    """Pre-activation residual block used by IResNet.

    Data path: BN -> 3x3 conv (bias) -> PReLU -> 3x3 conv (bias), summed with
    the (optionally downsampled) input. Convolutions keep ``bias=True`` for
    compatibility with the InsightFace w600k_r50 ONNX weights.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: nn.Module | None = None,
    ):
        super().__init__()
        # Pre-activation BatchNorm; eps matches the ONNX export (2e-5).
        self.bn1 = nn.BatchNorm2d(inplanes, eps=2e-5, momentum=0.1)
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=1, padding=1, bias=True,
        )
        self.prelu = nn.PReLU(planes)
        # Any spatial stride is applied by the second convolution.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=True,
        )
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv2(self.prelu(self.conv1(self.bn1(x))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return y + shortcut
# ---------------------------------------------------------------------------
# Backbone
# ---------------------------------------------------------------------------


class ArcFaceBackbone(nn.Module):
    """IResNet-50 backbone producing ArcFace identity embeddings.

    Input: (B, 3, 112, 112) face crops normalized to [-1, 1].
    Output: (B, 512) L2-normalized embeddings.

    The pipeline mirrors the InsightFace w600k_r50 ONNX model exactly:
    Conv(bias) -> PReLU -> 4 stages -> BN2d -> Flatten -> FC -> BN1d -> L2norm.
    Submodule names (conv1, prelu, layer1..4, bn2, fc, features) are part of
    the checkpoint key layout and must not change.
    """

    def __init__(
        self,
        layers: tuple[int, ...] = (3, 4, 14, 3),
        embedding_dim: int = 512,
    ):
        super().__init__()
        self.inplanes = 64
        # Stem: conv1(bias) -> PReLU; note there is no BatchNorm in the stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.prelu = nn.PReLU(64)
        # Four residual stages, each halving the spatial resolution
        # (112 -> 56 -> 28 -> 14 -> 7).
        self.layer1 = self._make_layer(64, layers[0], stride=2)
        self.layer2 = self._make_layer(128, layers[1], stride=2)
        self.layer3 = self._make_layer(256, layers[2], stride=2)
        self.layer4 = self._make_layer(512, layers[3], stride=2)
        # Embedding head: BN2d -> Flatten -> FC -> BN1d.
        self.bn2 = nn.BatchNorm2d(512, eps=2e-5, momentum=0.1)
        self.fc = nn.Linear(512 * 7 * 7, embedding_dim)
        self.features = nn.BatchNorm1d(embedding_dim, eps=2e-5, momentum=0.1)
        self._initialize_weights()

    def _make_layer(
        self,
        planes: int,
        num_blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one residual stage of ``num_blocks`` IBasicBlocks.

        The first block carries ``stride`` and, when the shape changes, a
        strided 1x1 conv projection on its shortcut path.
        """
        projection = None
        if stride != 1 or self.inplanes != planes:
            projection = nn.Conv2d(
                self.inplanes, planes, kernel_size=1, stride=stride, bias=True,
            )
        stage = [IBasicBlock(self.inplanes, planes, stride, projection)]
        self.inplanes = planes
        stage.extend(IBasicBlock(planes, planes) for _ in range(num_blocks - 1))
        return nn.Sequential(*stage)

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for conv/linear weights; unit-scale BN."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of faces.

        Args:
            x: (B, 3, 112, 112) in [-1, 1].

        Returns:
            (B, 512) L2-normalized embeddings.
        """
        x = self.prelu(self.conv1(x))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        emb = self.fc(torch.flatten(self.bn2(x), 1))
        # Unit-norm output so downstream dot products are cosine similarities.
        return F.normalize(self.features(emb), p=2, dim=1)
| # --------------------------------------------------------------------------- | |
| # Pretrained weight loading | |
| # --------------------------------------------------------------------------- | |
| # Known locations where converted backbone.pth may live | |
| _KNOWN_WEIGHT_PATHS = [ | |
| Path.home() / ".cache" / "arcface" / "backbone.pth", | |
| Path.home() / ".insightface" / "models" / "buffalo_l" / "backbone.pth", | |
| ] | |
| def _find_pretrained_weights() -> Path | None: | |
| """Search known locations for pretrained IResNet-50 weights.""" | |
| for p in _KNOWN_WEIGHT_PATHS: | |
| if p.exists() and p.suffix == ".pth" and p.stat().st_size > 0: | |
| return p | |
| return None | |
def load_pretrained_weights(
    model: ArcFaceBackbone,
    weights_path: str | None = None,
) -> bool:
    """Load pretrained InsightFace IResNet-50 weights into ``model``.

    The weights are a PyTorch state dict converted from the InsightFace
    w600k_r50.onnx model; its key names match our module structure exactly.

    Args:
        model: An ``ArcFaceBackbone`` instance, populated in place.
        weights_path: Explicit path to a ``.pth`` file. If ``None``, the
            known cache locations are searched instead.

    Returns:
        ``True`` if weights were loaded successfully, ``False`` otherwise
        (the model then keeps its random initialization).
    """
    # Resolve which checkpoint file to use, preferring the explicit path.
    resolved: Path | None = None
    if weights_path is not None:
        candidate = Path(weights_path)
        if candidate.exists():
            resolved = candidate
        else:
            logger.warning("Specified weights path does not exist: %s", candidate)
    if resolved is None:
        resolved = _find_pretrained_weights()
    if resolved is None:
        warnings.warn(
            "No pretrained ArcFace weights found. The model will use random "
            "initialization. Identity loss values will be meaningless until "
            "proper weights are loaded. Place backbone.pth at "
            f"{Path.home() / '.cache' / 'arcface' / 'backbone.pth'}",
            UserWarning,
            stacklevel=2,
        )
        return False

    logger.info("Loading ArcFace weights from %s", resolved)
    state_dict = torch.load(str(resolved), map_location="cpu", weights_only=True)
    # Unwrap checkpoints that nest the weights under a "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # First attempt: exact key match.
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        pass  # fall through to the remapped, non-strict attempt
    else:
        logger.info("Loaded ArcFace weights (strict match)")
        return True

    # Second attempt: remap known key-name differences, then load non-strictly
    # (tolerates checkpoints that carry extra keys).
    try:
        remapped = {
            (
                k.replace("output_layer.", "features.")
                if k.startswith("output_layer.")
                else k
            ): v
            for k, v in state_dict.items()
        }
        missing, unexpected = model.load_state_dict(remapped, strict=False)
        if missing:
            logger.warning(
                "Missing keys when loading ArcFace weights: %s",
                missing[:10],
            )
        if unexpected:
            logger.info("Unexpected keys (ignored): %s", unexpected[:10])
        logger.info("Loaded ArcFace weights (non-strict)")
        return True
    except Exception as e:
        warnings.warn(
            f"Failed to load ArcFace weights from {resolved}: {e}. "
            "Using random initialization.",
            UserWarning,
            stacklevel=2,
        )
        return False
# ---------------------------------------------------------------------------
# Differentiable face alignment
# ---------------------------------------------------------------------------


def align_face(
    images: torch.Tensor,
    size: int = 112,
) -> torch.Tensor:
    """Center-crop and resize face images to (size x size) differentiably.

    Uses ``F.grid_sample`` with bilinear interpolation so that gradients
    flow back through the spatial transform into the input images.

    The crop extracts the central 80% of the image (removes background
    padding that is common in generated 512x512 face images) and resizes
    to the target size.

    Args:
        images: (B, 3, H, W) tensor, any normalization.
        size: Target spatial size (default 112 for ArcFace).

    Returns:
        (B, 3, size, size) tensor with the same normalization as input.
    """
    B, C, H, W = images.shape
    if H == size and W == size:
        # Already at target resolution; assume pre-aligned and skip the crop.
        return images
    # Fraction of the image (per axis) kept by the center crop.
    crop_frac = 0.8
    # ``affine_grid`` maps output coords in [-1, 1] through ``theta``; since
    # normalized coords span the full image, a diagonal scale of s samples
    # exactly the central fraction s of the input. The scale must therefore
    # be crop_frac itself -- the previous ``crop_frac / 2.0`` sampled only
    # the central 40%, over-zooming the face by 2x versus the documented crop.
    theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
    theta[:, 0, 0] = crop_frac  # x scale
    theta[:, 1, 1] = crop_frac  # y scale
    grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
    aligned = F.grid_sample(
        images, grid, mode="bilinear", padding_mode="border", align_corners=False,
    )
    return aligned
def align_face_no_crop(
    images: torch.Tensor,
    size: int = 112,
) -> torch.Tensor:
    """Resize face images to (size x size) without cropping, differentiably.

    A plain bilinear ``F.interpolate`` resize, so gradients flow through.
    Use this when images are already tightly cropped faces.

    Args:
        images: (B, 3, H, W) tensor.
        size: Target spatial size.

    Returns:
        (B, 3, size, size) tensor.
    """
    h, w = images.shape[-2], images.shape[-1]
    if (h, w) == (size, size):
        # Nothing to do -- hand the input back untouched.
        return images
    return F.interpolate(
        images, size=(size, size), mode="bilinear", align_corners=False,
    )
| # --------------------------------------------------------------------------- | |
| # ArcFaceLoss: differentiable identity preservation loss | |
| # --------------------------------------------------------------------------- | |
| class ArcFaceLoss(nn.Module): | |
| """Differentiable identity loss using PyTorch-native ArcFace. | |
| Replaces the ONNX-based InsightFace ArcFace in ``IdentityLoss`` from | |
| ``losses.py``. Gradients flow through the predicted image into the | |
| generator, while the target embedding is detached. | |
| Loss = mean(1 - cosine_similarity(embed(pred), embed(target).detach())) | |
| The backbone is frozen (no gradient updates to ArcFace itself) but | |
| gradients DO flow through the forward pass of the backbone when | |
| computing pred embeddings. | |
| Example:: | |
| loss_fn = ArcFaceLoss(device=torch.device("cuda")) | |
| loss = loss_fn(pred_images, target_images) | |
| loss.backward() # gradients flow into pred_images | |
| """ | |
| def __init__( | |
| self, | |
| device: torch.device | None = None, | |
| weights_path: str | None = None, | |
| crop_face: bool = True, | |
| ): | |
| """ | |
| Args: | |
| device: Device to place the backbone on. If ``None``, determined | |
| from the first forward call. | |
| weights_path: Path to pretrained backbone.pth. If ``None``, | |
| searches known locations. | |
| crop_face: Whether to center-crop images before embedding. | |
| Set ``False`` if images are already 112x112 face crops. | |
| """ | |
| super().__init__() | |
| self.crop_face = crop_face | |
| self._weights_path = weights_path | |
| self._target_device = device | |
| self._initialized = False | |
| # Build backbone (lazy device placement) | |
| self.backbone = ArcFaceBackbone() | |
| def _ensure_initialized(self, device: torch.device) -> None: | |
| """Lazy initialization: load weights and move to device on first use.""" | |
| if self._initialized: | |
| return | |
| # Load pretrained weights | |
| loaded = load_pretrained_weights(self.backbone, self._weights_path) | |
| if not loaded: | |
| logger.warning( | |
| "ArcFaceLoss using random weights -- identity loss will not " | |
| "be meaningful. Download pretrained weights for proper training." | |
| ) | |
| # Move to device and freeze | |
| self.backbone = self.backbone.to(device) | |
| self.backbone.eval() | |
| for param in self.backbone.parameters(): | |
| param.requires_grad_(False) | |
| self._initialized = True | |
| def _prepare_images(self, images: torch.Tensor) -> torch.Tensor: | |
| """Prepare images for ArcFace: crop, resize, normalize to [-1, 1]. | |
| Args: | |
| images: (B, 3, H, W) in [0, 1]. | |
| Returns: | |
| (B, 3, 112, 112) in [-1, 1]. | |
| """ | |
| if self.crop_face: | |
| x = align_face(images, size=112) | |
| else: | |
| x = align_face_no_crop(images, size=112) | |
| # Normalize from [0, 1] to [-1, 1] | |
| x = x * 2.0 - 1.0 | |
| return x | |
| def _extract_embedding( | |
| self, | |
| images: torch.Tensor, | |
| enable_grad: bool = True, | |
| ) -> torch.Tensor: | |
| """Extract ArcFace embeddings. | |
| Args: | |
| images: (B, 3, 112, 112) in [-1, 1]. | |
| enable_grad: If ``True``, gradients flow through the backbone's | |
| forward pass (used for pred). If ``False``, detached (target). | |
| Returns: | |
| (B, 512) L2-normalized embeddings. | |
| """ | |
| if enable_grad: | |
| return self.backbone(images) | |
| else: | |
| with torch.no_grad(): | |
| return self.backbone(images) | |
| def forward( | |
| self, | |
| pred_image: torch.Tensor, | |
| target_image: torch.Tensor, | |
| procedure: str = "rhinoplasty", | |
| ) -> torch.Tensor: | |
| """Compute differentiable identity loss. | |
| Args: | |
| pred_image: (B, 3, H, W) predicted images in [0, 1]. | |
| Gradients WILL flow back through this tensor. | |
| target_image: (B, 3, H, W) target images in [0, 1]. | |
| Gradients will NOT flow through this (detached). | |
| procedure: Surgical procedure type. ``"orthognathic"`` returns | |
| zero loss (identity irrelevant for jaw surgery). | |
| Returns: | |
| Scalar loss: mean(1 - cosine_similarity(pred_emb, target_emb)). | |
| Returns 0 for orthognathic or empty batches. | |
| """ | |
| if procedure == "orthognathic": | |
| return torch.tensor(0.0, device=pred_image.device, dtype=pred_image.dtype) | |
| device = pred_image.device | |
| self._ensure_initialized(device) | |
| # Procedure-specific cropping (before ArcFace alignment) | |
| pred_crop = self._procedure_crop(pred_image, procedure) | |
| target_crop = self._procedure_crop(target_image, procedure) | |
| # Prepare for ArcFace (crop, resize to 112x112, normalize to [-1, 1]) | |
| pred_prepared = self._prepare_images(pred_crop) | |
| target_prepared = self._prepare_images(target_crop) | |
| # Extract embeddings | |
| pred_emb = self._extract_embedding(pred_prepared, enable_grad=True) | |
| target_emb = self._extract_embedding(target_prepared, enable_grad=False) | |
| # Detach target to be absolutely sure no gradients leak | |
| target_emb = target_emb.detach() | |
| # Cosine similarity loss: 1 - cos_sim | |
| cosine_sim = (pred_emb * target_emb).sum(dim=1) # (B,) | |
| # Clamp to valid range (numerical safety for BF16) | |
| cosine_sim = cosine_sim.clamp(-1.0, 1.0) | |
| loss = (1.0 - cosine_sim).mean() | |
| return loss | |
| def _procedure_crop( | |
| self, | |
| image: torch.Tensor, | |
| procedure: str, | |
| ) -> torch.Tensor: | |
| """Crop image based on surgical procedure for identity comparison.""" | |
| _, _, h, w = image.shape | |
| if procedure == "rhinoplasty": | |
| return image[:, :, : h * 2 // 3, :] | |
| elif procedure == "blepharoplasty": | |
| return image | |
| elif procedure == "rhytidectomy": | |
| return image[:, :, : h * 3 // 4, :] | |
| else: | |
| return image | |
| def get_embedding(self, images: torch.Tensor) -> torch.Tensor: | |
| """Extract identity embeddings (utility method for evaluation). | |
| Args: | |
| images: (B, 3, H, W) in [0, 1]. | |
| Returns: | |
| (B, 512) L2-normalized embeddings (detached). | |
| """ | |
| self._ensure_initialized(images.device) | |
| prepared = self._prepare_images(images) | |
| return self._extract_embedding(prepared, enable_grad=False) | |
# ---------------------------------------------------------------------------
# Convenience: create a pre-configured loss instance
# ---------------------------------------------------------------------------


def create_arcface_loss(
    device: torch.device | None = None,
    weights_path: str | None = None,
) -> ArcFaceLoss:
    """Build an ``ArcFaceLoss`` with sensible defaults.

    Args:
        device: Target device (auto-detected if ``None``).
        weights_path: Path to backbone.pth (auto-searched if ``None``).

    Returns:
        Configured ``ArcFaceLoss`` instance.
    """
    loss_module = ArcFaceLoss(device=device, weights_path=weights_path)
    return loss_module