VisionLanguageGroup commited on
Commit
4ce5a27
Β·
1 Parent(s): 02e04fb
Files changed (31) hide show
  1. _utils/load_track_data.py +1 -1
  2. _utils/track_args.py +1 -96
  3. inference_track.py +1 -1
  4. models/tra_post_model/{trackastra/data/wrfeat.py β†’ data.py} +0 -348
  5. models/tra_post_model/{trackastra/model/model.py β†’ model.py} +280 -12
  6. models/tra_post_model/trackastra/data/__init__.py +0 -18
  7. models/tra_post_model/trackastra/data/_check_ctc.py +0 -114
  8. models/tra_post_model/trackastra/data/_compute_overlap.py +0 -209
  9. models/tra_post_model/trackastra/data/augmentations.py +0 -557
  10. models/tra_post_model/trackastra/data/data.py +0 -1509
  11. models/tra_post_model/trackastra/data/distributed.py +0 -316
  12. models/tra_post_model/trackastra/data/example_data.py +0 -48
  13. models/tra_post_model/trackastra/data/features.py +0 -148
  14. models/tra_post_model/trackastra/data/matching.py +0 -251
  15. models/tra_post_model/trackastra/data/utils.py +0 -232
  16. models/tra_post_model/trackastra/model/__init__.py +0 -4
  17. models/tra_post_model/trackastra/model/model_api.py +0 -338
  18. models/tra_post_model/trackastra/model/model_parts.py +0 -287
  19. models/tra_post_model/trackastra/model/model_sd.py +0 -338
  20. models/tra_post_model/trackastra/model/predict.py +0 -188
  21. models/tra_post_model/trackastra/model/pretrained.json +0 -81
  22. models/tra_post_model/trackastra/model/pretrained.py +0 -90
  23. models/tra_post_model/trackastra/model/rope.py +0 -94
  24. models/tra_post_model/trackastra/utils/__init__.py +0 -14
  25. models/tra_post_model/{trackastra/tracking β†’ tracking}/__init__.py +0 -3
  26. models/tra_post_model/{trackastra/tracking β†’ tracking}/ilp.py +0 -0
  27. models/tra_post_model/{trackastra/tracking β†’ tracking}/track_graph.py +0 -0
  28. models/tra_post_model/{trackastra/tracking β†’ tracking}/tracking.py +1 -9
  29. models/tra_post_model/{trackastra/tracking β†’ tracking}/utils.py +73 -103
  30. models/tra_post_model/{trackastra/utils/utils.py β†’ utils.py} +0 -338
  31. tracking_one.py +6 -8
_utils/load_track_data.py CHANGED
@@ -9,7 +9,7 @@ import skimage.io as io
9
  import torchvision.transforms as T
10
  import cv2
11
  from tqdm import tqdm
12
- from models.tra_post_model.trackastra.utils import normalize_01, normalize
13
  IMG_SIZE = 512
14
 
15
  def _load_tiffs(folder: Path, dtype=None):
 
9
  import torchvision.transforms as T
10
  import cv2
11
  from tqdm import tqdm
12
+ from models.tra_post_model.utils import normalize_01, normalize
13
  IMG_SIZE = 512
14
 
15
  def _load_tiffs(folder: Path, dtype=None):
_utils/track_args.py CHANGED
@@ -14,49 +14,14 @@ def parse_train_args():
14
  is_config_file=True,
15
  help="config file path",
16
  )
17
- parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
18
- parser.add_argument("-o", "--outdir", type=str, default="runs")
19
- parser.add_argument("--name", type=str, help="Name to append to timestamp")
20
- parser.add_argument("--timestamp", type=bool, default=True)
21
- parser.add_argument(
22
- "-m",
23
- "--model",
24
- type=str,
25
- default="",
26
- help="load this model at start (e.g. to continue training)",
27
- )
28
- parser.add_argument(
29
- "--ndim", type=int, default=2, help="number of spatial dimensions"
30
- )
31
  parser.add_argument("-d", "--d_model", type=int, default=256)
32
  parser.add_argument("-w", "--window", type=int, default=10)
33
- parser.add_argument("--epochs", type=int, default=100)
34
- parser.add_argument("--warmup_epochs", type=int, default=10)
35
- parser.add_argument(
36
- "--detection_folders",
37
- type=str,
38
- nargs="+",
39
- default=["TRA"],
40
- help=(
41
- "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
42
- " to using only the GT."
43
- ),
44
- )
45
- parser.add_argument("--downscale_temporal", type=int, default=1)
46
- parser.add_argument("--downscale_spatial", type=int, default=1)
47
  parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
48
- parser.add_argument("--from_subfolder", action="store_true")
49
- # parser.add_argument("--train_samples", type=int, default=50000)
50
  parser.add_argument("--num_encoder_layers", type=int, default=6)
51
  parser.add_argument("--num_decoder_layers", type=int, default=6)
52
  parser.add_argument("--pos_embed_per_dim", type=int, default=32)
53
  parser.add_argument("--feat_embed_per_dim", type=int, default=8)
54
  parser.add_argument("--dropout", type=float, default=0.00)
55
- parser.add_argument("--num_workers", type=int, default=4)
56
- parser.add_argument("--batch_size", type=int, default=1)
57
- parser.add_argument("--max_tokens", type=int, default=None)
58
- parser.add_argument("--delta_cutoff", type=int, default=2)
59
- parser.add_argument("--lr", type=float, default=1e-4)
60
  parser.add_argument(
61
  "--attn_positional_bias",
62
  type=str,
@@ -65,72 +30,12 @@ def parse_train_args():
65
  )
66
  parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
67
  parser.add_argument("--attn_dist_mode", default="v0")
68
- parser.add_argument("--mixedp", type=bool, default=True)
69
- parser.add_argument("--dry", action="store_true")
70
- parser.add_argument("--profile", action="store_true")
71
- parser.add_argument(
72
- "--features",
73
- type=str,
74
- choices=[
75
- "none",
76
- "regionprops",
77
- "regionprops2",
78
- "patch",
79
- "patch_regionprops",
80
- "wrfeat",
81
- ],
82
- default="wrfeat",
83
- )
84
  parser.add_argument(
85
  "--causal_norm",
86
  type=str,
87
  choices=["none", "linear", "softmax", "quiet_softmax"],
88
  default="quiet_softmax",
89
  )
90
- parser.add_argument("--div_upweight", type=float, default=2)
91
-
92
- parser.add_argument("--augment", type=int, default=3)
93
- parser.add_argument("--tracking_frequency", type=int, default=-1)
94
-
95
- parser.add_argument("--sanity_dist", action="store_true")
96
- parser.add_argument("--preallocate", type=bool, default=False)
97
- parser.add_argument("--only_prechecks", action="store_true")
98
- parser.add_argument(
99
- "--compress", type=bool, default=True, help="compress dataset"
100
- )
101
-
102
-
103
- parser.add_argument("--seed", type=int, default=None)
104
- parser.add_argument(
105
- "--logger",
106
- type=str,
107
- default="tensorboard",
108
- choices=["tensorboard", "wandb", "none"],
109
- )
110
- parser.add_argument("--wandb_project", type=str, default="trackastra")
111
- parser.add_argument(
112
- "--crop_size",
113
- type=int,
114
- # required=True,
115
- nargs="+",
116
- default=None,
117
- help="random crop size for augmentation",
118
- )
119
- parser.add_argument(
120
- "--weight_by_ndivs",
121
- type=bool,
122
- default=True,
123
- help="Oversample windows that contain divisions",
124
- )
125
- parser.add_argument(
126
- "--weight_by_dataset",
127
- type=bool,
128
- default=False,
129
- help=(
130
- "Inversely weight datasets by number of samples (to counter dataset size"
131
- " imbalance)"
132
- ),
133
- )
134
 
135
  args, unknown_args = parser.parse_known_args()
136
 
@@ -154,4 +59,4 @@ def parse_train_args():
154
  # "--train_samples must be > 0, full dataset pass not supported."
155
  # )
156
 
157
- return args
 
14
  is_config_file=True,
15
  help="config file path",
