dreamlessx commited on
Commit
b0a1702
·
verified ·
1 Parent(s): 82d5f3d

Update landmarkdiff/synthetic/tps_warp.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/synthetic/tps_warp.py +59 -14
landmarkdiff/synthetic/tps_warp.py CHANGED
@@ -1,7 +1,8 @@
1
- """TPS warping for synthetic pair generation.
2
 
3
- Only warps deformable tissue - rigid structures (teeth, sclera) get
4
- rigid translation instead. Prevents "rubber teeth" from naive TPS.
 
5
  """
6
 
7
  from __future__ import annotations
@@ -14,7 +15,15 @@ def compute_tps_transform(
14
  src_pts: np.ndarray,
15
  dst_pts: np.ndarray,
16
  ) -> cv2.ThinPlateSplineShapeTransformer:
17
- """Fit a TPS transform from src to dst points."""
 
 
 
 
 
 
 
 
18
  src = src_pts.reshape(1, -1, 2).astype(np.float32)
19
  dst = dst_pts.reshape(1, -1, 2).astype(np.float32)
20
  matches = [cv2.DMatch(i, i, 0) for i in range(len(src_pts))]
@@ -30,7 +39,12 @@ def _subsample_control_points(
30
  max_points: int = 80,
31
  anchor_stride: int = 8,
32
  ) -> tuple[np.ndarray, np.ndarray]:
33
- """Keep all displaced points + sparse anchors. ~80 pts instead of 478, ~30x faster."""
 
 
 
 
 
34
  displacements = np.linalg.norm(dst - src, axis=1)
35
  displaced_mask = displacements > 0.5 # moved by > 0.5px
36
  displaced_idx = np.where(displaced_mask)[0]
@@ -61,7 +75,18 @@ def warp_image_tps(
61
  dst_landmarks: np.ndarray,
62
  rigid_mask: np.ndarray | None = None,
63
  ) -> np.ndarray:
64
- """Apply TPS warp to an image with optional rigid region preservation."""
 
 
 
 
 
 
 
 
 
 
 
65
  h, w = image.shape[:2]
66
 
67
  src_pts = src_landmarks.astype(np.float32)
@@ -87,8 +112,10 @@ def warp_image_tps(
87
  rigid_translation = _compute_rigid_translation(src_pts, dst_pts, rigid_mask, w, h)
88
  rigid_warped = _apply_rigid_translation(image, rigid_translation)
89
 
 
 
90
  # Composite: use rigid warp in rigid regions, TPS elsewhere
91
- mask_f = rigid_mask.astype(np.float32)
92
  if len(mask_f.shape) == 2:
93
  mask_f = np.stack([mask_f] * 3, axis=-1)
94
  mask_f = mask_f / 255.0 if mask_f.max() > 1 else mask_f
@@ -103,7 +130,10 @@ def _compute_tps_map(
103
  width: int,
104
  height: int,
105
  ) -> tuple[np.ndarray, np.ndarray]:
106
- """Build remap arrays from TPS control points via RBF interpolation."""
 
 
 
107
  # Displacement at control points
108
  dx = dst[:, 0] - src[:, 0]
109
  dy = dst[:, 1] - src[:, 1]
@@ -151,12 +181,16 @@ def _solve_tps_weights(
151
  control_pts: np.ndarray,
152
  values: np.ndarray,
153
  ) -> np.ndarray:
154
- """Solve TPS system -> weight vector [w1..wn, a0, a1, a2]."""
 
 
 
 
155
  n = len(control_pts)
156
 
157
  # Build kernel matrix K (vectorized)
158
  diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
159
- r_mat = np.sqrt((diff**2).sum(axis=2)) # (n, n)
160
  K = np.zeros((n, n))
161
  nz = r_mat > 0
162
  K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
@@ -205,7 +239,7 @@ def _evaluate_tps(
205
  # Compute all distances at once: (M, n)
206
  dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
207
  dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
208
- r = np.sqrt(dx**2 + dy**2)
209
 
210
  # TPS kernel: r^2 * log(r), with r=0 -> 0
211
  kernel = np.zeros_like(r)
@@ -230,8 +264,9 @@ def _compute_rigid_translation(
230
  inside = []
231
  for i, (x, y) in enumerate(src):
232
  ix, iy = int(x), int(y)
233
- if 0 <= ix < width and 0 <= iy < height and mask[iy, ix] > 0:
234
- inside.append(i)
 
235
 
236
  if not inside:
237
  return np.array([0.0, 0.0])
@@ -257,7 +292,17 @@ def generate_random_warp(
257
  max_displacement: float = 15.0,
258
  rng: np.random.Generator | None = None,
259
  ) -> np.ndarray:
260
- """Generate randomly warped landmarks for synthetic data."""
 
 
 
 
 
 
 
 
 
 
261
  rng = rng or np.random.default_rng()
262
  result = landmarks.copy()
263
 
 
1
+ """Thin-Plate Spline warping for synthetic training pair generation.
2
 
3
+ Applies TPS warp ONLY to deformable tissue regions. Rigid structures
4
+ (teeth, sclera) are rigidly translated, not warped. This prevents
5
+ the "rubber teeth" artifact from naive TPS.
6
  """
7
 
8
  from __future__ import annotations
 
15
  src_pts: np.ndarray,
16
  dst_pts: np.ndarray,
17
  ) -> cv2.ThinPlateSplineShapeTransformer:
18
+ """Compute a TPS transform from source to destination point pairs.
19
+
20
+ Args:
21
+ src_pts: (N, 2) source control points.
22
+ dst_pts: (N, 2) destination control points.
23
+
24
+ Returns:
25
+ Fitted TPS transformer.
26
+ """
27
  src = src_pts.reshape(1, -1, 2).astype(np.float32)
28
  dst = dst_pts.reshape(1, -1, 2).astype(np.float32)
29
  matches = [cv2.DMatch(i, i, 0) for i in range(len(src_pts))]
 
39
  max_points: int = 80,
40
  anchor_stride: int = 8,
41
  ) -> tuple[np.ndarray, np.ndarray]:
42
+ """Subsample control points for faster TPS: all displaced + sparse anchors.
43
+
44
+ With 478 MediaPipe landmarks, full TPS requires solving a 481x481 system
45
+ and evaluating 478 RBFs at each pixel — very slow. Subsampling to ~80
46
+ points gives nearly identical results ~30x faster.
47
+ """
48
  displacements = np.linalg.norm(dst - src, axis=1)
49
  displaced_mask = displacements > 0.5 # moved by > 0.5px
50
  displaced_idx = np.where(displaced_mask)[0]
 
75
  dst_landmarks: np.ndarray,
76
  rigid_mask: np.ndarray | None = None,
77
  ) -> np.ndarray:
