File size: 19,347 Bytes
544c445
c107618
544c445
 
c107618
 
 
 
 
433e26f
c107618
 
 
433e26f
c107618
 
 
 
 
 
 
544c445
c107618
 
 
544c445
c107618
 
 
 
 
544c445
 
c107618
 
544c445
 
 
c107618
 
544c445
 
 
c107618
 
544c445
 
 
 
c107618
 
544c445
 
 
c107618
 
544c445
 
 
c107618
 
544c445
c107618
 
 
 
 
 
 
544c445
c107618
 
 
 
 
 
 
544c445
 
 
 
 
 
 
 
 
 
 
c107618
 
 
 
 
544c445
c107618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
 
81d5fa1
c107618
81d5fa1
 
 
 
544c445
81d5fa1
 
 
 
 
 
 
 
 
 
 
c107618
 
 
 
81d5fa1
 
 
 
 
544c445
81d5fa1
 
c107618
 
 
 
 
 
 
 
 
544c445
c107618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544c445
c107618
 
 
544c445
 
c107618
 
 
 
 
 
 
 
 
81d5fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544c445
81d5fa1
 
 
 
544c445
 
81d5fa1
 
 
 
 
 
 
 
 
c107618
 
 
 
 
544c445
c107618
544c445
 
 
 
 
 
 
c107618
 
 
 
544c445
 
c107618
 
544c445
 
 
 
 
 
 
c107618
 
544c445
 
 
 
 
c107618
 
 
 
 
544c445
 
 
 
 
c107618
 
 
 
 
544c445
 
 
 
 
c107618
 
 
544c445
 
 
 
 
c107618
 
 
 
 
 
 
544c445
 
 
 
 
c107618
 
 
 
 
544c445
 
 
 
 
c107618
 
 
 
 
544c445
 
 
 
 
c107618
 
 
 
 
 
 
544c445
 
 
 
 
c107618
 
 
544c445
 
 
 
 
c107618
 
 
 
544c445
 
 
 
 
c107618
 
 
 
 
544c445
 
 
 
 
c107618
 
544c445
 
 
 
 
c107618
 
 
 
 
 
544c445
 
 
 
 
c107618
 
 
 
544c445
 
 
 
 
c107618
 
 
 
544c445
 
 
 
 
c107618
 
 
544c445
 
 
 
 
c107618
 
544c445
 
 
 
 
 
 
 
 
 
 
 
c107618
544c445
 
 
 
 
 
 
 
c107618
544c445
 
 
 
 
 
 
c107618
544c445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c107618
544c445
 
 
 
c107618
544c445
 
 
 
 
 
 
 
c107618
544c445
 
 
 
 
 
 
c107618
544c445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c107618
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
"""Landmark manipulation engine with Free-Form Deformation (FFD/RBF).

All v1/v2 UI uses RELATIVE sliders (0-100 intensity).
Millimeter inputs exist only in v3+ with FLAME calibrated metric space.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np

from landmarkdiff.landmarks import FaceLandmarks

if TYPE_CHECKING:
    from landmarkdiff.clinical import ClinicalFlags


@dataclass(frozen=True, eq=False)
class DeformationHandle:
    """A control handle for FFD manipulation.

    A handle pins one landmark, the displacement to apply at that landmark,
    and the radius of the Gaussian falloff used to spread the displacement
    to neighbouring landmarks.

    Equality and hashing are implemented by hand: the comparison methods the
    dataclass decorator would generate compare fields as a tuple, which raises
    ``ValueError`` for multi-element arrays, and the generated hash raises
    ``TypeError`` because ``np.ndarray`` is unhashable.

    Note that ``frozen=True`` only blocks attribute rebinding; the contents of
    ``displacement`` can still be mutated in place by anyone holding the array.
    """

    landmark_index: int
    displacement: np.ndarray  # (2,) or (3,) pixel displacement
    influence_radius: float   # Gaussian RBF radius in pixels

    def __eq__(self, other: object) -> bool:
        """Compare handles by value, using element-wise array equality."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(
            self.landmark_index == other.landmark_index
            and self.influence_radius == other.influence_radius
            and np.array_equal(self.displacement, other.displacement)
        )

    def __hash__(self) -> int:
        """Hash on the scalar fields only.

        The array is deliberately left out so that handles which compare equal
        (e.g. int vs. float arrays with the same values) always hash equal.
        """
        return hash((self.landmark_index, self.influence_radius))