16
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  parser.add_argument("-d", "--d_model", type=int, default=256)
18
  parser.add_argument("-w", "--window", type=int, default=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
 
 
20
  parser.add_argument("--num_encoder_layers", type=int, default=6)
21
  parser.add_argument("--num_decoder_layers", type=int, default=6)
22
  parser.add_argument("--pos_embed_per_dim", type=int, default=32)
23
  parser.add_argument("--feat_embed_per_dim", type=int, default=8)
24
  parser.add_argument("--dropout", type=float, default=0.00)
 
 
 
 
 
25
  parser.add_argument(
26
  "--attn_positional_bias",
27
  type=str,
 
30
  )
31
  parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
32
  parser.add_argument("--attn_dist_mode", default="v0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  parser.add_argument(
34
  "--causal_norm",
35
  type=str,
36
  choices=["none", "linear", "softmax", "quiet_softmax"],
37
  default="quiet_softmax",
38
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  args, unknown_args = parser.parse_known_args()
41
 
 
59
  # "--train_samples must be > 0, full dataset pass not supported."
60
  # )
61
 
62
+ return args
inference_track.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  from tqdm import tqdm
8
  from huggingface_hub import hf_hub_download
9
  from tracking_one import TrackingModule
10
- from models.tra_post_model.trackastra.tracking import graph_to_ctc
11
 
12
  MODEL = None
13
  DEVICE = torch.device("cpu")
 
7
  from tqdm import tqdm
8
  from huggingface_hub import hf_hub_download
9
  from tracking_one import TrackingModule
10
+ from models.tra_post_model.tracking import graph_to_ctc
11
 
12
  MODEL = None
13
  DEVICE = torch.device("cpu")
models/tra_post_model/{trackastra/data/wrfeat.py β†’ data.py} RENAMED
@@ -17,11 +17,6 @@ from skimage.measure import regionprops, regionprops_table
17
  from tqdm import tqdm
18
  from typing import Tuple, Optional, Sequence, Union, List
19
  import typing
20
-
21
- try:
22
- from .utils import load_tiff_timeseries
23
- except:
24
- from utils import load_tiff_timeseries
25
  import torch
26
  logger = logging.getLogger(__name__)
27
 
@@ -42,36 +37,6 @@ _PROPERTIES = {
42
  }
43
 
44
 
45
- def _filter_points(
46
- points: np.ndarray, shape: Tuple[int], origin: Optional[Tuple[int]] = None
47
- ) -> np.ndarray:
48
- """Returns indices of points that are inside the shape extent and given origin."""
49
- ndim = points.shape[-1]
50
- if origin is None:
51
- origin = (0,) * ndim
52
-
53
- idx = tuple(
54
- np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
55
- for i in range(ndim)
56
- )
57
- idx = np.where(np.all(idx, axis=0))[0]
58
- return idx
59
-
60
-
61
- def _border_dist(mask: np.ndarray, cutoff: float = 5):
62
- """Returns distance to border normalized to 0 (at least cutoff away) and 1 (at border)."""
63
- border = np.zeros_like(mask)
64
-
65
- # only apply to last two dimensions
66
- ss = tuple(
67
- slice(None) if i < mask.ndim - 2 else slice(1, -1)
68
- for i, s in enumerate(mask.shape)
69
- )
70
- border[ss] = 1
71
- dist = 1 - np.minimum(edt(border) / cutoff, 1)
72
- return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
73
-
74
-
75
  def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
76
  cutoff = int(cutoff)
77
  border = np.ones(mask.shape, dtype=np.float32)
@@ -241,268 +206,6 @@ class WRFeatures:
241
  )
242
 
243
 
244
- # augmentations
245
-
246
-
247
- class WRRandomCrop:
248
- """windowed region random crop augmentation."""
249
-
250
- def __init__(
251
- self,
252
- crop_size: Optional[Union[int, Tuple[int]]] = None,
253
- ndim: int = 2,
254
- ) -> None:
255
- """crop_size: tuple of int
256
- can be tuple of length 1 (all dimensions)
257
- of length ndim (y,x,...)
258
- of length 2*ndim (y1,y2, x1,x2, ...).
259
- """
260
- if isinstance(crop_size, int):
261
- crop_size = (crop_size,) * 2 * ndim
262
- elif isinstance(crop_size, Iterable):
263
- pass
264
- else:
265
- raise ValueError(f"{crop_size} has to be int or tuple of int")
266
-
267
- if len(crop_size) == 1:
268
- crop_size = (crop_size[0],) * 2 * ndim
269
- elif len(crop_size) == ndim:
270
- crop_size = tuple(itertools.chain(*tuple((c, c) for c in crop_size)))
271
- elif len(crop_size) == 2 * ndim:
272
- pass
273
- else:
274
- raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
275
-
276
- crop_size = np.array(crop_size)
277
- self._ndim = ndim
278
- self._crop_bounds = crop_size[::2], crop_size[1::2]
279
- self._rng = np.random.RandomState()
280
-
281
- def __call__(self, features: WRFeatures):
282
- crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
283
- points = features.coords
284
-
285
- if len(points) == 0:
286
- print("No points given, cannot ensure inside points")
287
- return features
288
-
289
- # sample point and corner relative to it
290
-
291
- _idx = np.random.randint(len(points))
292
- corner = (
293
- points[_idx]
294
- - crop_size
295
- + 1
296
- + self._rng.randint(crop_size // 4, 3 * crop_size // 4)
297
- )
298
-
299
- idx = _filter_points(points, shape=crop_size, origin=corner)
300
-
301
- return (
302
- WRFeatures(
303
- coords=points[idx],
304
- labels=features.labels[idx],
305
- timepoints=features.timepoints[idx],
306
- features=OrderedDict((k, v[idx]) for k, v in features.features.items()),
307
- ),
308
- idx,
309
- )
310
-
311
-
312
- class WRBaseAugmentation:
313
- def __init__(self, p: float = 0.5) -> None:
314
- self._p = p
315
- self._rng = np.random.RandomState()
316
-
317
- def __call__(self, features: WRFeatures):
318
- if self._rng.rand() > self._p or len(features) == 0:
319
- return features
320
- return self._augment(features)
321
-
322
- def _augment(self, features: WRFeatures):
323
- raise NotImplementedError()
324
-
325
-
326
- class WRRandomFlip(WRBaseAugmentation):
327
- def _augment(self, features: WRFeatures):
328
- ndim = features.ndim
329
- flip = self._rng.randint(0, 2, features.ndim)
330
- points = features.coords.copy()
331
- for i, f in enumerate(flip):
332
- if f == 1:
333
- points[:, ndim - i - 1] *= -1
334
- return WRFeatures(
335
- coords=points,
336
- labels=features.labels,
337
- timepoints=features.timepoints,
338
- features=features.features,
339
- )
340
-
341
-
342
- def _scale_matrix(sz: float, sy: float, sx: float):
343
- return np.diag([sz, sy, sx])
344
-
345
-
346
- # def _scale_matrix(sy: float, sx: float):
347
- # return np.array([[1, 0, 0], [0, sy, 0], [0, 0, sx]])
348
-
349
-
350
- def _shear_matrix(shy: float, shx: float):
351
- return np.array([[1, 0, 0], [0, 1 + shx * shy, shy], [0, shx, 1]])
352
-
353
-
354
- def _rotation_matrix(theta: float):
355
- return np.array([
356
- [1, 0, 0],
357
- [0, np.cos(theta), -np.sin(theta)],
358
- [0, np.sin(theta), np.cos(theta)],
359
- ])
360
-
361
-
362
- def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
363
- ndim = len(M)
364
- if k == "area":
365
- v = np.linalg.det(M) * v
366
- elif k == "equivalent_diameter_area":
367
- v = np.linalg.det(M) ** (1 / len(M)) * v
368
-
369
- elif k == "inertia_tensor":
370
- # v' = M * v * M^T
371
- v = v.reshape(-1, ndim, ndim)
372
- # v * M^T
373
- v = np.einsum("ijk, mk -> ijm", v, M)
374
- # M * v
375
- v = np.einsum("ij, kjm -> kim", M, v)
376
- v = v.reshape(-1, ndim * ndim)
377
- elif k in (
378
- "intensity_mean",
379
- "intensity_std",
380
- "intensity_max",
381
- "intensity_min",
382
- "border_dist",
383
- ):
384
- pass
385
- else:
386
- raise ValueError(f"Don't know how to affinely transform {k}")
387
- return v
388
-
389
-
390
- class WRRandomAffine(WRBaseAugmentation):
391
- def __init__(
392
- self,
393
- degrees: float = 10,
394
- scale: float = (0.9, 1.1),
395
- shear: float = (0.1, 0.1),
396
- p: float = 0.5,
397
- ):
398
- super().__init__(p)
399
- self.degrees = degrees if degrees is not None else 0
400
- self.scale = scale if scale is not None else (1, 1)
401
- self.shear = shear if shear is not None else (0, 0)
402
-
403
- def _augment(self, features: WRFeatures):
404
- degrees = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi
405
- scale = self._rng.uniform(*self.scale, 3)
406
- shy = self._rng.uniform(-self.shear[0], self.shear[0])
407
- shx = self._rng.uniform(-self.shear[1], self.shear[1])
408
-
409
- self._M = (
410
- _rotation_matrix(degrees) @ _scale_matrix(*scale) @ _shear_matrix(shy, shx)
411
- )
412
-
413
- # M is by default 3D , we need to remove the last dimension for 2D
414
- self._M = self._M[-features.ndim :, -features.ndim :]
415
- points = features.coords @ self._M.T
416
-
417
- feats = OrderedDict(
418
- (k, _transform_affine(k, v, self._M)) for k, v in features.features.items()
419
- )
420
-
421
- return WRFeatures(
422
- coords=points,
423
- labels=features.labels,
424
- timepoints=features.timepoints,
425
- features=feats,
426
- )
427
-
428
-
429
- class WRRandomBrightness(WRBaseAugmentation):
430
- def __init__(
431
- self,
432
- scale: Tuple[float] = (0.5, 2.0),
433
- shift: Tuple[float] = (-0.1, 0.1),
434
- p: float = 0.5,
435
- ):
436
- super().__init__(p)
437
- self.scale = scale
438
- self.shift = shift
439
-
440
- def _augment(self, features: WRFeatures):
441
- scale = self._rng.uniform(*self.scale)
442
- shift = self._rng.uniform(*self.shift)
443
-
444
- key_vals = []
445
-
446
- for k, v in features.features.items():
447
- if "intensity" in k:
448
- v = v * scale + shift
449
- key_vals.append((k, v))
450
- feats = OrderedDict(key_vals)
451
- return WRFeatures(
452
- coords=features.coords,
453
- labels=features.labels,
454
- timepoints=features.timepoints,
455
- features=feats,
456
- )
457
-
458
-
459
- class WRRandomOffset(WRBaseAugmentation):
460
- def __init__(self, offset: float = (-3, 3), p: float = 0.5):
461
- super().__init__(p)
462
- self.offset = offset
463
-
464
- def _augment(self, features: WRFeatures):
465
- offset = self._rng.uniform(*self.offset, features.coords.shape)
466
- coords = features.coords + offset
467
- return WRFeatures(
468
- coords=coords,
469
- labels=features.labels,
470
- timepoints=features.timepoints,
471
- features=features.features,
472
- )
473
-
474
-
475
- class WRRandomMovement(WRBaseAugmentation):
476
- """random global linear shift."""
477
-
478
- def __init__(self, offset: float = (-10, 10), p: float = 0.5):
479
- super().__init__(p)
480
- self.offset = offset
481
-
482
- def _augment(self, features: WRFeatures):
483
- base_offset = self._rng.uniform(*self.offset, features.coords.shape[-1])
484
- tmin = features.timepoints.min()
485
- offset = (features.timepoints[:, None] - tmin) * base_offset[None]
486
- coords = features.coords + offset
487
-
488
- return WRFeatures(
489
- coords=coords,
490
- labels=features.labels,
491
- timepoints=features.timepoints,
492
- features=features.features,
493
- )
494
-
495
-
496
- class WRAugmentationPipeline:
497
- def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
498
- self.augmentations = augmentations
499
-
500
- def __call__(self, feats: WRFeatures):
501
- for aug in self.augmentations:
502
- feats = aug(feats)
503
- return feats
504
-
505
-
506
  def get_features(
507
  detections: np.ndarray,
508
  imgs: Optional[np.ndarray] = None,
@@ -561,36 +264,6 @@ def _check_dimensions(x: np.ndarray, ndim: int):
561
  return x
562
 
563
 
564
- def build_windows(
565
- features: List[WRFeatures], window_size: int, progbar_class=tqdm
566
- ) -> List[dict]:
567
- windows = []
568
- for t1, t2 in progbar_class(
569
- zip(range(0, len(features)), range(window_size, len(features) + 1)),
570
- total=len(features) - window_size + 1,
571
- desc="Building windows",
572
- ):
573
- feat = WRFeatures.concat(features[t1:t2])
574
-
575
- labels = feat.labels
576
- timepoints = feat.timepoints
577
- coords = feat.coords
578
-
579
- if len(feat) == 0:
580
- coords = np.zeros((0, feat.ndim), dtype=int)
581
-
582
- w = dict(
583
- coords=coords,
584
- t1=t1,
585
- labels=labels,
586
- timepoints=timepoints,
587
- features=feat.features_stacked,
588
- )
589
- windows.append(w)
590
-
591
- logger.debug(f"Built {len(windows)} track windows.\n")
592
- return windows
593
-
594
  def build_windows_sd(
595
  features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
596
  ) -> List[dict]:
@@ -632,24 +305,3 @@ def build_windows_sd(
632
  logger.debug(f"Built {len(windows)} track windows.\n")
633
  return windows
634
 
635
- if __name__ == "__main__":
636
- imgs = load_tiff_timeseries(
637
- # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01",
638
- "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
639
- )
640
- masks = load_tiff_timeseries(
641
- # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA",
642
- "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
643
- dtype=int,
644
- )
645
-
646
- features = get_features(detections=masks, imgs=imgs, ndim=3)
647
- windows = build_windows(features, window_size=4)
648
-
649
-
650
- # if __name__ == "__main__":
651
- # y = np.zeros((1, 100, 100), np.uint8)
652
- # y[:, 20:40, 20:60] = 1
653
- # x = y + np.random.normal(0, 0.1, y.shape)
654
-
655
- # f = WRFeatures.from_mask_img(y, x, properties=("intensity_mean", "area"))
 
17
  from tqdm import tqdm
18
  from typing import Tuple, Optional, Sequence, Union, List
19
  import typing
 
 
 
 
 
20
  import torch
21
  logger = logging.getLogger(__name__)
22
 
 
37
  }
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
41
  cutoff = int(cutoff)
42
  border = np.ones(mask.shape, dtype=np.float32)
 
206
  )
207
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  def get_features(
210
  detections: np.ndarray,
211
  imgs: Optional[np.ndarray] = None,
 
264
  return x
265
 
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  def build_windows_sd(
268
  features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
269
  ) -> List[dict]:
 
305
  logger.debug(f"Built {len(windows)} track windows.\n")
306
  return windows
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/{trackastra/model/model.py β†’ model.py} RENAMED
@@ -1,32 +1,300 @@
1
  """Transformer class."""
2
 
3
  import logging
 
4
  from collections import OrderedDict
5
  from pathlib import Path
6
- from typing import Literal
7
 
8
  import torch
 
9
 
10
- # from torch_geometric.nn import GATv2Conv
11
  import yaml
12
  from torch import nn
13
 
14
  import sys, os
15
- sys.path.append(os.path.join(os.getcwd(), "External_Repos", "trackastra"))
16
 
17
- # NoPositionalEncoding,
18
- from ..utils import blockwise_causal_norm
19
 
20
- from .model_parts import (
21
- FeedForward,
22
- PositionalEncoding,
23
- RelativePositionalAttention,
24
- )
25
-
26
- # from memory_profiler import profile
27
  logger = logging.getLogger(__name__)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class EncoderLayer(nn.Module):
31
  def __init__(
32
  self,
 
1
  """Transformer class."""
2
 
3
  import logging
4
+ import math
5
  from collections import OrderedDict
6
  from pathlib import Path
7
+ from typing import Literal, Tuple
8
 
9
  import torch
10
+ import torch.nn.functional as F
11
 
 
12
  import yaml
13
  from torch import nn
14
 
15
  import sys, os
 
16
 
17
+ from .utils import blockwise_causal_norm
 
18
 
 
 
 
 
 
 
 
19
  logger = logging.getLogger(__name__)
20
 
21
 
22
+ def _pos_embed_fourier1d_init(
23
+ cutoff: float = 256, n: int = 32, cutoff_start: float = 1
24
+ ):
25
+ return (
26
+ torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n))
27
+ .unsqueeze(0)
28
+ .unsqueeze(0)
29
+ )
30
+
31
+
32
+ def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
33
+ # Maximum initial frequency is 1
34
+ return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0)
35
+
36
+
37
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
38
+ """Rotate pairs of scalars as 2d vectors by pi/2."""
39
+ x = x.unflatten(-1, (-1, 2))
40
+ x1, x2 = x.unbind(dim=-1)
41
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
42
+
43
+
44
+ class RotaryPositionalEncoding(nn.Module):
45
+ def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
46
+ super().__init__()
47
+ assert len(cutoffs) == len(n_pos)
48
+ if not all(n % 2 == 0 for n in n_pos):
49
+ raise ValueError("n_pos must be even")
50
+
51
+ self._n_dim = len(cutoffs)
52
+ self.freqs = nn.ParameterList([
53
+ nn.Parameter(_rope_pos_embed_fourier1d_init(cutoff, n // 2))
54
+ for cutoff, n in zip(cutoffs, n_pos)
55
+ ])
56
+
57
+ def get_co_si(self, coords: torch.Tensor):
58
+ _B, _N, D = coords.shape
59
+ assert D == len(self.freqs)
60
+ co = torch.cat(
61
+ tuple(
62
+ torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
63
+ for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
64
+ ),
65
+ axis=-1,
66
+ )
67
+ si = torch.cat(
68
+ tuple(
69
+ torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
70
+ for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
71
+ ),
72
+ axis=-1,
73
+ )
74
+ return co, si
75
+
76
+ def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
77
+ _B, _N, D = coords.shape
78
+ _B, _H, _N, _C = q.shape
79
+
80
+ if D != self._n_dim:
81
+ raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
82
+
83
+ co, si = self.get_co_si(coords)
84
+ co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
85
+ si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
86
+ q2 = q * co + _rotate_half(q) * si
87
+ k2 = k * co + _rotate_half(k) * si
88
+ return q2, k2
89
+
90
+
91
+ class FeedForward(nn.Module):
92
+ def __init__(self, d_model, expand: float = 2, bias: bool = True):
93
+ super().__init__()
94
+ self.fc1 = nn.Linear(d_model, int(d_model * expand))
95
+ self.fc2 = nn.Linear(int(d_model * expand), d_model, bias=bias)
96
+ self.act = nn.GELU()
97
+
98
+ def forward(self, x):
99
+ return self.fc2(self.act(self.fc1(x)))
100
+
101
+
102
+ class PositionalEncoding(nn.Module):
103
+ def __init__(
104
+ self,
105
+ cutoffs: Tuple[float] = (256,),
106
+ n_pos: Tuple[int] = (32,),
107
+ cutoffs_start=None,
108
+ ):
109
+ super().__init__()
110
+ if cutoffs_start is None:
111
+ cutoffs_start = (1,) * len(cutoffs)
112
+
113
+ assert len(cutoffs) == len(n_pos)
114
+ self.freqs = nn.ParameterList([
115
+ nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
116
+ for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
117
+ ])
118
+
119
+ def forward(self, coords: torch.Tensor):
120
+ _B, _N, D = coords.shape
121
+ assert D == len(self.freqs)
122
+ embed = torch.cat(
123
+ tuple(
124
+ torch.cat(
125
+ (
126
+ torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
127
+ torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
128
+ ),
129
+ axis=-1,
130
+ )
131
+ / math.sqrt(len(freq))
132
+ for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
133
+ ),
134
+ axis=-1,
135
+ )
136
+ return embed
137
+
138
+
139
+ def _bin_init_exp(cutoff: float, n: int):
140
+ return torch.exp(torch.linspace(0, math.log(cutoff + 1), n))
141
+
142
+
143
+ def _bin_init_linear(cutoff: float, n: int):
144
+ return torch.linspace(-cutoff, cutoff, n)
145
+
146
+
147
+ class RelativePositionalBias(nn.Module):
148
+ def __init__(
149
+ self,
150
+ n_head: int,
151
+ cutoff_spatial: float,
152
+ cutoff_temporal: float,
153
+ n_spatial: int = 32,
154
+ n_temporal: int = 16,
155
+ ):
156
+ super().__init__()
157
+ self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
158
+ self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
159
+ self.register_buffer("spatial_bins", self._spatial_bins)
160
+ self.register_buffer("temporal_bins", self._temporal_bins)
161
+ self.n_spatial = n_spatial
162
+ self.n_head = n_head
163
+ self.bias = nn.Parameter(
164
+ -0.5 + torch.rand((2 * n_temporal + 1) * n_spatial, n_head)
165
+ )
166
+
167
+ def forward(self, coords: torch.Tensor):
168
+ _B, _N, _D = coords.shape
169
+ t = coords[..., 0]
170
+ yx = coords[..., 1:]
171
+ temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2)
172
+ spatial_dist = torch.cdist(yx, yx)
173
+
174
+ spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins)
175
+ torch.clamp_(spatial_idx, max=len(self.spatial_bins) - 1)
176
+ temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins)
177
+ torch.clamp_(temporal_idx, max=len(self.temporal_bins) - 1)
178
+
179
+ idx = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial
180
+ bias = self.bias.index_select(0, idx).view((*spatial_idx.shape, self.n_head))
181
+ bias = bias.transpose(-1, 1)
182
+ return bias
183
+
184
+
185
+ class RelativePositionalAttention(nn.Module):
186
+ def __init__(
187
+ self,
188
+ coord_dim: int,
189
+ embed_dim: int,
190
+ n_head: int,
191
+ cutoff_spatial: float = 256,
192
+ cutoff_temporal: float = 16,
193
+ n_spatial: int = 32,
194
+ n_temporal: int = 16,
195
+ dropout: float = 0.0,
196
+ mode: Literal["bias", "rope", "none"] = "bias",
197
+ attn_dist_mode: str = "v0",
198
+ ):
199
+ super().__init__()
200
+
201
+ if not embed_dim % (2 * n_head) == 0:
202
+ raise ValueError(
203
+ f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
204
+ )
205
+
206
+ self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
207
+ self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
208
+ self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
209
+ self.proj = nn.Linear(embed_dim, embed_dim)
210
+ self.dropout = dropout
211
+ self.n_head = n_head
212
+ self.embed_dim = embed_dim
213
+ self.cutoff_spatial = cutoff_spatial
214
+ self.attn_dist_mode = attn_dist_mode
215
+
216
+ if mode == "bias" or mode is True:
217
+ self.pos_bias = RelativePositionalBias(
218
+ n_head=n_head,
219
+ cutoff_spatial=cutoff_spatial,
220
+ cutoff_temporal=cutoff_temporal,
221
+ n_spatial=n_spatial,
222
+ n_temporal=n_temporal,
223
+ )
224
+ elif mode == "rope":
225
+ n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
226
+ self.rot_pos_enc = RotaryPositionalEncoding(
227
+ cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
228
+ n_pos=(embed_dim // n_head - coord_dim * n_split,)
229
+ + (n_split,) * coord_dim,
230
+ )
231
+ elif mode == "none":
232
+ pass
233
+ elif mode is None or mode is False:
234
+ logger.warning(
235
+ "attn_positional_bias is not set (None or False), no positional bias."
236
+ )
237
+ else:
238
+ raise ValueError(f"Unknown mode {mode}")
239
+
240
+ self._mode = mode
241
+
242
+ def forward(
243
+ self,
244
+ query: torch.Tensor,
245
+ key: torch.Tensor,
246
+ value: torch.Tensor,
247
+ coords: torch.Tensor,
248
+ padding_mask: torch.Tensor = None,
249
+ ):
250
+ B, N, D = query.size()
251
+ q = self.q_pro(query)
252
+ k = self.k_pro(key)
253
+ v = self.v_pro(value)
254
+ k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
255
+ q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
256
+ v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
257
+
258
+ attn_mask = torch.zeros(
259
+ (B, self.n_head, N, N), device=query.device, dtype=q.dtype
260
+ )
261
+ attn_ignore_val = -1e3
262
+
263
+ yx = coords[..., 1:]
264
+ spatial_dist = torch.cdist(yx, yx)
265
+ spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
266
+ attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
267
+
268
+ if coords is not None:
269
+ if self._mode == "bias":
270
+ attn_mask = attn_mask + self.pos_bias(coords)
271
+ elif self._mode == "rope":
272
+ q, k = self.rot_pos_enc(q, k, coords)
273
+
274
+ if self.attn_dist_mode == "v0":
275
+ dist = torch.cdist(coords, coords, p=2)
276
+ attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
277
+ elif self.attn_dist_mode == "v1":
278
+ attn_mask += torch.exp(
279
+ -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
280
+ )
281
+ else:
282
+ raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
283
+
284
+ if padding_mask is not None:
285
+ ignore_mask = torch.logical_or(
286
+ padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
287
+ ).unsqueeze(1)
288
+ attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
289
+
290
+ y = F.scaled_dot_product_attention(
291
+ q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
292
+ )
293
+ y = y.transpose(1, 2).contiguous().view(B, N, D)
294
+ y = self.proj(y)
295
+ return y
296
+
297
+
298
  class EncoderLayer(nn.Module):
299
  def __init__(
300
  self,
models/tra_post_model/trackastra/data/__init__.py DELETED
@@ -1,18 +0,0 @@
1
- # ruff: noqa: F401
2
-
3
- from .augmentations import AugmentationPipeline, RandomCrop
4
- from .data import (
5
- CTCData,
6
- _ctc_lineages,
7
- # load_ctc_data_from_subfolders,
8
- collate_sequence_padding,
9
- extract_features_regionprops,
10
- )
11
- from .distributed import (
12
- BalancedBatchSampler,
13
- # BalancedDataModule,
14
- BalancedDistributedSampler,
15
- )
16
- from .example_data import example_data_bacteria, example_data_fluo_3d, example_data_hela, data_hela
17
- from .utils import filter_track_df, load_tiff_timeseries, load_tracklet_links
18
- from .wrfeat import WRFeatures, build_windows, get_features, build_windows_sd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/_check_ctc.py DELETED
@@ -1,114 +0,0 @@
1
- import logging
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from skimage.measure import label, regionprops_table
6
-
7
- logger = logging.getLogger(__name__)
8
-
9
-
10
- # from https://github.com/Janelia-Trackathon-2023/traccuracy/blob/main/src/traccuracy/loaders/_ctc.py
11
- def _check_ctc(tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray):
12
- """Sanity checks for valid CTC format.
13
-
14
- Hard checks (throws exception):
15
- - Tracklet IDs in tracks file must be unique and positive
16
- - Parent tracklet IDs must exist in the tracks file
17
- - Intertracklet edges must be directed forward in time.
18
- - In each time point, the set of segmentation IDs present in the detections must equal the set
19
- of tracklet IDs in the tracks file that overlap this time point.
20
-
21
- Soft checks (prints warning):
22
- - No duplicate tracklet IDs (non-connected pixels with same ID) in a single timepoint.
23
-
24
- Args:
25
- tracks (pd.DataFrame): Tracks in CTC format with columns Cell_ID, Start, End, Parent_ID.
26
- detections (pd.DataFrame): Detections extracted from masks, containing columns
27
- segmentation_id, t.
28
- masks (np.ndarray): Set of masks with time in the first axis.
29
-
30
- Raises:
31
- ValueError: If any of the hard checks fail.
32
- """
33
- logger.debug("Running CTC format checks")
34
- tracks = tracks.copy()
35
- tracks.columns = ["Cell_ID", "Start", "End", "Parent_ID"]
36
- if tracks["Cell_ID"].min() < 1:
37
- raise ValueError("Cell_IDs in tracks file must be positive integers.")
38
- if len(tracks["Cell_ID"]) < len(tracks["Cell_ID"].unique()):
39
- raise ValueError("Cell_IDs in tracks file must be unique integers.")
40
-
41
- for _, row in tracks.iterrows():
42
- if row["Parent_ID"] != 0:
43
- if row["Parent_ID"] not in tracks["Cell_ID"].values:
44
- raise ValueError(
45
- f"Parent_ID {row['Parent_ID']} is not present in tracks."
46
- )
47
- parent_end = tracks[tracks["Cell_ID"] == row["Parent_ID"]]["End"].iloc[0]
48
- if parent_end >= row["Start"]:
49
- raise ValueError(
50
- f"Invalid tracklet connection: Daughter tracklet with ID {row['Cell_ID']} "
51
- f"starts at t={row['Start']}, "
52
- f"but parent tracklet with ID {row['Parent_ID']} only ends at t={parent_end}."
53
- )
54
-
55
- for t in range(tracks["Start"].min(), tracks["End"].max()):
56
- track_ids = set(
57
- tracks[(tracks["Start"] <= t) & (tracks["End"] >= t)]["Cell_ID"]
58
- )
59
- det_ids = set(detections[(detections["t"] == t)]["segmentation_id"])
60
- if not track_ids.issubset(det_ids):
61
- raise ValueError(f"Missing IDs in masks at t={t}: {track_ids - det_ids}")
62
- if not det_ids.issubset(track_ids):
63
- raise ValueError(
64
- f"IDs {det_ids - track_ids} at t={t} not represented in tracks file."
65
- )
66
-
67
- for t, frame in enumerate(masks):
68
- _, n_components = label(frame, return_num=True)
69
- n_labels = len(detections[detections["t"] == t])
70
- if n_labels < n_components:
71
- logger.warning(f"{n_components - n_labels} non-connected masks at t={t}.")
72
-
73
-
74
- def _get_node_attributes(masks):
75
- """Calculates x,y,z,t,label for each detection in a movie.
76
-
77
- Args:
78
- masks (np.ndarray): Set of masks with time in the first axis
79
-
80
- Returns:
81
- pd.DataFrame: Dataframe with one detection per row. Columns
82
- segmentation_id, x, y, z, t
83
- """
84
- data_df = pd.concat([
85
- _detections_from_image(masks, idx) for idx in range(masks.shape[0])
86
- ]).reset_index(drop=True)
87
- data_df = data_df.rename(
88
- columns={
89
- "label": "segmentation_id",
90
- "centroid-2": "z",
91
- "centroid-1": "y",
92
- "centroid-0": "x",
93
- }
94
- )
95
- data_df["segmentation_id"] = data_df["segmentation_id"].astype(int)
96
- data_df["t"] = data_df["t"].astype(int)
97
- return data_df
98
-
99
-
100
- def _detections_from_image(stack, idx):
101
- """Return the unique track label, centroid and time for each track vertex.
102
-
103
- Args:
104
- stack (np.ndarray): Stack of masks
105
- idx (int): Index of the image to calculate the centroids and track labels
106
-
107
- Returns:
108
- pd.DataFrame: The dataframe of track data for one time step (specified by idx)
109
- """
110
- props = regionprops_table(
111
- np.asarray(stack[idx, ...]), properties=("label", "centroid")
112
- )
113
- props["t"] = np.full(props["label"].shape, idx)
114
- return pd.DataFrame(props)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/_compute_overlap.py DELETED
@@ -1,209 +0,0 @@
1
- """Adapted from Fast R-CNN
2
- Written by Sergey Karayev
3
- Licensed under The MIT License
4
- Copyright (c) 2015 Microsoft.
5
- """
6
-
7
- import numpy as np
8
- from skimage.measure import regionprops
9
-
10
-
11
- def _union_slice(a: tuple[slice], b: tuple[slice]):
12
- """Returns the union of slice tuples a and b."""
13
- starts = tuple(min(_a.start, _b.start) for _a, _b in zip(a, b))
14
- stops = tuple(max(_a.stop, _b.stop) for _a, _b in zip(a, b))
15
- return tuple(slice(start, stop) for start, stop in zip(starts, stops))
16
-
17
-
18
- def get_labels_with_overlap(gt_frame, res_frame):
19
- """Get all labels IDs in gt_frame and res_frame whose bounding boxes
20
- overlap.
21
-
22
- Args:
23
- gt_frame (np.ndarray): ground truth segmentation for a single frame
24
- res_frame (np.ndarray): result segmentation for a given frame
25
-
26
- Returns:
27
- overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes
28
- overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes
29
- intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area)
30
- """
31
- gt_frame = gt_frame.astype(np.uint16, copy=False)
32
- res_frame = res_frame.astype(np.uint16, copy=False)
33
- gt_props = regionprops(gt_frame)
34
- gt_boxes = [np.array(gt_prop.bbox) for gt_prop in gt_props]
35
- gt_boxes = np.array(gt_boxes).astype(np.float64)
36
- gt_box_labels = np.asarray(
37
- [int(gt_prop.label) for gt_prop in gt_props], dtype=np.uint16
38
- )
39
-
40
- res_props = regionprops(res_frame)
41
- res_boxes = [np.array(res_prop.bbox) for res_prop in res_props]
42
- res_boxes = np.array(res_boxes).astype(np.float64)
43
- res_box_labels = np.asarray(
44
- [int(res_prop.label) for res_prop in res_props], dtype=np.uint16
45
- )
46
- if len(gt_props) == 0 or len(res_props) == 0:
47
- return [], [], []
48
-
49
- if gt_frame.ndim == 3:
50
- overlaps = compute_overlap_3D(gt_boxes, res_boxes)
51
- else:
52
- overlaps = compute_overlap(
53
- gt_boxes, res_boxes
54
- ) # has the form [gt_bbox, res_bbox]
55
-
56
- # Find the bboxes that have overlap at all (ind_ corresponds to box number - starting at 0)
57
- ind_gt, ind_res = np.nonzero(overlaps)
58
- ind_gt = np.asarray(ind_gt, dtype=np.uint16)
59
- ind_res = np.asarray(ind_res, dtype=np.uint16)
60
- overlapping_gt_labels = gt_box_labels[ind_gt]
61
- overlapping_res_labels = res_box_labels[ind_res]
62
-
63
- intersections_over_gt = []
64
- for i, j in zip(ind_gt, ind_res):
65
- sslice = _union_slice(gt_props[i].slice, res_props[j].slice)
66
- gt_mask = gt_frame[sslice] == gt_box_labels[i]
67
- res_mask = res_frame[sslice] == res_box_labels[j]
68
- area_inter = np.count_nonzero(np.logical_and(gt_mask, res_mask))
69
- area_gt = np.count_nonzero(gt_mask)
70
- intersections_over_gt.append(area_inter / area_gt)
71
-
72
- return overlapping_gt_labels, overlapping_res_labels, intersections_over_gt
73
-
74
-
75
- def compute_overlap(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
76
- """Args:
77
- a: (N, 4) ndarray of float
78
- b: (K, 4) ndarray of float.
79
-
80
- Returns:
81
- overlaps: (N, K) ndarray of overlap between boxes and query_boxes
82
- """
83
- N = boxes.shape[0]
84
- K = query_boxes.shape[0]
85
- overlaps = np.zeros((N, K), dtype=np.float64)
86
- for k in range(K):
87
- box_area = (query_boxes[k, 2] - query_boxes[k, 0] + 1) * (
88
- query_boxes[k, 3] - query_boxes[k, 1] + 1
89
- )
90
- for n in range(N):
91
- iw = (
92
- min(boxes[n, 2], query_boxes[k, 2])
93
- - max(boxes[n, 0], query_boxes[k, 0])
94
- + 1
95
- )
96
- if iw > 0:
97
- ih = (
98
- min(boxes[n, 3], query_boxes[k, 3])
99
- - max(boxes[n, 1], query_boxes[k, 1])
100
- + 1
101
- )
102
- if ih > 0:
103
- ua = np.float64(
104
- (boxes[n, 2] - boxes[n, 0] + 1)
105
- * (boxes[n, 3] - boxes[n, 1] + 1)
106
- + box_area
107
- - iw * ih
108
- )
109
- overlaps[n, k] = iw * ih / ua
110
- return overlaps
111
-
112
-
113
- def compute_overlap_3D(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
114
- """Args:
115
- a: (N, 6) ndarray of float
116
- b: (K, 6) ndarray of float.
117
-
118
- Returns:
119
- overlaps: (N, K) ndarray of overlap between boxes and query_boxes
120
- """
121
- N = boxes.shape[0]
122
- K = query_boxes.shape[0]
123
- overlaps = np.zeros((N, K), dtype=np.float64)
124
- for k in range(K):
125
- box_volume = (
126
- (query_boxes[k, 3] - query_boxes[k, 0] + 1)
127
- * (query_boxes[k, 4] - query_boxes[k, 1] + 1)
128
- * (query_boxes[k, 5] - query_boxes[k, 2] + 1)
129
- )
130
- for n in range(N):
131
- id_ = (
132
- min(boxes[n, 3], query_boxes[k, 3])
133
- - max(boxes[n, 0], query_boxes[k, 0])
134
- + 1
135
- )
136
- if id_ > 0:
137
- iw = (
138
- min(boxes[n, 4], query_boxes[k, 4])
139
- - max(boxes[n, 1], query_boxes[k, 1])
140
- + 1
141
- )
142
- if iw > 0:
143
- ih = (
144
- min(boxes[n, 5], query_boxes[k, 5])
145
- - max(boxes[n, 2], query_boxes[k, 2])
146
- + 1
147
- )
148
- if ih > 0:
149
- ua = np.float64(
150
- (boxes[n, 3] - boxes[n, 0] + 1)
151
- * (boxes[n, 4] - boxes[n, 1] + 1)
152
- * (boxes[n, 5] - boxes[n, 2] + 1)
153
- + box_volume
154
- - iw * ih * id_
155
- )
156
- overlaps[n, k] = iw * ih * id_ / ua
157
- return overlaps
158
-
159
-
160
- try:
161
- import numba
162
- except ImportError:
163
- import os
164
- import warnings
165
-
166
- if not os.getenv("NO_JIT_WARNING", False):
167
- warnings.warn(
168
- "Numba not installed, falling back to slower numpy implementation. "
169
- "Install numba for a significant speedup. Set the environment "
170
- "variable NO_JIT_WARNING=1 to disable this warning.",
171
- stacklevel=2,
172
- )
173
- else:
174
- # compute_overlap 2d and 3d have the same signature
175
- signature = [
176
- "f8[:,::1](f8[:,::1], f8[:,::1])",
177
- numba.types.Array(numba.float64, 2, "C", readonly=True)(
178
- numba.types.Array(numba.float64, 2, "C", readonly=True),
179
- numba.types.Array(numba.float64, 2, "C", readonly=True),
180
- ),
181
- ]
182
-
183
- # variables that appear in the body of each function
184
- common_locals = {
185
- "N": numba.uint64,
186
- "K": numba.uint64,
187
- "overlaps": numba.types.Array(numba.float64, 2, "C"),
188
- "iw": numba.float64,
189
- "ih": numba.float64,
190
- "ua": numba.float64,
191
- "n": numba.uint64,
192
- "k": numba.uint64,
193
- }
194
-
195
- compute_overlap = numba.njit(
196
- signature,
197
- locals={**common_locals, "box_area": numba.float64},
198
- fastmath=True,
199
- nogil=True,
200
- boundscheck=False,
201
- )(compute_overlap)
202
-
203
- compute_overlap_3D = numba.njit(
204
- signature,
205
- locals={**common_locals, "id_": numba.float64, "box_volume": numba.float64},
206
- fastmath=True,
207
- nogil=True,
208
- boundscheck=False,
209
- )(compute_overlap_3D)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/augmentations.py DELETED
@@ -1,557 +0,0 @@
1
- """#TODO: dont convert to numpy and back to torch."""
2
-
3
- from collections.abc import Iterable, Sequence
4
- from itertools import chain
5
- from typing import Any
6
-
7
- import kornia.augmentation as K
8
- import numpy as np
9
- import torch
10
- from kornia.augmentation import random_generator as rg
11
- from kornia.augmentation.utils import _range_bound
12
- from kornia.constants import DataKey, Resample
13
- from typing import Optional, Tuple, Sequence, Dict, Union
14
-
15
- def default_augmenter(coords: np.ndarray):
16
- # TODO parametrize magnitude of different augmentations
17
- ndim = coords.shape[1]
18
-
19
- assert coords.ndim == 2 and ndim in (2, 3)
20
-
21
- # first remove offset
22
- center = coords.mean(axis=0, keepdims=True)
23
-
24
- coords = coords - center
25
-
26
- # apply random flip
27
- coords *= 2 * np.random.randint(0, 2, (1, ndim)) - 1
28
-
29
- # apply rotation along the last two dimensions
30
- phi = np.random.uniform(0, 2 * np.pi)
31
- coords = _rotate(coords, phi, center=None)
32
-
33
- if ndim == 3:
34
- # rotate along the first two dimensions too
35
- phi2, phi3 = np.random.uniform(0, 2 * np.pi, 2)
36
- coords = _rotate(coords, phi2, rot_axis=(0, 1), center=None)
37
- coords = _rotate(coords, phi3, rot_axis=(0, 2), center=None)
38
-
39
- coords += center
40
-
41
- # translation
42
- trans = 128 * np.random.uniform(-1, 1, (1, ndim))
43
- coords += trans
44
-
45
- # elastic
46
- coords += 1.5 * np.random.normal(0, 1, coords.shape)
47
-
48
- return coords
49
-
50
-
51
- def _rotate(
52
- coords: np.ndarray, phi: float, rot_axis=(-2, -1), center: Optional[Tuple] = None
53
- ):
54
- """Rotation along the last two dimensions of coords[..,:-2:]."""
55
- ndim = coords.shape[1]
56
- assert coords.ndim == 2 and ndim in (2, 3)
57
-
58
- if center is None:
59
- center = (0,) * ndim
60
-
61
- assert len(center) == ndim
62
-
63
- center = np.asarray(center)
64
- co, si = np.cos(phi), np.sin(phi)
65
- Rot = np.eye(ndim)
66
- Rot[np.ix_(rot_axis, rot_axis)] = np.array(((co, -si), (si, co)))
67
- x = coords - center
68
- x = x @ Rot.T
69
- x += center
70
- return x
71
-
72
-
73
- def _filter_points(
74
- points: np.ndarray, shape: tuple, origin: Optional[Tuple] = None
75
- ) -> np.ndarray:
76
- """Returns indices of points that are inside the shape extent and given origin."""
77
- ndim = points.shape[-1]
78
- if origin is None:
79
- origin = (0,) * ndim
80
-
81
- idx = tuple(
82
- np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
83
- for i in range(ndim)
84
- )
85
- idx = np.where(np.all(idx, axis=0))[0]
86
- return idx
87
-
88
-
89
- class ConcatAffine(K.RandomAffine):
90
- """Concatenate multiple affine transformations without intermediates."""
91
-
92
- def __init__(self, affines: Sequence[K.RandomAffine]):
93
- super().__init__(degrees=0)
94
- self._affines = affines
95
- if not all([a.same_on_batch for a in affines]):
96
- raise ValueError("all affines must have same_on_batch=True")
97
-
98
- def merge_params(self, params: Sequence[Dict[str, torch.Tensor]]):
99
- """Merge params from affines."""
100
- out = params[0].copy()
101
-
102
- def _torchmax(x, dim):
103
- return torch.max(x, dim=dim).values
104
-
105
- ops = {
106
- "translations": torch.sum,
107
- "center": torch.mean,
108
- "scale": torch.prod,
109
- "shear_x": torch.sum,
110
- "shear_y": torch.sum,
111
- "angle": torch.sum,
112
- "batch_prob": _torchmax,
113
- }
114
- for k, v in params[0].items():
115
- ps = [p[k] for p in params if len(p[k]) > 0]
116
- if len(ps) > 0 and k in ops:
117
- v_new = torch.stack(ps, dim=0).float()
118
- v_new = ops[k](v_new, dim=0)
119
- v_new = v_new.to(v.dtype)
120
- else:
121
- v_new = v
122
- out[k] = v_new
123
-
124
- return out
125
-
126
- def forward_parameters(
127
- self, batch_shape: Tuple[int, ...]
128
- ) -> Dict[str, torch.Tensor]:
129
- params = tuple(a.forward_parameters(batch_shape) for a in self._affines)
130
- # print(params)
131
- return self.merge_params(params)
132
-
133
-
134
- # custom augmentations
135
- class RandomIntensityScaleShift(K.IntensityAugmentationBase2D):
136
- r"""Apply a random scale and shift to the image intensity.
137
-
138
- Args:
139
- p: probability of applying the transformation.
140
- scale: the scale factor to apply
141
- shift: the offset to apply
142
- clip_output: if true clip output
143
- same_on_batch: apply the same transformation across the batch.
144
- keepdim: whether to keep the output shape the same as input (True) or broadcast it
145
- to the batch form (False).
146
- Shape:
147
- - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
148
- - Output: :math:`(B, C, H, W)`
149
-
150
- .. note::
151
- This function internally uses :func:`kornia.enhance.adjust_brightness`
152
-
153
- """
154
-
155
- def __init__(
156
- self,
157
- scale: Tuple[float, float] = (0.5, 2.0),
158
- shift: Tuple[float, float] = (-0.1, 0.1),
159
- clip_output: bool = True,
160
- same_on_batch: bool = False,
161
- p: float = 1.0,
162
- keepdim: bool = False,
163
- ) -> None:
164
- super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
165
- self.scale = _range_bound(
166
- scale, "scale", center=0, bounds=(-float("inf"), float("inf"))
167
- )
168
- self.shift = _range_bound(
169
- shift, "shift", center=0, bounds=(-float("inf"), float("inf"))
170
- )
171
- self._param_generator = rg.PlainUniformGenerator(
172
- (self.scale, "scale_factor", None, None),
173
- (self.shift, "shift_factor", None, None),
174
- )
175
-
176
- self.clip_output = clip_output
177
-
178
- def apply_transform(
179
- self,
180
- input: torch.Tensor,
181
- params: Dict[str, torch.Tensor],
182
- flags: Dict[str, Any],
183
- transform: Optional[torch.Tensor] = None,
184
- ) -> torch.Tensor:
185
- scale_factor = params["scale_factor"].to(input)
186
- shift_factor = params["shift_factor"].to(input)
187
- scale_factor = scale_factor.view(len(scale_factor), 1, 1, 1)
188
- shift_factor = shift_factor.view(len(scale_factor), 1, 1, 1)
189
- img_adjust = input * scale_factor + shift_factor
190
- if self.clip_output:
191
- img_adjust = img_adjust.clamp(min=0.0, max=1.0)
192
- return img_adjust
193
-
194
-
195
- class RandomTemporalAffine(K.RandomAffine):
196
- r"""Apply a random 2D affine transformation to a batch of images while
197
- varying the transformation across the time dimension from 0 to 1.
198
-
199
- Same args/kwargs as K.RandomAffine
200
-
201
- """
202
-
203
- def __init__(self, *args, **kwargs) -> None:
204
- super().__init__(*args, same_on_batch=True, **kwargs)
205
-
206
- def forward_parameters(
207
- self, batch_shape: Tuple[int, ...]
208
- ) -> Dict[str, torch.Tensor]:
209
- params = super().forward_parameters(batch_shape)
210
- factor = torch.linspace(0, 1, batch_shape[0]).to(params["translations"])
211
- for key in ["translations", "center", "angle", "shear_x", "shear_y"]:
212
- v = params[key]
213
- if len(v) > 0:
214
- params[key] = v * factor.view(*((-1,) + (1,) * len(v.shape[1:])))
215
-
216
- for key in [
217
- "scale",
218
- ]:
219
- v = params[key]
220
- if len(v) > 0:
221
- params[key] = 1 + (v - 1) * factor.view(
222
- *((-1,) + (1,) * len(v.shape[1:]))
223
- )
224
- return params
225
-
226
- # def compute_transformation(self, input: torch.Tensor,
227
- # params: Dict[str, torch.Tensor],
228
- # flags: Dict[str, Any]) -> torch.Tensor:
229
- # factor = torch.linspace(0, 1, input.shape[0]).to(input)
230
- # for key in ["translations", "center", "angle", "shear_x", "shear_y"]:
231
- # v = params[key]
232
- # params[key] = v * factor.view(*((-1,)+(1,)*len(v.shape[1:])))
233
-
234
- # for key in ["scale", ]:
235
- # v = params[key]
236
- # params[key] = 1 + (v-1) * factor.view(*((-1,)+(1,)*len(v.shape[1:])))
237
-
238
- # return super().compute_transformation(input, params, flags)
239
-
240
-
241
- class BasicPipeline:
242
- """transforms img, mask, and points.
243
-
244
- Only supports 2D transformations for now (any 3D object will preserve its z coordinates/dimensions)
245
- """
246
-
247
- def __init__(self, augs: tuple, filter_points: bool = True):
248
- self.data_keys = ("input", "mask", "keypoints")
249
- self.pipeline = K.AugmentationSequential(
250
- *augs,
251
- # disable align_corners to not trigger lots of warnings from kornia
252
- extra_args={
253
- DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": False}
254
- },
255
- data_keys=self.data_keys,
256
- )
257
- self.filter_points = filter_points
258
-
259
- def __call__(
260
- self,
261
- img: np.ndarray,
262
- mask: np.ndarray,
263
- points: np.ndarray,
264
- timepoints: np.ndarray,
265
- ):
266
- ndim = img.ndim - 1
267
- assert (
268
- ndim in (2, 3)
269
- and points.ndim == 2
270
- and points.shape[-1] == ndim
271
- and timepoints.ndim == 1
272
- and img.shape == mask.shape
273
- )
274
-
275
- x = torch.from_numpy(img).float()
276
- y = torch.from_numpy(mask.astype(np.int64)).float()
277
-
278
- # if 2D add dummy channel
279
- if ndim == 2:
280
- x = x.unsqueeze(1)
281
- y = y.unsqueeze(1)
282
- p = points[..., [1, 0]]
283
- # if 3D we use z as channel (i.e. fix augs across z)
284
- elif ndim == 3:
285
- p = points[..., [2, 1]]
286
-
287
- # flip as kornia expects xy and not yx
288
- p = torch.from_numpy(p).unsqueeze(0).float()
289
- # add batch by duplicating to make kornia happy
290
- p = p.expand(len(x), -1, -1)
291
- # create a mask to know which timepoint the points belong to
292
- ts = torch.from_numpy(timepoints).long()
293
- n_points = p.shape[1]
294
- if n_points > 0:
295
- x, y, p = self.pipeline(x, y, p)
296
- else:
297
- # dummy keypoints
298
- x, y = self.pipeline(x, y, torch.zeros((len(x), 1, 2)))[:2]
299
-
300
- # remove batch
301
- p = p[ts, torch.arange(n_points)]
302
- # flip back
303
- p = p[..., [1, 0]]
304
-
305
- # remove channel
306
- if ndim == 2:
307
- x = x.squeeze(1)
308
- y = y.squeeze(1)
309
-
310
- x = x.numpy()
311
- y = y.numpy().astype(np.uint16)
312
- # p = p.squeeze(0).numpy()
313
- p = p.numpy()
314
- # add back z coordinates
315
- if ndim == 3:
316
- p = np.concatenate([points[..., 0:1], p], axis=-1)
317
- ts = ts.numpy()
318
- # remove points outside of img/mask
319
-
320
- if self.filter_points:
321
- idx = _filter_points(p, shape=x.shape[-ndim:])
322
-
323
- else:
324
- idx = np.arange(len(p), dtype=int)
325
-
326
- p = p[idx]
327
- return (x, y, p), idx
328
-
329
-
330
- class RandomCrop:
331
- def __init__(
332
- self,
333
- crop_size: Optional[Union[int, Tuple[int]]] = None,
334
- ndim: int = 2,
335
- ensure_inside_points: bool = False,
336
- use_padding: bool = True,
337
- padding_mode="constant",
338
- ) -> None:
339
- """crop_size: tuple of int
340
- can be tuple of length 1 (all dimensions)
341
- of length ndim (y,x,...)
342
- of length 2*ndim (y1,y2, x1,x2, ...).
343
- """
344
- if isinstance(crop_size, int):
345
- crop_size = (crop_size,) * 2 * ndim
346
- elif isinstance(crop_size, Iterable):
347
- pass
348
- else:
349
- raise ValueError(f"{crop_size} has to be int or tuple of int")
350
-
351
- if len(crop_size) == 1:
352
- crop_size = (crop_size[0],) * 2 * ndim
353
- elif len(crop_size) == ndim:
354
- crop_size = tuple(chain(*tuple((c, c) for c in crop_size)))
355
- elif len(crop_size) == 2 * ndim:
356
- pass
357
- else:
358
- raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
359
-
360
- crop_size = np.array(crop_size)
361
- self._ndim = ndim
362
- self._crop_bounds = crop_size[::2], crop_size[1::2]
363
- self._use_padding = use_padding
364
- self._ensure_inside_points = ensure_inside_points
365
- self._rng = np.random.RandomState()
366
- self._padding_mode = padding_mode
367
-
368
- def crop_img(self, img: np.ndarray, corner: np.ndarray, crop_size: np.ndarray):
369
- if not img.ndim == self._ndim + 1:
370
- raise ValueError(
371
- f"img has to be 1 (time) + {self._ndim} spatial dimensions"
372
- )
373
-
374
- pad_left = np.maximum(0, -corner)
375
- pad_right = np.maximum(
376
- 0, corner + crop_size - np.array(img.shape[-self._ndim :])
377
- )
378
-
379
- img = np.pad(
380
- img,
381
- ((0, 0), *tuple(np.stack((pad_left, pad_right)).T)),
382
- mode=self._padding_mode,
383
- )
384
- slices = (
385
- slice(None),
386
- *tuple(slice(c, c + s) for c, s in zip(corner + pad_left, crop_size)),
387
- )
388
- return img[slices]
389
-
390
- def crop_points(
391
- self, points: np.ndarray, corner: np.ndarray, crop_size: np.ndarray
392
- ):
393
- idx = _filter_points(points, shape=crop_size, origin=corner)
394
- return points[idx] - corner, idx
395
-
396
- def __call__(self, img: np.ndarray, mask: np.ndarray, points: np.ndarray):
397
- assert (
398
- img.ndim == self._ndim + 1
399
- and points.ndim == 2
400
- and points.shape[-1] == self._ndim
401
- and img.shape == mask.shape
402
- )
403
-
404
- points = points.astype(int)
405
-
406
- crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
407
- # print(f'{crop_size=}')
408
-
409
- if self._ensure_inside_points:
410
- if len(points) == 0:
411
- print("No points given, cannot ensure inside points")
412
- return (img, mask, points), np.zeros((0,), int)
413
-
414
- # sample point and corner relative to it
415
-
416
- _idx = np.random.randint(len(points))
417
- corner = (
418
- points[_idx]
419
- - crop_size
420
- + 1
421
- + self._rng.randint(crop_size // 4, 3 * crop_size // 4)
422
- )
423
- else:
424
- corner = self._rng.randint(
425
- 0, np.maximum(1, np.array(img.shape[-self._ndim :]) - crop_size)
426
- )
427
-
428
- if not self._use_padding:
429
- corner = np.maximum(0, corner)
430
- crop_size = np.minimum(
431
- crop_size, np.array(img.shape[-self._ndim :]) - corner
432
- )
433
-
434
- img = self.crop_img(img, corner, crop_size)
435
- mask = self.crop_img(mask, corner, crop_size)
436
- points, idx = self.crop_points(points, corner, crop_size)
437
-
438
- return (img, mask, points), idx
439
-
440
-
441
- class AugmentationPipeline(BasicPipeline):
442
- """transforms img, mask, and points."""
443
-
444
- def __init__(self, p=0.5, filter_points=True, level=1):
445
- if level == 1:
446
- augs = [
447
- # Augmentations for all images in a window
448
- K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
449
- K.RandomVerticalFlip(p=0.5, same_on_batch=True),
450
- K.RandomAffine(
451
- degrees=180,
452
- shear=(-10, 10, -10, 10), # x_min, x_max, y_min, y_max
453
- translate=(0.05, 0.05),
454
- scale=(0.8, 1.2), # x_min, x_max, y_min, y_max
455
- p=p,
456
- same_on_batch=True,
457
- ),
458
- K.RandomBrightness(
459
- (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
460
- ),
461
- K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
462
- ]
463
- elif level == 2:
464
- # Crafted for DeepCell crop size 256
465
- augs = [
466
- # Augmentations for all images in a window
467
- K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
468
- K.RandomVerticalFlip(p=0.5, same_on_batch=True),
469
- K.RandomAffine(
470
- degrees=180,
471
- shear=(-5, 5, -5, 5), # x_min, x_max, y_min, y_max
472
- translate=(0.03, 0.03),
473
- scale=(0.8, 1.2), # isotropic
474
- p=p,
475
- same_on_batch=True,
476
- ),
477
- # Anisotropic scaling
478
- K.RandomAffine(
479
- degrees=0,
480
- scale=(0.9, 1.1, 0.9, 1.1), # x_min, x_max, y_min, y_max
481
- p=p,
482
- same_on_batch=True,
483
- ),
484
- # Independet augmentations for each image in window
485
- K.RandomAffine(
486
- degrees=3,
487
- shear=(-2, 2, -2, 2), # x_min, x_max, y_min, y_max
488
- translate=(0.04, 0.04),
489
- p=p,
490
- same_on_batch=False,
491
- ),
492
- # not implemented for points in kornia 0.7.0
493
- # K.RandomElasticTransform(alpha=50, sigma=5, p=p, same_on_batch=False),
494
- # Intensity-based augmentations
495
- K.RandomBrightness(
496
- (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
497
- ),
498
- K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
499
- ]
500
- elif level == 3:
501
- # Crafted for DeepCell crop size 256
502
- augs = [
503
- # Augmentations for all images in a window
504
- K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
505
- K.RandomVerticalFlip(p=0.5, same_on_batch=True),
506
- ConcatAffine([
507
- K.RandomAffine(
508
- degrees=180,
509
- shear=(-5, 5, -5, 5), # x_min, x_max, y_min, y_max
510
- translate=(0.03, 0.03),
511
- scale=(0.8, 1.2), # isotropic
512
- p=p,
513
- same_on_batch=True,
514
- ),
515
- # Anisotropic scaling
516
- K.RandomAffine(
517
- degrees=0,
518
- scale=(0.9, 1.1, 0.9, 1.1), # x_min, x_max, y_min, y_max
519
- p=p,
520
- same_on_batch=True,
521
- ),
522
- ]),
523
- RandomTemporalAffine(
524
- degrees=10,
525
- translate=(0.05, 0.05),
526
- p=p,
527
- # same_on_batch=True,
528
- ),
529
- # Independet augmentations for each image in window
530
- K.RandomAffine(
531
- degrees=2,
532
- shear=(-2, 2, -2, 2), # x_min, x_max, y_min, y_max
533
- translate=(0.01, 0.01),
534
- p=0.5 * p,
535
- same_on_batch=False,
536
- ),
537
- # Intensity-based augmentations
538
- RandomIntensityScaleShift(
539
- (0.5, 2.0), (-0.1, 0.1), clip_output=False, p=p, same_on_batch=True
540
- ),
541
- K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
542
- ]
543
- elif level == 4:
544
- # debug
545
- augs = [
546
- K.RandomAffine(
547
- degrees=30,
548
- shear=(-0, 0, -0, 0), # x_min, x_max, y_min, y_max
549
- translate=(0.0, 0.0),
550
- p=1,
551
- same_on_batch=True,
552
- ),
553
- ]
554
- else:
555
- raise ValueError(f"level {level} not supported")
556
-
557
- super().__init__(augs, filter_points)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/data.py DELETED
@@ -1,1509 +0,0 @@
1
- import logging
2
- # from collections.abc import Sequence
3
- from pathlib import Path
4
- from timeit import default_timer
5
- from typing import Literal
6
-
7
- import joblib
8
- import lz4.frame
9
- import networkx as nx
10
- import numpy as np
11
- import pandas as pd
12
- import tifffile
13
- import torch
14
- from numba import njit
15
- from scipy import ndimage as ndi
16
- from scipy.spatial.distance import cdist
17
- from skimage.measure import regionprops
18
- from skimage.segmentation import relabel_sequential
19
- from torch.utils.data import Dataset
20
- from tqdm import tqdm
21
-
22
- from . import wrfeat
23
- from ._check_ctc import _check_ctc, _get_node_attributes
24
- from .augmentations import (
25
- AugmentationPipeline,
26
- RandomCrop,
27
- default_augmenter,
28
- )
29
- from .features import (
30
- _PROPERTIES,
31
- extract_features_patch,
32
- extract_features_regionprops,
33
- )
34
- from .matching import matching
35
-
36
- from typing import List, Optional, Union, Tuple, Sequence
37
-
38
- # from ..utils import blockwise_sum, normalize
39
- from ..utils import blockwise_sum, normalize
40
-
41
- logger = logging.getLogger(__name__)
42
- logger.setLevel(logging.INFO)
43
-
44
-
45
- def _filter_track_df(df, start_frame, end_frame, downscale):
46
- """Only keep tracklets that are present in the given time interval."""
47
- # only retain cells in interval
48
- df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)]
49
-
50
- # shift start and end of each cell
51
- df.t1 = df.t1 - start_frame
52
- df.t2 = df.t2 - start_frame
53
- # set start/end to min/max
54
- df.t1 = df.t1.clip(0, end_frame - start_frame - 1)
55
- df.t2 = df.t2.clip(0, end_frame - start_frame - 1)
56
- # set all parents to 0 that are not in the interval
57
- df.loc[~df.parent.isin(df.label), "parent"] = 0
58
-
59
- if downscale > 1:
60
- if start_frame % downscale != 0:
61
- raise ValueError("start_frame must be a multiple of downscale")
62
-
63
- logger.info(f"Temporal downscaling of tracklet links by {downscale}")
64
-
65
- # remove tracklets that have been fully deleted by temporal downsampling
66
-
67
- mask = (
68
- # (df["t2"] - df["t1"] < downscale - 1)
69
- (df["t1"] % downscale != 0)
70
- & (df["t2"] % downscale != 0)
71
- & (df["t1"] // downscale == df["t2"] // downscale)
72
- )
73
- logger.info(
74
- f"Remove {mask.sum()} tracklets that are fully deleted by downsampling"
75
- )
76
- logger.debug(f"Remove {df[mask]}")
77
-
78
- df = df[~mask]
79
- # set parent to 0 if it has been deleted
80
- df.loc[~df.parent.isin(df.label), "parent"] = 0
81
-
82
- df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int)
83
- df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int)
84
-
85
- # Correct for edge case of single frame tracklet
86
- assert np.all(df["t1"] == np.minimum(df["t1"], df["t2"]))
87
-
88
- return df
89
-
90
-
91
- class _CompressedArray:
92
- """a simple class to compress and decompress a numpy arrays using lz4."""
93
-
94
- # dont compress float types
95
- def __init__(self, data):
96
- self._data = lz4.frame.compress(data)
97
- self._dtype = data.dtype.type
98
- self._shape = data.shape
99
-
100
- def decompress(self):
101
- s = lz4.frame.decompress(self._data)
102
- data = np.frombuffer(s, dtype=self._dtype).reshape(self._shape)
103
- return data
104
-
105
-
106
- def debug_function(f):
107
- def wrapper(*args, **kwargs):
108
- try:
109
- batch = f(*args, **kwargs)
110
- except Exception as e:
111
- logger.error(f"Error in {f.__name__}: {e}")
112
- return None
113
- logger.info(f"XXXX {len(batch['coords'])}")
114
- return batch
115
-
116
- return wrapper
117
-
118
-
119
- class CTCData(Dataset):
120
- def __init__(
121
- self,
122
- root: str = "",
123
- ndim: int = 2,
124
- use_gt: bool = True,
125
- detection_folders: List[str] = ["TRA"],
126
- window_size: int = 10,
127
- max_tokens: Optional[int] = None,
128
- slice_pct: tuple = (0.0, 1.0),
129
- downscale_spatial: int = 1,
130
- downscale_temporal: int = 1,
131
- augment: int = 0,
132
- features: Literal[
133
- "none",
134
- "regionprops",
135
- "regionprops2",
136
- "patch",
137
- "patch_regionprops",
138
- "wrfeat",
139
- ] = "wrfeat",
140
- sanity_dist: bool = False,
141
- crop_size: Optional[tuple] = None,
142
- return_dense: bool = False,
143
- compress: bool = False,
144
- **kwargs,
145
- ) -> None:
146
- """_summary_.
147
-
148
- Args:
149
- root (str):
150
- Folder containing the CTC TRA folder.
151
- ndim (int):
152
- Number of dimensions of the data. Defaults to 2d
153
- (if ndim=3 and data is two dimensional, it will be cast to 3D)
154
- detection_folders:
155
- List of relative paths to folder with detections.
156
- Defaults to ["TRA"], which uses the ground truth detections.
157
- window_size (int):
158
- Window size for transformer.
159
- slice_pct (tuple):
160
- Slice the dataset by percentages (from, to).
161
- augment (int):
162
- if 0, no data augmentation. if > 0, defines level of data augmentation.
163
- features (str):
164
- Types of features to use.
165
- sanity_dist (bool):
166
- Use euclidian distance instead of the association matrix as a target.
167
- crop_size (tuple):
168
- Size of the crops to use for augmentation. If None, no cropping is used.
169
- return_dense (bool):
170
- Return dense masks and images in the data samples.
171
- compress (bool):
172
- Compress elements/remove img if not needed to save memory for large datasets
173
- """
174
- super().__init__()
175
-
176
- self.root = Path(root)
177
- self.name = self.root.name
178
- self.use_gt = use_gt
179
- self.slice_pct = slice_pct
180
- if not 0 <= slice_pct[0] < slice_pct[1] <= 1:
181
- raise ValueError(f"Invalid slice_pct {slice_pct}")
182
- self.downscale_spatial = downscale_spatial
183
- self.downscale_temporal = downscale_temporal
184
- self.detection_folders = detection_folders
185
- self.ndim = ndim
186
- self.features = features
187
-
188
- if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
189
- raise ValueError(
190
- f"'{features}' not one of the supported {ndim}D features"
191
- f" {tuple(_PROPERTIES[ndim].keys())}"
192
- )
193
-
194
- logger.info(f"ROOT (config): \t{self.root}")
195
- self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
196
- logger.info(f"ROOT (guessed): \t{self.root}")
197
- logger.info(f"GT TRA (guessed):\t{self.gt_tra_folder}")
198
- if self.use_gt:
199
- self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder)
200
- else:
201
- logger.info("Using dummy masks as GT")
202
- self.gt_mask_folder = self._guess_det_folder(
203
- self.root, self.detection_folders[0]
204
- )
205
- logger.info(f"GT MASK (guessed):\t{self.gt_mask_folder}")
206
-
207
- # dont load image data if not needed
208
- if features in ("none",):
209
- self.img_folder = None
210
- else:
211
- self.img_folder = self._guess_img_folder(self.root)
212
- logger.info(f"IMG (guessed):\t{self.img_folder}")
213
-
214
- self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
215
- ndim, features, augment, crop_size
216
- )
217
-
218
- if window_size <= 1:
219
- raise ValueError("window must be >1")
220
- self.window_size = window_size
221
- self.max_tokens = max_tokens
222
-
223
- self.slice_pct = slice_pct
224
- self.sanity_dist = sanity_dist
225
- self.return_dense = return_dense
226
- self.compress = compress
227
- self.start_frame = 0
228
- self.end_frame = None
229
-
230
- start = default_timer()
231
-
232
- if self.features == "wrfeat":
233
- self.windows = self._load_wrfeat()
234
- else:
235
- self.windows = self._load()
236
-
237
- self.n_divs = self._get_ndivs(self.windows)
238
-
239
- if len(self.windows) > 0:
240
- self.ndim = self.windows[0]["coords"].shape[1]
241
- self.n_objects = tuple(len(t["coords"]) for t in self.windows)
242
- logger.info(
243
- f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
244
- f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
245
- )
246
- else:
247
- self.n_objects = 0
248
- logger.warning(f"Could not load any tracks from {self.root}")
249
-
250
- if self.compress:
251
- self._compress_data()
252
-
253
- # def from_ctc
254
-
255
- @classmethod
256
- def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
257
- self = cls(**train_args)
258
- # for key, value in train_args.items():
259
- # setattr(self, key, value)
260
-
261
- # self.use_gt = use_gt
262
- # self.slice_pct = slice_pct
263
- # if not 0 <= slice_pct[0] < slice_pct[1] <= 1:
264
- # raise ValueError(f"Invalid slice_pct {slice_pct}")
265
- # self.downscale_spatial = downscale_spatial
266
- # self.downscale_temporal = downscale_temporal
267
- # self.detection_folders = detection_folders
268
- # self.ndim = ndim
269
- # self.features = features
270
-
271
- # if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
272
- # raise ValueError(
273
- # f"'{features}' not one of the supported {ndim}D features {tuple(_PROPERTIES[ndim].keys())}"
274
- # )
275
-
276
- # logger.info(f"ROOT (config): {self.root}")
277
- # self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
278
- # logger.info(f"ROOT: \t{self.root}")
279
- # logger.info(f"GT TRA:\t{self.gt_tra_folder}")
280
- # if self.use_gt:
281
- # self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder)
282
- # else:
283
- # logger.info("Using dummy masks as GT")
284
- # self.gt_mask_folder = self._guess_det_folder(
285
- # self.root, self.detection_folders[0]
286
- # )
287
- # logger.info(f"GT MASK:\t{self.gt_mask_folder}")
288
-
289
- # dont load image data if not needed
290
- # if features in ("none",):
291
- # self.img_folder = None
292
- # else:
293
- # self.img_folder = self._guess_img_folder(self.root)
294
- # logger.info(f"IMG:\t\t{self.img_folder}")
295
-
296
- self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
297
- self.ndim, self.features, self.augment, self.crop_size
298
- )
299
-
300
- start = default_timer()
301
-
302
- if self.features == "wrfeat":
303
- self.windows = self._load_wrfeat()
304
- else:
305
- self.windows = self._load()
306
-
307
- self.n_divs = self._get_ndivs(self.windows)
308
-
309
- if len(self.windows) > 0:
310
- self.ndim = self.windows[0]["coords"].shape[1]
311
- self.n_objects = tuple(len(t["coords"]) for t in self.windows)
312
- logger.info(
313
- f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
314
- f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
315
- )
316
- else:
317
- self.n_objects = 0
318
- logger.warning(f"Could not load any tracks from {self.root}")
319
-
320
- if self.compress:
321
- self._compress_data()
322
-
323
- def _get_ndivs(self, windows):
324
- n_divs = []
325
- for w in tqdm(windows, desc="Counting divisions", leave=False):
326
- _n = (
327
- (
328
- blockwise_sum(
329
- torch.from_numpy(w["assoc_matrix"]).float(),
330
- torch.from_numpy(w["timepoints"]).long(),
331
- ).max(dim=0)[0]
332
- == 2
333
- )
334
- .sum()
335
- .item()
336
- )
337
- n_divs.append(_n)
338
- return n_divs
339
-
340
- def _setup_features_augs(
341
- self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
342
- ):
343
- if self.features == "wrfeat":
344
- return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)
345
-
346
- cropper = (
347
- RandomCrop(
348
- crop_size=crop_size,
349
- ndim=ndim,
350
- use_padding=False,
351
- ensure_inside_points=True,
352
- )
353
- if crop_size is not None
354
- else None
355
- )
356
-
357
- # Hack
358
- if self.features == "none":
359
- return 0, default_augmenter, cropper
360
-
361
- if ndim == 2:
362
- augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
363
- feat_dim = {
364
- "none": 0,
365
- "regionprops": 7,
366
- "regionprops2": 6,
367
- "patch": 256,
368
- "patch_regionprops": 256 + 5,
369
- }[features]
370
- elif ndim == 3:
371
- augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
372
- feat_dim = {
373
- "none": 0,
374
- "regionprops2": 11,
375
- "patch_regionprops": 256 + 8,
376
- }[features]
377
-
378
- return feat_dim, augmenter, cropper
379
-
380
- def _compress_data(self):
381
- # compress masks and assoc_matrix
382
- logger.info("Compressing masks and assoc_matrix to save memory")
383
- for w in self.windows:
384
- w["mask"] = _CompressedArray(w["mask"])
385
- # dont compress full imgs (as needed for patch features)
386
- w["img"] = _CompressedArray(w["img"])
387
- w["assoc_matrix"] = _CompressedArray(w["assoc_matrix"])
388
- self.gt_masks = _CompressedArray(self.gt_masks)
389
- self.det_masks = {k: _CompressedArray(v) for k, v in self.det_masks.items()}
390
- # dont compress full imgs (as needed for patch features)
391
- self.imgs = _CompressedArray(self.imgs)
392
-
393
- def _guess_root_and_gt_tra_folder(self, inp: Path):
394
- """Guesses the root and the ground truth folder from a given input path.
395
-
396
- Args:
397
- inp (Path): _description_
398
-
399
- Returns:
400
- Path: root folder,
401
- """
402
- if inp.name == "TRA":
403
- # 01_GT/TRA --> 01, 01_GT/TRA
404
- root = inp.parent.parent / inp.parent.name.split("_")[0]
405
- return root, inp
406
- elif "ERR_SEG" in inp.name:
407
- # 01_ERR_SEG --> 01, 01_GT/TRA. We know that the data is in CTC folder format
408
- num = inp.name.split("_")[0]
409
- return inp.parent / num, inp.parent / f"{num}_GT" / "TRA"
410
- else:
411
- ctc_tra = Path(f"{inp}_GT") / "TRA"
412
- tra = ctc_tra if ctc_tra.exists() else inp / "TRA"
413
- # 01 --> 01, 01_GT/TRA or 01/TRA
414
- return inp, tra
415
-
416
- def _guess_img_folder(self, root: Path):
417
- """Guesses the image folder corresponding to a root."""
418
- if (root / "img").exists():
419
- return root / "img"
420
- else:
421
- return root
422
-
423
- def _guess_mask_folder(self, root: Path, gt_tra: Path):
424
- """Guesses the mask folder corresponding to a root.
425
-
426
- In CTC format, we use silver truth segmentation masks.
427
- """
428
- f = None
429
- # first try CTC format
430
- if gt_tra.parent.name.endswith("_GT"):
431
- # We use the silver truth segmentation masks
432
- f = root / str(gt_tra.parent.name).replace("_GT", "_ST") / "SEG"
433
- # try our simpler 'img' format
434
- if f is None or not f.exists():
435
- f = gt_tra
436
- if not f.exists():
437
- raise ValueError(f"Could not find mask folder for {root}")
438
- return f
439
-
440
- @classmethod
441
- def _guess_det_folder(cls, root: Path, suffix: str):
442
- """Checks for the annoying CTC format with dataset numbering as part of folder names."""
443
- guesses = (
444
- (root / suffix),
445
- Path(f"{root}_{suffix}"),
446
- Path(f"{root}_GT") / suffix,
447
- )
448
- for path in guesses:
449
- if path.exists():
450
- return path
451
-
452
- logger.warning(f"Skipping non-existing detection folder {root / suffix}")
453
- return None
454
-
455
- def __len__(self):
456
- return len(self.windows)
457
-
458
- def _load_gt(self):
459
- logger.info("Loading ground truth")
460
- self.start_frame = int(
461
- len(list(self.gt_mask_folder.glob("*.tif"))) * self.slice_pct[0]
462
- )
463
- self.end_frame = int(
464
- len(list(self.gt_mask_folder.glob("*.tif"))) * self.slice_pct[1]
465
- )
466
-
467
- masks = self._load_tiffs(self.gt_mask_folder, dtype=np.int32)
468
- masks = self._correct_gt_with_st(self.gt_mask_folder, masks, dtype=np.int32)
469
-
470
- if self.use_gt:
471
- track_df = self._load_tracklet_links(self.gt_tra_folder)
472
- track_df = _filter_track_df(
473
- track_df, self.start_frame, self.end_frame, self.downscale_temporal
474
- )
475
- else:
476
- # create dummy track dataframe
477
- logger.info("Using dummy track dataframe")
478
- track_df = self._build_tracklets_without_gt(masks)
479
-
480
- _check_ctc(track_df, _get_node_attributes(masks), masks)
481
-
482
- # Build ground truth lineage graph
483
- self.gt_labels, self.gt_timepoints, self.gt_graph = _ctc_lineages(
484
- track_df, masks
485
- )
486
-
487
- return masks, track_df
488
-
489
- def _correct_gt_with_st(
490
- self, folder: Path, x: np.ndarray, dtype: Optional[str] = None
491
- ):
492
- if str(folder).endswith("_GT/TRA"):
493
- st_path = (
494
- tuple(folder.parents)[1]
495
- / folder.parent.stem.replace("_GT", "_ST")
496
- / "SEG"
497
- )
498
- if not st_path.exists():
499
- logger.debug("No _ST folder found, skipping correction")
500
- else:
501
- logger.info(f"ST MASK:\t\t{st_path} for correcting with ST masks")
502
- st_masks = self._load_tiffs(st_path, dtype)
503
- x = np.maximum(x, st_masks)
504
-
505
- return x
506
-
507
- def _load_tiffs(self, folder: Path, dtype=None):
508
- assert isinstance(self.downscale_temporal, int)
509
- logger.debug(f"Loading tiffs from {folder} as {dtype}")
510
- logger.debug(
511
- f"Temporal downscaling of {folder.name} by {self.downscale_temporal}"
512
- )
513
- x = np.stack([
514
- tifffile.imread(f).astype(dtype)
515
- for f in tqdm(
516
- sorted(folder.glob("*.tif"))[
517
- self.start_frame : self.end_frame : self.downscale_temporal
518
- ],
519
- leave=False,
520
- desc=f"Loading [{self.start_frame}:{self.end_frame}]",
521
- )
522
- ])
523
-
524
- # T, (Z), Y, X
525
- assert isinstance(self.downscale_spatial, int)
526
- if self.downscale_spatial > 1 or self.downscale_temporal > 1:
527
- # TODO make safe for label arrays
528
- logger.debug(
529
- f"Spatial downscaling of {folder.name} by {self.downscale_spatial}"
530
- )
531
- slices = (
532
- slice(None),
533
- *tuple(
534
- slice(None, None, self.downscale_spatial) for _ in range(x.ndim - 1)
535
- ),
536
- )
537
- x = x[slices]
538
-
539
- logger.debug(f"Loaded array of shape {x.shape} from {folder}")
540
- return x
541
-
542
- def _masks2properties(self, masks):
543
- """Turn label masks into lists of properties, sorted (ascending) by time and label id.
544
-
545
- Args:
546
- masks (np.ndarray): T, (Z), H, W
547
-
548
- Returns:
549
- labels: List of labels
550
- ts: List of timepoints
551
- coords: List of coordinates
552
- """
553
- # Get coordinates, timepoints, and labels of detections
554
- labels = []
555
- ts = []
556
- coords = []
557
- properties_by_time = dict()
558
- assert len(self.imgs) == len(masks)
559
- for _t, frame in tqdm(
560
- enumerate(masks),
561
- # total=len(detections),
562
- leave=False,
563
- desc="Loading masks and properties",
564
- ):
565
- regions = regionprops(frame)
566
- t_labels = []
567
- t_ts = []
568
- t_coords = []
569
- for _r in regions:
570
- t_labels.append(_r.label)
571
- t_ts.append(_t)
572
- centroid = np.array(_r.centroid).astype(int)
573
- t_coords.append(centroid)
574
-
575
- properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
576
- labels.extend(t_labels)
577
- ts.extend(t_ts)
578
- coords.extend(t_coords)
579
-
580
- labels = np.array(labels, dtype=int)
581
- ts = np.array(ts, dtype=int)
582
- coords = np.array(coords, dtype=int)
583
-
584
- return labels, ts, coords, properties_by_time
585
-
586
- def _load_tracklet_links(self, folder: Path) -> pd.DataFrame:
587
- df = pd.read_csv(
588
- folder / "man_track.txt",
589
- delimiter=" ",
590
- names=["label", "t1", "t2", "parent"],
591
- dtype=int,
592
- )
593
- n_dets = (df.t2 - df.t1 + 1).sum()
594
- logger.debug(f"{folder} has {n_dets} detections")
595
-
596
- n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum()
597
- logger.debug(f"{folder} has {n_divs} divisions")
598
- return df
599
-
600
- def _build_tracklets_without_gt(self, masks):
601
- """Create a dataframe with tracklets from masks."""
602
- rows = []
603
- for t, m in enumerate(masks):
604
- for c in np.unique(m[m > 0]):
605
- rows.append([c, t, t, 0])
606
- df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"])
607
- return df
608
-
609
- def _check_dimensions(self, x: np.ndarray):
610
- if self.ndim == 2 and not x.ndim == 3:
611
- raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
612
- elif self.ndim == 3:
613
- # if ndim=3 and data is two dimensional, it will be cast to 3D
614
- if x.ndim == 3:
615
- x = np.expand_dims(x, axis=1)
616
- elif x.ndim == 4:
617
- pass
618
- else:
619
- raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
620
- return x
621
-
622
- def _load(self):
623
- # Load ground truth
624
- logger.info("Loading ground truth")
625
- self.gt_masks, self.gt_track_df = self._load_gt()
626
-
627
- self.gt_masks = self._check_dimensions(self.gt_masks)
628
-
629
- # Load images
630
- if self.img_folder is None:
631
- self.imgs = np.zeros_like(self.gt_masks)
632
- else:
633
- logger.info("Loading images")
634
- imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
635
- self.imgs = np.stack([
636
- normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
637
- ])
638
- self.imgs = self._check_dimensions(self.imgs)
639
- if self.compress:
640
- # prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
641
- self.imgs = np.stack([
642
- _compress_img_mask_preproc(im, mask, self.features)
643
- for im, mask in zip(self.imgs, self.gt_masks)
644
- ])
645
-
646
- assert len(self.gt_masks) == len(self.imgs)
647
-
648
- # Load each of the detection folders and create data samples with a sliding window
649
- windows = []
650
- self.properties_by_time = dict()
651
- self.det_masks = dict()
652
- for _f in self.detection_folders:
653
- det_folder = self.root / _f
654
-
655
- if det_folder == self.gt_mask_folder:
656
- det_masks = self.gt_masks
657
- logger.info("DET MASK:\tUsing GT masks")
658
- (
659
- det_labels,
660
- det_ts,
661
- det_coords,
662
- det_properties_by_time,
663
- ) = self._masks2properties(det_masks)
664
-
665
- det_gt_matching = {
666
- t: {_l: _l for _l in det_properties_by_time[t]["labels"]}
667
- for t in range(len(det_masks))
668
- }
669
- else:
670
- det_folder = self._guess_det_folder(root=self.root, suffix=_f)
671
- if det_folder is None:
672
- continue
673
-
674
- logger.info(f"DET MASK:\t{det_folder}")
675
- det_masks = self._load_tiffs(det_folder, dtype=np.int32)
676
- det_masks = self._correct_gt_with_st(
677
- det_folder, det_masks, dtype=np.int32
678
- )
679
- det_masks = self._check_dimensions(det_masks)
680
- (
681
- det_labels,
682
- det_ts,
683
- det_coords,
684
- det_properties_by_time,
685
- ) = self._masks2properties(det_masks)
686
-
687
- # FIXME matching can be slow for big images
688
- # raise NotImplementedError("Matching not implemented for 3d version")
689
- det_gt_matching = {
690
- t: {
691
- _d: _gt
692
- for _gt, _d in matching(
693
- self.gt_masks[t],
694
- det_masks[t],
695
- threshold=0.3,
696
- max_distance=16,
697
- )
698
- }
699
- for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
700
- }
701
-
702
- self.properties_by_time[_f] = det_properties_by_time
703
- self.det_masks[_f] = det_masks
704
- _w = self._build_windows(
705
- det_folder,
706
- det_masks,
707
- det_labels,
708
- det_ts,
709
- det_coords,
710
- det_gt_matching,
711
- )
712
-
713
- windows.extend(_w)
714
-
715
- return windows
716
-
717
- def _build_windows(
718
- self,
719
- det_folder,
720
- det_masks,
721
- labels,
722
- ts,
723
- coords,
724
- matching,
725
- ):
726
- """_summary_.
727
-
728
- Args:
729
- det_folder (_type_): _description_
730
- det_masks (_type_): _description_
731
- labels (_type_): _description_
732
- ts (_type_): _description_
733
- coords (_type_): _description_
734
- matching (_type_): _description_
735
-
736
- Raises:
737
- ValueError: _description_
738
- ValueError: _description_
739
-
740
- Returns:
741
- _type_: _description_
742
- """
743
- window_size = self.window_size
744
- windows = []
745
-
746
- # Creates the data samples with a sliding window
747
- masks = self.gt_masks
748
- for t1, t2 in tqdm(
749
- zip(range(0, len(masks)), range(window_size, len(masks) + 1)),
750
- total=len(masks) - window_size + 1,
751
- leave=False,
752
- desc="Building windows",
753
- ):
754
- idx = (ts >= t1) & (ts < t2)
755
- _ts = ts[idx]
756
- _labels = labels[idx]
757
- _coords = coords[idx]
758
-
759
- # Use GT
760
- # _labels = self.gt_labels[idx]
761
- # _ts = self.gt_timepoints[idx]
762
-
763
- if len(_labels) == 0:
764
- # raise ValueError(f"No detections in sample {det_folder}:{t1}")
765
- A = np.zeros((0, 0), dtype=bool)
766
- _coords = np.zeros((0, masks.ndim - 1), dtype=int)
767
- else:
768
- if len(np.unique(_ts)) == 1:
769
- logger.debug(
770
- "Only detections from a single timepoint in sample"
771
- f" {det_folder}:{t1}"
772
- )
773
-
774
- # build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
775
- A = _ctc_assoc_matrix(
776
- _labels,
777
- _ts,
778
- self.gt_graph,
779
- matching,
780
- )
781
-
782
- if self.sanity_dist:
783
- # # Sanity check: Can the model learn the euclidian distances?
784
- # c = coords - coords.mean(axis=0, keepdims=True)
785
- # c /= c.std(axis=0, keepdims=True)
786
- # A = np.einsum('id,jd',c,c)
787
- # A = 1 / (1 + np.exp(-A))
788
- A = np.exp(-0.01 * cdist(_coords, _coords))
789
-
790
- w = dict(
791
- coords=_coords,
792
- # TODO imgs and masks are unaltered here
793
- t1=t1,
794
- img=self.imgs[t1:t2],
795
- mask=det_masks[t1:t2],
796
- assoc_matrix=A,
797
- labels=_labels,
798
- timepoints=_ts,
799
- )
800
-
801
- windows.append(w)
802
-
803
- logger.debug(f"Built {len(windows)} track windows from {det_folder}.\n")
804
- return windows
805
-
806
- def __getitem__(self, n: int, return_dense=None):
807
- # if not set, use default
808
- if self.features == "wrfeat":
809
- return self._getitem_wrfeat(n, return_dense)
810
-
811
- if return_dense is None:
812
- return_dense = self.return_dense
813
-
814
- track = self.windows[n]
815
- coords = track["coords"]
816
- assoc_matrix = track["assoc_matrix"]
817
- labels = track["labels"]
818
- img = track["img"]
819
- mask = track["mask"]
820
- timepoints = track["timepoints"]
821
- min_time = track["t1"]
822
-
823
- if isinstance(mask, _CompressedArray):
824
- mask = mask.decompress()
825
- if isinstance(img, _CompressedArray):
826
- img = img.decompress()
827
- if isinstance(assoc_matrix, _CompressedArray):
828
- assoc_matrix = assoc_matrix.decompress()
829
-
830
- # cropping
831
- if self.cropper is not None:
832
- (img2, mask2, coords2), idx = self.cropper(img, mask, coords)
833
- cropped_timepoints = timepoints[idx]
834
-
835
- # at least one detection in each timepoint to accept the crop
836
- if len(np.unique(cropped_timepoints)) == self.window_size:
837
- # at least two total detections to accept the crop
838
- # if len(idx) >= 2:
839
- img, mask, coords = img2, mask2, coords2
840
- labels = labels[idx]
841
- timepoints = timepoints[idx]
842
- assoc_matrix = assoc_matrix[idx][:, idx]
843
- else:
844
- logger.debug("disable cropping as no trajectories would be left")
845
-
846
- if self.features == "none":
847
- if self.augmenter is not None:
848
- coords = self.augmenter(coords)
849
- # Empty features
850
- features = np.zeros((len(coords), 0))
851
-
852
- elif self.features in ("regionprops", "regionprops2"):
853
- if self.augmenter is not None:
854
- (img2, mask2, coords2), idx = self.augmenter(
855
- img, mask, coords, timepoints - min_time
856
- )
857
- if len(idx) > 0:
858
- img, mask, coords = img2, mask2, coords2
859
- labels = labels[idx]
860
- timepoints = timepoints[idx]
861
- assoc_matrix = assoc_matrix[idx][:, idx]
862
- mask = mask.astype(int)
863
- else:
864
- logger.debug(
865
- "disable augmentation as no trajectories would be left"
866
- )
867
-
868
- features = tuple(
869
- extract_features_regionprops(
870
- m, im, labels[timepoints == i + min_time], properties=self.features
871
- )
872
- for i, (m, im) in enumerate(zip(mask, img))
873
- )
874
- features = np.concatenate(features, axis=0)
875
- # features = np.zeros((len(coords), self.feat_dim))
876
-
877
- elif self.features == "patch":
878
- if self.augmenter is not None:
879
- (img2, mask2, coords2), idx = self.augmenter(
880
- img, mask, coords, timepoints - min_time
881
- )
882
- if len(idx) > 0:
883
- img, mask, coords = img2, mask2, coords2
884
- labels = labels[idx]
885
- timepoints = timepoints[idx]
886
- assoc_matrix = assoc_matrix[idx][:, idx]
887
- mask = mask.astype(int)
888
- else:
889
- print("disable augmentation as no trajectories would be left")
890
-
891
- features = tuple(
892
- extract_features_patch(
893
- m,
894
- im,
895
- coords[timepoints == min_time + i],
896
- labels[timepoints == min_time + i],
897
- )
898
- for i, (m, im) in enumerate(zip(mask, img))
899
- )
900
- features = np.concatenate(features, axis=0)
901
- elif self.features == "patch_regionprops":
902
- if self.augmenter is not None:
903
- (img2, mask2, coords2), idx = self.augmenter(
904
- img, mask, coords, timepoints - min_time
905
- )
906
- if len(idx) > 0:
907
- img, mask, coords = img2, mask2, coords2
908
- labels = labels[idx]
909
- timepoints = timepoints[idx]
910
- assoc_matrix = assoc_matrix[idx][:, idx]
911
- mask = mask.astype(int)
912
- else:
913
- print("disable augmentation as no trajectories would be left")
914
-
915
- features1 = tuple(
916
- extract_features_patch(
917
- m,
918
- im,
919
- coords[timepoints == min_time + i],
920
- labels[timepoints == min_time + i],
921
- )
922
- for i, (m, im) in enumerate(zip(mask, img))
923
- )
924
- features2 = tuple(
925
- extract_features_regionprops(
926
- m,
927
- im,
928
- labels[timepoints == i + min_time],
929
- properties=self.features,
930
- )
931
- for i, (m, im) in enumerate(zip(mask, img))
932
- )
933
-
934
- features = tuple(
935
- np.concatenate((f1, f2), axis=-1)
936
- for f1, f2 in zip(features1, features2)
937
- )
938
-
939
- features = np.concatenate(features, axis=0)
940
-
941
- # remove temporal offset and add timepoints to coords
942
- relative_timepoints = timepoints - track["t1"]
943
- coords = np.concatenate((relative_timepoints[:, None], coords), axis=-1)
944
-
945
- if self.max_tokens and len(timepoints) > self.max_tokens:
946
- time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
947
- n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
948
- timepoints = timepoints[:n_elems]
949
- labels = labels[:n_elems]
950
- coords = coords[:n_elems]
951
- features = features[:n_elems]
952
- assoc_matrix = assoc_matrix[:n_elems, :n_elems]
953
- logger.info(
954
- f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
955
- )
956
-
957
- coords0 = torch.from_numpy(coords).float()
958
- features = torch.from_numpy(features).float()
959
- assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
960
- labels = torch.from_numpy(labels).long()
961
- timepoints = torch.from_numpy(timepoints).long()
962
-
963
- if self.augmenter is not None:
964
- coords = coords0.clone()
965
- coords[:, 1:] += torch.randint(0, 256, (1, self.ndim))
966
- else:
967
- coords = coords0.clone()
968
- res = dict(
969
- features=features,
970
- coords0=coords0,
971
- coords=coords,
972
- assoc_matrix=assoc_matrix,
973
- timepoints=timepoints,
974
- labels=labels,
975
- )
976
-
977
- if return_dense:
978
- if all([x is not None for x in img]):
979
- img = torch.from_numpy(img).float()
980
- res["img"] = img
981
-
982
- mask = torch.from_numpy(mask.astype(int)).long()
983
- res["mask"] = mask
984
-
985
- return res
986
-
987
- # wrfeat functions...
988
- # TODO: refactor this as a subclass or make everything a class factory. *very* hacky this way
989
-
990
- def _setup_features_augs_wrfeat(
991
- self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
992
- ):
993
- # FIXME: hardcoded
994
- feat_dim = 7 if ndim == 2 else 12
995
- if augment == 1:
996
- augmenter = wrfeat.WRAugmentationPipeline([
997
- wrfeat.WRRandomFlip(p=0.5),
998
- wrfeat.WRRandomAffine(
999
- p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
1000
- ),
1001
- # wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)),
1002
- # wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
1003
- ])
1004
- elif augment == 2:
1005
- augmenter = wrfeat.WRAugmentationPipeline([
1006
- wrfeat.WRRandomFlip(p=0.5),
1007
- wrfeat.WRRandomAffine(
1008
- p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
1009
- ),
1010
- wrfeat.WRRandomBrightness(p=0.8),
1011
- wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
1012
- ])
1013
- elif augment == 3:
1014
- augmenter = wrfeat.WRAugmentationPipeline([
1015
- wrfeat.WRRandomFlip(p=0.5),
1016
- wrfeat.WRRandomAffine(
1017
- p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
1018
- ),
1019
- wrfeat.WRRandomBrightness(p=0.8),
1020
- wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3),
1021
- wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
1022
- ])
1023
- else:
1024
- augmenter = None
1025
-
1026
- cropper = (
1027
- wrfeat.WRRandomCrop(
1028
- crop_size=crop_size,
1029
- ndim=ndim,
1030
- )
1031
- if crop_size is not None
1032
- else None
1033
- )
1034
- return feat_dim, augmenter, cropper
1035
-
1036
- def _load_wrfeat(self):
1037
- # Load ground truth
1038
- self.gt_masks, self.gt_track_df = self._load_gt()
1039
- self.gt_masks = self._check_dimensions(self.gt_masks)
1040
-
1041
- # Load images
1042
- if self.img_folder is None:
1043
- if self.gt_masks is not None:
1044
- self.imgs = np.zeros_like(self.gt_masks)
1045
- else:
1046
- raise NotImplementedError("No images and no GT masks")
1047
- else:
1048
- logger.info("Loading images")
1049
- imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
1050
- self.imgs = np.stack([
1051
- normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
1052
- ])
1053
- self.imgs = self._check_dimensions(self.imgs)
1054
- if self.compress:
1055
- # prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
1056
- self.imgs = np.stack([
1057
- _compress_img_mask_preproc(im, mask, self.features)
1058
- for im, mask in zip(self.imgs, self.gt_masks)
1059
- ])
1060
-
1061
- assert len(self.gt_masks) == len(self.imgs)
1062
-
1063
- # Load each of the detection folders and create data samples with a sliding window
1064
- windows = []
1065
- self.properties_by_time = dict()
1066
- self.det_masks = dict()
1067
- logger.info("Loading detections")
1068
- for _f in self.detection_folders:
1069
- det_folder = self.root / _f
1070
-
1071
- if det_folder == self.gt_mask_folder:
1072
- det_masks = self.gt_masks
1073
- logger.info("DET MASK:\tUsing GT masks")
1074
- # identity matching
1075
- det_gt_matching = {
1076
- t: {_l: _l for _l in set(np.unique(d)) - {0}}
1077
- for t, d in enumerate(det_masks)
1078
- }
1079
- else:
1080
- det_folder = self._guess_det_folder(root=self.root, suffix=_f)
1081
- if det_folder is None:
1082
- continue
1083
- logger.info(f"DET MASK (guessed):\t{det_folder}")
1084
- det_masks = self._load_tiffs(det_folder, dtype=np.int32)
1085
- det_masks = self._correct_gt_with_st(
1086
- det_folder, det_masks, dtype=np.int32
1087
- )
1088
- det_masks = self._check_dimensions(det_masks)
1089
- # FIXME matching can be slow for big images
1090
- # raise NotImplementedError("Matching not implemented for 3d version")
1091
- det_gt_matching = {
1092
- t: {
1093
- _d: _gt
1094
- for _gt, _d in matching(
1095
- self.gt_masks[t],
1096
- det_masks[t],
1097
- threshold=0.3,
1098
- max_distance=16,
1099
- )
1100
- }
1101
- for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
1102
- }
1103
-
1104
- self.det_masks[_f] = det_masks
1105
-
1106
- # build features
1107
-
1108
- features = joblib.Parallel(n_jobs=8)(
1109
- joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
1110
- mask=mask[None], img=img[None], t_start=t
1111
- )
1112
- for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
1113
- )
1114
-
1115
- properties_by_time = dict()
1116
- for _t, _feats in enumerate(features):
1117
- properties_by_time[_t] = dict(
1118
- coords=_feats.coords, labels=_feats.labels
1119
- )
1120
- self.properties_by_time[_f] = properties_by_time
1121
-
1122
- _w = self._build_windows_wrfeat(
1123
- features,
1124
- det_masks,
1125
- det_gt_matching,
1126
- )
1127
-
1128
- windows.extend(_w)
1129
-
1130
- return windows
1131
-
1132
- def _build_windows_wrfeat(
1133
- self,
1134
- features: Sequence[wrfeat.WRFeatures],
1135
- det_masks: np.ndarray,
1136
- matching: Tuple[dict],
1137
- ):
1138
- assert len(self.imgs) == len(det_masks)
1139
-
1140
- window_size = self.window_size
1141
- windows = []
1142
-
1143
- # Creates the data samples with a sliding window
1144
- for t1, t2 in tqdm(
1145
- zip(range(0, len(det_masks)), range(window_size, len(det_masks) + 1)),
1146
- total=len(det_masks) - window_size + 1,
1147
- leave=False,
1148
- desc="Building windows",
1149
- ):
1150
- img = self.imgs[t1:t2]
1151
- mask = det_masks[t1:t2]
1152
- feat = wrfeat.WRFeatures.concat(features[t1:t2])
1153
-
1154
- labels = feat.labels
1155
- timepoints = feat.timepoints
1156
- coords = feat.coords
1157
-
1158
- if len(feat) == 0:
1159
- A = np.zeros((0, 0), dtype=bool)
1160
- coords = np.zeros((0, feat.ndim), dtype=int)
1161
- else:
1162
- # build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
1163
- A = _ctc_assoc_matrix(
1164
- labels,
1165
- timepoints,
1166
- self.gt_graph,
1167
- matching,
1168
- )
1169
- w = dict(
1170
- coords=coords,
1171
- # TODO imgs and masks are unaltered here
1172
- t1=t1,
1173
- img=img,
1174
- mask=mask,
1175
- assoc_matrix=A,
1176
- labels=labels,
1177
- timepoints=timepoints,
1178
- wrfeat=feat,
1179
- )
1180
- windows.append(w)
1181
-
1182
- logger.debug(f"Built {len(windows)} track windows.\n")
1183
- return windows
1184
-
1185
- def _getitem_wrfeat(self, n: int, return_dense=None):
1186
- # if not set, use default
1187
-
1188
- if return_dense is None:
1189
- return_dense = self.return_dense
1190
-
1191
- track = self.windows[n]
1192
- # coords = track["coords"]
1193
- assoc_matrix = track["assoc_matrix"]
1194
- labels = track["labels"]
1195
- img = track["img"]
1196
- mask = track["mask"]
1197
- timepoints = track["timepoints"]
1198
- # track["t1"]
1199
- feat = track["wrfeat"]
1200
-
1201
- if return_dense and isinstance(mask, _CompressedArray):
1202
- mask = mask.decompress()
1203
- if return_dense and isinstance(img, _CompressedArray):
1204
- img = img.decompress()
1205
- if isinstance(assoc_matrix, _CompressedArray):
1206
- assoc_matrix = assoc_matrix.decompress()
1207
-
1208
- # cropping
1209
- if self.cropper is not None:
1210
- # Use only if there is at least one timepoint per detection
1211
- cropped_feat, cropped_idx = self.cropper(feat)
1212
- cropped_timepoints = timepoints[cropped_idx]
1213
- if len(np.unique(cropped_timepoints)) == self.window_size:
1214
- idx = cropped_idx
1215
- feat = cropped_feat
1216
- labels = labels[idx]
1217
- timepoints = timepoints[idx]
1218
- assoc_matrix = assoc_matrix[idx][:, idx]
1219
- else:
1220
- logger.debug("Skipping cropping")
1221
-
1222
- if self.augmenter is not None:
1223
- feat = self.augmenter(feat)
1224
-
1225
- coords0 = np.concatenate((feat.timepoints[:, None], feat.coords), axis=-1)
1226
- coords0 = torch.from_numpy(coords0).float()
1227
- assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32))
1228
- features = torch.from_numpy(feat.features_stacked).float()
1229
- labels = torch.from_numpy(feat.labels).long()
1230
- timepoints = torch.from_numpy(feat.timepoints).long()
1231
-
1232
- if self.max_tokens and len(timepoints) > self.max_tokens:
1233
- time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
1234
- n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
1235
- timepoints = timepoints[:n_elems]
1236
- labels = labels[:n_elems]
1237
- coords0 = coords0[:n_elems]
1238
- features = features[:n_elems]
1239
- assoc_matrix = assoc_matrix[:n_elems, :n_elems]
1240
- logger.debug(
1241
- f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
1242
- )
1243
-
1244
- if self.augmenter is not None:
1245
- coords = coords0.clone()
1246
- coords[:, 1:] += torch.randint(0, 512, (1, self.ndim))
1247
- else:
1248
- coords = coords0.clone()
1249
- res = dict(
1250
- features=features,
1251
- coords0=coords0,
1252
- coords=coords,
1253
- assoc_matrix=assoc_matrix,
1254
- timepoints=timepoints,
1255
- labels=labels,
1256
- )
1257
-
1258
- if return_dense:
1259
- if all([x is not None for x in img]):
1260
- img = torch.from_numpy(img).float()
1261
- res["img"] = img
1262
-
1263
- mask = torch.from_numpy(mask.astype(int)).long()
1264
- res["mask"] = mask
1265
-
1266
- return res
1267
-
1268
-
1269
- def _ctc_lineages(df, masks, t1=0, t2=None):
1270
- """From a ctc dataframe, create a digraph that contains all sublineages
1271
- between t1 and t2 (exclusive t2).
1272
-
1273
- Args:
1274
- df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt)
1275
- masks: List of masks. If t1 is not 0, then the masks are assumed to be already cropped accordingly.
1276
- t1: Start timepoint
1277
- t2: End timepoint (exclusive). If None, then t2 is set to len(masks)
1278
-
1279
- Returns:
1280
- labels: List of label ids extracted from the masks, ordered by timepoint.
1281
- ts: List of corresponding timepoints
1282
- graph: The digraph of the lineages between t1 and t2.
1283
- """
1284
- if t1 > 0:
1285
- assert t2 is not None
1286
- assert t2 - t1 == len(masks)
1287
- if t2 is None:
1288
- t2 = len(masks)
1289
-
1290
- graph = nx.DiGraph()
1291
- labels = []
1292
- ts = []
1293
-
1294
- # get all objects that are present in the time interval
1295
- df_sub = df[(df.t1 < t2) & (df.t2 >= t1)]
1296
-
1297
- # Correct offset
1298
- df_sub.loc[:, "t1"] -= t1
1299
- df_sub.loc[:, "t2"] -= t1
1300
-
1301
- # all_labels = df_sub.label.unique()
1302
- # TODO speed up by precalculating unique values once
1303
- # in_masks = set(np.where(np.bincount(np.stack(masks[t1:t2]).ravel()))[0]) - {0}
1304
- # all_labels = [l for l in all_labels if l in in_masks]
1305
- all_labels = set()
1306
-
1307
- for t in tqdm(
1308
- range(0, t2 - t1), desc="Building and checking lineage graph", leave=False
1309
- ):
1310
- # get all entities at timepoint
1311
- obs = df_sub[(df_sub.t1 <= t) & (df_sub.t2 >= t)]
1312
- in_t = set(np.where(np.bincount(masks[t].ravel()))[0]) - {0}
1313
- all_labels.update(in_t)
1314
- for row in obs.itertuples():
1315
- label, t1, t2, parent = row.label, row.t1, row.t2, row.parent
1316
- if label not in in_t:
1317
- continue
1318
-
1319
- labels.append(label)
1320
- ts.append(t)
1321
-
1322
- # add label as node if not already in graph
1323
- if not graph.has_node(label):
1324
- graph.add_node(label)
1325
-
1326
- # Parents have been added in previous timepoints
1327
- if parent in all_labels:
1328
- if not graph.has_node(parent):
1329
- graph.add_node(parent)
1330
- graph.add_edge(parent, label)
1331
-
1332
- labels = np.array(labels)
1333
- ts = np.array(ts)
1334
- return labels, ts, graph
1335
-
1336
-
1337
- @njit
1338
- def _assoc(A: np.ndarray, labels: np.ndarray, family: np.ndarray):
1339
- """For each detection, associate with all detections that are."""
1340
- for i in range(len(labels)):
1341
- for j in range(len(labels)):
1342
- A[i, j] = family[i, labels[j]]
1343
-
1344
-
1345
- def _ctc_assoc_matrix(detections, ts, graph, matching):
1346
- """Create the association matrix for a list of labels and a tracklet parent -> childrend graph.
1347
-
1348
- Each detection is associated with all its ancestors and descendants, but not its siblings and their offspring.
1349
-
1350
- Args:
1351
- detections: list of integer labels, ordered by timepoint
1352
- ts: list of timepoints corresponding to the detections
1353
- graph: networkx DiGraph with each ground truth tracklet id (spanning n timepoints) as a single node
1354
- and parent -> children relationships as edges.
1355
- matching: for each timepoint, a dictionary that maps from detection id to gt tracklet id
1356
- """
1357
- assert 0 not in graph
1358
- matched_gt = []
1359
- for i, (label, t) in enumerate(zip(detections, ts)):
1360
- gt_tracklet_id = matching[t].get(label, 0)
1361
- matched_gt.append(gt_tracklet_id)
1362
- matched_gt = np.array(matched_gt, dtype=int)
1363
- # Now we have the subset of gt nodes that is matched to any detection in the current window
1364
-
1365
- # relabel to reduce the size of lookup matrices
1366
- # offset 0 not allowed in skimage, which makes this very annoying
1367
- relabeled_gt, fwd_map, _inv_map = relabel_sequential(matched_gt, offset=1)
1368
- # dict is faster than arraymap
1369
- fwd_map = dict(zip(fwd_map.in_values, fwd_map.out_values))
1370
- # inv_map = dict(zip(inv_map.in_values, inv_map.out_values))
1371
-
1372
- # the family relationships for each ground truth detection,
1373
- # Maps from local detection number (0-indexed) to global gt tracklet id (1-indexed)
1374
- family = np.zeros((len(detections), len(relabeled_gt) + 1), bool)
1375
-
1376
- # Connects each tracklet id with its children and parent tracklets (according to man_track.txt)
1377
- for i, (label, t) in enumerate(zip(detections, ts)):
1378
- # Get the original label corresponding to the graph
1379
- gt_tracklet_id = matching[t].get(label, None)
1380
- if gt_tracklet_id is not None:
1381
- ancestors = []
1382
- descendants = []
1383
- # This iterates recursively through the graph
1384
- for n in nx.descendants(graph, gt_tracklet_id):
1385
- if n in fwd_map:
1386
- descendants.append(fwd_map[n])
1387
- for n in nx.ancestors(graph, gt_tracklet_id):
1388
- if n in fwd_map:
1389
- ancestors.append(fwd_map[n])
1390
-
1391
- family[i, np.array([fwd_map[gt_tracklet_id], *ancestors, *descendants])] = (
1392
- True
1393
- )
1394
- else:
1395
- pass
1396
- # Now we match to nothing, so even the matrix diagonal will not be filled.
1397
-
1398
- # This assures that matching to 0 is always false
1399
- assert family[:, 0].sum() == 0
1400
-
1401
- # Create the detection-to-detection association matrix
1402
- A = np.zeros((len(detections), len(detections)), dtype=bool)
1403
-
1404
- _assoc(A, relabeled_gt, family)
1405
-
1406
- return A
1407
-
1408
-
1409
- def sigmoid(x):
1410
- return 1 / (1 + np.exp(-x))
1411
-
1412
-
1413
- def _compress_img_mask_preproc(img, mask, features):
1414
- """Remove certain img pixels if not needed to save memory for large datasets."""
1415
- # dont change anything if we need patch values
1416
- if features in ("patch", "patch_regionprops"):
1417
- # clear img pixels outside of patch_mask of size 16x16
1418
- patch_width = 16 # TOD: hardcoded: change this if needed
1419
- coords = tuple(np.array(r.centroid).astype(int) for r in regionprops(mask))
1420
- img2 = np.zeros_like(img)
1421
- if len(coords) > 0:
1422
- coords = np.stack(coords)
1423
- coords = np.clip(coords, 0, np.array(mask.shape)[None] - 1)
1424
- patch_mask = np.zeros_like(img, dtype=bool)
1425
- patch_mask[tuple(coords.T)] = True
1426
- # retain 3*patch_width+1 around center to be safe...
1427
- patch_mask = ndi.maximum_filter(patch_mask, 3 * patch_width + 1)
1428
- img2[patch_mask] = img[patch_mask]
1429
-
1430
- else:
1431
- # otherwise set img value inside masks to mean
1432
- # FIXME: change when using other intensity based regionprops
1433
- img2 = np.zeros_like(img)
1434
- for reg in regionprops(mask, intensity_image=img):
1435
- m = mask[reg.slice] == reg.label
1436
- img2[reg.slice][m] = reg.mean_intensity
1437
- return img2
1438
-
1439
-
1440
- def pad_tensor(x, n_max: int, dim=0, value=0):
1441
- n = x.shape[dim]
1442
- if n_max < n:
1443
- raise ValueError(f"pad_tensor: n_max={n_max} must be larger than n={n} !")
1444
- pad_shape = list(x.shape)
1445
- pad_shape[dim] = n_max - n
1446
- # pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype).to(x.device)
1447
- pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype)
1448
- return torch.cat((x, pad), dim=dim)
1449
-
1450
-
1451
- def collate_sequence_padding(batch):
1452
- """Collate function that pads all sequences to the same length."""
1453
- lens = tuple(len(x["coords"]) for x in batch)
1454
- n_max_len = max(lens)
1455
- # print(tuple(len(x["coords"]) for x in batch))
1456
- # print(tuple(len(x["features"]) for x in batch))
1457
- # print(batch[0].keys())
1458
- tuple(batch[0].keys())
1459
- normal_keys = {
1460
- "coords": 0,
1461
- "features": 0,
1462
- "labels": 0, # Not needed, remove for speed.
1463
- "timepoints": -1, # There are real timepoints with t=0. -1 for distinction from that.
1464
- }
1465
- n_pads = tuple(n_max_len - s for s in lens)
1466
- batch_new = dict(
1467
- (
1468
- k,
1469
- torch.stack(
1470
- [pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0
1471
- ),
1472
- )
1473
- for k, v in normal_keys.items()
1474
- )
1475
- batch_new["assoc_matrix"] = torch.stack(
1476
- [
1477
- pad_tensor(
1478
- pad_tensor(x["assoc_matrix"], n_max_len, dim=0), n_max_len, dim=1
1479
- )
1480
- for x in batch
1481
- ],
1482
- dim=0,
1483
- )
1484
-
1485
- # add boolean mask that signifies whether tokens are padded or not (such that they can be ignored later)
1486
- pad_mask = torch.zeros((len(batch), n_max_len), dtype=torch.bool)
1487
- for i, n_pad in enumerate(n_pads):
1488
- pad_mask[i, n_max_len - n_pad :] = True
1489
-
1490
- batch_new["padding_mask"] = pad_mask.bool()
1491
- return batch_new
1492
-
1493
-
1494
- if __name__ == "__main__":
1495
- dummy_data = CTCData(
1496
- root="../../scripts/data/synthetic_cells/01",
1497
- ndim=2,
1498
- detection_folders=["TRA"],
1499
- window_size=4,
1500
- max_tokens=None,
1501
- augment=3,
1502
- features="none",
1503
- downscale_temporal=1,
1504
- downscale_spatial=1,
1505
- sanity_dist=False,
1506
- crop_size=(256, 256),
1507
- )
1508
-
1509
- x = dummy_data[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/distributed.py DELETED
@@ -1,316 +0,0 @@
1
- """Data loading and sampling utils for distributed training."""
2
-
3
- import hashlib
4
- import json
5
- import logging
6
- import pickle
7
- # from collections.abc import Iterable
8
- from copy import deepcopy
9
- from pathlib import Path
10
- from timeit import default_timer
11
-
12
- import numpy as np
13
- import torch
14
- # from lightning import LightningDataModule
15
- from torch.utils.data import (
16
- BatchSampler,
17
- ConcatDataset,
18
- DataLoader,
19
- Dataset,
20
- DistributedSampler,
21
- )
22
- from typing import Optional, Iterable
23
- from .data import CTCData
24
-
25
- logger = logging.getLogger(__name__)
26
- logger.setLevel(logging.INFO)
27
-
28
-
29
- def cache_class(cachedir=None):
30
- """A simple file cache for CTCData."""
31
-
32
- def make_hashable(obj):
33
- if isinstance(obj, tuple | list):
34
- return tuple(make_hashable(e) for e in obj)
35
- elif isinstance(obj, Path):
36
- return obj.as_posix()
37
- elif isinstance(obj, dict):
38
- return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
39
- else:
40
- return obj
41
-
42
- def hash_args_kwargs(*args, **kwargs):
43
- hashable_args = tuple(make_hashable(arg) for arg in args)
44
- hashable_kwargs = make_hashable(kwargs)
45
- combined_serialized = json.dumps(
46
- [hashable_args, hashable_kwargs], sort_keys=True
47
- )
48
- hash_obj = hashlib.sha256(combined_serialized.encode())
49
- return hash_obj.hexdigest()
50
-
51
- if cachedir is None:
52
- return CTCData
53
- else:
54
- cachedir = Path(cachedir)
55
-
56
- def _wrapped(*args, **kwargs):
57
- h = hash_args_kwargs(*args, **kwargs)
58
- cachedir.mkdir(exist_ok=True, parents=True)
59
- cache_file = cachedir / f"{h}.pkl"
60
- if cache_file.exists():
61
- logger.info(f"Loading cached dataset from {cache_file}")
62
- with open(cache_file, "rb") as f:
63
- return pickle.load(f)
64
- else:
65
- c = CTCData(*args, **kwargs)
66
- logger.info(f"Saving cached dataset to {cache_file}")
67
- pickle.dump(c, open(cache_file, "wb"))
68
- return c
69
-
70
- return _wrapped
71
-
72
-
73
- class BalancedBatchSampler(BatchSampler):
74
- """samples batch indices such that the number of objects in each batch is balanced
75
- (so to reduce the number of paddings in the batch).
76
-
77
-
78
- """
79
-
80
- def __init__(
81
- self,
82
- dataset: torch.utils.data.Dataset,
83
- batch_size: int,
84
- n_pool: int = 10,
85
- num_samples: Optional[int] = None,
86
- weight_by_ndivs: bool = False,
87
- weight_by_dataset: bool = False,
88
- drop_last: bool = False,
89
- ):
90
- """Setting n_pool =1 will result in a regular random batch sampler.
91
-
92
- weight_by_ndivs: if True, the probability of sampling an element is proportional to the number of divisions
93
- weight_by_dataset: if True, the probability of sampling an element is inversely proportional to the length of the dataset
94
- """
95
- if isinstance(dataset, CTCData):
96
- self.n_objects = dataset.n_objects
97
- self.n_divs = np.array(dataset.n_divs)
98
- self.n_sizes = np.ones(len(dataset)) * len(dataset)
99
- elif isinstance(dataset, ConcatDataset):
100
- self.n_objects = tuple(n for d in dataset.datasets for n in d.n_objects)
101
- self.n_divs = np.array(tuple(n for d in dataset.datasets for n in d.n_divs))
102
- self.n_sizes = np.array(
103
- tuple(len(d) for d in dataset.datasets for _ in range(len(d)))
104
- )
105
- else:
106
- raise NotImplementedError(
107
- f"BalancedBatchSampler: Unknown dataset type {type(dataset)}"
108
- )
109
- assert len(self.n_objects) == len(self.n_divs) == len(self.n_sizes)
110
-
111
- self.batch_size = batch_size
112
- self.n_pool = n_pool
113
- self.drop_last = drop_last
114
- self.num_samples = num_samples
115
- self.weight_by_ndivs = weight_by_ndivs
116
- self.weight_by_dataset = weight_by_dataset
117
- logger.debug(f"{weight_by_ndivs=}")
118
- logger.debug(f"{weight_by_dataset=}")
119
-
120
- def get_probs(self, idx):
121
- idx = np.array(idx)
122
- if self.weight_by_ndivs:
123
- probs = 1 + np.sqrt(self.n_divs[idx])
124
- else:
125
- probs = np.ones(len(idx))
126
- if self.weight_by_dataset:
127
- probs = probs / (self.n_sizes[idx] + 1e-6)
128
-
129
- probs = probs / (probs.sum() + 1e-10)
130
- return probs
131
-
132
- def sample_batches(self, idx: Iterable[int]):
133
- # we will split the indices into pools of size n_pool
134
- num_samples = self.num_samples if self.num_samples is not None else len(idx)
135
- # sample from the indices with replacement and given probabilites
136
- idx = np.random.choice(idx, num_samples, replace=True, p=self.get_probs(idx))
137
-
138
- n_pool = min(
139
- self.n_pool * self.batch_size,
140
- (len(idx) // self.batch_size) * self.batch_size,
141
- )
142
-
143
- batches = []
144
- for i in range(0, len(idx), n_pool):
145
- # the indices in the pool are sorted by their number of objects
146
- idx_pool = idx[i : i + n_pool]
147
- idx_pool = sorted(idx_pool, key=lambda i: self.n_objects[i])
148
-
149
- # such that we can create batches where each element has a similar number of objects
150
- jj = np.arange(0, len(idx_pool), self.batch_size)
151
- np.random.shuffle(jj)
152
-
153
- for j in jj:
154
- # dont drop_last, as this leads to a lot of lightning problems....
155
- # if j + self.batch_size > len(idx_pool): # assume drop_last=True
156
- # continue
157
- batch = idx_pool[j : j + self.batch_size]
158
- batches.append(batch)
159
- return batches
160
-
161
- def __iter__(self):
162
- idx = np.arange(len(self.n_objects))
163
- batches = self.sample_batches(idx)
164
- return iter(batches)
165
-
166
- def __len__(self):
167
- if self.num_samples is not None:
168
- return self.num_samples // self.batch_size
169
- else:
170
- return len(self.n_objects) // self.batch_size
171
-
172
-
173
- class BalancedDistributedSampler(DistributedSampler):
174
- def __init__(
175
- self,
176
- dataset: Dataset,
177
- batch_size: int,
178
- n_pool: int,
179
- num_samples: int,
180
- weight_by_ndivs: bool = False,
181
- weight_by_dataset: bool = False,
182
- *args,
183
- **kwargs,
184
- ) -> None:
185
- super().__init__(dataset=dataset, *args, drop_last=True, **kwargs)
186
- self._balanced_batch_sampler = BalancedBatchSampler(
187
- dataset,
188
- batch_size=batch_size,
189
- n_pool=n_pool,
190
- num_samples=max(1, num_samples // self.num_replicas),
191
- weight_by_ndivs=weight_by_ndivs,
192
- weight_by_dataset=weight_by_dataset,
193
- )
194
-
195
- def __len__(self) -> int:
196
- if self.num_samples is not None:
197
- return self._balanced_batch_sampler.num_samples
198
- else:
199
- return super().__len__()
200
-
201
- def __iter__(self):
202
- indices = list(super().__iter__())
203
- batches = self._balanced_batch_sampler.sample_batches(indices)
204
- for batch in batches:
205
- yield from batch
206
-
207
-
208
- # class BalancedDataModule(LightningDataModule):
209
- # def __init__(
210
- # self,
211
- # input_train: list,
212
- # input_val: list,
213
- # cachedir: str,
214
- # augment: int,
215
- # distributed: bool,
216
- # dataset_kwargs: dict,
217
- # sampler_kwargs: dict,
218
- # loader_kwargs: dict,
219
- # ):
220
- # super().__init__()
221
- # self.input_train = input_train
222
- # self.input_val = input_val
223
- # self.cachedir = cachedir
224
- # self.augment = augment
225
- # self.distributed = distributed
226
- # self.dataset_kwargs = dataset_kwargs
227
- # self.sampler_kwargs = sampler_kwargs
228
- # self.loader_kwargs = loader_kwargs
229
-
230
- # def prepare_data(self):
231
- # """Loads and caches the datasets if not already done.
232
-
233
- # Running on the main CPU process.
234
- # """
235
- # CTCData = cache_class(self.cachedir)
236
- # datasets = dict()
237
- # for split, inps in zip(
238
- # ("train", "val"),
239
- # (self.input_train, self.input_val),
240
- # ):
241
- # logger.info(f"Loading {split.upper()} data")
242
- # start = default_timer()
243
- # datasets[split] = torch.utils.data.ConcatDataset(
244
- # CTCData(
245
- # root=Path(inp),
246
- # augment=self.augment if split == "train" else 0,
247
- # **self.dataset_kwargs,
248
- # )
249
- # for inp in inps
250
- # )
251
- # logger.info(
252
- # f"Loaded {len(datasets[split])} {split.upper()} samples (in"
253
- # f" {(default_timer() - start):.1f} s)\n\n"
254
- # )
255
-
256
- # del datasets
257
-
258
- # def setup(self, stage: str):
259
- # CTCData = cache_class(self.cachedir)
260
- # self.datasets = dict()
261
- # for split, inps in zip(
262
- # ("train", "val"),
263
- # (self.input_train, self.input_val),
264
- # ):
265
- # logger.info(f"Loading {split.upper()} data")
266
- # start = default_timer()
267
- # self.datasets[split] = torch.utils.data.ConcatDataset(
268
- # CTCData(
269
- # root=Path(inp),
270
- # augment=self.augment if split == "train" else 0,
271
- # **self.dataset_kwargs,
272
- # )
273
- # for inp in inps
274
- # )
275
- # logger.info(
276
- # f"Loaded {len(self.datasets[split])} {split.upper()} samples (in"
277
- # f" {(default_timer() - start):.1f} s)\n\n"
278
- # )
279
-
280
- # def train_dataloader(self):
281
- # loader_kwargs = self.loader_kwargs.copy()
282
- # if self.distributed:
283
- # sampler = BalancedDistributedSampler(
284
- # self.datasets["train"],
285
- # **self.sampler_kwargs,
286
- # )
287
- # batch_sampler = None
288
- # else:
289
- # sampler = None
290
- # batch_sampler = BalancedBatchSampler(
291
- # self.datasets["train"],
292
- # **self.sampler_kwargs,
293
- # )
294
- # if not loader_kwargs["batch_size"] == batch_sampler.batch_size:
295
- # raise ValueError(
296
- # f"Batch size in loader_kwargs ({loader_kwargs['batch_size']}) and sampler_kwargs ({batch_sampler.batch_size}) must match"
297
- # )
298
- # del loader_kwargs["batch_size"]
299
-
300
- # loader = DataLoader(
301
- # self.datasets["train"],
302
- # sampler=sampler,
303
- # batch_sampler=batch_sampler,
304
- # **loader_kwargs,
305
- # )
306
- # return loader
307
-
308
- # def val_dataloader(self):
309
- # val_loader_kwargs = deepcopy(self.loader_kwargs)
310
- # val_loader_kwargs["persistent_workers"] = False
311
- # val_loader_kwargs["num_workers"] = 1
312
- # return DataLoader(
313
- # self.datasets["val"],
314
- # shuffle=False,
315
- # **val_loader_kwargs,
316
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/example_data.py DELETED
@@ -1,48 +0,0 @@
1
- from pathlib import Path
2
-
3
- import tifffile
4
-
5
- root = Path(__file__).parent / "resources"
6
-
7
-
8
- def example_data_bacteria():
9
- """Bacteria images and masks from.
10
-
11
- Van Vliet et al. Spatially Correlated Gene Expression in Bacterial Groups: The Role of Lineage History, Spatial Gradients, and Cell-Cell Interactions (2018)
12
- https://doi.org/10.1016/j.cels.2018.03.009
13
-
14
- subset of timelapse trpL/150310-11
15
- """
16
- img = tifffile.imread(root / "trpL_150310-11_img.tif")
17
- mask = tifffile.imread(root / "trpL_150310-11_mask.tif")
18
- return img, mask
19
-
20
-
21
- def example_data_hela():
22
- """Hela data from the cell tracking challenge.
23
-
24
- Neumann et al. Phenotypic profiling of the human genome by time-lapse microscopy reveals cell division genes (2010)
25
-
26
- subset of Fluo-N2DL-HeLa/train/02
27
- """
28
- img = tifffile.imread(root / "Fluo_Hela_02_img.tif")
29
- mask = tifffile.imread(root / "Fluo_Hela_02_ERR_SEG.tif")
30
- print(img.shape, mask.shape)
31
- return img, mask
32
-
33
-
34
- def example_data_fluo_3d():
35
- """Fluo-N3DH-CHO data from the cell tracking challenge.
36
-
37
- Dzyubachyk et al. Advanced Level-Set-Based Cell Tracking in Time-Lapse Fluorescence Microscopy (2010)
38
-
39
- subset of Fluo-N3DH-CHO/train/02
40
- """
41
- img = tifffile.imread(root / "Fluo-N3DH-CHO_02_img.tif")
42
- mask = tifffile.imread(root / "Fluo-N3DH-CHO_02_ERR_SEG.tif")
43
- return img, mask
44
-
45
- def data_hela():
46
- img = tifffile.imread("02_imgs.tif")
47
- mask = tifffile.imread("02_masks.tif")
48
- return img, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/features.py DELETED
@@ -1,148 +0,0 @@
1
- import itertools
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from skimage.measure import regionprops_table
6
-
7
- # the property keys that are supported for 2 and 3 dim
8
-
9
- _PROPERTIES = {
10
- 2: {
11
- # FIXME: The only image regionprop possible now (when compressing) is mean_intensity,
12
- # since we store a mask with the mean intensity of each detection as the image.
13
- "regionprops": (
14
- "label",
15
- "area",
16
- "intensity_mean",
17
- "eccentricity",
18
- "solidity",
19
- "inertia_tensor",
20
- ),
21
- # faster
22
- "regionprops2": (
23
- "label",
24
- "area",
25
- "intensity_mean",
26
- "inertia_tensor",
27
- ),
28
- "patch_regionprops": (
29
- "label",
30
- "area",
31
- "intensity_mean",
32
- "inertia_tensor",
33
- ),
34
- },
35
- 3: {
36
- "regionprops2": (
37
- "label",
38
- "area",
39
- "intensity_mean",
40
- "inertia_tensor",
41
- ),
42
- "patch_regionprops": (
43
- "label",
44
- "area",
45
- "intensity_mean",
46
- "inertia_tensor",
47
- ),
48
- },
49
- }
50
-
51
-
52
- def extract_features_regionprops(
53
- mask: np.ndarray,
54
- img: np.ndarray,
55
- labels: np.ndarray,
56
- properties="regionprops2",
57
- ):
58
- ndim = mask.ndim
59
- assert ndim in (2, 3)
60
- assert mask.shape == img.shape
61
-
62
- prop_dict = _PROPERTIES[ndim]
63
- if properties not in prop_dict:
64
- raise ValueError(f"properties must be one of {prop_dict.keys()}")
65
- properties_tuple = prop_dict[properties]
66
-
67
- assert properties_tuple[0] == "label"
68
-
69
- labels = np.asarray(labels)
70
-
71
- # remove mask labels that are not present
72
- # not needed, remove for speed
73
- # mask[~np.isin(mask, labels)] = 0
74
-
75
- df = pd.DataFrame(
76
- regionprops_table(mask, intensity_image=img, properties=properties_tuple)
77
- )
78
- assert df.columns[0] == "label"
79
- assert df.columns[1] == "area"
80
-
81
- # the bnumber of inertia tensor columns depends on the dimensionality
82
- n_cols_inertia = ndim**2
83
- assert np.all(["inertia_tensor" in col for col in df.columns[-n_cols_inertia:]])
84
-
85
- # Hack for backwards compatibility
86
- if properties in ("regionprops", "patch_regionprops"):
87
- # Nice for conceptual clarity, but does not matter for speed
88
- # drop upper triangular part of symmetric inertia tensor
89
- for i, j in itertools.product(range(ndim), repeat=2):
90
- if i > j:
91
- df.drop(f"inertia_tensor-{i}-{j}", axis=1, inplace=True)
92
-
93
- table = df.to_numpy()
94
- table[:, 1] *= 0.001
95
- table[:, -n_cols_inertia:] *= 0.01
96
- # reorder according to labels
97
- features = np.zeros((len(labels), len(df.columns) - 1))
98
-
99
- # faster than iterating over pandas dataframe
100
- for row in table:
101
- # old version with tuple indexing, slow.
102
- # n = labels.index(int(row.label))
103
- # features[n] = row.to_numpy()[1:]
104
-
105
- # Only process regions present in the labels
106
- n = np.where(labels == int(row[0]))[0]
107
- if len(n) > 0:
108
- # Remove label column (0)!
109
- features[n[0]] = row[1:]
110
-
111
- return features
112
-
113
-
114
- def extract_features_patch(
115
- mask: np.ndarray,
116
- img: np.ndarray,
117
- coords: np.ndarray,
118
- labels: np.ndarray,
119
- width_patch: int = 16,
120
- ):
121
- """16x16 Image patch around detection."""
122
- ndim = mask.ndim
123
- assert ndim in (2, 3) and mask.shape == img.shape
124
- if len(coords) == 0:
125
- return np.zeros((0, width_patch * width_patch))
126
-
127
- pads = (width_patch // 2,) * ndim
128
-
129
- img = np.pad(
130
- img,
131
- tuple((p, p) for p in pads),
132
- mode="constant",
133
- )
134
-
135
- coords = coords.astype(int) + np.array(pads)
136
-
137
- ss = tuple(
138
- tuple(slice(_c - width_patch // 2, _c + width_patch // 2) for _c in c)
139
- for c in coords
140
- )
141
- fs = tuple(img[_s] for _s in ss)
142
-
143
- # max project along z if 3D
144
- if ndim == 3:
145
- fs = tuple(f.max(0) for f in fs)
146
-
147
- features = np.stack([f.flatten() for f in fs])
148
- return features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/matching.py DELETED
@@ -1,251 +0,0 @@
1
- # Adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py
2
-
3
- import numpy as np
4
- from numba import jit
5
- from scipy.optimize import linear_sum_assignment
6
- from scipy.spatial.distance import cdist
7
- from skimage.measure import regionprops
8
-
9
- matching_criteria = dict()
10
-
11
-
12
- def label_are_sequential(y):
13
- """Returns true if y has only sequential labels from 1..."""
14
- labels = np.unique(y)
15
- return (set(labels) - {0}) == set(range(1, 1 + labels.max()))
16
-
17
-
18
- def is_array_of_integers(y):
19
- return isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.integer)
20
-
21
-
22
- def _check_label_array(y, name=None, check_sequential=False):
23
- err = ValueError(
24
- "{label} must be an array of {integers}.".format(
25
- label="labels" if name is None else name,
26
- integers=("sequential " if check_sequential else "")
27
- + "non-negative integers",
28
- )
29
- )
30
-
31
- if not is_array_of_integers(y):
32
- raise err
33
- if len(y) == 0:
34
- return True
35
- if check_sequential and not label_are_sequential(y):
36
- raise err
37
- else:
38
- if not y.min() >= 0:
39
- raise err
40
- return True
41
-
42
-
43
- def label_overlap(x, y, check=True):
44
- if check:
45
- _check_label_array(x, "x", True)
46
- _check_label_array(y, "y", True)
47
- if not x.shape == y.shape:
48
- raise ValueError("x and y must have the same shape")
49
- return _label_overlap(x, y)
50
-
51
-
52
- @jit(nopython=True)
53
- def _label_overlap(x, y):
54
- x = x.ravel()
55
- y = y.ravel()
56
- overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint32)
57
- for i in range(len(x)):
58
- overlap[x[i], y[i]] += 1
59
- return overlap[1:, 1:]
60
-
61
-
62
- def _safe_divide(x, y, eps=1e-10):
63
- """Computes a safe divide which returns 0 if y is zero."""
64
- if np.isscalar(x) and np.isscalar(y):
65
- return x / y if np.abs(y) > eps else 0.0
66
- else:
67
- out = np.zeros(np.broadcast(x, y).shape, np.float32)
68
- np.divide(x, y, out=out, where=np.abs(y) > eps)
69
- return out
70
-
71
-
72
- def intersection_over_union(overlap):
73
- _check_label_array(overlap, "overlap")
74
- if np.sum(overlap) == 0:
75
- return overlap
76
- n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
77
- n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
78
- return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap))
79
-
80
-
81
- def dist_score(y_true, y_pred, max_distance: int = 10):
82
- """Compute distance score between centroids of regions in y_true and y_pred
83
- and returns a score matrix of shape (n_true, n_pred) with values in [0,1]
84
- where
85
- distance >= max_distance -> score = 0
86
- distance = 0 -> score = 1.
87
- """
88
- c_true = np.stack([r.centroid for r in regionprops(y_true)], axis=0)
89
- c_pred = np.stack([r.centroid for r in regionprops(y_pred)], axis=0)
90
- dist = np.minimum(cdist(c_true, c_pred), max_distance)
91
- score = 1 - dist / max_distance
92
- return score
93
-
94
-
95
- # copied from scikit-image master for now (remove when part of a release)
96
- def relabel_sequential(label_field, offset=1):
97
- """Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.
98
-
99
- This function also returns the forward map (mapping the original labels to
100
- the reduced labels) and the inverse map (mapping the reduced labels back
101
- to the original ones).
102
-
103
- Parameters
104
- ----------
105
- label_field : numpy array of int, arbitrary shape
106
- An array of labels, which must be non-negative integers.
107
- offset : int, optional
108
- The return labels will start at `offset`, which should be
109
- strictly positive.
110
-
111
- Returns:
112
- -------
113
- relabeled : numpy array of int, same shape as `label_field`
114
- The input label field with labels mapped to
115
- {offset, ..., number_of_labels + offset - 1}.
116
- The data type will be the same as `label_field`, except when
117
- offset + number_of_labels causes overflow of the current data type.
118
- forward_map : numpy array of int, shape ``(label_field.max() + 1,)``
119
- The map from the original label space to the returned label
120
- space. Can be used to re-apply the same mapping. See examples
121
- for usage. The data type will be the same as `relabeled`.
122
- inverse_map : 1D numpy array of int, of length offset + number of labels
123
- The map from the new label space to the original space. This
124
- can be used to reconstruct the original label field from the
125
- relabeled one. The data type will be the same as `relabeled`.
126
-
127
- Notes:
128
- -----
129
- The label 0 is assumed to denote the background and is never remapped.
130
-
131
- The forward map can be extremely big for some inputs, since its
132
- length is given by the maximum of the label field. However, in most
133
- situations, ``label_field.max()`` is much smaller than
134
- ``label_field.size``, and in these cases the forward map is
135
- guaranteed to be smaller than either the input or output images.
136
-
137
- Examples:
138
- --------
139
- >>> from skimage.segmentation import relabel_sequential
140
- >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])
141
- >>> relab, fw, inv = relabel_sequential(label_field)
142
- >>> relab
143
- array([1, 1, 2, 2, 3, 5, 4])
144
- >>> fw
145
- array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
146
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,
147
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
148
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
149
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])
150
- >>> inv
151
- array([ 0, 1, 5, 8, 42, 99])
152
- >>> (fw[label_field] == relab).all()
153
- True
154
- >>> (inv[relab] == label_field).all()
155
- True
156
- >>> relab, fw, inv = relabel_sequential(label_field, offset=5)
157
- >>> relab
158
- array([5, 5, 6, 6, 7, 9, 8])
159
- """
160
- offset = int(offset)
161
- if offset <= 0:
162
- raise ValueError("Offset must be strictly positive.")
163
- if np.min(label_field) < 0:
164
- raise ValueError("Cannot relabel array that contains negative values.")
165
- max_label = int(label_field.max()) # Ensure max_label is an integer
166
- if not np.issubdtype(label_field.dtype, np.integer):
167
- new_type = np.min_scalar_type(max_label)
168
- label_field = label_field.astype(new_type)
169
- labels = np.unique(label_field)
170
- labels0 = labels[labels != 0]
171
- new_max_label = offset - 1 + len(labels0)
172
- new_labels0 = np.arange(offset, new_max_label + 1)
173
- output_type = label_field.dtype
174
- required_type = np.min_scalar_type(new_max_label)
175
- if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:
176
- output_type = required_type
177
- forward_map = np.zeros(max_label + 1, dtype=output_type)
178
- forward_map[labels0] = new_labels0
179
- inverse_map = np.zeros(new_max_label + 1, dtype=output_type)
180
- inverse_map[offset:] = labels0
181
- relabeled = forward_map[label_field]
182
- return relabeled, forward_map, inverse_map
183
-
184
-
185
- def matching(y_true, y_pred, threshold=0.5, max_distance: int = 16):
186
- """Computes IoU and distance score between all pairs of regions in y_true and y_pred.
187
-
188
- returns the true/pred matching based on the higher of the two scores for each pair of regions
189
-
190
- Parameters
191
- ----------
192
- y_true: ndarray
193
- ground truth label image (integer valued)
194
- y_pred: ndarray
195
- predicted label image (integer valued)
196
- threshold: float
197
- threshold for matching criterion (default 0.5)
198
- max_distance: int
199
- maximum distance between centroids of regions in y_true and y_pred (default 16)
200
-
201
- Returns:
202
- -------
203
- gt_pred: tuple
204
- tuple of all matched region label pairs in y_true and y_pred
205
-
206
-
207
- """
208
- y_true, y_pred = y_true.astype(np.int32), y_pred.astype(np.int32)
209
- _check_label_array(y_true, "y_true")
210
- _check_label_array(y_pred, "y_pred")
211
- if not y_true.shape == y_pred.shape:
212
- raise ValueError(
213
- f"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes"
214
- )
215
- if threshold is None:
216
- threshold = 0
217
-
218
- threshold = float(threshold) if np.isscalar(threshold) else map(float, threshold)
219
-
220
- y_true, _, map_rev_true = relabel_sequential(y_true)
221
- y_pred, _, map_rev_pred = relabel_sequential(y_pred)
222
-
223
- overlap = label_overlap(y_true, y_pred, check=False)
224
-
225
- scores_iou = intersection_over_union(overlap)
226
- scores_dist = dist_score(y_true, y_pred, max_distance)
227
- scores = np.maximum(scores_iou, scores_dist)
228
-
229
- assert 0 <= np.min(scores) <= np.max(scores) <= 1
230
-
231
- n_true, n_pred = scores.shape
232
- n_matched = min(n_true, n_pred)
233
-
234
- # not_trivial = n_matched > 0 and np.any(scores >= thr)
235
- not_trivial = n_matched > 0
236
- if not_trivial:
237
- # compute optimal matching with scores as tie-breaker
238
- costs = -(scores >= threshold).astype(float) - scores / (2 * n_matched)
239
- true_ind, pred_ind = linear_sum_assignment(costs)
240
- assert n_matched == len(true_ind) == len(pred_ind)
241
- match_ok = scores[true_ind, pred_ind] >= threshold
242
- true_ind = true_ind[match_ok]
243
- pred_ind = pred_ind[match_ok]
244
- matched = tuple(
245
- (int(map_rev_true[i]), int(map_rev_pred[j]))
246
- for i, j in zip(1 + true_ind, 1 + pred_ind)
247
- )
248
- else:
249
- matched = ()
250
-
251
- return matched
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/data/utils.py DELETED
@@ -1,232 +0,0 @@
1
- import logging
2
- import sys
3
- from pathlib import Path
4
-
5
- import numpy as np
6
- import pandas as pd
7
-
8
- # from .data import CTCData
9
- import tifffile
10
- from tqdm import tqdm
11
- from typing import Optional, Union, Tuple
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
- def load_tiff_timeseries(
17
- dir: Path,
18
- dtype: Optional[Union[str, type]] = None,
19
- downscale: Optional[Tuple[int, ...]] = None,
20
- start_frame: int = 0,
21
- end_frame: Optional[int] = None,
22
- ) -> np.ndarray:
23
- """Loads a folder of `.tif` or `.tiff` files into a numpy array.
24
- Each file is interpreted as a frame of a time series.
25
-
26
- Args:
27
- folder:
28
- dtype:
29
- downscale: One int for each dimension of the data. Avoids memory overhead.
30
- start_frame: The first frame to load.
31
- end_frame: The last frame to load.
32
-
33
- Returns:
34
- np.ndarray: The loaded data.
35
- """
36
- # TODO make safe for label arrays
37
- logger.debug(f"Loading tiffs from {dir} as {dtype}")
38
- files = sorted(list(dir.glob("*.tif")) + list(dir.glob("*.tiff")))[
39
- start_frame:end_frame
40
- ]
41
- shape = tifffile.imread(files[0]).shape
42
- if downscale:
43
- assert len(downscale) == len(shape)
44
- else:
45
- downscale = (1,) * len(shape)
46
-
47
- files = files[:: downscale[0]]
48
-
49
- x = []
50
- for f in tqdm(
51
- files,
52
- leave=False,
53
- desc=f"Loading [{start_frame}:{end_frame}:{downscale[0]}]",
54
- ):
55
- _x = tifffile.imread(f)
56
- if dtype:
57
- _x = _x.astype(dtype)
58
- assert _x.shape == shape
59
- slices = tuple(slice(None, None, d) for d in downscale[1:])
60
- _x = _x[slices]
61
- x.append(_x)
62
-
63
- x = np.stack(x)
64
- logger.debug(f"Loaded array of shape {x.shape} from {dir}")
65
- return x
66
-
67
-
68
- def load_tracklet_links(folder: Path) -> pd.DataFrame:
69
- candidates = [
70
- folder / "man_track.txt",
71
- folder / "res_track.txt",
72
- ]
73
- for c in candidates:
74
- if c.exists():
75
- path = c
76
- break
77
- else:
78
- raise FileNotFoundError(f"Could not find tracklet links in {folder}")
79
-
80
- df = pd.read_csv(
81
- path,
82
- delimiter=" ",
83
- names=["label", "t1", "t2", "parent"],
84
- dtype=int,
85
- )
86
- # Remove invalid tracks with t2 > t1
87
- df = df[df.t1 <= df.t2]
88
-
89
- n_dets = (df.t2 - df.t1 + 1).sum()
90
- logger.debug(f"{folder} has {n_dets} detections")
91
-
92
- n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum()
93
- logger.debug(f"{folder} has {n_divs} divisions")
94
- return df
95
-
96
-
97
- def filter_track_df(
98
- df: pd.DataFrame,
99
- start_frame: int = 0,
100
- end_frame: int = sys.maxsize,
101
- downscale: int = 1,
102
- ) -> pd.DataFrame:
103
- """Only keep tracklets that are present in the given time interval."""
104
- df.columns = ["label", "t1", "t2", "parent"]
105
- # only retain cells in interval
106
- df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)]
107
-
108
- # shift start and end of each cell
109
- df.t1 = df.t1 - start_frame
110
- df.t2 = df.t2 - start_frame
111
- # set start/end to min/max
112
- df.t1 = df.t1.clip(0, end_frame - start_frame - 1)
113
- df.t2 = df.t2.clip(0, end_frame - start_frame - 1)
114
- # set all parents to 0 that are not in the interval
115
- df.loc[~df.parent.isin(df.label), "parent"] = 0
116
-
117
- if downscale > 1:
118
- if start_frame % downscale != 0:
119
- raise ValueError("start_frame must be a multiple of downscale")
120
-
121
- logger.debug(f"Temporal downscaling of tracklet links by {downscale}")
122
-
123
- # remove tracklets that have been fully deleted by temporal downsampling
124
-
125
- mask = (
126
- # (df["t2"] - df["t1"] < downscale - 1)
127
- (df["t1"] % downscale != 0)
128
- & (df["t2"] % downscale != 0)
129
- & (df["t1"] // downscale == df["t2"] // downscale)
130
- )
131
- logger.debug(
132
- f"Remove {mask.sum()} tracklets that are fully deleted by downsampling"
133
- )
134
- logger.debug(f"Remove {df[mask]}")
135
-
136
- df = df[~mask]
137
- # set parent to 0 if it has been deleted
138
- df.loc[~df.parent.isin(df.label), "parent"] = 0
139
-
140
- df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int)
141
- df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int)
142
-
143
- # Correct for edge case of single frame tracklet
144
- assert np.all(df["t1"] == np.minimum(df["t1"], df["t2"]))
145
-
146
- return df
147
-
148
-
149
- # TODO fix
150
- # def dataset_to_ctc(dataset: CTCData, path, start: int = 0, stop: int | None = None):
151
- # """save dataset to ctc format for debugging purposes"""
152
- # out = Path(path)
153
- # print(f"Saving dataset to {out}")
154
- # out_img = out / "img"
155
- # out_img.mkdir(exist_ok=True, parents=True)
156
- # out_mask = out / "TRA"
157
- # out_mask.mkdir(exist_ok=True, parents=True)
158
- # if stop is None:
159
- # stop = len(self)
160
- # lines = []
161
- # masks, imgs = [], []
162
- # t_offset = 0
163
- # max_mask = 0
164
- # n_lines = 0
165
- # all_coords = []
166
- # for i in tqdm(range(start, stop)):
167
- # d = dataset.__getitem__(i, return_dense=True)
168
- # mask = d["mask"].numpy()
169
- # mask[mask > 0] += max_mask
170
- # max_mask = max(max_mask, mask.max())
171
- # masks.extend(mask)
172
- # imgs.extend(d["img"].numpy())
173
- # # add vertices
174
- # coords = d["coords0"].numpy()
175
- # ts, coords = coords[:, 0].astype(int), coords[:, 1:]
176
- # A = d["assoc_matrix"].numpy()
177
- # t_unique = sorted(np.unique(ts))
178
- # for t1, t2 in zip(t_unique[:-1], t_unique[1:]):
179
- # A_sub = A[ts == t1][:, ts == t2]
180
- # for i, a in enumerate(A_sub):
181
-
182
- # v1 = coords[ts == t1][i]
183
- # for j in np.where(a > 0)[0]:
184
- # v2 = coords[ts == t2][j]
185
- # # lines.append(
186
- # # {
187
- # # "index": n_lines,
188
- # # "shape-type": "line",
189
- # # "vertex-index": 0,
190
- # # "axis-0": t2 + t_offset,
191
- # # "axis-1": v1[0],
192
- # # "axis-2": v1[1],
193
- # # }
194
- # # )
195
- # # lines.append(
196
- # # {
197
- # # "index": n_lines,
198
- # # "shape-type": "line",
199
- # # "vertex-index": 1,
200
- # # "axis-0": t2 + t_offset,
201
- # # "axis-1": v2[0],
202
- # # "axis-2": v2[1],
203
- # # }
204
- # # )
205
- # lines.append([n_lines, "line", 0, t2 + t_offset] + v1.tolist())
206
- # lines.append([n_lines, "line", 1, t2 + t_offset] + v2.tolist())
207
- # n_lines += 1
208
-
209
- # c = d["coords0"].numpy()
210
- # c[:, 0] += t_offset
211
- # all_coords.extend(c)
212
- # t_offset += len(mask)
213
-
214
- # ax_cols = [f"axis-{i}" for i in range(dataset.ndim + 1)]
215
- # df = pd.DataFrame(lines, columns=["index", "shape-type", "vertex-index"] + ax_cols)
216
- # df.to_csv(out / "lines.csv", index=False)
217
-
218
- # df_c = pd.DataFrame(all_coords, columns=ax_cols)
219
- # df_c.to_csv(out / "coords.csv", index=False)
220
-
221
- # for i, m in enumerate(imgs):
222
- # # tifffile.imwrite(out_img/f'img_{i:04d}.tif', m)
223
- # if dataset.ndim == 2:
224
- # imageio.imwrite(
225
- # out_img / f"img_{i:04d}.jpg",
226
- # np.clip(20 + 100 * m, 0, 255).astype(np.uint8),
227
- # )
228
-
229
- # for i, m in enumerate(masks):
230
- # tifffile.imwrite(out_mask / f"mask_{i:04d}.tif", m, compression="zstd")
231
-
232
- # return d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # ruff: noqa: F401
2
-
3
- from .model import TrackingTransformer
4
- from .model_api import Trackastra
 
 
 
 
 