78
+ """Apply TPS warp to an image with optional rigid region preservation.
79
+
80
+ Args:
81
+ image: BGR input image.
82
+ src_landmarks: (N, 2) original landmark pixel coords.
83
+ dst_landmarks: (N, 2) target landmark pixel coords.
84
+ rigid_mask: Optional binary mask of rigid regions (teeth, sclera).
85
+ These regions are rigidly translated, not TPS-warped.
86
+
87
+ Returns:
88
+ Warped image.
89
+ """
90
  h, w = image.shape[:2]
91
 
92
  src_pts = src_landmarks.astype(np.float32)
 
112
  rigid_translation = _compute_rigid_translation(src_pts, dst_pts, rigid_mask, w, h)
113
  rigid_warped = _apply_rigid_translation(image, rigid_translation)
114
 
115
+ # Translate the mask to match the rigidly-shifted content
116
+ translated_mask = _apply_rigid_translation(rigid_mask, rigid_translation)
117
  # Composite: use rigid warp in rigid regions, TPS elsewhere
118
+ mask_f = translated_mask.astype(np.float32)
119
  if len(mask_f.shape) == 2:
120
  mask_f = np.stack([mask_f] * 3, axis=-1)
121
  mask_f = mask_f / 255.0 if mask_f.max() > 1 else mask_f
 
130
  width: int,
131
  height: int,
132
  ) -> tuple[np.ndarray, np.ndarray]:
133
+ """Compute pixel displacement maps from TPS control points.
134
+
135
+ Uses RBF interpolation of control point displacements.
136
+ """
137
  # Displacement at control points
138
  dx = dst[:, 0] - src[:, 0]
139
  dy = dst[:, 1] - src[:, 1]
 
181
  control_pts: np.ndarray,
182
  values: np.ndarray,
183
  ) -> np.ndarray:
184
+ """Solve for TPS weights given control points and target values.
185
+
186
+ Returns weight vector [w1..wn, a0, a1, a2] for n control points
187
+ plus affine terms.
188
+ """
189
  n = len(control_pts)
190
 
191
  # Build kernel matrix K (vectorized)
192
  diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
193
+ r_mat = np.sqrt((diff ** 2).sum(axis=2)) # (n, n)
194
  K = np.zeros((n, n))
195
  nz = r_mat > 0
196
  K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
 
239
  # Compute all distances at once: (M, n)
240
  dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
241
  dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
242
+ r = np.sqrt(dx ** 2 + dy ** 2)
243
 
244
  # TPS kernel: r^2 * log(r), with r=0 -> 0
245
  kernel = np.zeros_like(r)
 
264
  inside = []
265
  for i, (x, y) in enumerate(src):
266
  ix, iy = int(x), int(y)
267
+ if 0 <= ix < width and 0 <= iy < height:
268
+ if mask[iy, ix] > 0:
269
+ inside.append(i)
270
 
271
  if not inside:
272
  return np.array([0.0, 0.0])
 
292
  max_displacement: float = 15.0,
293
  rng: np.random.Generator | None = None,
294
  ) -> np.ndarray:
295
+ """Generate randomly warped landmarks for synthetic data.
296
+
297
+ Args:
298
+ landmarks: (N, 2) pixel coordinates.
299
+ procedure_indices: Which landmarks to warp.
300
+ max_displacement: Max pixel displacement.
301
+ rng: Random number generator.
302
+
303
+ Returns:
304
+ New landmark array with random deformations.
305
+ """
306
  rng = rng or np.random.default_rng()
307
  result = landmarks.copy()
308