# Procedure-specific landmark indices from the technical specification.
# _get_procedure_handles only creates a handle for a landmark that is listed
# here for the requested procedure; anything else is silently skipped.
PROCEDURE_LANDMARKS: dict[str, list[int]] = {
    "rhinoplasty": [
        1, 2, 4, 5, 6, 19, 94, 141, 168, 195, 197, 236, 240,
        274, 275, 278, 279, 294, 326, 327, 360, 363, 370, 456, 460,
    ],
    "blepharoplasty": [
        # 133 is referenced by the blepharoplasty lid-corner handles (alongside
        # 362, which is listed below). It was missing from this list, so its
        # handle was dropped and the lid-corner adjustment was applied unevenly.
        33, 7, 163, 144, 145, 153, 154, 155, 133, 157, 158, 159, 160, 161, 246,
        362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386,
        385, 384, 398,
    ],
    "rhytidectomy": [
        # 288 is referenced by the rhytidectomy jowl handles (58 is its
        # counterpart in the opposite jowl group). It was missing here, so one
        # jowl group lost a handle that the other kept.
        10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 150, 152,
        162, 172, 176, 187, 207, 213, 234, 284, 288, 297, 323, 332, 338, 356,
        361, 365, 377, 378, 379, 389, 397, 400, 427, 454,
    ],
    "orthognathic": [
        0, 17, 18, 36, 37, 39, 40, 57, 61, 78, 80, 81, 82, 84, 87, 88,
        91, 95, 146, 167, 169, 170, 175, 181, 191, 200, 201, 202, 204,
        208, 211, 212, 214, 269, 270, 291, 311, 312, 317, 321, 324, 325,
        375, 396, 405, 407, 415,
    ],
    "brow_lift": [
        10, 21, 46, 52, 53, 54, 55, 63, 65, 66, 67, 68, 69, 70, 71,
        103, 104, 105, 107, 108, 109, 151, 282, 283, 284, 285, 293, 295,
        296, 297, 298, 299, 300, 301, 332, 333, 334, 336, 337, 338,
    ],
    "mentoplasty": [
        0, 17, 18, 57, 83, 84, 85, 86, 87, 146, 167, 169, 170, 175,
        181, 191, 199, 200, 201, 202, 204, 208, 211, 212, 214, 316, 317,
        321, 324, 325, 375, 396, 405, 411, 415, 419, 421, 422, 424,
    ],
}

# Default influence radii per procedure (in pixels at 512x512).
# Must have exactly the same keys as PROCEDURE_LANDMARKS: the preset code looks
# up both tables with the same procedure name.
PROCEDURE_RADIUS: dict[str, float] = {
    "rhinoplasty": 30.0,
    "blepharoplasty": 15.0,
    "rhytidectomy": 40.0,
    "orthognathic": 35.0,
    "brow_lift": 25.0,
    "mentoplasty": 30.0,
}


def gaussian_rbf_deform(
    landmarks: np.ndarray,
    handle: DeformationHandle,
) -> np.ndarray:
    """Apply Gaussian RBF deformation around a control handle.

    Formula: delta_p_i = delta_handle * exp(-||p_i - p_handle||^2 / (2 * r^2))

    Distances are always measured in the X/Y plane, even for 3D landmarks.
    The Z column is only modified when both the landmarks and the handle
    displacement carry a third component.

    Args:
        landmarks: (N, 2) or (N, 3) landmark coordinates in pixels.
        handle: Control handle specifying index, displacement, and radius.

    Returns:
        New landmark array with deformation applied (immutable — returns copy).
    """
    result = landmarks.copy()
    center = landmarks[handle.landmark_index, :2]
    displacement = handle.displacement[:2]

    distances_sq = np.sum((landmarks[:, :2] - center) ** 2, axis=1)

    two_r_sq = 2.0 * handle.influence_radius ** 2
    if two_r_sq == 0.0:
        # A zero radius would evaluate 0/0 at the handle itself and poison the
        # result with NaN. Use the exact r -> 0 limit of the Gaussian instead:
        # full weight on points coincident with the handle, none elsewhere.
        weights = (distances_sq == 0).astype(float)
    else:
        weights = np.exp(-distances_sq / two_r_sq)

    result[:, 0] += displacement[0] * weights
    result[:, 1] += displacement[1] * weights

    if landmarks.shape[1] > 2 and len(handle.displacement) > 2:
        result[:, 2] += handle.displacement[2] * weights

    return result


