Spaces:
Running
Running
Update landmarkdiff/synthetic/tps_warp.py to v0.3.2
Browse files
landmarkdiff/synthetic/tps_warp.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
|
@@ -14,7 +15,15 @@ def compute_tps_transform(
|
|
| 14 |
src_pts: np.ndarray,
|
| 15 |
dst_pts: np.ndarray,
|
| 16 |
) -> cv2.ThinPlateSplineShapeTransformer:
|
| 17 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
src = src_pts.reshape(1, -1, 2).astype(np.float32)
|
| 19 |
dst = dst_pts.reshape(1, -1, 2).astype(np.float32)
|
| 20 |
matches = [cv2.DMatch(i, i, 0) for i in range(len(src_pts))]
|
|
@@ -30,7 +39,12 @@ def _subsample_control_points(
|
|
| 30 |
max_points: int = 80,
|
| 31 |
anchor_stride: int = 8,
|
| 32 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 33 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
displacements = np.linalg.norm(dst - src, axis=1)
|
| 35 |
displaced_mask = displacements > 0.5 # moved by > 0.5px
|
| 36 |
displaced_idx = np.where(displaced_mask)[0]
|
|
@@ -61,7 +75,18 @@ def warp_image_tps(
|
|
| 61 |
dst_landmarks: np.ndarray,
|
| 62 |
rigid_mask: np.ndarray | None = None,
|
| 63 |
) -> np.ndarray:
|
| 64 |
-
"""Apply TPS warp to an image with optional rigid region preservation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
h, w = image.shape[:2]
|
| 66 |
|
| 67 |
src_pts = src_landmarks.astype(np.float32)
|
|
@@ -87,8 +112,10 @@ def warp_image_tps(
|
|
| 87 |
rigid_translation = _compute_rigid_translation(src_pts, dst_pts, rigid_mask, w, h)
|
| 88 |
rigid_warped = _apply_rigid_translation(image, rigid_translation)
|
| 89 |
|
|
|
|
|
|
|
| 90 |
# Composite: use rigid warp in rigid regions, TPS elsewhere
|
| 91 |
-
mask_f =
|
| 92 |
if len(mask_f.shape) == 2:
|
| 93 |
mask_f = np.stack([mask_f] * 3, axis=-1)
|
| 94 |
mask_f = mask_f / 255.0 if mask_f.max() > 1 else mask_f
|
|
@@ -103,7 +130,10 @@ def _compute_tps_map(
|
|
| 103 |
width: int,
|
| 104 |
height: int,
|
| 105 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 106 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 107 |
# Displacement at control points
|
| 108 |
dx = dst[:, 0] - src[:, 0]
|
| 109 |
dy = dst[:, 1] - src[:, 1]
|
|
@@ -151,12 +181,16 @@ def _solve_tps_weights(
|
|
| 151 |
control_pts: np.ndarray,
|
| 152 |
values: np.ndarray,
|
| 153 |
) -> np.ndarray:
|
| 154 |
-
"""Solve TPS
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
n = len(control_pts)
|
| 156 |
|
| 157 |
# Build kernel matrix K (vectorized)
|
| 158 |
diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
|
| 159 |
-
r_mat = np.sqrt((diff**2).sum(axis=2)) # (n, n)
|
| 160 |
K = np.zeros((n, n))
|
| 161 |
nz = r_mat > 0
|
| 162 |
K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
|
|
@@ -205,7 +239,7 @@ def _evaluate_tps(
|
|
| 205 |
# Compute all distances at once: (M, n)
|
| 206 |
dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
|
| 207 |
dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
|
| 208 |
-
r = np.sqrt(dx**2 + dy**2)
|
| 209 |
|
| 210 |
# TPS kernel: r^2 * log(r), with r=0 -> 0
|
| 211 |
kernel = np.zeros_like(r)
|
|
@@ -230,8 +264,9 @@ def _compute_rigid_translation(
|
|
| 230 |
inside = []
|
| 231 |
for i, (x, y) in enumerate(src):
|
| 232 |
ix, iy = int(x), int(y)
|
| 233 |
-
if 0 <= ix < width and 0 <= iy < height
|
| 234 |
-
|
|
|
|
| 235 |
|
| 236 |
if not inside:
|
| 237 |
return np.array([0.0, 0.0])
|
|
@@ -257,7 +292,17 @@ def generate_random_warp(
|
|
| 257 |
max_displacement: float = 15.0,
|
| 258 |
rng: np.random.Generator | None = None,
|
| 259 |
) -> np.ndarray:
|
| 260 |
-
"""Generate randomly warped landmarks for synthetic data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
rng = rng or np.random.default_rng()
|
| 262 |
result = landmarks.copy()
|
| 263 |
|
|
|
|
| 1 |
+
"""Thin-Plate Spline warping for synthetic training pair generation.
|
| 2 |
|
| 3 |
+
Applies TPS warp ONLY to deformable tissue regions. Rigid structures
|
| 4 |
+
(teeth, sclera) are rigidly translated, not warped. This prevents
|
| 5 |
+
the "rubber teeth" artifact from naive TPS.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
|
|
|
| 15 |
src_pts: np.ndarray,
|
| 16 |
dst_pts: np.ndarray,
|
| 17 |
) -> cv2.ThinPlateSplineShapeTransformer:
|
| 18 |
+
"""Compute a TPS transform from source to destination point pairs.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
src_pts: (N, 2) source control points.
|
| 22 |
+
dst_pts: (N, 2) destination control points.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Fitted TPS transformer.
|
| 26 |
+
"""
|
| 27 |
src = src_pts.reshape(1, -1, 2).astype(np.float32)
|
| 28 |
dst = dst_pts.reshape(1, -1, 2).astype(np.float32)
|
| 29 |
matches = [cv2.DMatch(i, i, 0) for i in range(len(src_pts))]
|
|
|
|
| 39 |
max_points: int = 80,
|
| 40 |
anchor_stride: int = 8,
|
| 41 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 42 |
+
"""Subsample control points for faster TPS: all displaced + sparse anchors.
|
| 43 |
+
|
| 44 |
+
With 478 MediaPipe landmarks, full TPS requires solving a 481x481 system
|
| 45 |
+
and evaluating 478 RBFs at each pixel — very slow. Subsampling to ~80
|
| 46 |
+
points gives nearly identical results ~30x faster.
|
| 47 |
+
"""
|
| 48 |
displacements = np.linalg.norm(dst - src, axis=1)
|
| 49 |
displaced_mask = displacements > 0.5 # moved by > 0.5px
|
| 50 |
displaced_idx = np.where(displaced_mask)[0]
|
|
|
|
| 75 |
dst_landmarks: np.ndarray,
|
| 76 |
rigid_mask: np.ndarray | None = None,
|
| 77 |
) -> np.ndarray:
|
| 78 |
+
"""Apply TPS warp to an image with optional rigid region preservation.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
image: BGR input image.
|
| 82 |
+
src_landmarks: (N, 2) original landmark pixel coords.
|
| 83 |
+
dst_landmarks: (N, 2) target landmark pixel coords.
|
| 84 |
+
rigid_mask: Optional binary mask of rigid regions (teeth, sclera).
|
| 85 |
+
These regions are rigidly translated, not TPS-warped.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Warped image.
|
| 89 |
+
"""
|
| 90 |
h, w = image.shape[:2]
|
| 91 |
|
| 92 |
src_pts = src_landmarks.astype(np.float32)
|
|
|
|
| 112 |
rigid_translation = _compute_rigid_translation(src_pts, dst_pts, rigid_mask, w, h)
|
| 113 |
rigid_warped = _apply_rigid_translation(image, rigid_translation)
|
| 114 |
|
| 115 |
+
# Translate the mask to match the rigidly-shifted content
|
| 116 |
+
translated_mask = _apply_rigid_translation(rigid_mask, rigid_translation)
|
| 117 |
# Composite: use rigid warp in rigid regions, TPS elsewhere
|
| 118 |
+
mask_f = translated_mask.astype(np.float32)
|
| 119 |
if len(mask_f.shape) == 2:
|
| 120 |
mask_f = np.stack([mask_f] * 3, axis=-1)
|
| 121 |
mask_f = mask_f / 255.0 if mask_f.max() > 1 else mask_f
|
|
|
|
| 130 |
width: int,
|
| 131 |
height: int,
|
| 132 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 133 |
+
"""Compute pixel displacement maps from TPS control points.
|
| 134 |
+
|
| 135 |
+
Uses RBF interpolation of control point displacements.
|
| 136 |
+
"""
|
| 137 |
# Displacement at control points
|
| 138 |
dx = dst[:, 0] - src[:, 0]
|
| 139 |
dy = dst[:, 1] - src[:, 1]
|
|
|
|
| 181 |
control_pts: np.ndarray,
|
| 182 |
values: np.ndarray,
|
| 183 |
) -> np.ndarray:
|
| 184 |
+
"""Solve for TPS weights given control points and target values.
|
| 185 |
+
|
| 186 |
+
Returns weight vector [w1..wn, a0, a1, a2] for n control points
|
| 187 |
+
plus affine terms.
|
| 188 |
+
"""
|
| 189 |
n = len(control_pts)
|
| 190 |
|
| 191 |
# Build kernel matrix K (vectorized)
|
| 192 |
diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
|
| 193 |
+
r_mat = np.sqrt((diff ** 2).sum(axis=2)) # (n, n)
|
| 194 |
K = np.zeros((n, n))
|
| 195 |
nz = r_mat > 0
|
| 196 |
K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
|
|
|
|
| 239 |
# Compute all distances at once: (M, n)
|
| 240 |
dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
|
| 241 |
dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
|
| 242 |
+
r = np.sqrt(dx ** 2 + dy ** 2)
|
| 243 |
|
| 244 |
# TPS kernel: r^2 * log(r), with r=0 -> 0
|
| 245 |
kernel = np.zeros_like(r)
|
|
|
|
| 264 |
inside = []
|
| 265 |
for i, (x, y) in enumerate(src):
|
| 266 |
ix, iy = int(x), int(y)
|
| 267 |
+
if 0 <= ix < width and 0 <= iy < height:
|
| 268 |
+
if mask[iy, ix] > 0:
|
| 269 |
+
inside.append(i)
|
| 270 |
|
| 271 |
if not inside:
|
| 272 |
return np.array([0.0, 0.0])
|
|
|
|
| 292 |
max_displacement: float = 15.0,
|
| 293 |
rng: np.random.Generator | None = None,
|
| 294 |
) -> np.ndarray:
|
| 295 |
+
"""Generate randomly warped landmarks for synthetic data.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
landmarks: (N, 2) pixel coordinates.
|
| 299 |
+
procedure_indices: Which landmarks to warp.
|
| 300 |
+
max_displacement: Max pixel displacement.
|
| 301 |
+
rng: Random number generator.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
New landmark array with random deformations.
|
| 305 |
+
"""
|
| 306 |
rng = rng or np.random.default_rng()
|
| 307 |
result = landmarks.copy()
|
| 308 |
|