models/tra_post_model/trackastra/model/model_api.py DELETED
@@ -1,338 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
- from typing import Literal, Union, Optional, Tuple
5
-
6
- import dask.array as da
7
- import numpy as np
8
- import tifffile
9
- import torch
10
- import yaml
11
- from tqdm import tqdm
12
-
13
- from ..data import build_windows, get_features, load_tiff_timeseries
14
- from ..tracking import TrackGraph, build_graph, track_greedy
15
- from ..utils import normalize
16
- from .model import TrackingTransformer
17
- from .predict import predict_windows
18
- from .pretrained import download_pretrained
19
-
20
- logging.basicConfig(level=logging.INFO)
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- class Trackastra:
25
- """A transformer-based tracking model for time-lapse data.
26
-
27
- Trackastra links segmented objects across time frames by predicting
28
- associations with a transformer model trained on diverse time-lapse videos.
29
-
30
- The model takes as input:
31
- - A sequence of images of shape (T,(Z),Y,X)
32
- - Corresponding instance segmentation masks of shape (T,(Z),Y,X)
33
-
34
- It supports multiple tracking modes:
35
- - greedy_nodiv: Fast greedy linking without division
36
- - greedy: Fast greedy linking with division
37
- - ilp: Integer Linear Programming based linking (more accurate but slower)
38
-
39
- Examples:
40
- >>> # Load example data
41
- >>> from trackastra.data import example_data_bacteria
42
- >>> imgs, masks = example_data_bacteria()
43
- >>>
44
- >>> # Load pretrained model and track
45
- >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
46
- >>> track_graph = model.track(imgs, masks, mode="greedy")
47
- """
48
-
49
- def __init__(
50
- self,
51
- transformer: TrackingTransformer,
52
- train_args: dict,
53
- device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
54
- ):
55
- """Initialize Trackastra model.
56
-
57
- Args:
58
- transformer: The underlying transformer model.
59
- train_args: Training configuration arguments.
60
- device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
61
- """
62
- if device == "cuda":
63
- if torch.cuda.is_available():
64
- self.device = "cuda"
65
- else:
66
- logger.info("Cuda not available, falling back to cpu.")
67
- self.device = "cpu"
68
- elif device == "mps":
69
- if (
70
- torch.backends.mps.is_available()
71
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
72
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
73
- ):
74
- self.device = "mps"
75
- else:
76
- logger.info("Mps not available, falling back to cpu.")
77
- self.device = "cpu"
78
- elif device == "cpu":
79
- self.device = "cpu"
80
- elif device == "automatic" or device is None:
81
- should_use_mps = (
82
- torch.backends.mps.is_available()
83
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
84
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
85
- )
86
- self.device = (
87
- "cuda"
88
- if torch.cuda.is_available()
89
- else (
90
- "mps"
91
- if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
92
- else "cpu"
93
- )
94
- )
95
- else:
96
- raise ValueError(f"Device {device} not recognized.")
97
-
98
- logger.info(f"Using device {self.device}")
99
-
100
- self.transformer = transformer.to(self.device)
101
- self.train_args = train_args
102
-
103
- @classmethod
104
- def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
105
- """Load a Trackastra model from a local folder.
106
-
107
- Args:
108
- dir: Path to model folder containing:
109
- - model weights
110
- - train_config.yaml with training arguments
111
- device: Device to run model on.
112
-
113
- Returns:
114
- Trackastra model instance.
115
- """
116
- # Always load to cpu first
117
- transformer = TrackingTransformer.from_folder(
118
- Path(dir).expanduser(), map_location="cpu"
119
- )
120
- train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
121
- return cls(transformer=transformer, train_args=train_args, device=device)
122
-
123
- @classmethod
124
- def from_pretrained(
125
- cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
126
- ):
127
- """Load a pretrained Trackastra model.
128
-
129
- Available pretrained models are described in detail in pretrained.json.
130
-
131
- Args:
132
- name: Name of pretrained model (e.g. "general_2d").
133
- device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
134
- download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
135
-
136
- Returns:
137
- Trackastra model instance.
138
- """
139
- folder = download_pretrained(name, download_dir)
140
- # download zip from github to location/name, then unzip
141
- return cls.from_folder(folder, device=device)
142
-
143
- def _predict(
144
- self,
145
- imgs: Union[np.ndarray, da.Array],
146
- masks: Union[np.ndarray, da.Array],
147
- edge_threshold: float = 0.05,
148
- n_workers: int = 0,
149
- normalize_imgs: bool = True,
150
- progbar_class=tqdm,
151
- ):
152
- logger.info("Predicting weights for candidate graph")
153
- if normalize_imgs:
154
- if isinstance(imgs, da.Array):
155
- imgs = imgs.map_blocks(normalize)
156
- else:
157
- imgs = normalize(imgs)
158
-
159
- self.transformer.eval()
160
-
161
- features = get_features(
162
- detections=masks,
163
- imgs=imgs,
164
- ndim=self.transformer.config["coord_dim"],
165
- n_workers=n_workers,
166
- progbar_class=progbar_class,
167
- )
168
- logger.info("Building windows")
169
- windows = build_windows(
170
- features,
171
- window_size=self.transformer.config["window"],
172
- progbar_class=progbar_class,
173
- )
174
-
175
- logger.info("Predicting windows")
176
- predictions = predict_windows(
177
- windows=windows,
178
- features=features,
179
- model=self.transformer,
180
- edge_threshold=edge_threshold,
181
- spatial_dim=masks.ndim - 1,
182
- progbar_class=progbar_class,
183
- )
184
-
185
- return predictions
186
-
187
- def _track_from_predictions(
188
- self,
189
- predictions,
190
- mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
191
- use_distance: bool = False,
192
- max_distance: int = 256,
193
- max_neighbors: int = 10,
194
- delta_t: int = 1,
195
- **kwargs,
196
- ):
197
- logger.info("Running greedy tracker")
198
- nodes = predictions["nodes"]
199
- weights = predictions["weights"]
200
-
201
- candidate_graph = build_graph(
202
- nodes=nodes,
203
- weights=weights,
204
- use_distance=use_distance,
205
- max_distance=max_distance,
206
- max_neighbors=max_neighbors,
207
- delta_t=delta_t,
208
- )
209
- if mode == "greedy":
210
- return track_greedy(candidate_graph)
211
- elif mode == "greedy_nodiv":
212
- return track_greedy(candidate_graph, allow_divisions=False)
213
- elif mode == "ilp":
214
- from trackastra.tracking.ilp import track_ilp
215
-
216
- return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
217
- else:
218
- raise ValueError(f"Tracking mode {mode} does not exist.")
219
-
220
- def track(
221
- self,
222
- imgs: Union[np.ndarray, da.Array],
223
- masks: Union[np.ndarray, da.Array],
224
- mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
225
- normalize_imgs: bool = True,
226
- progbar_class=tqdm,
227
- n_workers: int = 0,
228
- **kwargs,
229
- ) -> TrackGraph:
230
- """Track objects across time frames.
231
-
232
- This method links segmented objects across time frames using the specified
233
- tracking mode. No hyperparameters need to be chosen beyond the tracking mode.
234
-
235
- Args:
236
- imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
237
- masks: Instance segmentation masks of shape (T,(Z),Y,X).
238
- mode: Tracking mode:
239
- - "greedy_nodiv": Fast greedy linking without division
240
- - "greedy": Fast greedy linking with division
241
- - "ilp": Integer Linear Programming based linking (more accurate but slower)
242
- progbar_class: Progress bar class to use.
243
- n_workers: Number of worker processes for feature extraction.
244
- normalize_imgs: Whether to normalize the images.
245
- **kwargs: Additional arguments passed to tracking algorithm.
246
-
247
- Returns:
248
- TrackGraph containing the tracking results.
249
- """
250
- if not imgs.shape == masks.shape:
251
- raise RuntimeError(
252
- f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
253
- )
254
-
255
- if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
256
- raise RuntimeError(
257
- f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
258
- )
259
-
260
- predictions = self._predict(
261
- imgs,
262
- masks,
263
- normalize_imgs=normalize_imgs,
264
- progbar_class=progbar_class,
265
- n_workers=n_workers,
266
- )
267
- track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
268
- return track_graph
269
-
270
- def track_from_disk(
271
- self,
272
- imgs_path: Path,
273
- masks_path: Path,
274
- mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
275
- normalize_imgs: bool = True,
276
- **kwargs,
277
- ) -> Tuple[TrackGraph, np.ndarray]:
278
- """Track objects directly from image and mask files on disk.
279
-
280
- This method supports both single tiff files and directories
281
-
282
- Args:
283
- imgs_path: Path to input images. Can be:
284
- - Directory containing numbered tiff files of shape (C),(Z),Y,X
285
- - Single tiff file with time series of shape T,(C),(Z),Y,X
286
- masks_path: Path to mask files. Can be:
287
- - Directory containing numbered tiff files of shape (Z),Y,X
288
- - Single tiff file with time series of shape T,(Z),Y,X
289
- mode: Tracking mode:
290
- - "greedy_nodiv": Fast greedy linking without division
291
- - "greedy": Fast greedy linking with division
292
- - "ilp": Integer Linear Programming based linking (more accurate but slower)
293
- normalize_imgs: Whether to normalize the images.
294
- **kwargs: Additional arguments passed to tracking algorithm.
295
-
296
- Returns:
297
- Tuple of (TrackGraph, tracked masks).
298
- """
299
- if not imgs_path.exists():
300
- raise FileNotFoundError(f"{imgs_path=} does not exist.")
301
- if not masks_path.exists():
302
- raise FileNotFoundError(f"{masks_path=} does not exist.")
303
-
304
- if imgs_path.is_dir():
305
- imgs = load_tiff_timeseries(imgs_path)
306
- else:
307
- imgs = tifffile.imread(imgs_path)
308
-
309
- if masks_path.is_dir():
310
- masks = load_tiff_timeseries(masks_path)
311
- else:
312
- masks = tifffile.imread(masks_path)
313
-
314
- if len(imgs) != len(masks):
315
- raise RuntimeError(
316
- f"#imgs and #masks do not match. Found {len(imgs)} images,"
317
- f" {len(masks)} masks."
318
- )
319
-
320
- if imgs.ndim - 1 == masks.ndim:
321
- if imgs[1] == 1:
322
- logger.info(
323
- "Found a channel dimension with a single channel. Removing dim."
324
- )
325
- masks = np.squeeze(masks, 1)
326
- else:
327
- raise RuntimeError(
328
- "Trackastra currently only supports single channel images."
329
- )
330
-
331
- if imgs.shape != masks.shape:
332
- raise RuntimeError(
333
- f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
334
- )
335
-
336
- return self.track(
337
- imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
338
- ), masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/model_parts.py DELETED
@@ -1,287 +0,0 @@
1
- """Transformer class."""
2
-
3
- import logging
4
- import math
5
- from typing import Literal
6
-
7
- import torch
8
- import torch.nn.functional as F
9
- from torch import nn
10
-
11
- from .rope import RotaryPositionalEncoding
12
-
13
- from typing import Tuple
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- def _pos_embed_fourier1d_init(
19
- cutoff: float = 256, n: int = 32, cutoff_start: float = 1
20
- ):
21
- return (
22
- torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n))
23
- .unsqueeze(0)
24
- .unsqueeze(0)
25
- )
26
-
27
-
28
- class FeedForward(nn.Module):
29
- def __init__(self, d_model, expand: float = 2, bias: bool = True):
30
- super().__init__()
31
- self.fc1 = nn.Linear(d_model, int(d_model * expand))
32
- self.fc2 = nn.Linear(int(d_model * expand), d_model, bias=bias)
33
- self.act = nn.GELU()
34
-
35
- def forward(self, x):
36
- return self.fc2(self.act(self.fc1(x)))
37
-
38
-
39
- class PositionalEncoding(nn.Module):
40
- def __init__(
41
- self,
42
- cutoffs: Tuple[float] = (256,),
43
- n_pos: Tuple[int] = (32,),
44
- cutoffs_start=None,
45
- ):
46
- """Positional encoding with given cutoff and number of frequencies for each dimension.
47
- number of dimension is inferred from the length of cutoffs and n_pos.
48
- """
49
- super().__init__()
50
- if cutoffs_start is None:
51
- cutoffs_start = (1,) * len(cutoffs)
52
-
53
- assert len(cutoffs) == len(n_pos)
54
- self.freqs = nn.ParameterList([
55
- nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
56
- for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
57
- ])
58
-
59
- def forward(self, coords: torch.Tensor):
60
- _B, _N, D = coords.shape
61
- assert D == len(self.freqs)
62
- embed = torch.cat(
63
- tuple(
64
- torch.cat(
65
- (
66
- torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
67
- torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
68
- ),
69
- axis=-1,
70
- )
71
- / math.sqrt(len(freq))
72
- for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
73
- ),
74
- axis=-1,
75
- )
76
-
77
- return embed
78
-
79
-
80
- class NoPositionalEncoding(nn.Module):
81
- def __init__(self, d):
82
- """One learnable input token that ignores positional information."""
83
- super().__init__()
84
- self.d = d
85
- # self.token = nn.Parameter(torch.randn(d))
86
-
87
- def forward(self, coords: torch.Tensor):
88
- B, N, _ = coords.shape
89
- return (
90
- # torch.ones((B, N, self.d), device=coords.device) * 0.1
91
- # torch.randn((1, 1, self.d), device=coords.device).expand(B, N, -1) * 0.01
92
- torch.randn((B, N, self.d), device=coords.device) * 0.01
93
- + torch.randn((1, 1, self.d), device=coords.device).expand(B, N, -1) * 0.1
94
- )
95
- # return self.token.view(1, 1, -1).expand(B, N, -1)
96
-
97
-
98
- def _bin_init_exp(cutoff: float, n: int):
99
- return torch.exp(torch.linspace(0, math.log(cutoff + 1), n))
100
-
101
-
102
- def _bin_init_linear(cutoff: float, n: int):
103
- return torch.linspace(-cutoff, cutoff, n)
104
-
105
-
106
- class RelativePositionalBias(nn.Module):
107
- def __init__(
108
- self,
109
- n_head: int,
110
- cutoff_spatial: float,
111
- cutoff_temporal: float,
112
- n_spatial: int = 32,
113
- n_temporal: int = 16,
114
- ):
115
- """Learnt relative positional bias to add to self-attention matrix.
116
-
117
- Spatial bins are exponentially spaced, temporal bins are linearly spaced.
118
-
119
- Args:
120
- n_head (int): Number of pos bias heads. Equal to number of attention heads
121
- cutoff_spatial (float): Maximum distance in space.
122
- cutoff_temporal (float): Maxium distance in time. Equal to window size of transformer.
123
- n_spatial (int, optional): Number of spatial bins.
124
- n_temporal (int, optional): Number of temporal bins in each direction. Should be equal to window size. Total = 2 * n_temporal + 1. Defaults to 16.
125
- """
126
- super().__init__()
127
- self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
128
- self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
129
- self.register_buffer("spatial_bins", self._spatial_bins)
130
- self.register_buffer("temporal_bins", self._temporal_bins)
131
- self.n_spatial = n_spatial
132
- self.n_head = n_head
133
- self.bias = nn.Parameter(
134
- -0.5 + torch.rand((2 * n_temporal + 1) * n_spatial, n_head)
135
- )
136
-
137
- def forward(self, coords: torch.Tensor):
138
- _B, _N, _D = coords.shape
139
- t = coords[..., 0]
140
- yx = coords[..., 1:]
141
- temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2)
142
- spatial_dist = torch.cdist(yx, yx)
143
-
144
- spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins)
145
- torch.clamp_(spatial_idx, max=len(self.spatial_bins) - 1)
146
- temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins)
147
- torch.clamp_(temporal_idx, max=len(self.temporal_bins) - 1)
148
-
149
- # do some index gymnastics such that backward is not super slow
150
- # https://discuss.pytorch.org/t/how-to-select-multiple-indexes-over-multiple-dimensions-at-the-same-time/98532/2
151
- idx = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial
152
- bias = self.bias.index_select(0, idx).view((*spatial_idx.shape, self.n_head))
153
- # -> B, nH, N, N
154
- bias = bias.transpose(-1, 1)
155
- return bias
156
-
157
-
158
- class RelativePositionalAttention(nn.Module):
159
- def __init__(
160
- self,
161
- coord_dim: int,
162
- embed_dim: int,
163
- n_head: int,
164
- cutoff_spatial: float = 256,
165
- cutoff_temporal: float = 16,
166
- n_spatial: int = 32,
167
- n_temporal: int = 16,
168
- dropout: float = 0.0,
169
- mode: Literal["bias", "rope", "none"] = "bias",
170
- attn_dist_mode: str = "v0",
171
- ):
172
- super().__init__()
173
-
174
- if not embed_dim % (2 * n_head) == 0:
175
- raise ValueError(
176
- f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
177
- )
178
-
179
- # qkv projection
180
- self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
181
- self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
182
- self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
183
-
184
- # output projection
185
- self.proj = nn.Linear(embed_dim, embed_dim)
186
- # regularization
187
- self.dropout = dropout
188
- self.n_head = n_head
189
- self.embed_dim = embed_dim
190
- self.cutoff_spatial = cutoff_spatial
191
- self.attn_dist_mode = attn_dist_mode
192
-
193
- if mode == "bias" or mode is True:
194
- self.pos_bias = RelativePositionalBias(
195
- n_head=n_head,
196
- cutoff_spatial=cutoff_spatial,
197
- cutoff_temporal=cutoff_temporal,
198
- n_spatial=n_spatial,
199
- n_temporal=n_temporal,
200
- )
201
- elif mode == "rope":
202
- # each part needs to be divisible by 2
203
- n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
204
-
205
- self.rot_pos_enc = RotaryPositionalEncoding(
206
- cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
207
- n_pos=(embed_dim // n_head - coord_dim * n_split,)
208
- + (n_split,) * coord_dim,
209
- )
210
- elif mode == "none":
211
- pass
212
- elif mode is None or mode is False:
213
- logger.warning(
214
- "attn_positional_bias is not set (None or False), no positional bias."
215
- )
216
- pass
217
- else:
218
- raise ValueError(f"Unknown mode {mode}")
219
-
220
- self._mode = mode
221
-
222
- def forward(
223
- self,
224
- query: torch.Tensor,
225
- key: torch.Tensor,
226
- value: torch.Tensor,
227
- coords: torch.Tensor,
228
- padding_mask: torch.Tensor = None,
229
- ):
230
- B, N, D = query.size()
231
- q = self.q_pro(query) # (B, N, D)
232
- k = self.k_pro(key) # (B, N, D)
233
- v = self.v_pro(value) # (B, N, D)
234
- # (B, nh, N, hs)
235
- k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
236
- q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
237
- v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
238
-
239
- attn_mask = torch.zeros(
240
- (B, self.n_head, N, N), device=query.device, dtype=q.dtype
241
- )
242
-
243
- # add negative value but not too large to keep mixed precision loss from becoming nan
244
- attn_ignore_val = -1e3
245
-
246
- # spatial cutoff
247
- yx = coords[..., 1:]
248
- spatial_dist = torch.cdist(yx, yx)
249
- spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
250
- attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
251
-
252
- # dont add positional bias to self-attention if coords is None
253
- if coords is not None:
254
- if self._mode == "bias":
255
- attn_mask = attn_mask + self.pos_bias(coords)
256
- elif self._mode == "rope":
257
- q, k = self.rot_pos_enc(q, k, coords)
258
- else:
259
- pass
260
-
261
- if self.attn_dist_mode == "v0":
262
- dist = torch.cdist(coords, coords, p=2)
263
- attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
264
- elif self.attn_dist_mode == "v1":
265
- attn_mask += torch.exp(
266
- -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
267
- )
268
- else:
269
- raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
270
-
271
- # if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens)
272
- if padding_mask is not None:
273
- ignore_mask = torch.logical_or(
274
- padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
275
- ).unsqueeze(1)
276
- attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
277
-
278
- # self.attn_mask = attn_mask.clone()
279
-
280
- y = F.scaled_dot_product_attention(
281
- q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
282
- )
283
-
284
- y = y.transpose(1, 2).contiguous().view(B, N, D)
285
- # output projection
286
- y = self.proj(y)
287
- return y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/model_sd.py DELETED
@@ -1,338 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
- from typing import Literal, Union, Optional, Tuple
5
-
6
- import dask.array as da
7
- import numpy as np
8
- import tifffile
9
- import torch
10
- import yaml
11
- from tqdm import tqdm
12
-
13
- from ..data import build_windows, get_features, load_tiff_timeseries
14
- from ..tracking import TrackGraph, build_graph, track_greedy
15
- from ..utils import normalize
16
- from .model import TrackingTransformer
17
- from .predict import predict_windows
18
- from .pretrained import download_pretrained
19
-
20
- logging.basicConfig(level=logging.INFO)
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- class Trackastra:
25
- """A transformer-based tracking model for time-lapse data.
26
-
27
- Trackastra links segmented objects across time frames by predicting
28
- associations with a transformer model trained on diverse time-lapse videos.
29
-
30
- The model takes as input:
31
- - A sequence of images of shape (T,(Z),Y,X)
32
- - Corresponding instance segmentation masks of shape (T,(Z),Y,X)
33
-
34
- It supports multiple tracking modes:
35
- - greedy_nodiv: Fast greedy linking without division
36
- - greedy: Fast greedy linking with division
37
- - ilp: Integer Linear Programming based linking (more accurate but slower)
38
-
39
- Examples:
40
- >>> # Load example data
41
- >>> from trackastra.data import example_data_bacteria
42
- >>> imgs, masks = example_data_bacteria()
43
- >>>
44
- >>> # Load pretrained model and track
45
- >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
46
- >>> track_graph = model.track(imgs, masks, mode="greedy")
47
- """
48
-
49
- def __init__(
50
- self,
51
- transformer: TrackingTransformer,
52
- train_args: dict,
53
- device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
54
- ):
55
- """Initialize Trackastra model.
56
-
57
- Args:
58
- transformer: The underlying transformer model.
59
- train_args: Training configuration arguments.
60
- device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
61
- """
62
- if device == "cuda":
63
- if torch.cuda.is_available():
64
- self.device = "cuda"
65
- else:
66
- logger.info("Cuda not available, falling back to cpu.")
67
- self.device = "cpu"
68
- elif device == "mps":
69
- if (
70
- torch.backends.mps.is_available()
71
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
72
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
73
- ):
74
- self.device = "mps"
75
- else:
76
- logger.info("Mps not available, falling back to cpu.")
77
- self.device = "cpu"
78
- elif device == "cpu":
79
- self.device = "cpu"
80
- elif device == "automatic" or device is None:
81
- should_use_mps = (
82
- torch.backends.mps.is_available()
83
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
84
- and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
85
- )
86
- self.device = (
87
- "cuda"
88
- if torch.cuda.is_available()
89
- else (
90
- "mps"
91
- if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
92
- else "cpu"
93
- )
94
- )
95
- else:
96
- raise ValueError(f"Device {device} not recognized.")
97
-
98
- logger.info(f"Using device {self.device}")
99
-
100
- self.transformer = transformer.to(self.device)
101
- self.train_args = train_args
102
-
103
- @classmethod
104
- def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
105
- """Load a Trackastra model from a local folder.
106
-
107
- Args:
108
- dir: Path to model folder containing:
109
- - model weights
110
- - train_config.yaml with training arguments
111
- device: Device to run model on.
112
-
113
- Returns:
114
- Trackastra model instance.
115
- """
116
- # Always load to cpu first
117
- transformer = TrackingTransformer.from_folder(
118
- Path(dir).expanduser(), map_location="cpu"
119
- )
120
- train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
121
- return cls(transformer=transformer, train_args=train_args, device=device)
122
-
123
- @classmethod
124
- def from_pretrained(
125
- cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
126
- ):
127
- """Load a pretrained Trackastra model.
128
-
129
- Available pretrained models are described in detail in pretrained.json.
130
-
131
- Args:
132
- name: Name of pretrained model (e.g. "general_2d").
133
- device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
134
- download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
135
-
136
- Returns:
137
- Trackastra model instance.
138
- """
139
- folder = download_pretrained(name, download_dir)
140
- # download zip from github to location/name, then unzip
141
- return cls.from_folder(folder, device=device)
142
-
143
- def _predict(
144
- self,
145
- imgs: Union[np.ndarray, da.Array],
146
- masks: Union[np.ndarray, da.Array],
147
- edge_threshold: float = 0.05,
148
- n_workers: int = 0,
149
- normalize_imgs: bool = True,
150
- progbar_class=tqdm,
151
- ):
152
- logger.info("Predicting weights for candidate graph")
153
- if normalize_imgs:
154
- if isinstance(imgs, da.Array):
155
- imgs = imgs.map_blocks(normalize)
156
- else:
157
- imgs = normalize(imgs)
158
-
159
- self.transformer.eval()
160
-
161
- features = get_features(
162
- detections=masks,
163
- imgs=imgs,
164
- ndim=self.transformer.config["coord_dim"],
165
- n_workers=n_workers,
166
- progbar_class=progbar_class,
167
- )
168
- logger.info("Building windows")
169
- windows = build_windows(
170
- features,
171
- window_size=self.transformer.config["window"],
172
- progbar_class=progbar_class,
173
- )
174
-
175
- logger.info("Predicting windows")
176
- predictions = predict_windows(
177
- windows=windows,
178
- features=features,
179
- model=self.transformer,
180
- edge_threshold=edge_threshold,
181
- spatial_dim=masks.ndim - 1,
182
- progbar_class=progbar_class,
183
- )
184
-
185
- return predictions
186
-
187
- def _track_from_predictions(
188
- self,
189
- predictions,
190
- mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
191
- use_distance: bool = False,
192
- max_distance: int = 256,
193
- max_neighbors: int = 10,
194
- delta_t: int = 1,
195
- **kwargs,
196
- ):
197
- logger.info("Running greedy tracker")
198
- nodes = predictions["nodes"]
199
- weights = predictions["weights"]
200
-
201
- candidate_graph = build_graph(
202
- nodes=nodes,
203
- weights=weights,
204
- use_distance=use_distance,
205
- max_distance=max_distance,
206
- max_neighbors=max_neighbors,
207
- delta_t=delta_t,
208
- )
209
- if mode == "greedy":
210
- return track_greedy(candidate_graph)
211
- elif mode == "greedy_nodiv":
212
- return track_greedy(candidate_graph, allow_divisions=False)
213
- elif mode == "ilp":
214
- from trackastra.tracking.ilp import track_ilp
215
-
216
- return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
217
- else:
218
- raise ValueError(f"Tracking mode {mode} does not exist.")
219
-
220
- def track(
221
- self,
222
- imgs: Union[np.ndarray, da.Array],
223
- masks: Union[np.ndarray, da.Array],
224
- mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
225
- normalize_imgs: bool = True,
226
- progbar_class=tqdm,
227
- n_workers: int = 0,
228
- **kwargs,
229
- ) -> TrackGraph:
230
- """Track objects across time frames.
231
-
232
- This method links segmented objects across time frames using the specified
233
- tracking mode. No hyperparameters need to be chosen beyond the tracking mode.
234
-
235
- Args:
236
- imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
237
- masks: Instance segmentation masks of shape (T,(Z),Y,X).
238
- mode: Tracking mode:
239
- - "greedy_nodiv": Fast greedy linking without division
240
- - "greedy": Fast greedy linking with division
241
- - "ilp": Integer Linear Programming based linking (more accurate but slower)
242
- progbar_class: Progress bar class to use.
243
- n_workers: Number of worker processes for feature extraction.
244
- normalize_imgs: Whether to normalize the images.
245
- **kwargs: Additional arguments passed to tracking algorithm.
246
-
247
- Returns:
248
- TrackGraph containing the tracking results.
249
- """
250
- if not imgs.shape == masks.shape:
251
- raise RuntimeError(
252
- f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
253
- )
254
-
255
- if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
256
- raise RuntimeError(
257
- f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
258
- )
259
-
260
- predictions = self._predict(
261
- imgs,
262
- masks,
263
- normalize_imgs=normalize_imgs,
264
- progbar_class=progbar_class,
265
- n_workers=n_workers,
266
- )
267
- track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
268
- return track_graph
269
-
270
- def track_from_disk(
271
- self,
272
- imgs_path: Path,
273
- masks_path: Path,
274
- mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
275
- normalize_imgs: bool = True,
276
- **kwargs,
277
- ) -> Tuple[TrackGraph, np.ndarray]:
278
- """Track objects directly from image and mask files on disk.
279
-
280
- This method supports both single tiff files and directories
281
-
282
- Args:
283
- imgs_path: Path to input images. Can be:
284
- - Directory containing numbered tiff files of shape (C),(Z),Y,X
285
- - Single tiff file with time series of shape T,(C),(Z),Y,X
286
- masks_path: Path to mask files. Can be:
287
- - Directory containing numbered tiff files of shape (Z),Y,X
288
- - Single tiff file with time series of shape T,(Z),Y,X
289
- mode: Tracking mode:
290
- - "greedy_nodiv": Fast greedy linking without division
291
- - "greedy": Fast greedy linking with division
292
- - "ilp": Integer Linear Programming based linking (more accurate but slower)
293
- normalize_imgs: Whether to normalize the images.
294
- **kwargs: Additional arguments passed to tracking algorithm.
295
-
296
- Returns:
297
- Tuple of (TrackGraph, tracked masks).
298
- """
299
- if not imgs_path.exists():
300
- raise FileNotFoundError(f"{imgs_path=} does not exist.")
301
- if not masks_path.exists():
302
- raise FileNotFoundError(f"{masks_path=} does not exist.")
303
-
304
- if imgs_path.is_dir():
305
- imgs = load_tiff_timeseries(imgs_path)
306
- else:
307
- imgs = tifffile.imread(imgs_path)
308
-
309
- if masks_path.is_dir():
310
- masks = load_tiff_timeseries(masks_path)
311
- else:
312
- masks = tifffile.imread(masks_path)
313
-
314
- if len(imgs) != len(masks):
315
- raise RuntimeError(
316
- f"#imgs and #masks do not match. Found {len(imgs)} images,"
317
- f" {len(masks)} masks."
318
- )
319
-
320
- if imgs.ndim - 1 == masks.ndim:
321
- if imgs[1] == 1:
322
- logger.info(
323
- "Found a channel dimension with a single channel. Removing dim."
324
- )
325
- masks = np.squeeze(masks, 1)
326
- else:
327
- raise RuntimeError(
328
- "Trackastra currently only supports single channel images."
329
- )
330
-
331
- if imgs.shape != masks.shape:
332
- raise RuntimeError(
333
- f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
334
- )
335
-
336
- return self.track(
337
- imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
338
- ), masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/predict.py DELETED
@@ -1,188 +0,0 @@
1
- import logging
2
- import warnings
3
-
4
- import numpy as np
5
- import torch
6
- from scipy.sparse import SparseEfficiencyWarning, csr_array
7
- from tqdm import tqdm
8
- from typing import List
9
-
10
- # TODO fix circular import
11
- # from .model import TrackingTransformer
12
- # from trackastra.data import WRFeatures
13
-
14
- warnings.simplefilter("ignore", SparseEfficiencyWarning)
15
-
16
- logger = logging.getLogger(__name__)
17
- logger.setLevel(logging.INFO)
18
-
19
-
20
- def predict(batch, model):
21
- """Predict association scores between objects in a batch.
22
-
23
- Args:
24
- batch: Dictionary containing:
25
- - features: Object features array
26
- - coords: Object coordinates array
27
- - timepoints: Time points array
28
- model: TrackingTransformer model to use for prediction.
29
-
30
- Returns:
31
- Array of association scores between objects.
32
- """
33
- feats = torch.from_numpy(batch["features"])
34
- coords = torch.from_numpy(batch["coords"])
35
- timepoints = torch.from_numpy(batch["timepoints"]).long()
36
- # Hack that assumes that all parameters of a model are on the same device
37
- device = next(model.parameters()).device
38
- feats = feats.unsqueeze(0).to(device)
39
- timepoints = timepoints.unsqueeze(0).to(device)
40
- coords = coords.unsqueeze(0).to(device)
41
-
42
- # Concat timepoints to coordinates
43
- coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
44
- with torch.no_grad():
45
- A = model(coords, features=feats)
46
- A = model.normalize_output(A, timepoints, coords)
47
-
48
- # # Spatially far entries should not influence the causal normalization
49
- # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
50
- # invalid = dist > model.config["spatial_pos_cutoff"]
51
- # A[invalid] = -torch.inf
52
-
53
- A = A.squeeze(0).detach().cpu().numpy()
54
-
55
- return A
56
-
57
-
58
- def predict_windows(
59
- windows: List[dict],
60
- # features: list[WRFeatures],
61
- # model: TrackingTransformer,
62
- features: list,
63
- model,
64
- intra_window_weight: float = 0,
65
- delta_t: int = 1,
66
- edge_threshold: float = 0.05,
67
- spatial_dim: int = 3,
68
- progbar_class=tqdm,
69
- ) -> dict:
70
- """Predict associations between objects across sliding windows.
71
-
72
- This function processes a sequence of sliding windows to predict associations
73
- between objects across time frames. It handles:
74
- - Object tracking across time
75
- - Weight normalization across windows
76
- - Edge thresholding
77
- - Time-based filtering
78
-
79
- Args:
80
- windows: List of window dictionaries containing:
81
- - timepoints: Array of time points
82
- - labels: Array of object labels
83
- - features: Object features
84
- - coords: Object coordinates
85
- features: List of feature objects containing:
86
- - labels: Object labels
87
- - timepoints: Time points
88
- - coords: Object coordinates
89
- model: TrackingTransformer model to use for prediction.
90
- intra_window_weight: Weight factor for objects in middle of window. Defaults to 0.
91
- delta_t: Maximum time difference between objects to consider. Defaults to 1.
92
- edge_threshold: Minimum association score to consider. Defaults to 0.05.
93
- spatial_dim: Dimensionality of input masks. May be less than model.coord_dim.
94
- progbar_class: Progress bar class to use. Defaults to tqdm.
95
-
96
- Returns:
97
- Dictionary containing:
98
- - nodes: List of node properties (id, coords, time, label)
99
- - weights: Tuple of ((node_i, node_j), weight) pairs
100
- """
101
- # first get all objects/coords
102
- time_labels_to_id = dict()
103
- node_properties = list()
104
- max_id = np.sum([len(f.labels) for f in features])
105
-
106
- all_timepoints = np.concatenate([f.timepoints for f in features])
107
- all_labels = np.concatenate([f.labels for f in features])
108
- all_coords = np.concatenate([f.coords for f in features])
109
- all_coords = all_coords[:, -spatial_dim:]
110
-
111
- for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)):
112
- time_labels_to_id[(t, la)] = i
113
- node_properties.append(
114
- dict(
115
- id=i,
116
- coords=tuple(c),
117
- time=t,
118
- # index=ix,
119
- label=la,
120
- )
121
- )
122
-
123
- # create assoc matrix between ids
124
- sp_weights, sp_accum = (
125
- csr_array((max_id, max_id), dtype=np.float32),
126
- csr_array((max_id, max_id), dtype=np.float32),
127
- )
128
-
129
- for t in progbar_class(
130
- range(len(windows)),
131
- desc="Computing associations",
132
- ):
133
- # This assumes that the samples in the dataset are ordered by time and start at 0.
134
- batch = windows[t]
135
- timepoints = batch["timepoints"]
136
- labels = batch["labels"]
137
-
138
- A = predict(batch, model)
139
-
140
- dt = timepoints[None, :] - timepoints[:, None]
141
- time_mask = np.logical_and(dt <= delta_t, dt > 0)
142
- A[~time_mask] = 0
143
- ii, jj = np.where(A >= edge_threshold)
144
-
145
- if len(ii) == 0:
146
- continue
147
-
148
- labels_ii = labels[ii]
149
- labels_jj = labels[jj]
150
- ts_ii = timepoints[ii]
151
- ts_jj = timepoints[jj]
152
- nodes_ii = np.array(
153
- tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii))
154
- )
155
- nodes_jj = np.array(
156
- tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj))
157
- )
158
-
159
- # weight middle parts higher
160
- t_middle = t + (model.config["window"] - 1) / 2
161
- ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
162
- window_weight = np.exp(-intra_window_weight * ddt**2) # default is 1
163
- # window_weight = np.exp(4*A) # smooth max
164
- sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
165
- sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]
166
-
167
- sp_weights_coo = sp_weights.tocoo()
168
- sp_accum_coo = sp_accum.tocoo()
169
- assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose(
170
- sp_weights_coo.row, sp_accum_coo.row
171
- )
172
-
173
- # Normalize weights by the number of times they were written from different sliding window positions
174
- weights = tuple(
175
- ((i, j), v / a)
176
- for i, j, v, a in zip(
177
- sp_weights_coo.row,
178
- sp_weights_coo.col,
179
- sp_weights_coo.data,
180
- sp_accum_coo.data,
181
- )
182
- )
183
-
184
- results = dict()
185
- results["nodes"] = node_properties
186
- results["weights"] = weights
187
-
188
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/pretrained.json DELETED
@@ -1,81 +0,0 @@
1
- {
2
- "general_2d": {
3
- "tags": ["cells, nuclei, bacteria, epithelial, yeast, particles"],
4
- "dimensionality": [2],
5
- "description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane, budding yeast cells (PhC), fluorescent particles, .",
6
- "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
7
- "datasets": {
8
- "Subset of Cell Tracking Challenge 2d datasets": {
9
- "url": "https://celltrackingchallenge.net/2d-datasets/",
10
- "reference": "Maőka M, Ulman V, Delgado-Rodriguez P, Gómez-de-Mariscal E, NečasovÑ T, Guerrero Peña FA, Ren TI, Meyerowitz EM, Scherr T, Lâffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20."
11
- },
12
- "Bacteria van Vliet": {
13
- "url": "https://zenodo.org/records/268921",
14
- "reference": "van Vliet S, Winkler AR, Spriewald S, Stecher B, Ackermann M. Spatially correlated gene expression in bacterial groups: the role of lineage history, spatial gradients, and cell-cell interactions. Cell systems. 2018 Apr 25;6(4):496-507."
15
- },
16
- "Bacteria ObiWan-Microbi": {
17
- "url": "https://zenodo.org/records/7260137",
18
- "reference": "Seiffarth J, Scherr T, Wollenhaupt B, Neumann O, Scharr H, Kohlheyer D, Mikut R, NΓΆh K. ObiWan-Microbi: OMERO-based integrated workflow for annotating microbes in the cloud. SoftwareX. 2024 May 1;26:101638."
19
- },
20
- "Bacteria Persat": {
21
- "url": "https://www.p-lab.science",
22
- "reference": "Datasets kindly provided by Persat lab, EPFL."
23
- },
24
- "DeepCell": {
25
- "url": "https://datasets.deepcell.org/data",
26
- "reference": "Schwartz, M, Moen E, Miller G, Dougherty T, Borba E, Ding R, Graf W, Pao E, Van Valen D. Caliban: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning. Biorxiv. 2023 Sept 13:803205."
27
- },
28
- "Ker phase contrast": {
29
- "url": "https://osf.io/ysaq2/",
30
- "reference": "Ker DF, Eom S, Sanami S, Bise R, Pascale C, Yin Z, Huh SI, Osuna-Highley E, Junkers SN, Helfrich CJ, Liang PY. Phase contrast time-lapse microscopy datasets with automated and manual cell tracking annotations. Scientific data. 2018 Nov 13;5(1):1-2."
31
- },
32
- "Epithelia benchmark": {
33
- "reference": "Funke J, Mais L, Champion A, Dye N, Kainmueller D. A benchmark for epithelial cell tracking. InProceedings of The European Conference on Computer Vision (ECCV) Workshops 2018 (pp. 0-0)."
34
- },
35
- "T Cells": {
36
- "url": "https://zenodo.org/records/5206119"
37
- },
38
- "Neisseria meningitidis bacterial growth": {
39
- "url": "https://zenodo.org/records/5419619"
40
- },
41
- "Synthetic nuclei": {
42
- "reference": "Weigert group live cell simulator."
43
- },
44
- "Synthetic particles": {
45
- "reference": "Weigert group particle simulator."
46
- },
47
- "Particle Tracking Challenge": {
48
- "url": "http://bioimageanalysis.org/track/#data",
49
- "reference": "Chenouard, N., Smal, I., De Chaumont, F., MaΕ‘ka, M., Sbalzarini, I. F., Gong, Y., ... & Meijering, E. (2014). Objective comparison of particle tracking methods. Nature methods, 11(3), 281-289."
50
- },
51
- "Yeast Cell-ACDC": {
52
- "url": "https://zenodo.org/records/6795124",
53
- "reference": "Padovani, F., MairhΓΆrmann, B., Falter-Braun, P., Lengefeld, J., & Schmoller, K. M. (2022). Segmentation, tracking and cell cycle analysis of live-cell imaging data with Cell-ACDC. BMC biology, 20(1), 174."
54
- },
55
- "DeepSea": {
56
- "url": "https://deepseas.org/datasets/",
57
- "reference": "Zargari, A., Lodewijk, G. A., Mashhadi, N., Cook, N., Neudorf, C. W., Araghbidikashani, K., ... & Shariati, S. A. (2023). DeepSea is an efficient deep-learning model for single-cell segmentation and tracking in time-lapse microscopy. Cell Reports Methods, 3(6)."
58
- },
59
- "Btrack" : {
60
- "url": "https://rdr.ucl.ac.uk/articles/dataset/Cell_tracking_reference_dataset/16595978",
61
- "reference": "Ulicna, K., Vallardi, G., Charras, G., & Lowe, A. R. (2021). Automated deep lineage tree analysis using a Bayesian single cell tracking approach. Frontiers in Computer Science, 3, 734559."
62
- },
63
- "E. coli in mother machine": {
64
- "url": "https://zenodo.org/records/11237127",
65
- "reference": "O’Connor, O. M., & Dunlop, M. J. (2024). Cell-TRACTR: A transformer-based model for end-to-end segmentation and tracking of cells. bioRxiv, 2024-07."
66
- }
67
- }
68
- },
69
- "ctc": {
70
- "tags": ["ctc", "Cell Tracking Challenge", "Cell Linking Benchmark"],
71
- "dimensionality": [2, 3],
72
- "description": "For tracking Cell Tracking Challenge datasets. This is the successor of the winning model of the ISBI 2024 CTC generalizable linking challenge.",
73
- "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
74
- "datasets": {
75
- "All Cell Tracking Challenge 2d+3d datasets with available GT and ERR_SEG": {
76
- "url": "https://celltrackingchallenge.net/3d-datasets/",
77
- "reference": "Maőka M, Ulman V, Delgado-Rodriguez P, Gómez-de-Mariscal E, NečasovÑ T, Guerrero Peña FA, Ren TI, Meyerowitz EM, Scherr T, Lâffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20."
78
- }
79
- }
80
- }
81
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/pretrained.py DELETED
@@ -1,90 +0,0 @@
1
- import logging
2
- import shutil
3
- import tempfile
4
- import zipfile
5
- try:
6
- from importlib.resources import files
7
- except:
8
- from importlib_resources import files
9
- from pathlib import Path
10
-
11
- import requests
12
- from tqdm import tqdm
13
- from typing import Optional
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
- _MODELS = {
18
- "ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
19
- "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
20
- }
21
-
22
-
23
- def download_and_unzip(url: str, dst: Path):
24
- # TODO make safe and use tempfile lib
25
- if dst.exists():
26
- print(f"{dst} already downloaded, skipping.")
27
- return
28
-
29
- # get the name of the zipfile
30
- zip_base = Path(url.split("/")[-1])
31
-
32
- with tempfile.TemporaryDirectory() as tmp:
33
- tmp = Path(tmp)
34
- zip_file = tmp / zip_base
35
- # Download the zip file
36
- download(url, zip_file)
37
-
38
- # Unzip the file
39
- with zipfile.ZipFile(zip_file, "r") as zip_ref:
40
- zip_ref.extractall(tmp)
41
-
42
- shutil.move(tmp / zip_base.stem, dst)
43
-
44
-
45
- def download(url: str, fname: Path):
46
- resp = requests.get(url, stream=True)
47
- total = int(resp.headers.get("content-length", 0))
48
- # try:
49
- # with (open(str(fname), "wb") as file,
50
- # tqdm(
51
- # desc=str(fname),
52
- # total=total,
53
- # unit="iB",
54
- # unit_scale=True,
55
- # unit_divisor=1024,
56
- # ) as bar,):
57
- # for data in resp.iter_content(chunk_size=1024):
58
- # size = file.write(data)
59
- # bar.update(size)
60
- # except:
61
- with open(str(fname), "wb") as file, tqdm(
62
- desc=str(fname),
63
- total=total,
64
- unit="iB",
65
- unit_scale=True,
66
- unit_divisor=1024,
67
- ) as bar:
68
- for data in resp.iter_content(chunk_size=1024):
69
- size = file.write(data)
70
- bar.update(size)
71
-
72
-
73
- def download_pretrained(name: str, download_dir: Optional[Path] = None):
74
- # TODO make safe, introduce versioning
75
- if download_dir is None:
76
- download_dir = files("trackastra").joinpath(".models")
77
- else:
78
- download_dir = Path(download_dir)
79
-
80
- download_dir.mkdir(exist_ok=True, parents=True)
81
- try:
82
- url = _MODELS[name]
83
- except KeyError:
84
- raise ValueError(
85
- "Pretrained model `name` is not available. Choose from"
86
- f" {list(_MODELS.keys())}"
87
- )
88
- folder = download_dir / name
89
- download_and_unzip(url=url, dst=folder)
90
- return folder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/rope.py DELETED
@@ -1,94 +0,0 @@
1
- """Transformer class."""
2
-
3
- # from torch_geometric.nn import GATv2Conv
4
- import math
5
-
6
- import torch
7
- from torch import nn
8
- from typing import Tuple
9
-
10
-
11
- def _pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
12
- # Maximum initial frequency is 1
13
- return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0)
14
-
15
-
16
- # https://github.com/cvg/LightGlue/blob/b1cd942fc4a3a824b6aedff059d84f5c31c297f6/lightglue/lightglue.py#L51
17
- def _rotate_half(x: torch.Tensor) -> torch.Tensor:
18
- """Rotate pairs of scalars as 2d vectors by pi/2.
19
- Refer to eq 34 in https://arxiv.org/pdf/2104.09864.pdf.
20
- """
21
- x = x.unflatten(-1, (-1, 2))
22
- x1, x2 = x.unbind(dim=-1)
23
- return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
24
-
25
-
26
- class RotaryPositionalEncoding(nn.Module):
27
- def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
28
- """Rotary positional encoding with given cutoff and number of frequencies for each dimension.
29
- number of dimension is inferred from the length of cutoffs and n_pos.
30
-
31
- see
32
- https://arxiv.org/pdf/2104.09864.pdf
33
- """
34
- super().__init__()
35
- assert len(cutoffs) == len(n_pos)
36
- if not all(n % 2 == 0 for n in n_pos):
37
- raise ValueError("n_pos must be even")
38
-
39
- self._n_dim = len(cutoffs)
40
- # theta in RoFormer https://arxiv.org/pdf/2104.09864.pdf
41
- self.freqs = nn.ParameterList([
42
- nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
43
- for cutoff, n in zip(cutoffs, n_pos)
44
- ])
45
-
46
- def get_co_si(self, coords: torch.Tensor):
47
- _B, _N, D = coords.shape
48
- assert D == len(self.freqs)
49
- co = torch.cat(
50
- tuple(
51
- torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
52
- for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
53
- ),
54
- axis=-1,
55
- )
56
- si = torch.cat(
57
- tuple(
58
- torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
59
- for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
60
- ),
61
- axis=-1,
62
- )
63
-
64
- return co, si
65
-
66
- def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
67
- _B, _N, D = coords.shape
68
- _B, _H, _N, _C = q.shape
69
-
70
- if not D == self._n_dim:
71
- raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
72
-
73
- co, si = self.get_co_si(coords)
74
-
75
- co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
76
- si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
77
- q2 = q * co + _rotate_half(q) * si
78
- k2 = k * co + _rotate_half(k) * si
79
-
80
- return q2, k2
81
-
82
-
83
- if __name__ == "__main__":
84
- model = RotaryPositionalEncoding((256, 256), (32, 32))
85
-
86
- x = 100 * torch.rand(1, 17, 2)
87
- q = torch.rand(1, 4, 17, 64)
88
- k = torch.rand(1, 4, 17, 64)
89
-
90
- q1, k1 = model(q, k, x)
91
- A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2)
92
-
93
- q2, k2 = model(q, k, x + 10)
94
- A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/utils/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # ruff: noqa: F401
2
-
3
- from .utils import (
4
- blockwise_causal_norm,
5
- blockwise_sum,
6
- normalize,
7
- normalize_01,
8
- preallocate_memory,
9
- random_label_cmap,
10
- render_label,
11
- seed,
12
- str2bool,
13
- str2path,
14
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/{trackastra/tracking β†’ tracking}/__init__.py RENAMED
@@ -7,9 +7,6 @@ from .tracking import (
7
  )
8
  from .utils import (
9
  ctc_to_graph,
10
- ctc_to_napari_tracks,
11
  graph_to_ctc,
12
  graph_to_edge_table,
13
- graph_to_napari_tracks,
14
- linear_chains,
15
  )
 
7
  )