def apply_procedure_preset(
    face: FaceLandmarks,
    procedure: str,
    intensity: float = 50.0,
    image_size: int = 512,
    clinical_flags: ClinicalFlags | None = None,
    displacement_model_path: str | None = None,
    noise_scale: float = 0.0,
) -> FaceLandmarks:
    """Apply a surgical procedure preset to landmarks.

    Two modes are supported. With ``displacement_model_path`` set, the
    displacement field comes from a fitted model and is handed off to
    ``_apply_data_driven``; in that mode ``image_size`` and ``clinical_flags``
    are not consulted. Otherwise hand-tuned Gaussian RBF handles are built for
    the procedure and applied one after another in pixel space.

    Only x and y are clamped to [0, 1] on the way out. The depth column (when
    present) is returned as-is: the preset handles are 2D and never move it,
    and this matches what the data-driven path does.

    Args:
        face: Input face landmarks.
        procedure: One of the supported procedures (see PROCEDURE_LANDMARKS).
        intensity: Relative intensity 0-100 (mild=33, moderate=66, aggressive=100).
        image_size: Reference image size for displacement scaling. Handle
            displacements and radii are tuned for 512 and scaled by
            ``image_size / 512``; the normalized<->pixel conversion itself
            uses ``face.image_width`` / ``face.image_height``.
        clinical_flags: Optional clinical condition flags.
        displacement_model_path: Path to a fitted DisplacementModel (.npz).
            When provided, uses data-driven displacements from real surgery pairs
            instead of hand-tuned RBF vectors.
        noise_scale: Variation noise scale for data-driven mode (0=deterministic).

    Returns:
        New FaceLandmarks with manipulated landmarks.

    Raises:
        ValueError: If ``procedure`` is not a key of PROCEDURE_LANDMARKS.
    """
    if procedure not in PROCEDURE_LANDMARKS:
        raise ValueError(f"Unknown procedure: {procedure}. Choose from {list(PROCEDURE_LANDMARKS)}")

    scale = intensity / 100.0

    # Data-driven displacement mode
    if displacement_model_path is not None:
        return _apply_data_driven(
            face, procedure, scale, displacement_model_path, noise_scale,
        )

    indices = PROCEDURE_LANDMARKS[procedure]
    radius = PROCEDURE_RADIUS[procedure]

    # Ehlers-Danlos: wider influence radii for hypermobile tissue
    if clinical_flags and clinical_flags.ehlers_danlos:
        radius *= 1.5

    # Procedure-specific displacement vectors (normalized to image_size)
    pixel_scale = image_size / 512.0
    handles = _get_procedure_handles(procedure, indices, scale, radius * pixel_scale, pixel_scale)

    # Bell's palsy: remove handles on the affected (paralyzed) side
    if clinical_flags and clinical_flags.bells_palsy:
        # Imported lazily; at module level this name is only available to
        # type checkers (see the TYPE_CHECKING guard at the top of the file).
        from landmarkdiff.clinical import get_bells_palsy_side_indices
        affected = get_bells_palsy_side_indices(clinical_flags.bells_palsy_side)
        affected_indices = set()
        for region_indices in affected.values():
            affected_indices.update(region_indices)
        handles = [h for h in handles if h.landmark_index not in affected_indices]

    # Convert to pixel space for deformation (work on a private copy so the
    # caller's FaceLandmarks is never modified)
    pixel_landmarks = face.landmarks.copy()
    pixel_landmarks[:, 0] *= face.image_width
    pixel_landmarks[:, 1] *= face.image_height

    # Handles are applied sequentially: each one sees the output of the previous
    for handle in handles:
        pixel_landmarks = gaussian_rbf_deform(pixel_landmarks, handle)

    # Convert back to normalized and clamp x,y to [0, 1] (preserve z-depth).
    # pixel_landmarks is already our own array, so it is reused in place.
    result = pixel_landmarks
    result[:, 0] /= face.image_width
    result[:, 1] /= face.image_height
    result[:, :2] = np.clip(result[:, :2], 0.0, 1.0)

    return FaceLandmarks(
        landmarks=result,
        image_width=face.image_width,
        image_height=face.image_height,
        confidence=face.confidence,
    )


def _apply_data_driven(
    face: FaceLandmarks,
    procedure: str,
    scale: float,
    model_path: str,
    noise_scale: float = 0.0,
) -> FaceLandmarks:
    """Shift landmarks by a displacement field from a fitted DisplacementModel.

    The model is loaded from ``model_path`` and queried for ``procedure`` at
    the given ``scale`` (passed to the model as ``intensity``) and
    ``noise_scale``. The returned field is added to the x/y columns of a copy
    of ``face.landmarks``; unlike the handle-based path, it is not restricted
    to the procedure-specific landmark subset.

    Args:
        face: Input face landmarks (left unmodified).
        procedure: Procedure name forwarded to the model.
        scale: Intensity forwarded to the model.
        model_path: Location of the fitted model to load.
        noise_scale: Variation noise forwarded to the model.

    Returns:
        New FaceLandmarks with displaced x/y and the original image size and
        confidence.
    """
    # Imported lazily, exactly as the original does
    from landmarkdiff.displacement_model import DisplacementModel

    displacement_field = DisplacementModel.load(model_path).get_displacement_field(
        procedure=procedure,
        intensity=scale,
        noise_scale=noise_scale,
    )

    moved = face.landmarks.copy()

    # The field is expected to be (478, 2) in normalized coordinates; only the
    # rows present in both arrays are touched.
    shared_rows = min(moved.shape[0], displacement_field.shape[0])
    xy = np.s_[:shared_rows, :2]

    moved[xy] += displacement_field[:shared_rows]
    # Keep x,y inside the unit square; the depth column is left alone
    moved[xy] = np.clip(moved[xy], 0.0, 1.0)

    return FaceLandmarks(
        landmarks=moved,
        image_width=face.image_width,
        image_height=face.image_height,
        confidence=face.confidence,
    )


# Hand-tuned handle groups per procedure, listed in the order they are applied.
# Each entry is (landmark indices, (dx, dy), radius multiplier) where (dx, dy)
# is the pixel displacement at scale=1.0 on a 512x512 image (+X is right,
# -Y is up) and the multiplier is applied to the procedure's base radius.
_PROCEDURE_HANDLE_GROUPS = {
    "rhinoplasty": (
        # Alar base narrowing: move nostrils inward (toward midline).
        # Left nostril landmarks (viewer's left) -> move RIGHT (+X)
        ((240, 236, 141), (2.5, 0.0), 0.6),
        # Right nostril landmarks (viewer's right) -> move LEFT (-X)
        ((460, 456, 274, 275, 278, 279, 363, 370), (-2.5, 0.0), 0.6),
        # Tip refinement: subtle upward rotation + narrowing
        ((1, 2, 94, 19), (0.0, -2.0), 0.5),
        # Dorsum narrowing: bilateral squeeze of nasal bridge (left, then right).
        # 236 and 456 also appear in the alar groups, so they get two handles.
        ((195, 197, 236), (1.5, 0.0), 0.5),
        ((326, 327, 456), (-1.5, 0.0), 0.5),
    ),
    "blepharoplasty": (
        # Upper lid elevation (primary effect): central upper lid, left + right
        ((159, 160, 161, 386, 385, 384), (0.0, -2.0), 1.0),
        # Medial/lateral lid corners: less displacement (tapered), left + right
        ((158, 157, 133, 33, 387, 388, 362, 263), (0.0, -0.8), 0.7),
        # Subtle lower lid tightening, left + right
        ((145, 153, 154, 374, 380, 381), (0.0, 0.5), 0.5),
    ),
    "rhytidectomy": (
        # Jowl area: strongest lift (upward + toward ear), left then right
        ((132, 136, 172, 58, 150, 176), (-2.5, -3.0), 1.0),
        ((361, 365, 397, 288, 379, 400), (2.5, -3.0), 1.0),
        # Chin/submental: upward only (no lateral)
        ((152, 148, 377, 378), (0.0, -2.0), 0.8),
        # Temple/upper face: very mild lift, left then right
        ((10, 21, 54, 67, 103, 109, 162, 127), (-0.5, -1.0), 0.6),
        ((284, 297, 332, 338, 323, 356, 389, 454), (0.5, -1.0), 0.6),
    ),
    "orthognathic": (
        # Mandible repositioning: move jaw up and forward (visible as upward in 2D)
        ((17, 18, 200, 201, 202, 204, 208, 211, 212, 214), (0.0, -3.0), 1.0),
        # Chin projection: move chin point forward/upward
        ((175, 170, 169, 167, 396), (0.0, -2.0), 0.7),
        # Lateral jaw: bilateral symmetric inward pull for narrowing
        ((57, 61, 78, 91, 95, 146, 181), (1.5, -1.0), 0.8),
        ((291, 311, 312, 321, 324, 325, 375, 405), (-1.5, -1.0), 0.8),
    ),
    "brow_lift": (
        # Forehead/brow elevation: lift central brow landmarks, left then right
        ((46, 52, 53, 55, 65, 66, 105, 107), (0.0, -3.0), 1.0),
        ((282, 283, 285, 295, 296, 334, 336), (0.0, -3.0), 1.0),
        # Lateral brow: slightly less lift, mild outward pull
        ((63, 67, 68, 69, 70, 71, 103, 104, 108, 109), (-0.5, -2.0), 0.8),
        ((293, 297, 298, 299, 300, 301, 332, 333, 337, 338), (0.5, -2.0), 0.8),
        # Forehead hairline: subtle upward shift
        ((10, 21, 54, 151, 284), (0.0, -1.0), 1.2),
    ),
    "mentoplasty": (
        # Chin augmentation/reduction: project chin forward and down.
        # Central chin point: strongest projection
        ((175, 170, 169, 199, 200), (0.0, 2.5), 1.0),
        # Lateral chin contour: bilateral symmetric outward projection
        ((17, 18, 83, 84, 85, 86, 146, 167, 181, 191), (-1.0, 1.5), 0.8),
        ((316, 317, 321, 324, 325, 375, 396, 411, 415, 419), (1.0, 1.5), 0.8),
        # Jawline transition: subtle smoothing
        ((57, 87, 201, 202, 204, 208, 211, 212, 214, 405, 421, 422, 424), (0.0, 0.8), 0.6),
    ),
}


def _get_procedure_handles(
    procedure: str,
    indices: list[int],
    scale: float,
    radius: float,
    pixel_scale: float = 1.0,
) -> list[DeformationHandle]:
    """Generate anatomically-grounded deformation handles for a procedure.

    Displacements are in 2D pixel space (X, Y) since the mesh conditioning
    and TPS warp are both 2D. Values calibrated to look natural at 512x512
    and scaled by pixel_scale for other resolutions.
    Based on anthropometric studies (Singh et al. TIFS 2010).

    Args:
        procedure: Procedure name; a name with no preset yields an empty list.
        indices: Landmarks that may receive a handle. Preset landmarks that
            are not in this collection are skipped.
        scale: Multiplier applied to every preset displacement.
        radius: Base influence radius. It is used as given (only multiplied
            by the per-group factor); ``pixel_scale`` is NOT applied to it
            here, so the caller must pass an already-scaled radius.
        pixel_scale: Extra multiplier applied to displacements only.

    Returns:
        Handles in preset order. A landmark that appears in several groups
        produces one handle per group.
    """
    # Set membership keeps the filter O(1) per landmark
    allowed = set(indices)
    handles = []

    for group_indices, (dx, dy), radius_factor in _PROCEDURE_HANDLE_GROUPS.get(procedure, ()):
        for idx in group_indices:
            if idx not in allowed:
                continue
            displacement = np.array([dx * scale, dy * scale])
            # Scale displacements for non-512 image sizes
            if pixel_scale != 1.0:
                displacement = displacement * pixel_scale
            handles.append(DeformationHandle(
                landmark_index=idx,
                displacement=displacement,
                influence_radius=radius * radius_factor,
            ))

    return handles