8
  from .utils import (
9
  ctc_to_graph,
 
10
  graph_to_ctc,
11
  graph_to_edge_table,
 
 
12
  )
models/tra_post_model/{trackastra/tracking β†’ tracking}/ilp.py RENAMED
File without changes
models/tra_post_model/{trackastra/tracking β†’ tracking}/track_graph.py RENAMED
File without changes
models/tra_post_model/{trackastra/tracking β†’ tracking}/tracking.py RENAMED
@@ -9,8 +9,6 @@ from tqdm import tqdm
9
  from .track_graph import TrackGraph
10
  from typing import Optional, Tuple
11
 
12
- # from trackastra.tracking import graph_to_napari_tracks, graph_to_ctc
13
-
14
  logger = logging.getLogger(__name__)
15
 
16
 
@@ -43,9 +41,6 @@ def track_greedy(
43
 
44
  solution_graph = nx.DiGraph()
45
 
46
- # TODO bring back
47
- # if args.gt_as_dets:
48
- # solution_graph.add_nodes_from(candidate_graph.nodes(data=True))
49
 
50
  edges = candidate_graph.edges(data=True)
51
  edges = sorted(
@@ -75,12 +70,9 @@ def track_greedy(
75
  # otherwise add to solution
76
  copy_edge(edge, candidate_graph, solution_graph)
77
 
78
- # df, masks = graph_to_ctc(solution_graph, masks_original)
79
- # tracks, tracks_graph, _ = graph_to_napari_tracks(solution_graph)
80
 
81
  return solution_graph
82
- # TODO this should all be in a tracker class
83
- # return df, masks, solution_graph, tracks_graph, tracks, candidate_graph
84
 
85
 
86
  def build_graph(
 
9
  from .track_graph import TrackGraph
10
  from typing import Optional, Tuple
11
 
 
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
 
41
 
42
  solution_graph = nx.DiGraph()
43
 
 
 
 
44
 
45
  edges = candidate_graph.edges(data=True)
46
  edges = sorted(
 
70
  # otherwise add to solution
71
  copy_edge(edge, candidate_graph, solution_graph)
72
 
 
 
73
 
74
  return solution_graph
75
+
 
76
 
77
 
78
  def build_graph(
models/tra_post_model/{trackastra/tracking β†’ tracking}/utils.py RENAMED
@@ -14,38 +14,8 @@ logger = logging.getLogger(__name__)
14
  logger.setLevel(logging.INFO)
15
 
16
 
17
- class FoundTracks(Exception):
18
- pass
19
-
20
-
21
- def ctc_to_napari_tracks(segmentation: np.ndarray, man_track: pd.DataFrame):
22
- """Convert tracks in CTC format to tracks in napari format.
23
-
24
- Args:
25
- segmentation: Dims time, spatial_0, ... , spatial_n
26
- man_track: columns id, start, end, parent
27
- """
28
- tracks = []
29
- for t, frame in tqdm(
30
- enumerate(segmentation),
31
- total=len(segmentation),
32
- leave=False,
33
- desc="Computing centroids",
34
- ):
35
- for r in regionprops(frame):
36
- tracks.append((r.label, t, *r.centroid))
37
-
38
- tracks_graph = {}
39
- for idx, _, _, parent in tqdm(
40
- man_track.to_numpy(),
41
- desc="Converting CTC to napari tracks",
42
- leave=False,
43
- ):
44
- if parent != 0:
45
- tracks_graph[idx] = [parent]
46
-
47
- return tracks, tracks_graph
48
-
49
 
50
  class CtcTracklet:
51
  def __init__(self, parent: int, nodes: List[int], start_frame: int) -> None:
@@ -125,77 +95,77 @@ def ctc_tracklets(G: nx.DiGraph, frame_attribute: str = "time") -> List[CtcTrack
125
  return tracklets
126
 
127
 
128
- def linear_chains(G: nx.DiGraph):
129
- """Find all linear chains in a tree/graph, i.e. paths that.
130
-
131
- i) either start/end at a node with out_degree>in_degree or and have no internal branches, or
132
- ii) consists of a single node or a single splitting node
133
-
134
- Note that each chain includes its start/end node, i.e. they can be appear in multiple chains.
135
- """
136
- # get all nodes with out_degree>in_degree (i.e. start of chain)
137
- nodes = tuple(n for n in G.nodes if G.out_degree[n] > G.in_degree[n])
138
- # single nodes are those that are not starting a linear chain
139
- # single_nodes = tuple(n for n in G.nodes if G.out_degree[n] == G.in_degree[n] == 0)
140
- single_nodes = tuple(
141
- n for n in G.nodes if G.in_degree[n] == 0 and G.out_degree[n] != 1
142
- )
143
-
144
- for ni in single_nodes:
145
- yield [ni]
146
-
147
- for ni in nodes:
148
- neighs = tuple(G.neighbors(ni))
149
- for child in neighs:
150
- path = [ni, child]
151
- while len(childs := tuple(G.neighbors(path[-1]))) == 1:
152
- path.append(childs[0])
153
- yield path
154
-
155
-
156
- def graph_to_napari_tracks(
157
- graph: nx.DiGraph,
158
- properties: List[str] = [],
159
- ):
160
- """Convert a track graph to napari tracks."""
161
- # each tracklet is a linear chain in the graph
162
- chains = tuple(linear_chains(graph))
163
-
164
- track_end_to_track_id = dict()
165
- labels = []
166
- for i, cs in enumerate(chains):
167
- label = i + 1
168
- labels.append(label)
169
- # if len(cs) == 1:
170
- # print(cs)
171
- # # Non-connected node
172
- # continue
173
- end = cs[-1]
174
- track_end_to_track_id[end] = label
175
-
176
- tracks = []
177
- tracks_graph = dict()
178
- tracks_props = {p: [] for p in properties}
179
-
180
- for label, cs in tqdm(zip(labels, chains), total=len(chains)):
181
- start = cs[0]
182
- if start in track_end_to_track_id and len(cs) > 1:
183
- tracks_graph[label] = track_end_to_track_id[start]
184
- nodes = cs[1:]
185
- else:
186
- nodes = cs
187
-
188
- for c in nodes:
189
- node = graph.nodes[c]
190
- t = node["time"]
191
- coord = node["coords"]
192
- tracks.append([label, t, *list(coord)])
193
-
194
- for p in properties:
195
- tracks_props[p].append(node[p])
196
-
197
- tracks = np.array(tracks)
198
- return tracks, tracks_graph, tracks_props
199
 
200
 
201
  def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray):
 
14
  logger.setLevel(logging.INFO)
15
 
16
 
17
+ # class FoundTracks(Exception):
18
+ # pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class CtcTracklet:
21
  def __init__(self, parent: int, nodes: List[int], start_frame: int) -> None:
 
95
  return tracklets
96
 
97
 
98
+ # def linear_chains(G: nx.DiGraph):
99
+ # """Find all linear chains in a tree/graph, i.e. paths that.
100
+
101
+ # i) either start/end at a node with out_degree>in_degree or and have no internal branches, or
102
+ # ii) consists of a single node or a single splitting node
103
+
104
+ # Note that each chain includes its start/end node, i.e. they can be appear in multiple chains.
105
+ # """
106
+ # # get all nodes with out_degree>in_degree (i.e. start of chain)
107
+ # nodes = tuple(n for n in G.nodes if G.out_degree[n] > G.in_degree[n])
108
+ # # single nodes are those that are not starting a linear chain
109
+ # # single_nodes = tuple(n for n in G.nodes if G.out_degree[n] == G.in_degree[n] == 0)
110
+ # single_nodes = tuple(
111
+ # n for n in G.nodes if G.in_degree[n] == 0 and G.out_degree[n] != 1
112
+ # )
113
+
114
+ # for ni in single_nodes:
115
+ # yield [ni]
116
+
117
+ # for ni in nodes:
118
+ # neighs = tuple(G.neighbors(ni))
119
+ # for child in neighs:
120
+ # path = [ni, child]
121
+ # while len(childs := tuple(G.neighbors(path[-1]))) == 1:
122
+ # path.append(childs[0])
123
+ # yield path
124
+
125
+
126
+ # def graph_to_napari_tracks(
127
+ # graph: nx.DiGraph,
128
+ # properties: List[str] = [],
129
+ # ):
130
+ # """Convert a track graph to napari tracks."""
131
+ # # each tracklet is a linear chain in the graph
132
+ # chains = tuple(linear_chains(graph))
133
+
134
+ # track_end_to_track_id = dict()
135
+ # labels = []
136
+ # for i, cs in enumerate(chains):
137
+ # label = i + 1
138
+ # labels.append(label)
139
+ # # if len(cs) == 1:
140
+ # # print(cs)
141
+ # # # Non-connected node
142
+ # # continue
143
+ # end = cs[-1]
144
+ # track_end_to_track_id[end] = label
145
+
146
+ # tracks = []
147
+ # tracks_graph = dict()
148
+ # tracks_props = {p: [] for p in properties}
149
+
150
+ # for label, cs in tqdm(zip(labels, chains), total=len(chains)):
151
+ # start = cs[0]
152
+ # if start in track_end_to_track_id and len(cs) > 1:
153
+ # tracks_graph[label] = track_end_to_track_id[start]
154
+ # nodes = cs[1:]
155
+ # else:
156
+ # nodes = cs
157
+
158
+ # for c in nodes:
159
+ # node = graph.nodes[c]
160
+ # t = node["time"]
161
+ # coord = node["coords"]
162
+ # tracks.append([label, t, *list(coord)])
163
+
164
+ # for p in properties:
165
+ # tracks_props[p].append(node[p])
166
+
167
+ # tracks = np.array(tracks)
168
+ # return tracks, tracks_graph, tracks_props
169
 
170
 
171
  def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray):
models/tra_post_model/{trackastra/utils/utils.py β†’ utils.py} RENAMED
@@ -1,13 +1,6 @@
1
- import colorsys
2
- import itertools
3
  import logging
4
- import random
5
- import sys
6
- from pathlib import Path
7
- from timeit import default_timer
8
 
9
  import dask.array as da
10
- import matplotlib
11
  import numpy as np
12
  import torch
13
  from typing import Optional, Union
@@ -15,174 +8,6 @@ from typing import Optional, Union
15
  logger = logging.getLogger(__name__)
16
 
17
 
18
- def _single_color_integer_cmap(color=(0.3, 0.4, 0.5)):
19
- from matplotlib.colors import Colormap
20
-
21
- assert len(color) in (3, 4)
22
-
23
- class BinaryMap(Colormap):
24
- def __init__(self, color):
25
- self.color = np.array(color)
26
- if len(self.color) == 3:
27
- self.color = np.concatenate([self.color, [1]])
28
-
29
- def __call__(self, X, alpha=None, bytes=False):
30
- res = np.zeros((*X.shape, 4), np.float32)
31
- res[..., -1] = self.color[-1]
32
- res[X > 0] = np.expand_dims(self.color, 0)
33
- if bytes:
34
- return np.clip(256 * res, 0, 255).astype(np.uint8)
35
- else:
36
- return res
37
-
38
- return BinaryMap(color)
39
-
40
-
41
- def render_label(
42
- lbl,
43
- img=None,
44
- cmap=None,
45
- cmap_img="gray",
46
- alpha=0.5,
47
- alpha_boundary=None,
48
- normalize_img=True,
49
- ):
50
- """Renders a label image and optionally overlays it with another image. Used for generating simple output images to asses the label quality.
51
-
52
- Parameters
53
- ----------
54
- lbl: np.ndarray of dtype np.uint16
55
- The 2D label image
56
- img: np.ndarray
57
- The array to overlay the label image with (optional)
58
- cmap: string, tuple, or callable
59
- The label colormap. If given as rgb(a) only a single color is used, if None uses a random colormap
60
- cmap_img: string or callable
61
- The colormap of img (optional)
62
- alpha: float
63
- The alpha value of the overlay. Set alpha=1 to get fully opaque labels
64
- alpha_boundary: float
65
- The alpha value of the boundary (if None, use the same as for labels, i.e. no boundaries are visible)
66
- normalize_img: bool
67
- If True, normalizes the img (if given)
68
-
69
- Returns:
70
- -------
71
- img: np.ndarray
72
- the (m,n,4) RGBA image of the rendered label
73
-
74
- Example:
75
- -------
76
- from scipy.ndimage import label, zoom
77
- img = zoom(np.random.uniform(0,1,(16,16)),(8,8),order=3)
78
- lbl,_ = label(img>.8)
79
- u1 = render_label(lbl, img = img, alpha = .7)
80
- u2 = render_label(lbl, img = img, alpha = 0, alpha_boundary =.8)
81
- plt.subplot(1,2,1);plt.imshow(u1)
82
- plt.subplot(1,2,2);plt.imshow(u2)
83
-
84
- """
85
- from matplotlib import cm
86
- from skimage.segmentation import find_boundaries
87
-
88
- alpha = np.clip(alpha, 0, 1)
89
-
90
- if alpha_boundary is None:
91
- alpha_boundary = alpha
92
-
93
- if cmap is None:
94
- cmap = random_label_cmap()
95
- elif isinstance(cmap, tuple):
96
- cmap = _single_color_integer_cmap(cmap)
97
- else:
98
- pass
99
-
100
- cmap = cm.get_cmap(cmap) if isinstance(cmap, str) else cmap
101
- cmap_img = cm.get_cmap(cmap_img) if isinstance(cmap_img, str) else cmap_img
102
-
103
- # render image if given
104
- if img is None:
105
- im_img = np.zeros((*lbl.shape, 4), np.float32)
106
- im_img[..., -1] = 1
107
-
108
- else:
109
- assert lbl.shape[:2] == img.shape[:2]
110
- img = normalize(img) if normalize_img else img
111
- if img.ndim == 2:
112
- im_img = cmap_img(img)
113
- elif img.ndim == 3:
114
- im_img = img[..., :4]
115
- if img.shape[-1] < 4:
116
- im_img = np.concatenate(
117
- [img, np.ones(img.shape[:2] + (4 - img.shape[-1],))], axis=-1
118
- )
119
- else:
120
- raise ValueError("img should be 2 or 3 dimensional")
121
-
122
- # render label
123
- im_lbl = cmap(lbl)
124
-
125
- mask_lbl = lbl > 0
126
- mask_bound = np.bitwise_and(mask_lbl, find_boundaries(lbl, mode="thick"))
127
-
128
- # blend
129
- im = im_img.copy()
130
-
131
- im[mask_lbl] = alpha * im_lbl[mask_lbl] + (1 - alpha) * im_img[mask_lbl]
132
- im[mask_bound] = (
133
- alpha_boundary * im_lbl[mask_bound] + (1 - alpha_boundary) * im_img[mask_bound]
134
- )
135
-
136
- return im
137
-
138
-
139
- def random_label_cmap(n=2**16, h=(0, 1), lightness=(0.4, 1), s=(0.2, 0.8)):
140
- h, lightness, s = (
141
- np.random.uniform(*h, n),
142
- np.random.uniform(*lightness, n),
143
- np.random.uniform(*s, n),
144
- )
145
- cols = np.stack(
146
- [colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(h, lightness, s)], axis=0
147
- )
148
- cols[0] = 0
149
- return matplotlib.colors.ListedColormap(cols)
150
-
151
-
152
- # @torch.jit.script
153
- def _blockwise_sum_with_bounds(A: torch.Tensor, bounds: torch.Tensor, dim: int = 0):
154
- A = A.transpose(dim, 0)
155
- cum = torch.cumsum(A, dim=0)
156
- cum = torch.cat((torch.zeros_like(cum[:1]), cum), dim=0)
157
- B = torch.zeros_like(A, device=A.device)
158
- for i, j in itertools.pairwise(bounds[:-1], bounds[1:]):
159
- B[i:j] = cum[j] - cum[i]
160
- B = B.transpose(0, dim)
161
- return B
162
-
163
-
164
- def _bounds_from_timepoints(timepoints: torch.Tensor):
165
- assert timepoints.ndim == 1
166
- bounds = torch.cat((
167
- torch.tensor([0], device=timepoints.device),
168
- # torch.nonzero faster than torch.where
169
- torch.nonzero(timepoints[1:] - timepoints[:-1], as_tuple=False)[:, 0] + 1,
170
- torch.tensor([len(timepoints)], device=timepoints.device),
171
- ))
172
- return bounds
173
-
174
-
175
- # def blockwise_sum(A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0):
176
- # # get block boundaries
177
- # assert A.shape[dim] == len(timepoints)
178
-
179
- # bounds = _bounds_from_timepoints(timepoints)
180
-
181
- # # normalize within blocks
182
- # u = _blockwise_sum_with_bounds(A, bounds, dim=dim)
183
- # return u
184
-
185
-
186
  def blockwise_sum(
187
  A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum"
188
  ):
@@ -270,11 +95,6 @@ def blockwise_causal_norm(
270
  else:
271
  raise NotImplementedError(f"Mode {mode} not implemented")
272
 
273
- # get block boundaries and normalize within blocks
274
- # bounds = _bounds_from_timepoints(timepoints)
275
- # u0_sum = _blockwise_sum_with_bounds(u0, bounds, dim=0) + eps
276
- # u1_sum = _blockwise_sum_with_bounds(u1, bounds, dim=1) + eps
277
-
278
  u0_sum = blockwise_sum(u0, timepoints, dim=0) + eps
279
  u1_sum = blockwise_sum(u1, timepoints, dim=1) + eps
280
 
@@ -296,12 +116,6 @@ def blockwise_causal_norm(
296
  return res
297
 
298
 
299
def normalize_tensor(x: torch.Tensor, dim: Optional[int] = None, eps: float = 1e-8):
    """Min-max normalize ``x`` to approximately [0, 1].

    Args:
        x: Input tensor.
        dim: Dimension(s) over which min/max are taken; all dims if None.
        eps: Small constant to avoid division by zero for constant input.

    Returns:
        Normalized tensor of the same shape as ``x``.
    """
    reduce_dims = tuple(range(x.ndim)) if dim is None else dim
    lo = torch.amin(x, dim=reduce_dims, keepdim=True)
    hi = torch.amax(x, dim=reduce_dims, keepdim=True)
    return (x - lo) / (hi - lo + eps)
305
 
306
 
307
  def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
@@ -340,155 +154,3 @@ def normalize_01(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
340
  x /= ma - mi + 1e-8
341
  return x
342
 
343
-
344
def batched(x, batch_size, device):
    """Tile tensor ``x`` along a new leading batch dimension and move it to ``device``.

    Uses ``expand`` (a view, no copy) before the device transfer.
    """
    views = x.unsqueeze(0).expand(batch_size, *((-1,) * x.ndim))
    return views.to(device)
346
-
347
-
348
def preallocate_memory(dataset, model_lightning, batch_size, max_tokens, device):
    """Run one maximal-size train step to preallocate memory up front.

    See https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#preallocate-memory-in-case-of-variable-input-length.

    Args:
        dataset: Dataset whose first sample provides the dtypes and feature
            dimensions for keys "features", "coords", "assoc_matrix",
            "timepoints".
        model_lightning: Lightning module exposing ``_common_step``.
        batch_size: Batch size of the dummy batch.
        max_tokens: Token length of the dummy batch. If None, nothing is
            done (automatic detection is not implemented).
        device: Target device for the dummy batch.
    """
    start = default_timer()

    if max_tokens is None:
        logger.warning(
            "Preallocating memory without specifying max_tokens not implemented."
        )
        return

    max_len = max_tokens
    # Only dtypes / trailing feature shapes are read from the sample; the
    # dummy batch itself is all zeros at the maximum token length.
    x = dataset[0]
    batch = dict(
        features=batched(
            torch.zeros(
                (max_len,) + x["features"].shape[1:], dtype=x["features"].dtype
            ),
            batch_size,
            device,
        ),
        coords=batched(
            torch.zeros((max_len,) + x["coords"].shape[1:], dtype=x["coords"].dtype),
            batch_size,
            device,
        ),
        assoc_matrix=batched(
            torch.zeros((max_len, max_len), dtype=x["assoc_matrix"].dtype),
            batch_size,
            device,
        ),
        timepoints=batched(
            torch.zeros(max_len, dtype=x["timepoints"].dtype), batch_size, device
        ),
        padding_mask=batched(torch.zeros(max_len, dtype=bool), batch_size, device),
    )

    # One full forward + backward allocates worst-case activation memory.
    loss = model_lightning._common_step(batch)["loss"]
    loss.backward()
    model_lightning.zero_grad()

    logger.info(
        # BUG FIX: report the actual dummy-batch length instead of the stale
        # hard-coded "22,952" from a previous experiment.
        f"Preallocated memory for largest training batch (length {max_len}) in"
        f" {default_timer() - start:.02f} s"
    )
    if device.type == "cuda":
        logger.info(
            "Memory allocated for model:"
            f" {torch.cuda.max_memory_allocated() / 1024**3:.02f} GB"
        )
-
435
-
436
def seed(s=None):
    """Seed random number generators.

    Seeds `random`, `numpy`, and (only if already imported) `torch`.
    Defaults to the unix timestamp of the function call.

    Args:
        s (``int``): Manual seed.

    Returns:
        The seed that was used.
    """
    if s is None:
        s = int(default_timer())

    # (name, seeding function) pairs, applied in a fixed order
    seeders = [("random", random.seed), ("numpy", np.random.seed)]
    if "torch" in sys.modules:
        seeders.append(("torch", torch.manual_seed))

    for name, fn in seeders:
        fn(s)
        logger.debug(f"Seed `{name}` rng with {s}.")

    return s
456
-
457
-
458
def str2bool(x: str) -> bool:
    """Cast string to boolean.

    Useful for parsing command line arguments.

    Raises:
        TypeError: If ``x`` is not a string.
        ValueError: If ``x`` is not a recognized boolean spelling.
    """
    if not isinstance(x, str):
        raise TypeError("String expected.")
    lowered = x.lower()
    if lowered in ("true", "t", "1"):
        return True
    if lowered in ("false", "f", "0"):
        return False
    raise ValueError(f"'{x}' does not seem to be boolean.")
471
-
472
-
473
def str2path(x: str) -> Path:
    """Cast string to a resolved absolute path.

    Useful for parsing command line arguments.

    Raises:
        TypeError: If ``x`` is not a string.
    """
    if not isinstance(x, str):
        raise TypeError("String expected.")
    return Path(x).expanduser().resolve()
482
-
483
-
484
if __name__ == "__main__":
    # Quick manual smoke test for the blockwise helpers.
    mat = torch.rand(50, 50)
    bounds = torch.tensor([0, 10, 20, mat.shape[0]])

    mat = torch.eye(50)

    summed = _blockwise_sum_with_bounds(mat, bounds)

    timepts = torch.repeat_interleave(torch.arange(5), 10)

    normed = blockwise_causal_norm(mat, timepts)
 
 
 
1
  import logging
 
 
 
 
2
 
3
  import dask.array as da
 
4
  import numpy as np
5
  import torch
6
  from typing import Optional, Union
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def blockwise_sum(
12
  A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum"
13
  ):
 
95
  else:
96
  raise NotImplementedError(f"Mode {mode} not implemented")
97
 
 
 
 
 
 
98
  u0_sum = blockwise_sum(u0, timepoints, dim=0) + eps
99
  u1_sum = blockwise_sum(u1, timepoints, dim=1) + eps
100
 
 
116
  return res
117
 
118
 
 
 
 
 
 
 
119
 
120
 
121
  def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
 
154
  x /= ma - mi + 1e-8
155
  return x
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tracking_one.py CHANGED
@@ -27,12 +27,12 @@ from _utils.load_models import load_stable_diffusion_model
27
  from models.model import Counting_with_SD_features_track as Counting
28
  from models.enc_model.loca import build_model as build_loca_model
29
  import time
30
- from models.tra_post_model.trackastra.model import TrackingTransformer
31
- from models.tra_post_model.trackastra.utils import (
32
  normalize,
33
  )
34
- from models.tra_post_model.trackastra.data import build_windows_sd, get_features
35
- from models.tra_post_model.trackastra.tracking import TrackGraph, build_graph, track_greedy
36
  from _utils.track_args import parse_train_args as get_track_args
37
  import torchvision.transforms as T
38
  from pathlib import Path
@@ -142,9 +142,7 @@ class TrackingModule(pl.LightningModule):
142
  # others
143
  self.placeholder_token = placeholder_token
144
  self.placeholder_token_id = placeholder_token_id
145
-
146
- # tracking model
147
- # fpath = Path("models/tra_post_model/trackastra/.models/general_2d/model.pt")
148
  fpath = Path("_utils/config.yaml")
149
  args_ = get_track_args()
150
 
@@ -935,7 +933,7 @@ class TrackingModule(pl.LightningModule):
935
  elif mode == "greedy_nodiv":
936
  return track_greedy(candidate_graph, allow_divisions=False)
937
  elif mode == "ilp":
938
- from models.tra_post_model.trackastra.tracking.ilp import track_ilp
939
 
940
  return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
941
  else:
 
27
  from models.model import Counting_with_SD_features_track as Counting
28
  from models.enc_model.loca import build_model as build_loca_model
29
  import time
30
+ from models.tra_post_model.model import TrackingTransformer
31
+ from models.tra_post_model.utils import (
32
  normalize,
33
  )
34
+ from models.tra_post_model.data import build_windows_sd, get_features
35
+ from models.tra_post_model.tracking import TrackGraph, build_graph, track_greedy
36
  from _utils.track_args import parse_train_args as get_track_args
37
  import torchvision.transforms as T
38
  from pathlib import Path
 
142
  # others
143
  self.placeholder_token = placeholder_token
144
  self.placeholder_token_id = placeholder_token_id
145
+
 
 
146
  fpath = Path("_utils/config.yaml")
147
  args_ = get_track_args()
148
 
 
933
  elif mode == "greedy_nodiv":
934
  return track_greedy(candidate_graph, allow_divisions=False)
935
  elif mode == "ilp":
936
+ from models.tra_post_model.tracking.ilp import track_ilp
937
 
938
  return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
939
  else: