Spaces:
Running on Zero
Running on Zero
Commit Β·
4ce5a27
1
Parent(s): 02e04fb
cleanup
Browse files- _utils/load_track_data.py +1 -1
- _utils/track_args.py +1 -96
- inference_track.py +1 -1
- models/tra_post_model/{trackastra/data/wrfeat.py β data.py} +0 -348
- models/tra_post_model/{trackastra/model/model.py β model.py} +280 -12
- models/tra_post_model/trackastra/data/__init__.py +0 -18
- models/tra_post_model/trackastra/data/_check_ctc.py +0 -114
- models/tra_post_model/trackastra/data/_compute_overlap.py +0 -209
- models/tra_post_model/trackastra/data/augmentations.py +0 -557
- models/tra_post_model/trackastra/data/data.py +0 -1509
- models/tra_post_model/trackastra/data/distributed.py +0 -316
- models/tra_post_model/trackastra/data/example_data.py +0 -48
- models/tra_post_model/trackastra/data/features.py +0 -148
- models/tra_post_model/trackastra/data/matching.py +0 -251
- models/tra_post_model/trackastra/data/utils.py +0 -232
- models/tra_post_model/trackastra/model/__init__.py +0 -4
- models/tra_post_model/trackastra/model/model_api.py +0 -338
- models/tra_post_model/trackastra/model/model_parts.py +0 -287
- models/tra_post_model/trackastra/model/model_sd.py +0 -338
- models/tra_post_model/trackastra/model/predict.py +0 -188
- models/tra_post_model/trackastra/model/pretrained.json +0 -81
- models/tra_post_model/trackastra/model/pretrained.py +0 -90
- models/tra_post_model/trackastra/model/rope.py +0 -94
- models/tra_post_model/trackastra/utils/__init__.py +0 -14
- models/tra_post_model/{trackastra/tracking β tracking}/__init__.py +0 -3
- models/tra_post_model/{trackastra/tracking β tracking}/ilp.py +0 -0
- models/tra_post_model/{trackastra/tracking β tracking}/track_graph.py +0 -0
- models/tra_post_model/{trackastra/tracking β tracking}/tracking.py +1 -9
- models/tra_post_model/{trackastra/tracking β tracking}/utils.py +73 -103
- models/tra_post_model/{trackastra/utils/utils.py β utils.py} +0 -338
- tracking_one.py +6 -8
_utils/load_track_data.py
CHANGED
|
@@ -9,7 +9,7 @@ import skimage.io as io
|
|
| 9 |
import torchvision.transforms as T
|
| 10 |
import cv2
|
| 11 |
from tqdm import tqdm
|
| 12 |
-
from models.tra_post_model.
|
| 13 |
IMG_SIZE = 512
|
| 14 |
|
| 15 |
def _load_tiffs(folder: Path, dtype=None):
|
|
|
|
| 9 |
import torchvision.transforms as T
|
| 10 |
import cv2
|
| 11 |
from tqdm import tqdm
|
| 12 |
+
from models.tra_post_model.utils import normalize_01, normalize
|
| 13 |
IMG_SIZE = 512
|
| 14 |
|
| 15 |
def _load_tiffs(folder: Path, dtype=None):
|
_utils/track_args.py
CHANGED
|
@@ -14,49 +14,14 @@ def parse_train_args():
|
|
| 14 |
is_config_file=True,
|
| 15 |
help="config file path",
|
| 16 |
)
|
| 17 |
-
parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
|
| 18 |
-
parser.add_argument("-o", "--outdir", type=str, default="runs")
|
| 19 |
-
parser.add_argument("--name", type=str, help="Name to append to timestamp")
|
| 20 |
-
parser.add_argument("--timestamp", type=bool, default=True)
|
| 21 |
-
parser.add_argument(
|
| 22 |
-
"-m",
|
| 23 |
-
"--model",
|
| 24 |
-
type=str,
|
| 25 |
-
default="",
|
| 26 |
-
help="load this model at start (e.g. to continue training)",
|
| 27 |
-
)
|
| 28 |
-
parser.add_argument(
|
| 29 |
-
"--ndim", type=int, default=2, help="number of spatial dimensions"
|
| 30 |
-
)
|
| 31 |
parser.add_argument("-d", "--d_model", type=int, default=256)
|
| 32 |
parser.add_argument("-w", "--window", type=int, default=10)
|
| 33 |
-
parser.add_argument("--epochs", type=int, default=100)
|
| 34 |
-
parser.add_argument("--warmup_epochs", type=int, default=10)
|
| 35 |
-
parser.add_argument(
|
| 36 |
-
"--detection_folders",
|
| 37 |
-
type=str,
|
| 38 |
-
nargs="+",
|
| 39 |
-
default=["TRA"],
|
| 40 |
-
help=(
|
| 41 |
-
"Subfolders to search for detections. Defaults to `TRA`, which corresponds"
|
| 42 |
-
" to using only the GT."
|
| 43 |
-
),
|
| 44 |
-
)
|
| 45 |
-
parser.add_argument("--downscale_temporal", type=int, default=1)
|
| 46 |
-
parser.add_argument("--downscale_spatial", type=int, default=1)
|
| 47 |
parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
|
| 48 |
-
parser.add_argument("--from_subfolder", action="store_true")
|
| 49 |
-
# parser.add_argument("--train_samples", type=int, default=50000)
|
| 50 |
parser.add_argument("--num_encoder_layers", type=int, default=6)
|
| 51 |
parser.add_argument("--num_decoder_layers", type=int, default=6)
|
| 52 |
parser.add_argument("--pos_embed_per_dim", type=int, default=32)
|
| 53 |
parser.add_argument("--feat_embed_per_dim", type=int, default=8)
|
| 54 |
parser.add_argument("--dropout", type=float, default=0.00)
|
| 55 |
-
parser.add_argument("--num_workers", type=int, default=4)
|
| 56 |
-
parser.add_argument("--batch_size", type=int, default=1)
|
| 57 |
-
parser.add_argument("--max_tokens", type=int, default=None)
|
| 58 |
-
parser.add_argument("--delta_cutoff", type=int, default=2)
|
| 59 |
-
parser.add_argument("--lr", type=float, default=1e-4)
|
| 60 |
parser.add_argument(
|
| 61 |
"--attn_positional_bias",
|
| 62 |
type=str,
|
|
@@ -65,72 +30,12 @@ def parse_train_args():
|
|
| 65 |
)
|
| 66 |
parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
|
| 67 |
parser.add_argument("--attn_dist_mode", default="v0")
|
| 68 |
-
parser.add_argument("--mixedp", type=bool, default=True)
|
| 69 |
-
parser.add_argument("--dry", action="store_true")
|
| 70 |
-
parser.add_argument("--profile", action="store_true")
|
| 71 |
-
parser.add_argument(
|
| 72 |
-
"--features",
|
| 73 |
-
type=str,
|
| 74 |
-
choices=[
|
| 75 |
-
"none",
|
| 76 |
-
"regionprops",
|
| 77 |
-
"regionprops2",
|
| 78 |
-
"patch",
|
| 79 |
-
"patch_regionprops",
|
| 80 |
-
"wrfeat",
|
| 81 |
-
],
|
| 82 |
-
default="wrfeat",
|
| 83 |
-
)
|
| 84 |
parser.add_argument(
|
| 85 |
"--causal_norm",
|
| 86 |
type=str,
|
| 87 |
choices=["none", "linear", "softmax", "quiet_softmax"],
|
| 88 |
default="quiet_softmax",
|
| 89 |
)
|
| 90 |
-
parser.add_argument("--div_upweight", type=float, default=2)
|
| 91 |
-
|
| 92 |
-
parser.add_argument("--augment", type=int, default=3)
|
| 93 |
-
parser.add_argument("--tracking_frequency", type=int, default=-1)
|
| 94 |
-
|
| 95 |
-
parser.add_argument("--sanity_dist", action="store_true")
|
| 96 |
-
parser.add_argument("--preallocate", type=bool, default=False)
|
| 97 |
-
parser.add_argument("--only_prechecks", action="store_true")
|
| 98 |
-
parser.add_argument(
|
| 99 |
-
"--compress", type=bool, default=True, help="compress dataset"
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
parser.add_argument("--seed", type=int, default=None)
|
| 104 |
-
parser.add_argument(
|
| 105 |
-
"--logger",
|
| 106 |
-
type=str,
|
| 107 |
-
default="tensorboard",
|
| 108 |
-
choices=["tensorboard", "wandb", "none"],
|
| 109 |
-
)
|
| 110 |
-
parser.add_argument("--wandb_project", type=str, default="trackastra")
|
| 111 |
-
parser.add_argument(
|
| 112 |
-
"--crop_size",
|
| 113 |
-
type=int,
|
| 114 |
-
# required=True,
|
| 115 |
-
nargs="+",
|
| 116 |
-
default=None,
|
| 117 |
-
help="random crop size for augmentation",
|
| 118 |
-
)
|
| 119 |
-
parser.add_argument(
|
| 120 |
-
"--weight_by_ndivs",
|
| 121 |
-
type=bool,
|
| 122 |
-
default=True,
|
| 123 |
-
help="Oversample windows that contain divisions",
|
| 124 |
-
)
|
| 125 |
-
parser.add_argument(
|
| 126 |
-
"--weight_by_dataset",
|
| 127 |
-
type=bool,
|
| 128 |
-
default=False,
|
| 129 |
-
help=(
|
| 130 |
-
"Inversely weight datasets by number of samples (to counter dataset size"
|
| 131 |
-
" imbalance)"
|
| 132 |
-
),
|
| 133 |
-
)
|
| 134 |
|
| 135 |
args, unknown_args = parser.parse_known_args()
|
| 136 |
|
|
@@ -154,4 +59,4 @@ def parse_train_args():
|
|
| 154 |
# "--train_samples must be > 0, full dataset pass not supported."
|
| 155 |
# )
|
| 156 |
|
| 157 |
-
return args
|
|
|
|
| 14 |
is_config_file=True,
|
| 15 |
help="config file path",
|
| 16 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
parser.add_argument("-d", "--d_model", type=int, default=256)
|
| 18 |
parser.add_argument("-w", "--window", type=int, default=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
|
|
|
|
|
|
|
| 20 |
parser.add_argument("--num_encoder_layers", type=int, default=6)
|
| 21 |
parser.add_argument("--num_decoder_layers", type=int, default=6)
|
| 22 |
parser.add_argument("--pos_embed_per_dim", type=int, default=32)
|
| 23 |
parser.add_argument("--feat_embed_per_dim", type=int, default=8)
|
| 24 |
parser.add_argument("--dropout", type=float, default=0.00)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
parser.add_argument(
|
| 26 |
"--attn_positional_bias",
|
| 27 |
type=str,
|
|
|
|
| 30 |
)
|
| 31 |
parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
|
| 32 |
parser.add_argument("--attn_dist_mode", default="v0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
parser.add_argument(
|
| 34 |
"--causal_norm",
|
| 35 |
type=str,
|
| 36 |
choices=["none", "linear", "softmax", "quiet_softmax"],
|
| 37 |
default="quiet_softmax",
|
| 38 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
args, unknown_args = parser.parse_known_args()
|
| 41 |
|
|
|
|
| 59 |
# "--train_samples must be > 0, full dataset pass not supported."
|
| 60 |
# )
|
| 61 |
|
| 62 |
+
return args
|
inference_track.py
CHANGED
|
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
from tracking_one import TrackingModule
|
| 10 |
-
from models.tra_post_model.
|
| 11 |
|
| 12 |
MODEL = None
|
| 13 |
DEVICE = torch.device("cpu")
|
|
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
from tracking_one import TrackingModule
|
| 10 |
+
from models.tra_post_model.tracking import graph_to_ctc
|
| 11 |
|
| 12 |
MODEL = None
|
| 13 |
DEVICE = torch.device("cpu")
|
models/tra_post_model/{trackastra/data/wrfeat.py β data.py}
RENAMED
|
@@ -17,11 +17,6 @@ from skimage.measure import regionprops, regionprops_table
|
|
| 17 |
from tqdm import tqdm
|
| 18 |
from typing import Tuple, Optional, Sequence, Union, List
|
| 19 |
import typing
|
| 20 |
-
|
| 21 |
-
try:
|
| 22 |
-
from .utils import load_tiff_timeseries
|
| 23 |
-
except:
|
| 24 |
-
from utils import load_tiff_timeseries
|
| 25 |
import torch
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
|
@@ -42,36 +37,6 @@ _PROPERTIES = {
|
|
| 42 |
}
|
| 43 |
|
| 44 |
|
| 45 |
-
def _filter_points(
|
| 46 |
-
points: np.ndarray, shape: Tuple[int], origin: Optional[Tuple[int]] = None
|
| 47 |
-
) -> np.ndarray:
|
| 48 |
-
"""Returns indices of points that are inside the shape extent and given origin."""
|
| 49 |
-
ndim = points.shape[-1]
|
| 50 |
-
if origin is None:
|
| 51 |
-
origin = (0,) * ndim
|
| 52 |
-
|
| 53 |
-
idx = tuple(
|
| 54 |
-
np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
|
| 55 |
-
for i in range(ndim)
|
| 56 |
-
)
|
| 57 |
-
idx = np.where(np.all(idx, axis=0))[0]
|
| 58 |
-
return idx
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def _border_dist(mask: np.ndarray, cutoff: float = 5):
|
| 62 |
-
"""Returns distance to border normalized to 0 (at least cutoff away) and 1 (at border)."""
|
| 63 |
-
border = np.zeros_like(mask)
|
| 64 |
-
|
| 65 |
-
# only apply to last two dimensions
|
| 66 |
-
ss = tuple(
|
| 67 |
-
slice(None) if i < mask.ndim - 2 else slice(1, -1)
|
| 68 |
-
for i, s in enumerate(mask.shape)
|
| 69 |
-
)
|
| 70 |
-
border[ss] = 1
|
| 71 |
-
dist = 1 - np.minimum(edt(border) / cutoff, 1)
|
| 72 |
-
return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
|
| 73 |
-
|
| 74 |
-
|
| 75 |
def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
|
| 76 |
cutoff = int(cutoff)
|
| 77 |
border = np.ones(mask.shape, dtype=np.float32)
|
|
@@ -241,268 +206,6 @@ class WRFeatures:
|
|
| 241 |
)
|
| 242 |
|
| 243 |
|
| 244 |
-
# augmentations
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
class WRRandomCrop:
|
| 248 |
-
"""windowed region random crop augmentation."""
|
| 249 |
-
|
| 250 |
-
def __init__(
|
| 251 |
-
self,
|
| 252 |
-
crop_size: Optional[Union[int, Tuple[int]]] = None,
|
| 253 |
-
ndim: int = 2,
|
| 254 |
-
) -> None:
|
| 255 |
-
"""crop_size: tuple of int
|
| 256 |
-
can be tuple of length 1 (all dimensions)
|
| 257 |
-
of length ndim (y,x,...)
|
| 258 |
-
of length 2*ndim (y1,y2, x1,x2, ...).
|
| 259 |
-
"""
|
| 260 |
-
if isinstance(crop_size, int):
|
| 261 |
-
crop_size = (crop_size,) * 2 * ndim
|
| 262 |
-
elif isinstance(crop_size, Iterable):
|
| 263 |
-
pass
|
| 264 |
-
else:
|
| 265 |
-
raise ValueError(f"{crop_size} has to be int or tuple of int")
|
| 266 |
-
|
| 267 |
-
if len(crop_size) == 1:
|
| 268 |
-
crop_size = (crop_size[0],) * 2 * ndim
|
| 269 |
-
elif len(crop_size) == ndim:
|
| 270 |
-
crop_size = tuple(itertools.chain(*tuple((c, c) for c in crop_size)))
|
| 271 |
-
elif len(crop_size) == 2 * ndim:
|
| 272 |
-
pass
|
| 273 |
-
else:
|
| 274 |
-
raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
|
| 275 |
-
|
| 276 |
-
crop_size = np.array(crop_size)
|
| 277 |
-
self._ndim = ndim
|
| 278 |
-
self._crop_bounds = crop_size[::2], crop_size[1::2]
|
| 279 |
-
self._rng = np.random.RandomState()
|
| 280 |
-
|
| 281 |
-
def __call__(self, features: WRFeatures):
|
| 282 |
-
crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
|
| 283 |
-
points = features.coords
|
| 284 |
-
|
| 285 |
-
if len(points) == 0:
|
| 286 |
-
print("No points given, cannot ensure inside points")
|
| 287 |
-
return features
|
| 288 |
-
|
| 289 |
-
# sample point and corner relative to it
|
| 290 |
-
|
| 291 |
-
_idx = np.random.randint(len(points))
|
| 292 |
-
corner = (
|
| 293 |
-
points[_idx]
|
| 294 |
-
- crop_size
|
| 295 |
-
+ 1
|
| 296 |
-
+ self._rng.randint(crop_size // 4, 3 * crop_size // 4)
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
idx = _filter_points(points, shape=crop_size, origin=corner)
|
| 300 |
-
|
| 301 |
-
return (
|
| 302 |
-
WRFeatures(
|
| 303 |
-
coords=points[idx],
|
| 304 |
-
labels=features.labels[idx],
|
| 305 |
-
timepoints=features.timepoints[idx],
|
| 306 |
-
features=OrderedDict((k, v[idx]) for k, v in features.features.items()),
|
| 307 |
-
),
|
| 308 |
-
idx,
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
class WRBaseAugmentation:
|
| 313 |
-
def __init__(self, p: float = 0.5) -> None:
|
| 314 |
-
self._p = p
|
| 315 |
-
self._rng = np.random.RandomState()
|
| 316 |
-
|
| 317 |
-
def __call__(self, features: WRFeatures):
|
| 318 |
-
if self._rng.rand() > self._p or len(features) == 0:
|
| 319 |
-
return features
|
| 320 |
-
return self._augment(features)
|
| 321 |
-
|
| 322 |
-
def _augment(self, features: WRFeatures):
|
| 323 |
-
raise NotImplementedError()
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
class WRRandomFlip(WRBaseAugmentation):
|
| 327 |
-
def _augment(self, features: WRFeatures):
|
| 328 |
-
ndim = features.ndim
|
| 329 |
-
flip = self._rng.randint(0, 2, features.ndim)
|
| 330 |
-
points = features.coords.copy()
|
| 331 |
-
for i, f in enumerate(flip):
|
| 332 |
-
if f == 1:
|
| 333 |
-
points[:, ndim - i - 1] *= -1
|
| 334 |
-
return WRFeatures(
|
| 335 |
-
coords=points,
|
| 336 |
-
labels=features.labels,
|
| 337 |
-
timepoints=features.timepoints,
|
| 338 |
-
features=features.features,
|
| 339 |
-
)
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
def _scale_matrix(sz: float, sy: float, sx: float):
|
| 343 |
-
return np.diag([sz, sy, sx])
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
# def _scale_matrix(sy: float, sx: float):
|
| 347 |
-
# return np.array([[1, 0, 0], [0, sy, 0], [0, 0, sx]])
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
def _shear_matrix(shy: float, shx: float):
|
| 351 |
-
return np.array([[1, 0, 0], [0, 1 + shx * shy, shy], [0, shx, 1]])
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def _rotation_matrix(theta: float):
|
| 355 |
-
return np.array([
|
| 356 |
-
[1, 0, 0],
|
| 357 |
-
[0, np.cos(theta), -np.sin(theta)],
|
| 358 |
-
[0, np.sin(theta), np.cos(theta)],
|
| 359 |
-
])
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
|
| 363 |
-
ndim = len(M)
|
| 364 |
-
if k == "area":
|
| 365 |
-
v = np.linalg.det(M) * v
|
| 366 |
-
elif k == "equivalent_diameter_area":
|
| 367 |
-
v = np.linalg.det(M) ** (1 / len(M)) * v
|
| 368 |
-
|
| 369 |
-
elif k == "inertia_tensor":
|
| 370 |
-
# v' = M * v * M^T
|
| 371 |
-
v = v.reshape(-1, ndim, ndim)
|
| 372 |
-
# v * M^T
|
| 373 |
-
v = np.einsum("ijk, mk -> ijm", v, M)
|
| 374 |
-
# M * v
|
| 375 |
-
v = np.einsum("ij, kjm -> kim", M, v)
|
| 376 |
-
v = v.reshape(-1, ndim * ndim)
|
| 377 |
-
elif k in (
|
| 378 |
-
"intensity_mean",
|
| 379 |
-
"intensity_std",
|
| 380 |
-
"intensity_max",
|
| 381 |
-
"intensity_min",
|
| 382 |
-
"border_dist",
|
| 383 |
-
):
|
| 384 |
-
pass
|
| 385 |
-
else:
|
| 386 |
-
raise ValueError(f"Don't know how to affinely transform {k}")
|
| 387 |
-
return v
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
class WRRandomAffine(WRBaseAugmentation):
|
| 391 |
-
def __init__(
|
| 392 |
-
self,
|
| 393 |
-
degrees: float = 10,
|
| 394 |
-
scale: float = (0.9, 1.1),
|
| 395 |
-
shear: float = (0.1, 0.1),
|
| 396 |
-
p: float = 0.5,
|
| 397 |
-
):
|
| 398 |
-
super().__init__(p)
|
| 399 |
-
self.degrees = degrees if degrees is not None else 0
|
| 400 |
-
self.scale = scale if scale is not None else (1, 1)
|
| 401 |
-
self.shear = shear if shear is not None else (0, 0)
|
| 402 |
-
|
| 403 |
-
def _augment(self, features: WRFeatures):
|
| 404 |
-
degrees = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi
|
| 405 |
-
scale = self._rng.uniform(*self.scale, 3)
|
| 406 |
-
shy = self._rng.uniform(-self.shear[0], self.shear[0])
|
| 407 |
-
shx = self._rng.uniform(-self.shear[1], self.shear[1])
|
| 408 |
-
|
| 409 |
-
self._M = (
|
| 410 |
-
_rotation_matrix(degrees) @ _scale_matrix(*scale) @ _shear_matrix(shy, shx)
|
| 411 |
-
)
|
| 412 |
-
|
| 413 |
-
# M is by default 3D , we need to remove the last dimension for 2D
|
| 414 |
-
self._M = self._M[-features.ndim :, -features.ndim :]
|
| 415 |
-
points = features.coords @ self._M.T
|
| 416 |
-
|
| 417 |
-
feats = OrderedDict(
|
| 418 |
-
(k, _transform_affine(k, v, self._M)) for k, v in features.features.items()
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
-
return WRFeatures(
|
| 422 |
-
coords=points,
|
| 423 |
-
labels=features.labels,
|
| 424 |
-
timepoints=features.timepoints,
|
| 425 |
-
features=feats,
|
| 426 |
-
)
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
class WRRandomBrightness(WRBaseAugmentation):
|
| 430 |
-
def __init__(
|
| 431 |
-
self,
|
| 432 |
-
scale: Tuple[float] = (0.5, 2.0),
|
| 433 |
-
shift: Tuple[float] = (-0.1, 0.1),
|
| 434 |
-
p: float = 0.5,
|
| 435 |
-
):
|
| 436 |
-
super().__init__(p)
|
| 437 |
-
self.scale = scale
|
| 438 |
-
self.shift = shift
|
| 439 |
-
|
| 440 |
-
def _augment(self, features: WRFeatures):
|
| 441 |
-
scale = self._rng.uniform(*self.scale)
|
| 442 |
-
shift = self._rng.uniform(*self.shift)
|
| 443 |
-
|
| 444 |
-
key_vals = []
|
| 445 |
-
|
| 446 |
-
for k, v in features.features.items():
|
| 447 |
-
if "intensity" in k:
|
| 448 |
-
v = v * scale + shift
|
| 449 |
-
key_vals.append((k, v))
|
| 450 |
-
feats = OrderedDict(key_vals)
|
| 451 |
-
return WRFeatures(
|
| 452 |
-
coords=features.coords,
|
| 453 |
-
labels=features.labels,
|
| 454 |
-
timepoints=features.timepoints,
|
| 455 |
-
features=feats,
|
| 456 |
-
)
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
class WRRandomOffset(WRBaseAugmentation):
|
| 460 |
-
def __init__(self, offset: float = (-3, 3), p: float = 0.5):
|
| 461 |
-
super().__init__(p)
|
| 462 |
-
self.offset = offset
|
| 463 |
-
|
| 464 |
-
def _augment(self, features: WRFeatures):
|
| 465 |
-
offset = self._rng.uniform(*self.offset, features.coords.shape)
|
| 466 |
-
coords = features.coords + offset
|
| 467 |
-
return WRFeatures(
|
| 468 |
-
coords=coords,
|
| 469 |
-
labels=features.labels,
|
| 470 |
-
timepoints=features.timepoints,
|
| 471 |
-
features=features.features,
|
| 472 |
-
)
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
class WRRandomMovement(WRBaseAugmentation):
|
| 476 |
-
"""random global linear shift."""
|
| 477 |
-
|
| 478 |
-
def __init__(self, offset: float = (-10, 10), p: float = 0.5):
|
| 479 |
-
super().__init__(p)
|
| 480 |
-
self.offset = offset
|
| 481 |
-
|
| 482 |
-
def _augment(self, features: WRFeatures):
|
| 483 |
-
base_offset = self._rng.uniform(*self.offset, features.coords.shape[-1])
|
| 484 |
-
tmin = features.timepoints.min()
|
| 485 |
-
offset = (features.timepoints[:, None] - tmin) * base_offset[None]
|
| 486 |
-
coords = features.coords + offset
|
| 487 |
-
|
| 488 |
-
return WRFeatures(
|
| 489 |
-
coords=coords,
|
| 490 |
-
labels=features.labels,
|
| 491 |
-
timepoints=features.timepoints,
|
| 492 |
-
features=features.features,
|
| 493 |
-
)
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
class WRAugmentationPipeline:
|
| 497 |
-
def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
|
| 498 |
-
self.augmentations = augmentations
|
| 499 |
-
|
| 500 |
-
def __call__(self, feats: WRFeatures):
|
| 501 |
-
for aug in self.augmentations:
|
| 502 |
-
feats = aug(feats)
|
| 503 |
-
return feats
|
| 504 |
-
|
| 505 |
-
|
| 506 |
def get_features(
|
| 507 |
detections: np.ndarray,
|
| 508 |
imgs: Optional[np.ndarray] = None,
|
|
@@ -561,36 +264,6 @@ def _check_dimensions(x: np.ndarray, ndim: int):
|
|
| 561 |
return x
|
| 562 |
|
| 563 |
|
| 564 |
-
def build_windows(
|
| 565 |
-
features: List[WRFeatures], window_size: int, progbar_class=tqdm
|
| 566 |
-
) -> List[dict]:
|
| 567 |
-
windows = []
|
| 568 |
-
for t1, t2 in progbar_class(
|
| 569 |
-
zip(range(0, len(features)), range(window_size, len(features) + 1)),
|
| 570 |
-
total=len(features) - window_size + 1,
|
| 571 |
-
desc="Building windows",
|
| 572 |
-
):
|
| 573 |
-
feat = WRFeatures.concat(features[t1:t2])
|
| 574 |
-
|
| 575 |
-
labels = feat.labels
|
| 576 |
-
timepoints = feat.timepoints
|
| 577 |
-
coords = feat.coords
|
| 578 |
-
|
| 579 |
-
if len(feat) == 0:
|
| 580 |
-
coords = np.zeros((0, feat.ndim), dtype=int)
|
| 581 |
-
|
| 582 |
-
w = dict(
|
| 583 |
-
coords=coords,
|
| 584 |
-
t1=t1,
|
| 585 |
-
labels=labels,
|
| 586 |
-
timepoints=timepoints,
|
| 587 |
-
features=feat.features_stacked,
|
| 588 |
-
)
|
| 589 |
-
windows.append(w)
|
| 590 |
-
|
| 591 |
-
logger.debug(f"Built {len(windows)} track windows.\n")
|
| 592 |
-
return windows
|
| 593 |
-
|
| 594 |
def build_windows_sd(
|
| 595 |
features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
|
| 596 |
) -> List[dict]:
|
|
@@ -632,24 +305,3 @@ def build_windows_sd(
|
|
| 632 |
logger.debug(f"Built {len(windows)} track windows.\n")
|
| 633 |
return windows
|
| 634 |
|
| 635 |
-
if __name__ == "__main__":
|
| 636 |
-
imgs = load_tiff_timeseries(
|
| 637 |
-
# "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01",
|
| 638 |
-
"/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
|
| 639 |
-
)
|
| 640 |
-
masks = load_tiff_timeseries(
|
| 641 |
-
# "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA",
|
| 642 |
-
"/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
|
| 643 |
-
dtype=int,
|
| 644 |
-
)
|
| 645 |
-
|
| 646 |
-
features = get_features(detections=masks, imgs=imgs, ndim=3)
|
| 647 |
-
windows = build_windows(features, window_size=4)
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
# if __name__ == "__main__":
|
| 651 |
-
# y = np.zeros((1, 100, 100), np.uint8)
|
| 652 |
-
# y[:, 20:40, 20:60] = 1
|
| 653 |
-
# x = y + np.random.normal(0, 0.1, y.shape)
|
| 654 |
-
|
| 655 |
-
# f = WRFeatures.from_mask_img(y, x, properties=("intensity_mean", "area"))
|
|
|
|
| 17 |
from tqdm import tqdm
|
| 18 |
from typing import Tuple, Optional, Sequence, Union, List
|
| 19 |
import typing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
import torch
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
|
|
|
| 37 |
}
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
|
| 41 |
cutoff = int(cutoff)
|
| 42 |
border = np.ones(mask.shape, dtype=np.float32)
|
|
|
|
| 206 |
)
|
| 207 |
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
def get_features(
|
| 210 |
detections: np.ndarray,
|
| 211 |
imgs: Optional[np.ndarray] = None,
|
|
|
|
| 264 |
return x
|
| 265 |
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
def build_windows_sd(
|
| 268 |
features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
|
| 269 |
) -> List[dict]:
|
|
|
|
| 305 |
logger.debug(f"Built {len(windows)} track windows.\n")
|
| 306 |
return windows
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/{trackastra/model/model.py β model.py}
RENAMED
|
@@ -1,32 +1,300 @@
|
|
| 1 |
"""Transformer class."""
|
| 2 |
|
| 3 |
import logging
|
|
|
|
| 4 |
from collections import OrderedDict
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Literal
|
| 7 |
|
| 8 |
import torch
|
|
|
|
| 9 |
|
| 10 |
-
# from torch_geometric.nn import GATv2Conv
|
| 11 |
import yaml
|
| 12 |
from torch import nn
|
| 13 |
|
| 14 |
import sys, os
|
| 15 |
-
sys.path.append(os.path.join(os.getcwd(), "External_Repos", "trackastra"))
|
| 16 |
|
| 17 |
-
|
| 18 |
-
from ..utils import blockwise_causal_norm
|
| 19 |
|
| 20 |
-
from .model_parts import (
|
| 21 |
-
FeedForward,
|
| 22 |
-
PositionalEncoding,
|
| 23 |
-
RelativePositionalAttention,
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
# from memory_profiler import profile
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class EncoderLayer(nn.Module):
|
| 31 |
def __init__(
|
| 32 |
self,
|
|
|
|
| 1 |
"""Transformer class."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
+
import math
|
| 5 |
from collections import OrderedDict
|
| 6 |
from pathlib import Path
|
| 7 |
+
from typing import Literal, Tuple
|
| 8 |
|
| 9 |
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
|
|
|
|
| 12 |
import yaml
|
| 13 |
from torch import nn
|
| 14 |
|
| 15 |
import sys, os
|
|
|
|
| 16 |
|
| 17 |
+
from .utils import blockwise_causal_norm
|
|
|
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
|
| 22 |
+
def _pos_embed_fourier1d_init(
|
| 23 |
+
cutoff: float = 256, n: int = 32, cutoff_start: float = 1
|
| 24 |
+
):
|
| 25 |
+
return (
|
| 26 |
+
torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n))
|
| 27 |
+
.unsqueeze(0)
|
| 28 |
+
.unsqueeze(0)
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
|
| 33 |
+
# Maximum initial frequency is 1
|
| 34 |
+
return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 38 |
+
"""Rotate pairs of scalars as 2d vectors by pi/2."""
|
| 39 |
+
x = x.unflatten(-1, (-1, 2))
|
| 40 |
+
x1, x2 = x.unbind(dim=-1)
|
| 41 |
+
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RotaryPositionalEncoding(nn.Module):
|
| 45 |
+
def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
|
| 46 |
+
super().__init__()
|
| 47 |
+
assert len(cutoffs) == len(n_pos)
|
| 48 |
+
if not all(n % 2 == 0 for n in n_pos):
|
| 49 |
+
raise ValueError("n_pos must be even")
|
| 50 |
+
|
| 51 |
+
self._n_dim = len(cutoffs)
|
| 52 |
+
self.freqs = nn.ParameterList([
|
| 53 |
+
nn.Parameter(_rope_pos_embed_fourier1d_init(cutoff, n // 2))
|
| 54 |
+
for cutoff, n in zip(cutoffs, n_pos)
|
| 55 |
+
])
|
| 56 |
+
|
| 57 |
+
def get_co_si(self, coords: torch.Tensor):
|
| 58 |
+
_B, _N, D = coords.shape
|
| 59 |
+
assert D == len(self.freqs)
|
| 60 |
+
co = torch.cat(
|
| 61 |
+
tuple(
|
| 62 |
+
torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
|
| 63 |
+
for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
|
| 64 |
+
),
|
| 65 |
+
axis=-1,
|
| 66 |
+
)
|
| 67 |
+
si = torch.cat(
|
| 68 |
+
tuple(
|
| 69 |
+
torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
|
| 70 |
+
for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
|
| 71 |
+
),
|
| 72 |
+
axis=-1,
|
| 73 |
+
)
|
| 74 |
+
return co, si
|
| 75 |
+
|
| 76 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
|
| 77 |
+
_B, _N, D = coords.shape
|
| 78 |
+
_B, _H, _N, _C = q.shape
|
| 79 |
+
|
| 80 |
+
if D != self._n_dim:
|
| 81 |
+
raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
|
| 82 |
+
|
| 83 |
+
co, si = self.get_co_si(coords)
|
| 84 |
+
co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
|
| 85 |
+
si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
|
| 86 |
+
q2 = q * co + _rotate_half(q) * si
|
| 87 |
+
k2 = k * co + _rotate_half(k) * si
|
| 88 |
+
return q2, k2
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class FeedForward(nn.Module):
|
| 92 |
+
def __init__(self, d_model, expand: float = 2, bias: bool = True):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.fc1 = nn.Linear(d_model, int(d_model * expand))
|
| 95 |
+
self.fc2 = nn.Linear(int(d_model * expand), d_model, bias=bias)
|
| 96 |
+
self.act = nn.GELU()
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
return self.fc2(self.act(self.fc1(x)))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class PositionalEncoding(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
cutoffs: Tuple[float] = (256,),
|
| 106 |
+
n_pos: Tuple[int] = (32,),
|
| 107 |
+
cutoffs_start=None,
|
| 108 |
+
):
|
| 109 |
+
super().__init__()
|
| 110 |
+
if cutoffs_start is None:
|
| 111 |
+
cutoffs_start = (1,) * len(cutoffs)
|
| 112 |
+
|
| 113 |
+
assert len(cutoffs) == len(n_pos)
|
| 114 |
+
self.freqs = nn.ParameterList([
|
| 115 |
+
nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
|
| 116 |
+
for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
|
| 117 |
+
])
|
| 118 |
+
|
| 119 |
+
def forward(self, coords: torch.Tensor):
|
| 120 |
+
_B, _N, D = coords.shape
|
| 121 |
+
assert D == len(self.freqs)
|
| 122 |
+
embed = torch.cat(
|
| 123 |
+
tuple(
|
| 124 |
+
torch.cat(
|
| 125 |
+
(
|
| 126 |
+
torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
|
| 127 |
+
torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
|
| 128 |
+
),
|
| 129 |
+
axis=-1,
|
| 130 |
+
)
|
| 131 |
+
/ math.sqrt(len(freq))
|
| 132 |
+
for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
|
| 133 |
+
),
|
| 134 |
+
axis=-1,
|
| 135 |
+
)
|
| 136 |
+
return embed
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _bin_init_exp(cutoff: float, n: int):
|
| 140 |
+
return torch.exp(torch.linspace(0, math.log(cutoff + 1), n))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _bin_init_linear(cutoff: float, n: int):
|
| 144 |
+
return torch.linspace(-cutoff, cutoff, n)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class RelativePositionalBias(nn.Module):
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
n_head: int,
|
| 151 |
+
cutoff_spatial: float,
|
| 152 |
+
cutoff_temporal: float,
|
| 153 |
+
n_spatial: int = 32,
|
| 154 |
+
n_temporal: int = 16,
|
| 155 |
+
):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
|
| 158 |
+
self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
|
| 159 |
+
self.register_buffer("spatial_bins", self._spatial_bins)
|
| 160 |
+
self.register_buffer("temporal_bins", self._temporal_bins)
|
| 161 |
+
self.n_spatial = n_spatial
|
| 162 |
+
self.n_head = n_head
|
| 163 |
+
self.bias = nn.Parameter(
|
| 164 |
+
-0.5 + torch.rand((2 * n_temporal + 1) * n_spatial, n_head)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def forward(self, coords: torch.Tensor):
|
| 168 |
+
_B, _N, _D = coords.shape
|
| 169 |
+
t = coords[..., 0]
|
| 170 |
+
yx = coords[..., 1:]
|
| 171 |
+
temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2)
|
| 172 |
+
spatial_dist = torch.cdist(yx, yx)
|
| 173 |
+
|
| 174 |
+
spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins)
|
| 175 |
+
torch.clamp_(spatial_idx, max=len(self.spatial_bins) - 1)
|
| 176 |
+
temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins)
|
| 177 |
+
torch.clamp_(temporal_idx, max=len(self.temporal_bins) - 1)
|
| 178 |
+
|
| 179 |
+
idx = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial
|
| 180 |
+
bias = self.bias.index_select(0, idx).view((*spatial_idx.shape, self.n_head))
|
| 181 |
+
bias = bias.transpose(-1, 1)
|
| 182 |
+
return bias
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class RelativePositionalAttention(nn.Module):
|
| 186 |
+
def __init__(
|
| 187 |
+
self,
|
| 188 |
+
coord_dim: int,
|
| 189 |
+
embed_dim: int,
|
| 190 |
+
n_head: int,
|
| 191 |
+
cutoff_spatial: float = 256,
|
| 192 |
+
cutoff_temporal: float = 16,
|
| 193 |
+
n_spatial: int = 32,
|
| 194 |
+
n_temporal: int = 16,
|
| 195 |
+
dropout: float = 0.0,
|
| 196 |
+
mode: Literal["bias", "rope", "none"] = "bias",
|
| 197 |
+
attn_dist_mode: str = "v0",
|
| 198 |
+
):
|
| 199 |
+
super().__init__()
|
| 200 |
+
|
| 201 |
+
if not embed_dim % (2 * n_head) == 0:
|
| 202 |
+
raise ValueError(
|
| 203 |
+
f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 207 |
+
self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 208 |
+
self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 209 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
| 210 |
+
self.dropout = dropout
|
| 211 |
+
self.n_head = n_head
|
| 212 |
+
self.embed_dim = embed_dim
|
| 213 |
+
self.cutoff_spatial = cutoff_spatial
|
| 214 |
+
self.attn_dist_mode = attn_dist_mode
|
| 215 |
+
|
| 216 |
+
if mode == "bias" or mode is True:
|
| 217 |
+
self.pos_bias = RelativePositionalBias(
|
| 218 |
+
n_head=n_head,
|
| 219 |
+
cutoff_spatial=cutoff_spatial,
|
| 220 |
+
cutoff_temporal=cutoff_temporal,
|
| 221 |
+
n_spatial=n_spatial,
|
| 222 |
+
n_temporal=n_temporal,
|
| 223 |
+
)
|
| 224 |
+
elif mode == "rope":
|
| 225 |
+
n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
|
| 226 |
+
self.rot_pos_enc = RotaryPositionalEncoding(
|
| 227 |
+
cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
|
| 228 |
+
n_pos=(embed_dim // n_head - coord_dim * n_split,)
|
| 229 |
+
+ (n_split,) * coord_dim,
|
| 230 |
+
)
|
| 231 |
+
elif mode == "none":
|
| 232 |
+
pass
|
| 233 |
+
elif mode is None or mode is False:
|
| 234 |
+
logger.warning(
|
| 235 |
+
"attn_positional_bias is not set (None or False), no positional bias."
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"Unknown mode {mode}")
|
| 239 |
+
|
| 240 |
+
self._mode = mode
|
| 241 |
+
|
| 242 |
+
def forward(
|
| 243 |
+
self,
|
| 244 |
+
query: torch.Tensor,
|
| 245 |
+
key: torch.Tensor,
|
| 246 |
+
value: torch.Tensor,
|
| 247 |
+
coords: torch.Tensor,
|
| 248 |
+
padding_mask: torch.Tensor = None,
|
| 249 |
+
):
|
| 250 |
+
B, N, D = query.size()
|
| 251 |
+
q = self.q_pro(query)
|
| 252 |
+
k = self.k_pro(key)
|
| 253 |
+
v = self.v_pro(value)
|
| 254 |
+
k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
|
| 255 |
+
q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
|
| 256 |
+
v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
|
| 257 |
+
|
| 258 |
+
attn_mask = torch.zeros(
|
| 259 |
+
(B, self.n_head, N, N), device=query.device, dtype=q.dtype
|
| 260 |
+
)
|
| 261 |
+
attn_ignore_val = -1e3
|
| 262 |
+
|
| 263 |
+
yx = coords[..., 1:]
|
| 264 |
+
spatial_dist = torch.cdist(yx, yx)
|
| 265 |
+
spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
|
| 266 |
+
attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
|
| 267 |
+
|
| 268 |
+
if coords is not None:
|
| 269 |
+
if self._mode == "bias":
|
| 270 |
+
attn_mask = attn_mask + self.pos_bias(coords)
|
| 271 |
+
elif self._mode == "rope":
|
| 272 |
+
q, k = self.rot_pos_enc(q, k, coords)
|
| 273 |
+
|
| 274 |
+
if self.attn_dist_mode == "v0":
|
| 275 |
+
dist = torch.cdist(coords, coords, p=2)
|
| 276 |
+
attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
|
| 277 |
+
elif self.attn_dist_mode == "v1":
|
| 278 |
+
attn_mask += torch.exp(
|
| 279 |
+
-5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
|
| 280 |
+
)
|
| 281 |
+
else:
|
| 282 |
+
raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
|
| 283 |
+
|
| 284 |
+
if padding_mask is not None:
|
| 285 |
+
ignore_mask = torch.logical_or(
|
| 286 |
+
padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
|
| 287 |
+
).unsqueeze(1)
|
| 288 |
+
attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
|
| 289 |
+
|
| 290 |
+
y = F.scaled_dot_product_attention(
|
| 291 |
+
q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
|
| 292 |
+
)
|
| 293 |
+
y = y.transpose(1, 2).contiguous().view(B, N, D)
|
| 294 |
+
y = self.proj(y)
|
| 295 |
+
return y
|
| 296 |
+
|
| 297 |
+
|
| 298 |
class EncoderLayer(nn.Module):
|
| 299 |
def __init__(
|
| 300 |
self,
|
models/tra_post_model/trackastra/data/__init__.py
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
# ruff: noqa: F401
|
| 2 |
-
|
| 3 |
-
from .augmentations import AugmentationPipeline, RandomCrop
|
| 4 |
-
from .data import (
|
| 5 |
-
CTCData,
|
| 6 |
-
_ctc_lineages,
|
| 7 |
-
# load_ctc_data_from_subfolders,
|
| 8 |
-
collate_sequence_padding,
|
| 9 |
-
extract_features_regionprops,
|
| 10 |
-
)
|
| 11 |
-
from .distributed import (
|
| 12 |
-
BalancedBatchSampler,
|
| 13 |
-
# BalancedDataModule,
|
| 14 |
-
BalancedDistributedSampler,
|
| 15 |
-
)
|
| 16 |
-
from .example_data import example_data_bacteria, example_data_fluo_3d, example_data_hela, data_hela
|
| 17 |
-
from .utils import filter_track_df, load_tiff_timeseries, load_tracklet_links
|
| 18 |
-
from .wrfeat import WRFeatures, build_windows, get_features, build_windows_sd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/_check_ctc.py
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
from skimage.measure import label, regionprops_table
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# from https://github.com/Janelia-Trackathon-2023/traccuracy/blob/main/src/traccuracy/loaders/_ctc.py
|
| 11 |
-
def _check_ctc(tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray):
|
| 12 |
-
"""Sanity checks for valid CTC format.
|
| 13 |
-
|
| 14 |
-
Hard checks (throws exception):
|
| 15 |
-
- Tracklet IDs in tracks file must be unique and positive
|
| 16 |
-
- Parent tracklet IDs must exist in the tracks file
|
| 17 |
-
- Intertracklet edges must be directed forward in time.
|
| 18 |
-
- In each time point, the set of segmentation IDs present in the detections must equal the set
|
| 19 |
-
of tracklet IDs in the tracks file that overlap this time point.
|
| 20 |
-
|
| 21 |
-
Soft checks (prints warning):
|
| 22 |
-
- No duplicate tracklet IDs (non-connected pixels with same ID) in a single timepoint.
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
tracks (pd.DataFrame): Tracks in CTC format with columns Cell_ID, Start, End, Parent_ID.
|
| 26 |
-
detections (pd.DataFrame): Detections extracted from masks, containing columns
|
| 27 |
-
segmentation_id, t.
|
| 28 |
-
masks (np.ndarray): Set of masks with time in the first axis.
|
| 29 |
-
|
| 30 |
-
Raises:
|
| 31 |
-
ValueError: If any of the hard checks fail.
|
| 32 |
-
"""
|
| 33 |
-
logger.debug("Running CTC format checks")
|
| 34 |
-
tracks = tracks.copy()
|
| 35 |
-
tracks.columns = ["Cell_ID", "Start", "End", "Parent_ID"]
|
| 36 |
-
if tracks["Cell_ID"].min() < 1:
|
| 37 |
-
raise ValueError("Cell_IDs in tracks file must be positive integers.")
|
| 38 |
-
if len(tracks["Cell_ID"]) < len(tracks["Cell_ID"].unique()):
|
| 39 |
-
raise ValueError("Cell_IDs in tracks file must be unique integers.")
|
| 40 |
-
|
| 41 |
-
for _, row in tracks.iterrows():
|
| 42 |
-
if row["Parent_ID"] != 0:
|
| 43 |
-
if row["Parent_ID"] not in tracks["Cell_ID"].values:
|
| 44 |
-
raise ValueError(
|
| 45 |
-
f"Parent_ID {row['Parent_ID']} is not present in tracks."
|
| 46 |
-
)
|
| 47 |
-
parent_end = tracks[tracks["Cell_ID"] == row["Parent_ID"]]["End"].iloc[0]
|
| 48 |
-
if parent_end >= row["Start"]:
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"Invalid tracklet connection: Daughter tracklet with ID {row['Cell_ID']} "
|
| 51 |
-
f"starts at t={row['Start']}, "
|
| 52 |
-
f"but parent tracklet with ID {row['Parent_ID']} only ends at t={parent_end}."
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
for t in range(tracks["Start"].min(), tracks["End"].max()):
|
| 56 |
-
track_ids = set(
|
| 57 |
-
tracks[(tracks["Start"] <= t) & (tracks["End"] >= t)]["Cell_ID"]
|
| 58 |
-
)
|
| 59 |
-
det_ids = set(detections[(detections["t"] == t)]["segmentation_id"])
|
| 60 |
-
if not track_ids.issubset(det_ids):
|
| 61 |
-
raise ValueError(f"Missing IDs in masks at t={t}: {track_ids - det_ids}")
|
| 62 |
-
if not det_ids.issubset(track_ids):
|
| 63 |
-
raise ValueError(
|
| 64 |
-
f"IDs {det_ids - track_ids} at t={t} not represented in tracks file."
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
for t, frame in enumerate(masks):
|
| 68 |
-
_, n_components = label(frame, return_num=True)
|
| 69 |
-
n_labels = len(detections[detections["t"] == t])
|
| 70 |
-
if n_labels < n_components:
|
| 71 |
-
logger.warning(f"{n_components - n_labels} non-connected masks at t={t}.")
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def _get_node_attributes(masks):
|
| 75 |
-
"""Calculates x,y,z,t,label for each detection in a movie.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
masks (np.ndarray): Set of masks with time in the first axis
|
| 79 |
-
|
| 80 |
-
Returns:
|
| 81 |
-
pd.DataFrame: Dataframe with one detection per row. Columns
|
| 82 |
-
segmentation_id, x, y, z, t
|
| 83 |
-
"""
|
| 84 |
-
data_df = pd.concat([
|
| 85 |
-
_detections_from_image(masks, idx) for idx in range(masks.shape[0])
|
| 86 |
-
]).reset_index(drop=True)
|
| 87 |
-
data_df = data_df.rename(
|
| 88 |
-
columns={
|
| 89 |
-
"label": "segmentation_id",
|
| 90 |
-
"centroid-2": "z",
|
| 91 |
-
"centroid-1": "y",
|
| 92 |
-
"centroid-0": "x",
|
| 93 |
-
}
|
| 94 |
-
)
|
| 95 |
-
data_df["segmentation_id"] = data_df["segmentation_id"].astype(int)
|
| 96 |
-
data_df["t"] = data_df["t"].astype(int)
|
| 97 |
-
return data_df
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def _detections_from_image(stack, idx):
|
| 101 |
-
"""Return the unique track label, centroid and time for each track vertex.
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
stack (np.ndarray): Stack of masks
|
| 105 |
-
idx (int): Index of the image to calculate the centroids and track labels
|
| 106 |
-
|
| 107 |
-
Returns:
|
| 108 |
-
pd.DataFrame: The dataframe of track data for one time step (specified by idx)
|
| 109 |
-
"""
|
| 110 |
-
props = regionprops_table(
|
| 111 |
-
np.asarray(stack[idx, ...]), properties=("label", "centroid")
|
| 112 |
-
)
|
| 113 |
-
props["t"] = np.full(props["label"].shape, idx)
|
| 114 |
-
return pd.DataFrame(props)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/_compute_overlap.py
DELETED
|
@@ -1,209 +0,0 @@
|
|
| 1 |
-
"""Adapted from Fast R-CNN
|
| 2 |
-
Written by Sergey Karayev
|
| 3 |
-
Licensed under The MIT License
|
| 4 |
-
Copyright (c) 2015 Microsoft.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
from skimage.measure import regionprops
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def _union_slice(a: tuple[slice], b: tuple[slice]):
|
| 12 |
-
"""Returns the union of slice tuples a and b."""
|
| 13 |
-
starts = tuple(min(_a.start, _b.start) for _a, _b in zip(a, b))
|
| 14 |
-
stops = tuple(max(_a.stop, _b.stop) for _a, _b in zip(a, b))
|
| 15 |
-
return tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def get_labels_with_overlap(gt_frame, res_frame):
|
| 19 |
-
"""Get all labels IDs in gt_frame and res_frame whose bounding boxes
|
| 20 |
-
overlap.
|
| 21 |
-
|
| 22 |
-
Args:
|
| 23 |
-
gt_frame (np.ndarray): ground truth segmentation for a single frame
|
| 24 |
-
res_frame (np.ndarray): result segmentation for a given frame
|
| 25 |
-
|
| 26 |
-
Returns:
|
| 27 |
-
overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes
|
| 28 |
-
overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes
|
| 29 |
-
intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area)
|
| 30 |
-
"""
|
| 31 |
-
gt_frame = gt_frame.astype(np.uint16, copy=False)
|
| 32 |
-
res_frame = res_frame.astype(np.uint16, copy=False)
|
| 33 |
-
gt_props = regionprops(gt_frame)
|
| 34 |
-
gt_boxes = [np.array(gt_prop.bbox) for gt_prop in gt_props]
|
| 35 |
-
gt_boxes = np.array(gt_boxes).astype(np.float64)
|
| 36 |
-
gt_box_labels = np.asarray(
|
| 37 |
-
[int(gt_prop.label) for gt_prop in gt_props], dtype=np.uint16
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
res_props = regionprops(res_frame)
|
| 41 |
-
res_boxes = [np.array(res_prop.bbox) for res_prop in res_props]
|
| 42 |
-
res_boxes = np.array(res_boxes).astype(np.float64)
|
| 43 |
-
res_box_labels = np.asarray(
|
| 44 |
-
[int(res_prop.label) for res_prop in res_props], dtype=np.uint16
|
| 45 |
-
)
|
| 46 |
-
if len(gt_props) == 0 or len(res_props) == 0:
|
| 47 |
-
return [], [], []
|
| 48 |
-
|
| 49 |
-
if gt_frame.ndim == 3:
|
| 50 |
-
overlaps = compute_overlap_3D(gt_boxes, res_boxes)
|
| 51 |
-
else:
|
| 52 |
-
overlaps = compute_overlap(
|
| 53 |
-
gt_boxes, res_boxes
|
| 54 |
-
) # has the form [gt_bbox, res_bbox]
|
| 55 |
-
|
| 56 |
-
# Find the bboxes that have overlap at all (ind_ corresponds to box number - starting at 0)
|
| 57 |
-
ind_gt, ind_res = np.nonzero(overlaps)
|
| 58 |
-
ind_gt = np.asarray(ind_gt, dtype=np.uint16)
|
| 59 |
-
ind_res = np.asarray(ind_res, dtype=np.uint16)
|
| 60 |
-
overlapping_gt_labels = gt_box_labels[ind_gt]
|
| 61 |
-
overlapping_res_labels = res_box_labels[ind_res]
|
| 62 |
-
|
| 63 |
-
intersections_over_gt = []
|
| 64 |
-
for i, j in zip(ind_gt, ind_res):
|
| 65 |
-
sslice = _union_slice(gt_props[i].slice, res_props[j].slice)
|
| 66 |
-
gt_mask = gt_frame[sslice] == gt_box_labels[i]
|
| 67 |
-
res_mask = res_frame[sslice] == res_box_labels[j]
|
| 68 |
-
area_inter = np.count_nonzero(np.logical_and(gt_mask, res_mask))
|
| 69 |
-
area_gt = np.count_nonzero(gt_mask)
|
| 70 |
-
intersections_over_gt.append(area_inter / area_gt)
|
| 71 |
-
|
| 72 |
-
return overlapping_gt_labels, overlapping_res_labels, intersections_over_gt
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def compute_overlap(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
|
| 76 |
-
"""Args:
|
| 77 |
-
a: (N, 4) ndarray of float
|
| 78 |
-
b: (K, 4) ndarray of float.
|
| 79 |
-
|
| 80 |
-
Returns:
|
| 81 |
-
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
|
| 82 |
-
"""
|
| 83 |
-
N = boxes.shape[0]
|
| 84 |
-
K = query_boxes.shape[0]
|
| 85 |
-
overlaps = np.zeros((N, K), dtype=np.float64)
|
| 86 |
-
for k in range(K):
|
| 87 |
-
box_area = (query_boxes[k, 2] - query_boxes[k, 0] + 1) * (
|
| 88 |
-
query_boxes[k, 3] - query_boxes[k, 1] + 1
|
| 89 |
-
)
|
| 90 |
-
for n in range(N):
|
| 91 |
-
iw = (
|
| 92 |
-
min(boxes[n, 2], query_boxes[k, 2])
|
| 93 |
-
- max(boxes[n, 0], query_boxes[k, 0])
|
| 94 |
-
+ 1
|
| 95 |
-
)
|
| 96 |
-
if iw > 0:
|
| 97 |
-
ih = (
|
| 98 |
-
min(boxes[n, 3], query_boxes[k, 3])
|
| 99 |
-
- max(boxes[n, 1], query_boxes[k, 1])
|
| 100 |
-
+ 1
|
| 101 |
-
)
|
| 102 |
-
if ih > 0:
|
| 103 |
-
ua = np.float64(
|
| 104 |
-
(boxes[n, 2] - boxes[n, 0] + 1)
|
| 105 |
-
* (boxes[n, 3] - boxes[n, 1] + 1)
|
| 106 |
-
+ box_area
|
| 107 |
-
- iw * ih
|
| 108 |
-
)
|
| 109 |
-
overlaps[n, k] = iw * ih / ua
|
| 110 |
-
return overlaps
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def compute_overlap_3D(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
|
| 114 |
-
"""Args:
|
| 115 |
-
a: (N, 6) ndarray of float
|
| 116 |
-
b: (K, 6) ndarray of float.
|
| 117 |
-
|
| 118 |
-
Returns:
|
| 119 |
-
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
|
| 120 |
-
"""
|
| 121 |
-
N = boxes.shape[0]
|
| 122 |
-
K = query_boxes.shape[0]
|
| 123 |
-
overlaps = np.zeros((N, K), dtype=np.float64)
|
| 124 |
-
for k in range(K):
|
| 125 |
-
box_volume = (
|
| 126 |
-
(query_boxes[k, 3] - query_boxes[k, 0] + 1)
|
| 127 |
-
* (query_boxes[k, 4] - query_boxes[k, 1] + 1)
|
| 128 |
-
* (query_boxes[k, 5] - query_boxes[k, 2] + 1)
|
| 129 |
-
)
|
| 130 |
-
for n in range(N):
|
| 131 |
-
id_ = (
|
| 132 |
-
min(boxes[n, 3], query_boxes[k, 3])
|
| 133 |
-
- max(boxes[n, 0], query_boxes[k, 0])
|
| 134 |
-
+ 1
|
| 135 |
-
)
|
| 136 |
-
if id_ > 0:
|
| 137 |
-
iw = (
|
| 138 |
-
min(boxes[n, 4], query_boxes[k, 4])
|
| 139 |
-
- max(boxes[n, 1], query_boxes[k, 1])
|
| 140 |
-
+ 1
|
| 141 |
-
)
|
| 142 |
-
if iw > 0:
|
| 143 |
-
ih = (
|
| 144 |
-
min(boxes[n, 5], query_boxes[k, 5])
|
| 145 |
-
- max(boxes[n, 2], query_boxes[k, 2])
|
| 146 |
-
+ 1
|
| 147 |
-
)
|
| 148 |
-
if ih > 0:
|
| 149 |
-
ua = np.float64(
|
| 150 |
-
(boxes[n, 3] - boxes[n, 0] + 1)
|
| 151 |
-
* (boxes[n, 4] - boxes[n, 1] + 1)
|
| 152 |
-
* (boxes[n, 5] - boxes[n, 2] + 1)
|
| 153 |
-
+ box_volume
|
| 154 |
-
- iw * ih * id_
|
| 155 |
-
)
|
| 156 |
-
overlaps[n, k] = iw * ih * id_ / ua
|
| 157 |
-
return overlaps
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
try:
|
| 161 |
-
import numba
|
| 162 |
-
except ImportError:
|
| 163 |
-
import os
|
| 164 |
-
import warnings
|
| 165 |
-
|
| 166 |
-
if not os.getenv("NO_JIT_WARNING", False):
|
| 167 |
-
warnings.warn(
|
| 168 |
-
"Numba not installed, falling back to slower numpy implementation. "
|
| 169 |
-
"Install numba for a significant speedup. Set the environment "
|
| 170 |
-
"variable NO_JIT_WARNING=1 to disable this warning.",
|
| 171 |
-
stacklevel=2,
|
| 172 |
-
)
|
| 173 |
-
else:
|
| 174 |
-
# compute_overlap 2d and 3d have the same signature
|
| 175 |
-
signature = [
|
| 176 |
-
"f8[:,::1](f8[:,::1], f8[:,::1])",
|
| 177 |
-
numba.types.Array(numba.float64, 2, "C", readonly=True)(
|
| 178 |
-
numba.types.Array(numba.float64, 2, "C", readonly=True),
|
| 179 |
-
numba.types.Array(numba.float64, 2, "C", readonly=True),
|
| 180 |
-
),
|
| 181 |
-
]
|
| 182 |
-
|
| 183 |
-
# variables that appear in the body of each function
|
| 184 |
-
common_locals = {
|
| 185 |
-
"N": numba.uint64,
|
| 186 |
-
"K": numba.uint64,
|
| 187 |
-
"overlaps": numba.types.Array(numba.float64, 2, "C"),
|
| 188 |
-
"iw": numba.float64,
|
| 189 |
-
"ih": numba.float64,
|
| 190 |
-
"ua": numba.float64,
|
| 191 |
-
"n": numba.uint64,
|
| 192 |
-
"k": numba.uint64,
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
compute_overlap = numba.njit(
|
| 196 |
-
signature,
|
| 197 |
-
locals={**common_locals, "box_area": numba.float64},
|
| 198 |
-
fastmath=True,
|
| 199 |
-
nogil=True,
|
| 200 |
-
boundscheck=False,
|
| 201 |
-
)(compute_overlap)
|
| 202 |
-
|
| 203 |
-
compute_overlap_3D = numba.njit(
|
| 204 |
-
signature,
|
| 205 |
-
locals={**common_locals, "id_": numba.float64, "box_volume": numba.float64},
|
| 206 |
-
fastmath=True,
|
| 207 |
-
nogil=True,
|
| 208 |
-
boundscheck=False,
|
| 209 |
-
)(compute_overlap_3D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/augmentations.py
DELETED
|
@@ -1,557 +0,0 @@
|
|
| 1 |
-
"""#TODO: dont convert to numpy and back to torch."""
|
| 2 |
-
|
| 3 |
-
from collections.abc import Iterable, Sequence
|
| 4 |
-
from itertools import chain
|
| 5 |
-
from typing import Any
|
| 6 |
-
|
| 7 |
-
import kornia.augmentation as K
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
from kornia.augmentation import random_generator as rg
|
| 11 |
-
from kornia.augmentation.utils import _range_bound
|
| 12 |
-
from kornia.constants import DataKey, Resample
|
| 13 |
-
from typing import Optional, Tuple, Sequence, Dict, Union
|
| 14 |
-
|
| 15 |
-
def default_augmenter(coords: np.ndarray):
|
| 16 |
-
# TODO parametrize magnitude of different augmentations
|
| 17 |
-
ndim = coords.shape[1]
|
| 18 |
-
|
| 19 |
-
assert coords.ndim == 2 and ndim in (2, 3)
|
| 20 |
-
|
| 21 |
-
# first remove offset
|
| 22 |
-
center = coords.mean(axis=0, keepdims=True)
|
| 23 |
-
|
| 24 |
-
coords = coords - center
|
| 25 |
-
|
| 26 |
-
# apply random flip
|
| 27 |
-
coords *= 2 * np.random.randint(0, 2, (1, ndim)) - 1
|
| 28 |
-
|
| 29 |
-
# apply rotation along the last two dimensions
|
| 30 |
-
phi = np.random.uniform(0, 2 * np.pi)
|
| 31 |
-
coords = _rotate(coords, phi, center=None)
|
| 32 |
-
|
| 33 |
-
if ndim == 3:
|
| 34 |
-
# rotate along the first two dimensions too
|
| 35 |
-
phi2, phi3 = np.random.uniform(0, 2 * np.pi, 2)
|
| 36 |
-
coords = _rotate(coords, phi2, rot_axis=(0, 1), center=None)
|
| 37 |
-
coords = _rotate(coords, phi3, rot_axis=(0, 2), center=None)
|
| 38 |
-
|
| 39 |
-
coords += center
|
| 40 |
-
|
| 41 |
-
# translation
|
| 42 |
-
trans = 128 * np.random.uniform(-1, 1, (1, ndim))
|
| 43 |
-
coords += trans
|
| 44 |
-
|
| 45 |
-
# elastic
|
| 46 |
-
coords += 1.5 * np.random.normal(0, 1, coords.shape)
|
| 47 |
-
|
| 48 |
-
return coords
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def _rotate(
|
| 52 |
-
coords: np.ndarray, phi: float, rot_axis=(-2, -1), center: Optional[Tuple] = None
|
| 53 |
-
):
|
| 54 |
-
"""Rotation along the last two dimensions of coords[..,:-2:]."""
|
| 55 |
-
ndim = coords.shape[1]
|
| 56 |
-
assert coords.ndim == 2 and ndim in (2, 3)
|
| 57 |
-
|
| 58 |
-
if center is None:
|
| 59 |
-
center = (0,) * ndim
|
| 60 |
-
|
| 61 |
-
assert len(center) == ndim
|
| 62 |
-
|
| 63 |
-
center = np.asarray(center)
|
| 64 |
-
co, si = np.cos(phi), np.sin(phi)
|
| 65 |
-
Rot = np.eye(ndim)
|
| 66 |
-
Rot[np.ix_(rot_axis, rot_axis)] = np.array(((co, -si), (si, co)))
|
| 67 |
-
x = coords - center
|
| 68 |
-
x = x @ Rot.T
|
| 69 |
-
x += center
|
| 70 |
-
return x
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _filter_points(
|
| 74 |
-
points: np.ndarray, shape: tuple, origin: Optional[Tuple] = None
|
| 75 |
-
) -> np.ndarray:
|
| 76 |
-
"""Returns indices of points that are inside the shape extent and given origin."""
|
| 77 |
-
ndim = points.shape[-1]
|
| 78 |
-
if origin is None:
|
| 79 |
-
origin = (0,) * ndim
|
| 80 |
-
|
| 81 |
-
idx = tuple(
|
| 82 |
-
np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
|
| 83 |
-
for i in range(ndim)
|
| 84 |
-
)
|
| 85 |
-
idx = np.where(np.all(idx, axis=0))[0]
|
| 86 |
-
return idx
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class ConcatAffine(K.RandomAffine):
|
| 90 |
-
"""Concatenate multiple affine transformations without intermediates."""
|
| 91 |
-
|
| 92 |
-
def __init__(self, affines: Sequence[K.RandomAffine]):
|
| 93 |
-
super().__init__(degrees=0)
|
| 94 |
-
self._affines = affines
|
| 95 |
-
if not all([a.same_on_batch for a in affines]):
|
| 96 |
-
raise ValueError("all affines must have same_on_batch=True")
|
| 97 |
-
|
| 98 |
-
def merge_params(self, params: Sequence[Dict[str, torch.Tensor]]):
|
| 99 |
-
"""Merge params from affines."""
|
| 100 |
-
out = params[0].copy()
|
| 101 |
-
|
| 102 |
-
def _torchmax(x, dim):
|
| 103 |
-
return torch.max(x, dim=dim).values
|
| 104 |
-
|
| 105 |
-
ops = {
|
| 106 |
-
"translations": torch.sum,
|
| 107 |
-
"center": torch.mean,
|
| 108 |
-
"scale": torch.prod,
|
| 109 |
-
"shear_x": torch.sum,
|
| 110 |
-
"shear_y": torch.sum,
|
| 111 |
-
"angle": torch.sum,
|
| 112 |
-
"batch_prob": _torchmax,
|
| 113 |
-
}
|
| 114 |
-
for k, v in params[0].items():
|
| 115 |
-
ps = [p[k] for p in params if len(p[k]) > 0]
|
| 116 |
-
if len(ps) > 0 and k in ops:
|
| 117 |
-
v_new = torch.stack(ps, dim=0).float()
|
| 118 |
-
v_new = ops[k](v_new, dim=0)
|
| 119 |
-
v_new = v_new.to(v.dtype)
|
| 120 |
-
else:
|
| 121 |
-
v_new = v
|
| 122 |
-
out[k] = v_new
|
| 123 |
-
|
| 124 |
-
return out
|
| 125 |
-
|
| 126 |
-
def forward_parameters(
|
| 127 |
-
self, batch_shape: Tuple[int, ...]
|
| 128 |
-
) -> Dict[str, torch.Tensor]:
|
| 129 |
-
params = tuple(a.forward_parameters(batch_shape) for a in self._affines)
|
| 130 |
-
# print(params)
|
| 131 |
-
return self.merge_params(params)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# custom augmentations
|
| 135 |
-
class RandomIntensityScaleShift(K.IntensityAugmentationBase2D):
|
| 136 |
-
r"""Apply a random scale and shift to the image intensity.
|
| 137 |
-
|
| 138 |
-
Args:
|
| 139 |
-
p: probability of applying the transformation.
|
| 140 |
-
scale: the scale factor to apply
|
| 141 |
-
shift: the offset to apply
|
| 142 |
-
clip_output: if true clip output
|
| 143 |
-
same_on_batch: apply the same transformation across the batch.
|
| 144 |
-
keepdim: whether to keep the output shape the same as input (True) or broadcast it
|
| 145 |
-
to the batch form (False).
|
| 146 |
-
Shape:
|
| 147 |
-
- Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
|
| 148 |
-
- Output: :math:`(B, C, H, W)`
|
| 149 |
-
|
| 150 |
-
.. note::
|
| 151 |
-
This function internally uses :func:`kornia.enhance.adjust_brightness`
|
| 152 |
-
|
| 153 |
-
"""
|
| 154 |
-
|
| 155 |
-
def __init__(
|
| 156 |
-
self,
|
| 157 |
-
scale: Tuple[float, float] = (0.5, 2.0),
|
| 158 |
-
shift: Tuple[float, float] = (-0.1, 0.1),
|
| 159 |
-
clip_output: bool = True,
|
| 160 |
-
same_on_batch: bool = False,
|
| 161 |
-
p: float = 1.0,
|
| 162 |
-
keepdim: bool = False,
|
| 163 |
-
) -> None:
|
| 164 |
-
super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
|
| 165 |
-
self.scale = _range_bound(
|
| 166 |
-
scale, "scale", center=0, bounds=(-float("inf"), float("inf"))
|
| 167 |
-
)
|
| 168 |
-
self.shift = _range_bound(
|
| 169 |
-
shift, "shift", center=0, bounds=(-float("inf"), float("inf"))
|
| 170 |
-
)
|
| 171 |
-
self._param_generator = rg.PlainUniformGenerator(
|
| 172 |
-
(self.scale, "scale_factor", None, None),
|
| 173 |
-
(self.shift, "shift_factor", None, None),
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
self.clip_output = clip_output
|
| 177 |
-
|
| 178 |
-
def apply_transform(
|
| 179 |
-
self,
|
| 180 |
-
input: torch.Tensor,
|
| 181 |
-
params: Dict[str, torch.Tensor],
|
| 182 |
-
flags: Dict[str, Any],
|
| 183 |
-
transform: Optional[torch.Tensor] = None,
|
| 184 |
-
) -> torch.Tensor:
|
| 185 |
-
scale_factor = params["scale_factor"].to(input)
|
| 186 |
-
shift_factor = params["shift_factor"].to(input)
|
| 187 |
-
scale_factor = scale_factor.view(len(scale_factor), 1, 1, 1)
|
| 188 |
-
shift_factor = shift_factor.view(len(scale_factor), 1, 1, 1)
|
| 189 |
-
img_adjust = input * scale_factor + shift_factor
|
| 190 |
-
if self.clip_output:
|
| 191 |
-
img_adjust = img_adjust.clamp(min=0.0, max=1.0)
|
| 192 |
-
return img_adjust
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
class RandomTemporalAffine(K.RandomAffine):
|
| 196 |
-
r"""Apply a random 2D affine transformation to a batch of images while
|
| 197 |
-
varying the transformation across the time dimension from 0 to 1.
|
| 198 |
-
|
| 199 |
-
Same args/kwargs as K.RandomAffine
|
| 200 |
-
|
| 201 |
-
"""
|
| 202 |
-
|
| 203 |
-
def __init__(self, *args, **kwargs) -> None:
|
| 204 |
-
super().__init__(*args, same_on_batch=True, **kwargs)
|
| 205 |
-
|
| 206 |
-
def forward_parameters(
|
| 207 |
-
self, batch_shape: Tuple[int, ...]
|
| 208 |
-
) -> Dict[str, torch.Tensor]:
|
| 209 |
-
params = super().forward_parameters(batch_shape)
|
| 210 |
-
factor = torch.linspace(0, 1, batch_shape[0]).to(params["translations"])
|
| 211 |
-
for key in ["translations", "center", "angle", "shear_x", "shear_y"]:
|
| 212 |
-
v = params[key]
|
| 213 |
-
if len(v) > 0:
|
| 214 |
-
params[key] = v * factor.view(*((-1,) + (1,) * len(v.shape[1:])))
|
| 215 |
-
|
| 216 |
-
for key in [
|
| 217 |
-
"scale",
|
| 218 |
-
]:
|
| 219 |
-
v = params[key]
|
| 220 |
-
if len(v) > 0:
|
| 221 |
-
params[key] = 1 + (v - 1) * factor.view(
|
| 222 |
-
*((-1,) + (1,) * len(v.shape[1:]))
|
| 223 |
-
)
|
| 224 |
-
return params
|
| 225 |
-
|
| 226 |
-
# def compute_transformation(self, input: torch.Tensor,
|
| 227 |
-
# params: Dict[str, torch.Tensor],
|
| 228 |
-
# flags: Dict[str, Any]) -> torch.Tensor:
|
| 229 |
-
# factor = torch.linspace(0, 1, input.shape[0]).to(input)
|
| 230 |
-
# for key in ["translations", "center", "angle", "shear_x", "shear_y"]:
|
| 231 |
-
# v = params[key]
|
| 232 |
-
# params[key] = v * factor.view(*((-1,)+(1,)*len(v.shape[1:])))
|
| 233 |
-
|
| 234 |
-
# for key in ["scale", ]:
|
| 235 |
-
# v = params[key]
|
| 236 |
-
# params[key] = 1 + (v-1) * factor.view(*((-1,)+(1,)*len(v.shape[1:])))
|
| 237 |
-
|
| 238 |
-
# return super().compute_transformation(input, params, flags)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
class BasicPipeline:
|
| 242 |
-
"""transforms img, mask, and points.
|
| 243 |
-
|
| 244 |
-
Only supports 2D transformations for now (any 3D object will preserve its z coordinates/dimensions)
|
| 245 |
-
"""
|
| 246 |
-
|
| 247 |
-
def __init__(self, augs: tuple, filter_points: bool = True):
|
| 248 |
-
self.data_keys = ("input", "mask", "keypoints")
|
| 249 |
-
self.pipeline = K.AugmentationSequential(
|
| 250 |
-
*augs,
|
| 251 |
-
# disable align_corners to not trigger lots of warnings from kornia
|
| 252 |
-
extra_args={
|
| 253 |
-
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": False}
|
| 254 |
-
},
|
| 255 |
-
data_keys=self.data_keys,
|
| 256 |
-
)
|
| 257 |
-
self.filter_points = filter_points
|
| 258 |
-
|
| 259 |
-
def __call__(
|
| 260 |
-
self,
|
| 261 |
-
img: np.ndarray,
|
| 262 |
-
mask: np.ndarray,
|
| 263 |
-
points: np.ndarray,
|
| 264 |
-
timepoints: np.ndarray,
|
| 265 |
-
):
|
| 266 |
-
ndim = img.ndim - 1
|
| 267 |
-
assert (
|
| 268 |
-
ndim in (2, 3)
|
| 269 |
-
and points.ndim == 2
|
| 270 |
-
and points.shape[-1] == ndim
|
| 271 |
-
and timepoints.ndim == 1
|
| 272 |
-
and img.shape == mask.shape
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
x = torch.from_numpy(img).float()
|
| 276 |
-
y = torch.from_numpy(mask.astype(np.int64)).float()
|
| 277 |
-
|
| 278 |
-
# if 2D add dummy channel
|
| 279 |
-
if ndim == 2:
|
| 280 |
-
x = x.unsqueeze(1)
|
| 281 |
-
y = y.unsqueeze(1)
|
| 282 |
-
p = points[..., [1, 0]]
|
| 283 |
-
# if 3D we use z as channel (i.e. fix augs across z)
|
| 284 |
-
elif ndim == 3:
|
| 285 |
-
p = points[..., [2, 1]]
|
| 286 |
-
|
| 287 |
-
# flip as kornia expects xy and not yx
|
| 288 |
-
p = torch.from_numpy(p).unsqueeze(0).float()
|
| 289 |
-
# add batch by duplicating to make kornia happy
|
| 290 |
-
p = p.expand(len(x), -1, -1)
|
| 291 |
-
# create a mask to know which timepoint the points belong to
|
| 292 |
-
ts = torch.from_numpy(timepoints).long()
|
| 293 |
-
n_points = p.shape[1]
|
| 294 |
-
if n_points > 0:
|
| 295 |
-
x, y, p = self.pipeline(x, y, p)
|
| 296 |
-
else:
|
| 297 |
-
# dummy keypoints
|
| 298 |
-
x, y = self.pipeline(x, y, torch.zeros((len(x), 1, 2)))[:2]
|
| 299 |
-
|
| 300 |
-
# remove batch
|
| 301 |
-
p = p[ts, torch.arange(n_points)]
|
| 302 |
-
# flip back
|
| 303 |
-
p = p[..., [1, 0]]
|
| 304 |
-
|
| 305 |
-
# remove channel
|
| 306 |
-
if ndim == 2:
|
| 307 |
-
x = x.squeeze(1)
|
| 308 |
-
y = y.squeeze(1)
|
| 309 |
-
|
| 310 |
-
x = x.numpy()
|
| 311 |
-
y = y.numpy().astype(np.uint16)
|
| 312 |
-
# p = p.squeeze(0).numpy()
|
| 313 |
-
p = p.numpy()
|
| 314 |
-
# add back z coordinates
|
| 315 |
-
if ndim == 3:
|
| 316 |
-
p = np.concatenate([points[..., 0:1], p], axis=-1)
|
| 317 |
-
ts = ts.numpy()
|
| 318 |
-
# remove points outside of img/mask
|
| 319 |
-
|
| 320 |
-
if self.filter_points:
|
| 321 |
-
idx = _filter_points(p, shape=x.shape[-ndim:])
|
| 322 |
-
|
| 323 |
-
else:
|
| 324 |
-
idx = np.arange(len(p), dtype=int)
|
| 325 |
-
|
| 326 |
-
p = p[idx]
|
| 327 |
-
return (x, y, p), idx
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
class RandomCrop:
|
| 331 |
-
def __init__(
|
| 332 |
-
self,
|
| 333 |
-
crop_size: Optional[Union[int, Tuple[int]]] = None,
|
| 334 |
-
ndim: int = 2,
|
| 335 |
-
ensure_inside_points: bool = False,
|
| 336 |
-
use_padding: bool = True,
|
| 337 |
-
padding_mode="constant",
|
| 338 |
-
) -> None:
|
| 339 |
-
"""crop_size: tuple of int
|
| 340 |
-
can be tuple of length 1 (all dimensions)
|
| 341 |
-
of length ndim (y,x,...)
|
| 342 |
-
of length 2*ndim (y1,y2, x1,x2, ...).
|
| 343 |
-
"""
|
| 344 |
-
if isinstance(crop_size, int):
|
| 345 |
-
crop_size = (crop_size,) * 2 * ndim
|
| 346 |
-
elif isinstance(crop_size, Iterable):
|
| 347 |
-
pass
|
| 348 |
-
else:
|
| 349 |
-
raise ValueError(f"{crop_size} has to be int or tuple of int")
|
| 350 |
-
|
| 351 |
-
if len(crop_size) == 1:
|
| 352 |
-
crop_size = (crop_size[0],) * 2 * ndim
|
| 353 |
-
elif len(crop_size) == ndim:
|
| 354 |
-
crop_size = tuple(chain(*tuple((c, c) for c in crop_size)))
|
| 355 |
-
elif len(crop_size) == 2 * ndim:
|
| 356 |
-
pass
|
| 357 |
-
else:
|
| 358 |
-
raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
|
| 359 |
-
|
| 360 |
-
crop_size = np.array(crop_size)
|
| 361 |
-
self._ndim = ndim
|
| 362 |
-
self._crop_bounds = crop_size[::2], crop_size[1::2]
|
| 363 |
-
self._use_padding = use_padding
|
| 364 |
-
self._ensure_inside_points = ensure_inside_points
|
| 365 |
-
self._rng = np.random.RandomState()
|
| 366 |
-
self._padding_mode = padding_mode
|
| 367 |
-
|
| 368 |
-
def crop_img(self, img: np.ndarray, corner: np.ndarray, crop_size: np.ndarray):
|
| 369 |
-
if not img.ndim == self._ndim + 1:
|
| 370 |
-
raise ValueError(
|
| 371 |
-
f"img has to be 1 (time) + {self._ndim} spatial dimensions"
|
| 372 |
-
)
|
| 373 |
-
|
| 374 |
-
pad_left = np.maximum(0, -corner)
|
| 375 |
-
pad_right = np.maximum(
|
| 376 |
-
0, corner + crop_size - np.array(img.shape[-self._ndim :])
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
img = np.pad(
|
| 380 |
-
img,
|
| 381 |
-
((0, 0), *tuple(np.stack((pad_left, pad_right)).T)),
|
| 382 |
-
mode=self._padding_mode,
|
| 383 |
-
)
|
| 384 |
-
slices = (
|
| 385 |
-
slice(None),
|
| 386 |
-
*tuple(slice(c, c + s) for c, s in zip(corner + pad_left, crop_size)),
|
| 387 |
-
)
|
| 388 |
-
return img[slices]
|
| 389 |
-
|
| 390 |
-
def crop_points(
|
| 391 |
-
self, points: np.ndarray, corner: np.ndarray, crop_size: np.ndarray
|
| 392 |
-
):
|
| 393 |
-
idx = _filter_points(points, shape=crop_size, origin=corner)
|
| 394 |
-
return points[idx] - corner, idx
|
| 395 |
-
|
| 396 |
-
def __call__(self, img: np.ndarray, mask: np.ndarray, points: np.ndarray):
|
| 397 |
-
assert (
|
| 398 |
-
img.ndim == self._ndim + 1
|
| 399 |
-
and points.ndim == 2
|
| 400 |
-
and points.shape[-1] == self._ndim
|
| 401 |
-
and img.shape == mask.shape
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
points = points.astype(int)
|
| 405 |
-
|
| 406 |
-
crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
|
| 407 |
-
# print(f'{crop_size=}')
|
| 408 |
-
|
| 409 |
-
if self._ensure_inside_points:
|
| 410 |
-
if len(points) == 0:
|
| 411 |
-
print("No points given, cannot ensure inside points")
|
| 412 |
-
return (img, mask, points), np.zeros((0,), int)
|
| 413 |
-
|
| 414 |
-
# sample point and corner relative to it
|
| 415 |
-
|
| 416 |
-
_idx = np.random.randint(len(points))
|
| 417 |
-
corner = (
|
| 418 |
-
points[_idx]
|
| 419 |
-
- crop_size
|
| 420 |
-
+ 1
|
| 421 |
-
+ self._rng.randint(crop_size // 4, 3 * crop_size // 4)
|
| 422 |
-
)
|
| 423 |
-
else:
|
| 424 |
-
corner = self._rng.randint(
|
| 425 |
-
0, np.maximum(1, np.array(img.shape[-self._ndim :]) - crop_size)
|
| 426 |
-
)
|
| 427 |
-
|
| 428 |
-
if not self._use_padding:
|
| 429 |
-
corner = np.maximum(0, corner)
|
| 430 |
-
crop_size = np.minimum(
|
| 431 |
-
crop_size, np.array(img.shape[-self._ndim :]) - corner
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
img = self.crop_img(img, corner, crop_size)
|
| 435 |
-
mask = self.crop_img(mask, corner, crop_size)
|
| 436 |
-
points, idx = self.crop_points(points, corner, crop_size)
|
| 437 |
-
|
| 438 |
-
return (img, mask, points), idx
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
class AugmentationPipeline(BasicPipeline):
|
| 442 |
-
"""transforms img, mask, and points."""
|
| 443 |
-
|
| 444 |
-
def __init__(self, p=0.5, filter_points=True, level=1):
|
| 445 |
-
if level == 1:
|
| 446 |
-
augs = [
|
| 447 |
-
# Augmentations for all images in a window
|
| 448 |
-
K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
|
| 449 |
-
K.RandomVerticalFlip(p=0.5, same_on_batch=True),
|
| 450 |
-
K.RandomAffine(
|
| 451 |
-
degrees=180,
|
| 452 |
-
shear=(-10, 10, -10, 10), # x_min, x_max, y_min, y_max
|
| 453 |
-
translate=(0.05, 0.05),
|
| 454 |
-
scale=(0.8, 1.2), # x_min, x_max, y_min, y_max
|
| 455 |
-
p=p,
|
| 456 |
-
same_on_batch=True,
|
| 457 |
-
),
|
| 458 |
-
K.RandomBrightness(
|
| 459 |
-
(0.5, 1.5), clip_output=False, p=p, same_on_batch=True
|
| 460 |
-
),
|
| 461 |
-
K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
|
| 462 |
-
]
|
| 463 |
-
elif level == 2:
|
| 464 |
-
# Crafted for DeepCell crop size 256
|
| 465 |
-
augs = [
|
| 466 |
-
# Augmentations for all images in a window
|
| 467 |
-
K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
|
| 468 |
-
K.RandomVerticalFlip(p=0.5, same_on_batch=True),
|
| 469 |
-
K.RandomAffine(
|
| 470 |
-
degrees=180,
|
| 471 |
-
shear=(-5, 5, -5, 5), # x_min, x_max, y_min, y_max
|
| 472 |
-
translate=(0.03, 0.03),
|
| 473 |
-
scale=(0.8, 1.2), # isotropic
|
| 474 |
-
p=p,
|
| 475 |
-
same_on_batch=True,
|
| 476 |
-
),
|
| 477 |
-
# Anisotropic scaling
|
| 478 |
-
K.RandomAffine(
|
| 479 |
-
degrees=0,
|
| 480 |
-
scale=(0.9, 1.1, 0.9, 1.1), # x_min, x_max, y_min, y_max
|
| 481 |
-
p=p,
|
| 482 |
-
same_on_batch=True,
|
| 483 |
-
),
|
| 484 |
-
# Independet augmentations for each image in window
|
| 485 |
-
K.RandomAffine(
|
| 486 |
-
degrees=3,
|
| 487 |
-
shear=(-2, 2, -2, 2), # x_min, x_max, y_min, y_max
|
| 488 |
-
translate=(0.04, 0.04),
|
| 489 |
-
p=p,
|
| 490 |
-
same_on_batch=False,
|
| 491 |
-
),
|
| 492 |
-
# not implemented for points in kornia 0.7.0
|
| 493 |
-
# K.RandomElasticTransform(alpha=50, sigma=5, p=p, same_on_batch=False),
|
| 494 |
-
# Intensity-based augmentations
|
| 495 |
-
K.RandomBrightness(
|
| 496 |
-
(0.5, 1.5), clip_output=False, p=p, same_on_batch=True
|
| 497 |
-
),
|
| 498 |
-
K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
|
| 499 |
-
]
|
| 500 |
-
elif level == 3:
|
| 501 |
-
# Crafted for DeepCell crop size 256
|
| 502 |
-
augs = [
|
| 503 |
-
# Augmentations for all images in a window
|
| 504 |
-
K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
|
| 505 |
-
K.RandomVerticalFlip(p=0.5, same_on_batch=True),
|
| 506 |
-
ConcatAffine([
|
| 507 |
-
K.RandomAffine(
|
| 508 |
-
degrees=180,
|
| 509 |
-
shear=(-5, 5, -5, 5), # x_min, x_max, y_min, y_max
|
| 510 |
-
translate=(0.03, 0.03),
|
| 511 |
-
scale=(0.8, 1.2), # isotropic
|
| 512 |
-
p=p,
|
| 513 |
-
same_on_batch=True,
|
| 514 |
-
),
|
| 515 |
-
# Anisotropic scaling
|
| 516 |
-
K.RandomAffine(
|
| 517 |
-
degrees=0,
|
| 518 |
-
scale=(0.9, 1.1, 0.9, 1.1), # x_min, x_max, y_min, y_max
|
| 519 |
-
p=p,
|
| 520 |
-
same_on_batch=True,
|
| 521 |
-
),
|
| 522 |
-
]),
|
| 523 |
-
RandomTemporalAffine(
|
| 524 |
-
degrees=10,
|
| 525 |
-
translate=(0.05, 0.05),
|
| 526 |
-
p=p,
|
| 527 |
-
# same_on_batch=True,
|
| 528 |
-
),
|
| 529 |
-
# Independet augmentations for each image in window
|
| 530 |
-
K.RandomAffine(
|
| 531 |
-
degrees=2,
|
| 532 |
-
shear=(-2, 2, -2, 2), # x_min, x_max, y_min, y_max
|
| 533 |
-
translate=(0.01, 0.01),
|
| 534 |
-
p=0.5 * p,
|
| 535 |
-
same_on_batch=False,
|
| 536 |
-
),
|
| 537 |
-
# Intensity-based augmentations
|
| 538 |
-
RandomIntensityScaleShift(
|
| 539 |
-
(0.5, 2.0), (-0.1, 0.1), clip_output=False, p=p, same_on_batch=True
|
| 540 |
-
),
|
| 541 |
-
K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
|
| 542 |
-
]
|
| 543 |
-
elif level == 4:
|
| 544 |
-
# debug
|
| 545 |
-
augs = [
|
| 546 |
-
K.RandomAffine(
|
| 547 |
-
degrees=30,
|
| 548 |
-
shear=(-0, 0, -0, 0), # x_min, x_max, y_min, y_max
|
| 549 |
-
translate=(0.0, 0.0),
|
| 550 |
-
p=1,
|
| 551 |
-
same_on_batch=True,
|
| 552 |
-
),
|
| 553 |
-
]
|
| 554 |
-
else:
|
| 555 |
-
raise ValueError(f"level {level} not supported")
|
| 556 |
-
|
| 557 |
-
super().__init__(augs, filter_points)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/data.py
DELETED
|
@@ -1,1509 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
# from collections.abc import Sequence
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
from timeit import default_timer
|
| 5 |
-
from typing import Literal
|
| 6 |
-
|
| 7 |
-
import joblib
|
| 8 |
-
import lz4.frame
|
| 9 |
-
import networkx as nx
|
| 10 |
-
import numpy as np
|
| 11 |
-
import pandas as pd
|
| 12 |
-
import tifffile
|
| 13 |
-
import torch
|
| 14 |
-
from numba import njit
|
| 15 |
-
from scipy import ndimage as ndi
|
| 16 |
-
from scipy.spatial.distance import cdist
|
| 17 |
-
from skimage.measure import regionprops
|
| 18 |
-
from skimage.segmentation import relabel_sequential
|
| 19 |
-
from torch.utils.data import Dataset
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
|
| 22 |
-
from . import wrfeat
|
| 23 |
-
from ._check_ctc import _check_ctc, _get_node_attributes
|
| 24 |
-
from .augmentations import (
|
| 25 |
-
AugmentationPipeline,
|
| 26 |
-
RandomCrop,
|
| 27 |
-
default_augmenter,
|
| 28 |
-
)
|
| 29 |
-
from .features import (
|
| 30 |
-
_PROPERTIES,
|
| 31 |
-
extract_features_patch,
|
| 32 |
-
extract_features_regionprops,
|
| 33 |
-
)
|
| 34 |
-
from .matching import matching
|
| 35 |
-
|
| 36 |
-
from typing import List, Optional, Union, Tuple, Sequence
|
| 37 |
-
|
| 38 |
-
# from ..utils import blockwise_sum, normalize
|
| 39 |
-
from ..utils import blockwise_sum, normalize
|
| 40 |
-
|
| 41 |
-
logger = logging.getLogger(__name__)
|
| 42 |
-
logger.setLevel(logging.INFO)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _filter_track_df(df, start_frame, end_frame, downscale):
|
| 46 |
-
"""Only keep tracklets that are present in the given time interval."""
|
| 47 |
-
# only retain cells in interval
|
| 48 |
-
df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)]
|
| 49 |
-
|
| 50 |
-
# shift start and end of each cell
|
| 51 |
-
df.t1 = df.t1 - start_frame
|
| 52 |
-
df.t2 = df.t2 - start_frame
|
| 53 |
-
# set start/end to min/max
|
| 54 |
-
df.t1 = df.t1.clip(0, end_frame - start_frame - 1)
|
| 55 |
-
df.t2 = df.t2.clip(0, end_frame - start_frame - 1)
|
| 56 |
-
# set all parents to 0 that are not in the interval
|
| 57 |
-
df.loc[~df.parent.isin(df.label), "parent"] = 0
|
| 58 |
-
|
| 59 |
-
if downscale > 1:
|
| 60 |
-
if start_frame % downscale != 0:
|
| 61 |
-
raise ValueError("start_frame must be a multiple of downscale")
|
| 62 |
-
|
| 63 |
-
logger.info(f"Temporal downscaling of tracklet links by {downscale}")
|
| 64 |
-
|
| 65 |
-
# remove tracklets that have been fully deleted by temporal downsampling
|
| 66 |
-
|
| 67 |
-
mask = (
|
| 68 |
-
# (df["t2"] - df["t1"] < downscale - 1)
|
| 69 |
-
(df["t1"] % downscale != 0)
|
| 70 |
-
& (df["t2"] % downscale != 0)
|
| 71 |
-
& (df["t1"] // downscale == df["t2"] // downscale)
|
| 72 |
-
)
|
| 73 |
-
logger.info(
|
| 74 |
-
f"Remove {mask.sum()} tracklets that are fully deleted by downsampling"
|
| 75 |
-
)
|
| 76 |
-
logger.debug(f"Remove {df[mask]}")
|
| 77 |
-
|
| 78 |
-
df = df[~mask]
|
| 79 |
-
# set parent to 0 if it has been deleted
|
| 80 |
-
df.loc[~df.parent.isin(df.label), "parent"] = 0
|
| 81 |
-
|
| 82 |
-
df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int)
|
| 83 |
-
df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int)
|
| 84 |
-
|
| 85 |
-
# Correct for edge case of single frame tracklet
|
| 86 |
-
assert np.all(df["t1"] == np.minimum(df["t1"], df["t2"]))
|
| 87 |
-
|
| 88 |
-
return df
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class _CompressedArray:
|
| 92 |
-
"""a simple class to compress and decompress a numpy arrays using lz4."""
|
| 93 |
-
|
| 94 |
-
# dont compress float types
|
| 95 |
-
def __init__(self, data):
|
| 96 |
-
self._data = lz4.frame.compress(data)
|
| 97 |
-
self._dtype = data.dtype.type
|
| 98 |
-
self._shape = data.shape
|
| 99 |
-
|
| 100 |
-
def decompress(self):
|
| 101 |
-
s = lz4.frame.decompress(self._data)
|
| 102 |
-
data = np.frombuffer(s, dtype=self._dtype).reshape(self._shape)
|
| 103 |
-
return data
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
def debug_function(f):
|
| 107 |
-
def wrapper(*args, **kwargs):
|
| 108 |
-
try:
|
| 109 |
-
batch = f(*args, **kwargs)
|
| 110 |
-
except Exception as e:
|
| 111 |
-
logger.error(f"Error in {f.__name__}: {e}")
|
| 112 |
-
return None
|
| 113 |
-
logger.info(f"XXXX {len(batch['coords'])}")
|
| 114 |
-
return batch
|
| 115 |
-
|
| 116 |
-
return wrapper
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
class CTCData(Dataset):
|
| 120 |
-
def __init__(
|
| 121 |
-
self,
|
| 122 |
-
root: str = "",
|
| 123 |
-
ndim: int = 2,
|
| 124 |
-
use_gt: bool = True,
|
| 125 |
-
detection_folders: List[str] = ["TRA"],
|
| 126 |
-
window_size: int = 10,
|
| 127 |
-
max_tokens: Optional[int] = None,
|
| 128 |
-
slice_pct: tuple = (0.0, 1.0),
|
| 129 |
-
downscale_spatial: int = 1,
|
| 130 |
-
downscale_temporal: int = 1,
|
| 131 |
-
augment: int = 0,
|
| 132 |
-
features: Literal[
|
| 133 |
-
"none",
|
| 134 |
-
"regionprops",
|
| 135 |
-
"regionprops2",
|
| 136 |
-
"patch",
|
| 137 |
-
"patch_regionprops",
|
| 138 |
-
"wrfeat",
|
| 139 |
-
] = "wrfeat",
|
| 140 |
-
sanity_dist: bool = False,
|
| 141 |
-
crop_size: Optional[tuple] = None,
|
| 142 |
-
return_dense: bool = False,
|
| 143 |
-
compress: bool = False,
|
| 144 |
-
**kwargs,
|
| 145 |
-
) -> None:
|
| 146 |
-
"""_summary_.
|
| 147 |
-
|
| 148 |
-
Args:
|
| 149 |
-
root (str):
|
| 150 |
-
Folder containing the CTC TRA folder.
|
| 151 |
-
ndim (int):
|
| 152 |
-
Number of dimensions of the data. Defaults to 2d
|
| 153 |
-
(if ndim=3 and data is two dimensional, it will be cast to 3D)
|
| 154 |
-
detection_folders:
|
| 155 |
-
List of relative paths to folder with detections.
|
| 156 |
-
Defaults to ["TRA"], which uses the ground truth detections.
|
| 157 |
-
window_size (int):
|
| 158 |
-
Window size for transformer.
|
| 159 |
-
slice_pct (tuple):
|
| 160 |
-
Slice the dataset by percentages (from, to).
|
| 161 |
-
augment (int):
|
| 162 |
-
if 0, no data augmentation. if > 0, defines level of data augmentation.
|
| 163 |
-
features (str):
|
| 164 |
-
Types of features to use.
|
| 165 |
-
sanity_dist (bool):
|
| 166 |
-
Use euclidian distance instead of the association matrix as a target.
|
| 167 |
-
crop_size (tuple):
|
| 168 |
-
Size of the crops to use for augmentation. If None, no cropping is used.
|
| 169 |
-
return_dense (bool):
|
| 170 |
-
Return dense masks and images in the data samples.
|
| 171 |
-
compress (bool):
|
| 172 |
-
Compress elements/remove img if not needed to save memory for large datasets
|
| 173 |
-
"""
|
| 174 |
-
super().__init__()
|
| 175 |
-
|
| 176 |
-
self.root = Path(root)
|
| 177 |
-
self.name = self.root.name
|
| 178 |
-
self.use_gt = use_gt
|
| 179 |
-
self.slice_pct = slice_pct
|
| 180 |
-
if not 0 <= slice_pct[0] < slice_pct[1] <= 1:
|
| 181 |
-
raise ValueError(f"Invalid slice_pct {slice_pct}")
|
| 182 |
-
self.downscale_spatial = downscale_spatial
|
| 183 |
-
self.downscale_temporal = downscale_temporal
|
| 184 |
-
self.detection_folders = detection_folders
|
| 185 |
-
self.ndim = ndim
|
| 186 |
-
self.features = features
|
| 187 |
-
|
| 188 |
-
if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
|
| 189 |
-
raise ValueError(
|
| 190 |
-
f"'{features}' not one of the supported {ndim}D features"
|
| 191 |
-
f" {tuple(_PROPERTIES[ndim].keys())}"
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
logger.info(f"ROOT (config): \t{self.root}")
|
| 195 |
-
self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
|
| 196 |
-
logger.info(f"ROOT (guessed): \t{self.root}")
|
| 197 |
-
logger.info(f"GT TRA (guessed):\t{self.gt_tra_folder}")
|
| 198 |
-
if self.use_gt:
|
| 199 |
-
self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder)
|
| 200 |
-
else:
|
| 201 |
-
logger.info("Using dummy masks as GT")
|
| 202 |
-
self.gt_mask_folder = self._guess_det_folder(
|
| 203 |
-
self.root, self.detection_folders[0]
|
| 204 |
-
)
|
| 205 |
-
logger.info(f"GT MASK (guessed):\t{self.gt_mask_folder}")
|
| 206 |
-
|
| 207 |
-
# dont load image data if not needed
|
| 208 |
-
if features in ("none",):
|
| 209 |
-
self.img_folder = None
|
| 210 |
-
else:
|
| 211 |
-
self.img_folder = self._guess_img_folder(self.root)
|
| 212 |
-
logger.info(f"IMG (guessed):\t{self.img_folder}")
|
| 213 |
-
|
| 214 |
-
self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
|
| 215 |
-
ndim, features, augment, crop_size
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
if window_size <= 1:
|
| 219 |
-
raise ValueError("window must be >1")
|
| 220 |
-
self.window_size = window_size
|
| 221 |
-
self.max_tokens = max_tokens
|
| 222 |
-
|
| 223 |
-
self.slice_pct = slice_pct
|
| 224 |
-
self.sanity_dist = sanity_dist
|
| 225 |
-
self.return_dense = return_dense
|
| 226 |
-
self.compress = compress
|
| 227 |
-
self.start_frame = 0
|
| 228 |
-
self.end_frame = None
|
| 229 |
-
|
| 230 |
-
start = default_timer()
|
| 231 |
-
|
| 232 |
-
if self.features == "wrfeat":
|
| 233 |
-
self.windows = self._load_wrfeat()
|
| 234 |
-
else:
|
| 235 |
-
self.windows = self._load()
|
| 236 |
-
|
| 237 |
-
self.n_divs = self._get_ndivs(self.windows)
|
| 238 |
-
|
| 239 |
-
if len(self.windows) > 0:
|
| 240 |
-
self.ndim = self.windows[0]["coords"].shape[1]
|
| 241 |
-
self.n_objects = tuple(len(t["coords"]) for t in self.windows)
|
| 242 |
-
logger.info(
|
| 243 |
-
f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
|
| 244 |
-
f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
|
| 245 |
-
)
|
| 246 |
-
else:
|
| 247 |
-
self.n_objects = 0
|
| 248 |
-
logger.warning(f"Could not load any tracks from {self.root}")
|
| 249 |
-
|
| 250 |
-
if self.compress:
|
| 251 |
-
self._compress_data()
|
| 252 |
-
|
| 253 |
-
# def from_ctc
|
| 254 |
-
|
| 255 |
-
@classmethod
|
| 256 |
-
def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
|
| 257 |
-
self = cls(**train_args)
|
| 258 |
-
# for key, value in train_args.items():
|
| 259 |
-
# setattr(self, key, value)
|
| 260 |
-
|
| 261 |
-
# self.use_gt = use_gt
|
| 262 |
-
# self.slice_pct = slice_pct
|
| 263 |
-
# if not 0 <= slice_pct[0] < slice_pct[1] <= 1:
|
| 264 |
-
# raise ValueError(f"Invalid slice_pct {slice_pct}")
|
| 265 |
-
# self.downscale_spatial = downscale_spatial
|
| 266 |
-
# self.downscale_temporal = downscale_temporal
|
| 267 |
-
# self.detection_folders = detection_folders
|
| 268 |
-
# self.ndim = ndim
|
| 269 |
-
# self.features = features
|
| 270 |
-
|
| 271 |
-
# if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
|
| 272 |
-
# raise ValueError(
|
| 273 |
-
# f"'{features}' not one of the supported {ndim}D features {tuple(_PROPERTIES[ndim].keys())}"
|
| 274 |
-
# )
|
| 275 |
-
|
| 276 |
-
# logger.info(f"ROOT (config): {self.root}")
|
| 277 |
-
# self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
|
| 278 |
-
# logger.info(f"ROOT: \t{self.root}")
|
| 279 |
-
# logger.info(f"GT TRA:\t{self.gt_tra_folder}")
|
| 280 |
-
# if self.use_gt:
|
| 281 |
-
# self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder)
|
| 282 |
-
# else:
|
| 283 |
-
# logger.info("Using dummy masks as GT")
|
| 284 |
-
# self.gt_mask_folder = self._guess_det_folder(
|
| 285 |
-
# self.root, self.detection_folders[0]
|
| 286 |
-
# )
|
| 287 |
-
# logger.info(f"GT MASK:\t{self.gt_mask_folder}")
|
| 288 |
-
|
| 289 |
-
# dont load image data if not needed
|
| 290 |
-
# if features in ("none",):
|
| 291 |
-
# self.img_folder = None
|
| 292 |
-
# else:
|
| 293 |
-
# self.img_folder = self._guess_img_folder(self.root)
|
| 294 |
-
# logger.info(f"IMG:\t\t{self.img_folder}")
|
| 295 |
-
|
| 296 |
-
self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
|
| 297 |
-
self.ndim, self.features, self.augment, self.crop_size
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
start = default_timer()
|
| 301 |
-
|
| 302 |
-
if self.features == "wrfeat":
|
| 303 |
-
self.windows = self._load_wrfeat()
|
| 304 |
-
else:
|
| 305 |
-
self.windows = self._load()
|
| 306 |
-
|
| 307 |
-
self.n_divs = self._get_ndivs(self.windows)
|
| 308 |
-
|
| 309 |
-
if len(self.windows) > 0:
|
| 310 |
-
self.ndim = self.windows[0]["coords"].shape[1]
|
| 311 |
-
self.n_objects = tuple(len(t["coords"]) for t in self.windows)
|
| 312 |
-
logger.info(
|
| 313 |
-
f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
|
| 314 |
-
f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
|
| 315 |
-
)
|
| 316 |
-
else:
|
| 317 |
-
self.n_objects = 0
|
| 318 |
-
logger.warning(f"Could not load any tracks from {self.root}")
|
| 319 |
-
|
| 320 |
-
if self.compress:
|
| 321 |
-
self._compress_data()
|
| 322 |
-
|
| 323 |
-
def _get_ndivs(self, windows):
|
| 324 |
-
n_divs = []
|
| 325 |
-
for w in tqdm(windows, desc="Counting divisions", leave=False):
|
| 326 |
-
_n = (
|
| 327 |
-
(
|
| 328 |
-
blockwise_sum(
|
| 329 |
-
torch.from_numpy(w["assoc_matrix"]).float(),
|
| 330 |
-
torch.from_numpy(w["timepoints"]).long(),
|
| 331 |
-
).max(dim=0)[0]
|
| 332 |
-
== 2
|
| 333 |
-
)
|
| 334 |
-
.sum()
|
| 335 |
-
.item()
|
| 336 |
-
)
|
| 337 |
-
n_divs.append(_n)
|
| 338 |
-
return n_divs
|
| 339 |
-
|
| 340 |
-
def _setup_features_augs(
|
| 341 |
-
self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
|
| 342 |
-
):
|
| 343 |
-
if self.features == "wrfeat":
|
| 344 |
-
return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)
|
| 345 |
-
|
| 346 |
-
cropper = (
|
| 347 |
-
RandomCrop(
|
| 348 |
-
crop_size=crop_size,
|
| 349 |
-
ndim=ndim,
|
| 350 |
-
use_padding=False,
|
| 351 |
-
ensure_inside_points=True,
|
| 352 |
-
)
|
| 353 |
-
if crop_size is not None
|
| 354 |
-
else None
|
| 355 |
-
)
|
| 356 |
-
|
| 357 |
-
# Hack
|
| 358 |
-
if self.features == "none":
|
| 359 |
-
return 0, default_augmenter, cropper
|
| 360 |
-
|
| 361 |
-
if ndim == 2:
|
| 362 |
-
augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
|
| 363 |
-
feat_dim = {
|
| 364 |
-
"none": 0,
|
| 365 |
-
"regionprops": 7,
|
| 366 |
-
"regionprops2": 6,
|
| 367 |
-
"patch": 256,
|
| 368 |
-
"patch_regionprops": 256 + 5,
|
| 369 |
-
}[features]
|
| 370 |
-
elif ndim == 3:
|
| 371 |
-
augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
|
| 372 |
-
feat_dim = {
|
| 373 |
-
"none": 0,
|
| 374 |
-
"regionprops2": 11,
|
| 375 |
-
"patch_regionprops": 256 + 8,
|
| 376 |
-
}[features]
|
| 377 |
-
|
| 378 |
-
return feat_dim, augmenter, cropper
|
| 379 |
-
|
| 380 |
-
def _compress_data(self):
|
| 381 |
-
# compress masks and assoc_matrix
|
| 382 |
-
logger.info("Compressing masks and assoc_matrix to save memory")
|
| 383 |
-
for w in self.windows:
|
| 384 |
-
w["mask"] = _CompressedArray(w["mask"])
|
| 385 |
-
# dont compress full imgs (as needed for patch features)
|
| 386 |
-
w["img"] = _CompressedArray(w["img"])
|
| 387 |
-
w["assoc_matrix"] = _CompressedArray(w["assoc_matrix"])
|
| 388 |
-
self.gt_masks = _CompressedArray(self.gt_masks)
|
| 389 |
-
self.det_masks = {k: _CompressedArray(v) for k, v in self.det_masks.items()}
|
| 390 |
-
# dont compress full imgs (as needed for patch features)
|
| 391 |
-
self.imgs = _CompressedArray(self.imgs)
|
| 392 |
-
|
| 393 |
-
def _guess_root_and_gt_tra_folder(self, inp: Path):
|
| 394 |
-
"""Guesses the root and the ground truth folder from a given input path.
|
| 395 |
-
|
| 396 |
-
Args:
|
| 397 |
-
inp (Path): _description_
|
| 398 |
-
|
| 399 |
-
Returns:
|
| 400 |
-
Path: root folder,
|
| 401 |
-
"""
|
| 402 |
-
if inp.name == "TRA":
|
| 403 |
-
# 01_GT/TRA --> 01, 01_GT/TRA
|
| 404 |
-
root = inp.parent.parent / inp.parent.name.split("_")[0]
|
| 405 |
-
return root, inp
|
| 406 |
-
elif "ERR_SEG" in inp.name:
|
| 407 |
-
# 01_ERR_SEG --> 01, 01_GT/TRA. We know that the data is in CTC folder format
|
| 408 |
-
num = inp.name.split("_")[0]
|
| 409 |
-
return inp.parent / num, inp.parent / f"{num}_GT" / "TRA"
|
| 410 |
-
else:
|
| 411 |
-
ctc_tra = Path(f"{inp}_GT") / "TRA"
|
| 412 |
-
tra = ctc_tra if ctc_tra.exists() else inp / "TRA"
|
| 413 |
-
# 01 --> 01, 01_GT/TRA or 01/TRA
|
| 414 |
-
return inp, tra
|
| 415 |
-
|
| 416 |
-
def _guess_img_folder(self, root: Path):
|
| 417 |
-
"""Guesses the image folder corresponding to a root."""
|
| 418 |
-
if (root / "img").exists():
|
| 419 |
-
return root / "img"
|
| 420 |
-
else:
|
| 421 |
-
return root
|
| 422 |
-
|
| 423 |
-
def _guess_mask_folder(self, root: Path, gt_tra: Path):
|
| 424 |
-
"""Guesses the mask folder corresponding to a root.
|
| 425 |
-
|
| 426 |
-
In CTC format, we use silver truth segmentation masks.
|
| 427 |
-
"""
|
| 428 |
-
f = None
|
| 429 |
-
# first try CTC format
|
| 430 |
-
if gt_tra.parent.name.endswith("_GT"):
|
| 431 |
-
# We use the silver truth segmentation masks
|
| 432 |
-
f = root / str(gt_tra.parent.name).replace("_GT", "_ST") / "SEG"
|
| 433 |
-
# try our simpler 'img' format
|
| 434 |
-
if f is None or not f.exists():
|
| 435 |
-
f = gt_tra
|
| 436 |
-
if not f.exists():
|
| 437 |
-
raise ValueError(f"Could not find mask folder for {root}")
|
| 438 |
-
return f
|
| 439 |
-
|
| 440 |
-
@classmethod
|
| 441 |
-
def _guess_det_folder(cls, root: Path, suffix: str):
|
| 442 |
-
"""Checks for the annoying CTC format with dataset numbering as part of folder names."""
|
| 443 |
-
guesses = (
|
| 444 |
-
(root / suffix),
|
| 445 |
-
Path(f"{root}_{suffix}"),
|
| 446 |
-
Path(f"{root}_GT") / suffix,
|
| 447 |
-
)
|
| 448 |
-
for path in guesses:
|
| 449 |
-
if path.exists():
|
| 450 |
-
return path
|
| 451 |
-
|
| 452 |
-
logger.warning(f"Skipping non-existing detection folder {root / suffix}")
|
| 453 |
-
return None
|
| 454 |
-
|
| 455 |
-
def __len__(self):
|
| 456 |
-
return len(self.windows)
|
| 457 |
-
|
| 458 |
-
def _load_gt(self):
|
| 459 |
-
logger.info("Loading ground truth")
|
| 460 |
-
self.start_frame = int(
|
| 461 |
-
len(list(self.gt_mask_folder.glob("*.tif"))) * self.slice_pct[0]
|
| 462 |
-
)
|
| 463 |
-
self.end_frame = int(
|
| 464 |
-
len(list(self.gt_mask_folder.glob("*.tif"))) * self.slice_pct[1]
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
masks = self._load_tiffs(self.gt_mask_folder, dtype=np.int32)
|
| 468 |
-
masks = self._correct_gt_with_st(self.gt_mask_folder, masks, dtype=np.int32)
|
| 469 |
-
|
| 470 |
-
if self.use_gt:
|
| 471 |
-
track_df = self._load_tracklet_links(self.gt_tra_folder)
|
| 472 |
-
track_df = _filter_track_df(
|
| 473 |
-
track_df, self.start_frame, self.end_frame, self.downscale_temporal
|
| 474 |
-
)
|
| 475 |
-
else:
|
| 476 |
-
# create dummy track dataframe
|
| 477 |
-
logger.info("Using dummy track dataframe")
|
| 478 |
-
track_df = self._build_tracklets_without_gt(masks)
|
| 479 |
-
|
| 480 |
-
_check_ctc(track_df, _get_node_attributes(masks), masks)
|
| 481 |
-
|
| 482 |
-
# Build ground truth lineage graph
|
| 483 |
-
self.gt_labels, self.gt_timepoints, self.gt_graph = _ctc_lineages(
|
| 484 |
-
track_df, masks
|
| 485 |
-
)
|
| 486 |
-
|
| 487 |
-
return masks, track_df
|
| 488 |
-
|
| 489 |
-
def _correct_gt_with_st(
|
| 490 |
-
self, folder: Path, x: np.ndarray, dtype: Optional[str] = None
|
| 491 |
-
):
|
| 492 |
-
if str(folder).endswith("_GT/TRA"):
|
| 493 |
-
st_path = (
|
| 494 |
-
tuple(folder.parents)[1]
|
| 495 |
-
/ folder.parent.stem.replace("_GT", "_ST")
|
| 496 |
-
/ "SEG"
|
| 497 |
-
)
|
| 498 |
-
if not st_path.exists():
|
| 499 |
-
logger.debug("No _ST folder found, skipping correction")
|
| 500 |
-
else:
|
| 501 |
-
logger.info(f"ST MASK:\t\t{st_path} for correcting with ST masks")
|
| 502 |
-
st_masks = self._load_tiffs(st_path, dtype)
|
| 503 |
-
x = np.maximum(x, st_masks)
|
| 504 |
-
|
| 505 |
-
return x
|
| 506 |
-
|
| 507 |
-
def _load_tiffs(self, folder: Path, dtype=None):
|
| 508 |
-
assert isinstance(self.downscale_temporal, int)
|
| 509 |
-
logger.debug(f"Loading tiffs from {folder} as {dtype}")
|
| 510 |
-
logger.debug(
|
| 511 |
-
f"Temporal downscaling of {folder.name} by {self.downscale_temporal}"
|
| 512 |
-
)
|
| 513 |
-
x = np.stack([
|
| 514 |
-
tifffile.imread(f).astype(dtype)
|
| 515 |
-
for f in tqdm(
|
| 516 |
-
sorted(folder.glob("*.tif"))[
|
| 517 |
-
self.start_frame : self.end_frame : self.downscale_temporal
|
| 518 |
-
],
|
| 519 |
-
leave=False,
|
| 520 |
-
desc=f"Loading [{self.start_frame}:{self.end_frame}]",
|
| 521 |
-
)
|
| 522 |
-
])
|
| 523 |
-
|
| 524 |
-
# T, (Z), Y, X
|
| 525 |
-
assert isinstance(self.downscale_spatial, int)
|
| 526 |
-
if self.downscale_spatial > 1 or self.downscale_temporal > 1:
|
| 527 |
-
# TODO make safe for label arrays
|
| 528 |
-
logger.debug(
|
| 529 |
-
f"Spatial downscaling of {folder.name} by {self.downscale_spatial}"
|
| 530 |
-
)
|
| 531 |
-
slices = (
|
| 532 |
-
slice(None),
|
| 533 |
-
*tuple(
|
| 534 |
-
slice(None, None, self.downscale_spatial) for _ in range(x.ndim - 1)
|
| 535 |
-
),
|
| 536 |
-
)
|
| 537 |
-
x = x[slices]
|
| 538 |
-
|
| 539 |
-
logger.debug(f"Loaded array of shape {x.shape} from {folder}")
|
| 540 |
-
return x
|
| 541 |
-
|
| 542 |
-
def _masks2properties(self, masks):
|
| 543 |
-
"""Turn label masks into lists of properties, sorted (ascending) by time and label id.
|
| 544 |
-
|
| 545 |
-
Args:
|
| 546 |
-
masks (np.ndarray): T, (Z), H, W
|
| 547 |
-
|
| 548 |
-
Returns:
|
| 549 |
-
labels: List of labels
|
| 550 |
-
ts: List of timepoints
|
| 551 |
-
coords: List of coordinates
|
| 552 |
-
"""
|
| 553 |
-
# Get coordinates, timepoints, and labels of detections
|
| 554 |
-
labels = []
|
| 555 |
-
ts = []
|
| 556 |
-
coords = []
|
| 557 |
-
properties_by_time = dict()
|
| 558 |
-
assert len(self.imgs) == len(masks)
|
| 559 |
-
for _t, frame in tqdm(
|
| 560 |
-
enumerate(masks),
|
| 561 |
-
# total=len(detections),
|
| 562 |
-
leave=False,
|
| 563 |
-
desc="Loading masks and properties",
|
| 564 |
-
):
|
| 565 |
-
regions = regionprops(frame)
|
| 566 |
-
t_labels = []
|
| 567 |
-
t_ts = []
|
| 568 |
-
t_coords = []
|
| 569 |
-
for _r in regions:
|
| 570 |
-
t_labels.append(_r.label)
|
| 571 |
-
t_ts.append(_t)
|
| 572 |
-
centroid = np.array(_r.centroid).astype(int)
|
| 573 |
-
t_coords.append(centroid)
|
| 574 |
-
|
| 575 |
-
properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
|
| 576 |
-
labels.extend(t_labels)
|
| 577 |
-
ts.extend(t_ts)
|
| 578 |
-
coords.extend(t_coords)
|
| 579 |
-
|
| 580 |
-
labels = np.array(labels, dtype=int)
|
| 581 |
-
ts = np.array(ts, dtype=int)
|
| 582 |
-
coords = np.array(coords, dtype=int)
|
| 583 |
-
|
| 584 |
-
return labels, ts, coords, properties_by_time
|
| 585 |
-
|
| 586 |
-
def _load_tracklet_links(self, folder: Path) -> pd.DataFrame:
|
| 587 |
-
df = pd.read_csv(
|
| 588 |
-
folder / "man_track.txt",
|
| 589 |
-
delimiter=" ",
|
| 590 |
-
names=["label", "t1", "t2", "parent"],
|
| 591 |
-
dtype=int,
|
| 592 |
-
)
|
| 593 |
-
n_dets = (df.t2 - df.t1 + 1).sum()
|
| 594 |
-
logger.debug(f"{folder} has {n_dets} detections")
|
| 595 |
-
|
| 596 |
-
n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum()
|
| 597 |
-
logger.debug(f"{folder} has {n_divs} divisions")
|
| 598 |
-
return df
|
| 599 |
-
|
| 600 |
-
def _build_tracklets_without_gt(self, masks):
|
| 601 |
-
"""Create a dataframe with tracklets from masks."""
|
| 602 |
-
rows = []
|
| 603 |
-
for t, m in enumerate(masks):
|
| 604 |
-
for c in np.unique(m[m > 0]):
|
| 605 |
-
rows.append([c, t, t, 0])
|
| 606 |
-
df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"])
|
| 607 |
-
return df
|
| 608 |
-
|
| 609 |
-
def _check_dimensions(self, x: np.ndarray):
|
| 610 |
-
if self.ndim == 2 and not x.ndim == 3:
|
| 611 |
-
raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
|
| 612 |
-
elif self.ndim == 3:
|
| 613 |
-
# if ndim=3 and data is two dimensional, it will be cast to 3D
|
| 614 |
-
if x.ndim == 3:
|
| 615 |
-
x = np.expand_dims(x, axis=1)
|
| 616 |
-
elif x.ndim == 4:
|
| 617 |
-
pass
|
| 618 |
-
else:
|
| 619 |
-
raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
|
| 620 |
-
return x
|
| 621 |
-
|
| 622 |
-
def _load(self):
|
| 623 |
-
# Load ground truth
|
| 624 |
-
logger.info("Loading ground truth")
|
| 625 |
-
self.gt_masks, self.gt_track_df = self._load_gt()
|
| 626 |
-
|
| 627 |
-
self.gt_masks = self._check_dimensions(self.gt_masks)
|
| 628 |
-
|
| 629 |
-
# Load images
|
| 630 |
-
if self.img_folder is None:
|
| 631 |
-
self.imgs = np.zeros_like(self.gt_masks)
|
| 632 |
-
else:
|
| 633 |
-
logger.info("Loading images")
|
| 634 |
-
imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
|
| 635 |
-
self.imgs = np.stack([
|
| 636 |
-
normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
|
| 637 |
-
])
|
| 638 |
-
self.imgs = self._check_dimensions(self.imgs)
|
| 639 |
-
if self.compress:
|
| 640 |
-
# prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
|
| 641 |
-
self.imgs = np.stack([
|
| 642 |
-
_compress_img_mask_preproc(im, mask, self.features)
|
| 643 |
-
for im, mask in zip(self.imgs, self.gt_masks)
|
| 644 |
-
])
|
| 645 |
-
|
| 646 |
-
assert len(self.gt_masks) == len(self.imgs)
|
| 647 |
-
|
| 648 |
-
# Load each of the detection folders and create data samples with a sliding window
|
| 649 |
-
windows = []
|
| 650 |
-
self.properties_by_time = dict()
|
| 651 |
-
self.det_masks = dict()
|
| 652 |
-
for _f in self.detection_folders:
|
| 653 |
-
det_folder = self.root / _f
|
| 654 |
-
|
| 655 |
-
if det_folder == self.gt_mask_folder:
|
| 656 |
-
det_masks = self.gt_masks
|
| 657 |
-
logger.info("DET MASK:\tUsing GT masks")
|
| 658 |
-
(
|
| 659 |
-
det_labels,
|
| 660 |
-
det_ts,
|
| 661 |
-
det_coords,
|
| 662 |
-
det_properties_by_time,
|
| 663 |
-
) = self._masks2properties(det_masks)
|
| 664 |
-
|
| 665 |
-
det_gt_matching = {
|
| 666 |
-
t: {_l: _l for _l in det_properties_by_time[t]["labels"]}
|
| 667 |
-
for t in range(len(det_masks))
|
| 668 |
-
}
|
| 669 |
-
else:
|
| 670 |
-
det_folder = self._guess_det_folder(root=self.root, suffix=_f)
|
| 671 |
-
if det_folder is None:
|
| 672 |
-
continue
|
| 673 |
-
|
| 674 |
-
logger.info(f"DET MASK:\t{det_folder}")
|
| 675 |
-
det_masks = self._load_tiffs(det_folder, dtype=np.int32)
|
| 676 |
-
det_masks = self._correct_gt_with_st(
|
| 677 |
-
det_folder, det_masks, dtype=np.int32
|
| 678 |
-
)
|
| 679 |
-
det_masks = self._check_dimensions(det_masks)
|
| 680 |
-
(
|
| 681 |
-
det_labels,
|
| 682 |
-
det_ts,
|
| 683 |
-
det_coords,
|
| 684 |
-
det_properties_by_time,
|
| 685 |
-
) = self._masks2properties(det_masks)
|
| 686 |
-
|
| 687 |
-
# FIXME matching can be slow for big images
|
| 688 |
-
# raise NotImplementedError("Matching not implemented for 3d version")
|
| 689 |
-
det_gt_matching = {
|
| 690 |
-
t: {
|
| 691 |
-
_d: _gt
|
| 692 |
-
for _gt, _d in matching(
|
| 693 |
-
self.gt_masks[t],
|
| 694 |
-
det_masks[t],
|
| 695 |
-
threshold=0.3,
|
| 696 |
-
max_distance=16,
|
| 697 |
-
)
|
| 698 |
-
}
|
| 699 |
-
for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
|
| 700 |
-
}
|
| 701 |
-
|
| 702 |
-
self.properties_by_time[_f] = det_properties_by_time
|
| 703 |
-
self.det_masks[_f] = det_masks
|
| 704 |
-
_w = self._build_windows(
|
| 705 |
-
det_folder,
|
| 706 |
-
det_masks,
|
| 707 |
-
det_labels,
|
| 708 |
-
det_ts,
|
| 709 |
-
det_coords,
|
| 710 |
-
det_gt_matching,
|
| 711 |
-
)
|
| 712 |
-
|
| 713 |
-
windows.extend(_w)
|
| 714 |
-
|
| 715 |
-
return windows
|
| 716 |
-
|
| 717 |
-
def _build_windows(
|
| 718 |
-
self,
|
| 719 |
-
det_folder,
|
| 720 |
-
det_masks,
|
| 721 |
-
labels,
|
| 722 |
-
ts,
|
| 723 |
-
coords,
|
| 724 |
-
matching,
|
| 725 |
-
):
|
| 726 |
-
"""_summary_.
|
| 727 |
-
|
| 728 |
-
Args:
|
| 729 |
-
det_folder (_type_): _description_
|
| 730 |
-
det_masks (_type_): _description_
|
| 731 |
-
labels (_type_): _description_
|
| 732 |
-
ts (_type_): _description_
|
| 733 |
-
coords (_type_): _description_
|
| 734 |
-
matching (_type_): _description_
|
| 735 |
-
|
| 736 |
-
Raises:
|
| 737 |
-
ValueError: _description_
|
| 738 |
-
ValueError: _description_
|
| 739 |
-
|
| 740 |
-
Returns:
|
| 741 |
-
_type_: _description_
|
| 742 |
-
"""
|
| 743 |
-
window_size = self.window_size
|
| 744 |
-
windows = []
|
| 745 |
-
|
| 746 |
-
# Creates the data samples with a sliding window
|
| 747 |
-
masks = self.gt_masks
|
| 748 |
-
for t1, t2 in tqdm(
|
| 749 |
-
zip(range(0, len(masks)), range(window_size, len(masks) + 1)),
|
| 750 |
-
total=len(masks) - window_size + 1,
|
| 751 |
-
leave=False,
|
| 752 |
-
desc="Building windows",
|
| 753 |
-
):
|
| 754 |
-
idx = (ts >= t1) & (ts < t2)
|
| 755 |
-
_ts = ts[idx]
|
| 756 |
-
_labels = labels[idx]
|
| 757 |
-
_coords = coords[idx]
|
| 758 |
-
|
| 759 |
-
# Use GT
|
| 760 |
-
# _labels = self.gt_labels[idx]
|
| 761 |
-
# _ts = self.gt_timepoints[idx]
|
| 762 |
-
|
| 763 |
-
if len(_labels) == 0:
|
| 764 |
-
# raise ValueError(f"No detections in sample {det_folder}:{t1}")
|
| 765 |
-
A = np.zeros((0, 0), dtype=bool)
|
| 766 |
-
_coords = np.zeros((0, masks.ndim - 1), dtype=int)
|
| 767 |
-
else:
|
| 768 |
-
if len(np.unique(_ts)) == 1:
|
| 769 |
-
logger.debug(
|
| 770 |
-
"Only detections from a single timepoint in sample"
|
| 771 |
-
f" {det_folder}:{t1}"
|
| 772 |
-
)
|
| 773 |
-
|
| 774 |
-
# build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
|
| 775 |
-
A = _ctc_assoc_matrix(
|
| 776 |
-
_labels,
|
| 777 |
-
_ts,
|
| 778 |
-
self.gt_graph,
|
| 779 |
-
matching,
|
| 780 |
-
)
|
| 781 |
-
|
| 782 |
-
if self.sanity_dist:
|
| 783 |
-
# # Sanity check: Can the model learn the euclidian distances?
|
| 784 |
-
# c = coords - coords.mean(axis=0, keepdims=True)
|
| 785 |
-
# c /= c.std(axis=0, keepdims=True)
|
| 786 |
-
# A = np.einsum('id,jd',c,c)
|
| 787 |
-
# A = 1 / (1 + np.exp(-A))
|
| 788 |
-
A = np.exp(-0.01 * cdist(_coords, _coords))
|
| 789 |
-
|
| 790 |
-
w = dict(
|
| 791 |
-
coords=_coords,
|
| 792 |
-
# TODO imgs and masks are unaltered here
|
| 793 |
-
t1=t1,
|
| 794 |
-
img=self.imgs[t1:t2],
|
| 795 |
-
mask=det_masks[t1:t2],
|
| 796 |
-
assoc_matrix=A,
|
| 797 |
-
labels=_labels,
|
| 798 |
-
timepoints=_ts,
|
| 799 |
-
)
|
| 800 |
-
|
| 801 |
-
windows.append(w)
|
| 802 |
-
|
| 803 |
-
logger.debug(f"Built {len(windows)} track windows from {det_folder}.\n")
|
| 804 |
-
return windows
|
| 805 |
-
|
| 806 |
-
def __getitem__(self, n: int, return_dense=None):
|
| 807 |
-
# if not set, use default
|
| 808 |
-
if self.features == "wrfeat":
|
| 809 |
-
return self._getitem_wrfeat(n, return_dense)
|
| 810 |
-
|
| 811 |
-
if return_dense is None:
|
| 812 |
-
return_dense = self.return_dense
|
| 813 |
-
|
| 814 |
-
track = self.windows[n]
|
| 815 |
-
coords = track["coords"]
|
| 816 |
-
assoc_matrix = track["assoc_matrix"]
|
| 817 |
-
labels = track["labels"]
|
| 818 |
-
img = track["img"]
|
| 819 |
-
mask = track["mask"]
|
| 820 |
-
timepoints = track["timepoints"]
|
| 821 |
-
min_time = track["t1"]
|
| 822 |
-
|
| 823 |
-
if isinstance(mask, _CompressedArray):
|
| 824 |
-
mask = mask.decompress()
|
| 825 |
-
if isinstance(img, _CompressedArray):
|
| 826 |
-
img = img.decompress()
|
| 827 |
-
if isinstance(assoc_matrix, _CompressedArray):
|
| 828 |
-
assoc_matrix = assoc_matrix.decompress()
|
| 829 |
-
|
| 830 |
-
# cropping
|
| 831 |
-
if self.cropper is not None:
|
| 832 |
-
(img2, mask2, coords2), idx = self.cropper(img, mask, coords)
|
| 833 |
-
cropped_timepoints = timepoints[idx]
|
| 834 |
-
|
| 835 |
-
# at least one detection in each timepoint to accept the crop
|
| 836 |
-
if len(np.unique(cropped_timepoints)) == self.window_size:
|
| 837 |
-
# at least two total detections to accept the crop
|
| 838 |
-
# if len(idx) >= 2:
|
| 839 |
-
img, mask, coords = img2, mask2, coords2
|
| 840 |
-
labels = labels[idx]
|
| 841 |
-
timepoints = timepoints[idx]
|
| 842 |
-
assoc_matrix = assoc_matrix[idx][:, idx]
|
| 843 |
-
else:
|
| 844 |
-
logger.debug("disable cropping as no trajectories would be left")
|
| 845 |
-
|
| 846 |
-
if self.features == "none":
|
| 847 |
-
if self.augmenter is not None:
|
| 848 |
-
coords = self.augmenter(coords)
|
| 849 |
-
# Empty features
|
| 850 |
-
features = np.zeros((len(coords), 0))
|
| 851 |
-
|
| 852 |
-
elif self.features in ("regionprops", "regionprops2"):
|
| 853 |
-
if self.augmenter is not None:
|
| 854 |
-
(img2, mask2, coords2), idx = self.augmenter(
|
| 855 |
-
img, mask, coords, timepoints - min_time
|
| 856 |
-
)
|
| 857 |
-
if len(idx) > 0:
|
| 858 |
-
img, mask, coords = img2, mask2, coords2
|
| 859 |
-
labels = labels[idx]
|
| 860 |
-
timepoints = timepoints[idx]
|
| 861 |
-
assoc_matrix = assoc_matrix[idx][:, idx]
|
| 862 |
-
mask = mask.astype(int)
|
| 863 |
-
else:
|
| 864 |
-
logger.debug(
|
| 865 |
-
"disable augmentation as no trajectories would be left"
|
| 866 |
-
)
|
| 867 |
-
|
| 868 |
-
features = tuple(
|
| 869 |
-
extract_features_regionprops(
|
| 870 |
-
m, im, labels[timepoints == i + min_time], properties=self.features
|
| 871 |
-
)
|
| 872 |
-
for i, (m, im) in enumerate(zip(mask, img))
|
| 873 |
-
)
|
| 874 |
-
features = np.concatenate(features, axis=0)
|
| 875 |
-
# features = np.zeros((len(coords), self.feat_dim))
|
| 876 |
-
|
| 877 |
-
elif self.features == "patch":
|
| 878 |
-
if self.augmenter is not None:
|
| 879 |
-
(img2, mask2, coords2), idx = self.augmenter(
|
| 880 |
-
img, mask, coords, timepoints - min_time
|
| 881 |
-
)
|
| 882 |
-
if len(idx) > 0:
|
| 883 |
-
img, mask, coords = img2, mask2, coords2
|
| 884 |
-
labels = labels[idx]
|
| 885 |
-
timepoints = timepoints[idx]
|
| 886 |
-
assoc_matrix = assoc_matrix[idx][:, idx]
|
| 887 |
-
mask = mask.astype(int)
|
| 888 |
-
else:
|
| 889 |
-
print("disable augmentation as no trajectories would be left")
|
| 890 |
-
|
| 891 |
-
features = tuple(
|
| 892 |
-
extract_features_patch(
|
| 893 |
-
m,
|
| 894 |
-
im,
|
| 895 |
-
coords[timepoints == min_time + i],
|
| 896 |
-
labels[timepoints == min_time + i],
|
| 897 |
-
)
|
| 898 |
-
for i, (m, im) in enumerate(zip(mask, img))
|
| 899 |
-
)
|
| 900 |
-
features = np.concatenate(features, axis=0)
|
| 901 |
-
elif self.features == "patch_regionprops":
|
| 902 |
-
if self.augmenter is not None:
|
| 903 |
-
(img2, mask2, coords2), idx = self.augmenter(
|
| 904 |
-
img, mask, coords, timepoints - min_time
|
| 905 |
-
)
|
| 906 |
-
if len(idx) > 0:
|
| 907 |
-
img, mask, coords = img2, mask2, coords2
|
| 908 |
-
labels = labels[idx]
|
| 909 |
-
timepoints = timepoints[idx]
|
| 910 |
-
assoc_matrix = assoc_matrix[idx][:, idx]
|
| 911 |
-
mask = mask.astype(int)
|
| 912 |
-
else:
|
| 913 |
-
print("disable augmentation as no trajectories would be left")
|
| 914 |
-
|
| 915 |
-
features1 = tuple(
|
| 916 |
-
extract_features_patch(
|
| 917 |
-
m,
|
| 918 |
-
im,
|
| 919 |
-
coords[timepoints == min_time + i],
|
| 920 |
-
labels[timepoints == min_time + i],
|
| 921 |
-
)
|
| 922 |
-
for i, (m, im) in enumerate(zip(mask, img))
|
| 923 |
-
)
|
| 924 |
-
features2 = tuple(
|
| 925 |
-
extract_features_regionprops(
|
| 926 |
-
m,
|
| 927 |
-
im,
|
| 928 |
-
labels[timepoints == i + min_time],
|
| 929 |
-
properties=self.features,
|
| 930 |
-
)
|
| 931 |
-
for i, (m, im) in enumerate(zip(mask, img))
|
| 932 |
-
)
|
| 933 |
-
|
| 934 |
-
features = tuple(
|
| 935 |
-
np.concatenate((f1, f2), axis=-1)
|
| 936 |
-
for f1, f2 in zip(features1, features2)
|
| 937 |
-
)
|
| 938 |
-
|
| 939 |
-
features = np.concatenate(features, axis=0)
|
| 940 |
-
|
| 941 |
-
# remove temporal offset and add timepoints to coords
|
| 942 |
-
relative_timepoints = timepoints - track["t1"]
|
| 943 |
-
coords = np.concatenate((relative_timepoints[:, None], coords), axis=-1)
|
| 944 |
-
|
| 945 |
-
if self.max_tokens and len(timepoints) > self.max_tokens:
|
| 946 |
-
time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
|
| 947 |
-
n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
|
| 948 |
-
timepoints = timepoints[:n_elems]
|
| 949 |
-
labels = labels[:n_elems]
|
| 950 |
-
coords = coords[:n_elems]
|
| 951 |
-
features = features[:n_elems]
|
| 952 |
-
assoc_matrix = assoc_matrix[:n_elems, :n_elems]
|
| 953 |
-
logger.info(
|
| 954 |
-
f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
|
| 955 |
-
)
|
| 956 |
-
|
| 957 |
-
coords0 = torch.from_numpy(coords).float()
|
| 958 |
-
features = torch.from_numpy(features).float()
|
| 959 |
-
assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
|
| 960 |
-
labels = torch.from_numpy(labels).long()
|
| 961 |
-
timepoints = torch.from_numpy(timepoints).long()
|
| 962 |
-
|
| 963 |
-
if self.augmenter is not None:
|
| 964 |
-
coords = coords0.clone()
|
| 965 |
-
coords[:, 1:] += torch.randint(0, 256, (1, self.ndim))
|
| 966 |
-
else:
|
| 967 |
-
coords = coords0.clone()
|
| 968 |
-
res = dict(
|
| 969 |
-
features=features,
|
| 970 |
-
coords0=coords0,
|
| 971 |
-
coords=coords,
|
| 972 |
-
assoc_matrix=assoc_matrix,
|
| 973 |
-
timepoints=timepoints,
|
| 974 |
-
labels=labels,
|
| 975 |
-
)
|
| 976 |
-
|
| 977 |
-
if return_dense:
|
| 978 |
-
if all([x is not None for x in img]):
|
| 979 |
-
img = torch.from_numpy(img).float()
|
| 980 |
-
res["img"] = img
|
| 981 |
-
|
| 982 |
-
mask = torch.from_numpy(mask.astype(int)).long()
|
| 983 |
-
res["mask"] = mask
|
| 984 |
-
|
| 985 |
-
return res
|
| 986 |
-
|
| 987 |
-
# wrfeat functions...
|
| 988 |
-
# TODO: refactor this as a subclass or make everything a class factory. *very* hacky this way
|
| 989 |
-
|
| 990 |
-
def _setup_features_augs_wrfeat(
|
| 991 |
-
self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
|
| 992 |
-
):
|
| 993 |
-
# FIXME: hardcoded
|
| 994 |
-
feat_dim = 7 if ndim == 2 else 12
|
| 995 |
-
if augment == 1:
|
| 996 |
-
augmenter = wrfeat.WRAugmentationPipeline([
|
| 997 |
-
wrfeat.WRRandomFlip(p=0.5),
|
| 998 |
-
wrfeat.WRRandomAffine(
|
| 999 |
-
p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
|
| 1000 |
-
),
|
| 1001 |
-
# wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)),
|
| 1002 |
-
# wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
|
| 1003 |
-
])
|
| 1004 |
-
elif augment == 2:
|
| 1005 |
-
augmenter = wrfeat.WRAugmentationPipeline([
|
| 1006 |
-
wrfeat.WRRandomFlip(p=0.5),
|
| 1007 |
-
wrfeat.WRRandomAffine(
|
| 1008 |
-
p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
|
| 1009 |
-
),
|
| 1010 |
-
wrfeat.WRRandomBrightness(p=0.8),
|
| 1011 |
-
wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
|
| 1012 |
-
])
|
| 1013 |
-
elif augment == 3:
|
| 1014 |
-
augmenter = wrfeat.WRAugmentationPipeline([
|
| 1015 |
-
wrfeat.WRRandomFlip(p=0.5),
|
| 1016 |
-
wrfeat.WRRandomAffine(
|
| 1017 |
-
p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
|
| 1018 |
-
),
|
| 1019 |
-
wrfeat.WRRandomBrightness(p=0.8),
|
| 1020 |
-
wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3),
|
| 1021 |
-
wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
|
| 1022 |
-
])
|
| 1023 |
-
else:
|
| 1024 |
-
augmenter = None
|
| 1025 |
-
|
| 1026 |
-
cropper = (
|
| 1027 |
-
wrfeat.WRRandomCrop(
|
| 1028 |
-
crop_size=crop_size,
|
| 1029 |
-
ndim=ndim,
|
| 1030 |
-
)
|
| 1031 |
-
if crop_size is not None
|
| 1032 |
-
else None
|
| 1033 |
-
)
|
| 1034 |
-
return feat_dim, augmenter, cropper
|
| 1035 |
-
|
| 1036 |
-
def _load_wrfeat(self):
|
| 1037 |
-
# Load ground truth
|
| 1038 |
-
self.gt_masks, self.gt_track_df = self._load_gt()
|
| 1039 |
-
self.gt_masks = self._check_dimensions(self.gt_masks)
|
| 1040 |
-
|
| 1041 |
-
# Load images
|
| 1042 |
-
if self.img_folder is None:
|
| 1043 |
-
if self.gt_masks is not None:
|
| 1044 |
-
self.imgs = np.zeros_like(self.gt_masks)
|
| 1045 |
-
else:
|
| 1046 |
-
raise NotImplementedError("No images and no GT masks")
|
| 1047 |
-
else:
|
| 1048 |
-
logger.info("Loading images")
|
| 1049 |
-
imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
|
| 1050 |
-
self.imgs = np.stack([
|
| 1051 |
-
normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
|
| 1052 |
-
])
|
| 1053 |
-
self.imgs = self._check_dimensions(self.imgs)
|
| 1054 |
-
if self.compress:
|
| 1055 |
-
# prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
|
| 1056 |
-
self.imgs = np.stack([
|
| 1057 |
-
_compress_img_mask_preproc(im, mask, self.features)
|
| 1058 |
-
for im, mask in zip(self.imgs, self.gt_masks)
|
| 1059 |
-
])
|
| 1060 |
-
|
| 1061 |
-
assert len(self.gt_masks) == len(self.imgs)
|
| 1062 |
-
|
| 1063 |
-
# Load each of the detection folders and create data samples with a sliding window
|
| 1064 |
-
windows = []
|
| 1065 |
-
self.properties_by_time = dict()
|
| 1066 |
-
self.det_masks = dict()
|
| 1067 |
-
logger.info("Loading detections")
|
| 1068 |
-
for _f in self.detection_folders:
|
| 1069 |
-
det_folder = self.root / _f
|
| 1070 |
-
|
| 1071 |
-
if det_folder == self.gt_mask_folder:
|
| 1072 |
-
det_masks = self.gt_masks
|
| 1073 |
-
logger.info("DET MASK:\tUsing GT masks")
|
| 1074 |
-
# identity matching
|
| 1075 |
-
det_gt_matching = {
|
| 1076 |
-
t: {_l: _l for _l in set(np.unique(d)) - {0}}
|
| 1077 |
-
for t, d in enumerate(det_masks)
|
| 1078 |
-
}
|
| 1079 |
-
else:
|
| 1080 |
-
det_folder = self._guess_det_folder(root=self.root, suffix=_f)
|
| 1081 |
-
if det_folder is None:
|
| 1082 |
-
continue
|
| 1083 |
-
logger.info(f"DET MASK (guessed):\t{det_folder}")
|
| 1084 |
-
det_masks = self._load_tiffs(det_folder, dtype=np.int32)
|
| 1085 |
-
det_masks = self._correct_gt_with_st(
|
| 1086 |
-
det_folder, det_masks, dtype=np.int32
|
| 1087 |
-
)
|
| 1088 |
-
det_masks = self._check_dimensions(det_masks)
|
| 1089 |
-
# FIXME matching can be slow for big images
|
| 1090 |
-
# raise NotImplementedError("Matching not implemented for 3d version")
|
| 1091 |
-
det_gt_matching = {
|
| 1092 |
-
t: {
|
| 1093 |
-
_d: _gt
|
| 1094 |
-
for _gt, _d in matching(
|
| 1095 |
-
self.gt_masks[t],
|
| 1096 |
-
det_masks[t],
|
| 1097 |
-
threshold=0.3,
|
| 1098 |
-
max_distance=16,
|
| 1099 |
-
)
|
| 1100 |
-
}
|
| 1101 |
-
for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
|
| 1102 |
-
}
|
| 1103 |
-
|
| 1104 |
-
self.det_masks[_f] = det_masks
|
| 1105 |
-
|
| 1106 |
-
# build features
|
| 1107 |
-
|
| 1108 |
-
features = joblib.Parallel(n_jobs=8)(
|
| 1109 |
-
joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
|
| 1110 |
-
mask=mask[None], img=img[None], t_start=t
|
| 1111 |
-
)
|
| 1112 |
-
for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
|
| 1113 |
-
)
|
| 1114 |
-
|
| 1115 |
-
properties_by_time = dict()
|
| 1116 |
-
for _t, _feats in enumerate(features):
|
| 1117 |
-
properties_by_time[_t] = dict(
|
| 1118 |
-
coords=_feats.coords, labels=_feats.labels
|
| 1119 |
-
)
|
| 1120 |
-
self.properties_by_time[_f] = properties_by_time
|
| 1121 |
-
|
| 1122 |
-
_w = self._build_windows_wrfeat(
|
| 1123 |
-
features,
|
| 1124 |
-
det_masks,
|
| 1125 |
-
det_gt_matching,
|
| 1126 |
-
)
|
| 1127 |
-
|
| 1128 |
-
windows.extend(_w)
|
| 1129 |
-
|
| 1130 |
-
return windows
|
| 1131 |
-
|
| 1132 |
-
def _build_windows_wrfeat(
|
| 1133 |
-
self,
|
| 1134 |
-
features: Sequence[wrfeat.WRFeatures],
|
| 1135 |
-
det_masks: np.ndarray,
|
| 1136 |
-
matching: Tuple[dict],
|
| 1137 |
-
):
|
| 1138 |
-
assert len(self.imgs) == len(det_masks)
|
| 1139 |
-
|
| 1140 |
-
window_size = self.window_size
|
| 1141 |
-
windows = []
|
| 1142 |
-
|
| 1143 |
-
# Creates the data samples with a sliding window
|
| 1144 |
-
for t1, t2 in tqdm(
|
| 1145 |
-
zip(range(0, len(det_masks)), range(window_size, len(det_masks) + 1)),
|
| 1146 |
-
total=len(det_masks) - window_size + 1,
|
| 1147 |
-
leave=False,
|
| 1148 |
-
desc="Building windows",
|
| 1149 |
-
):
|
| 1150 |
-
img = self.imgs[t1:t2]
|
| 1151 |
-
mask = det_masks[t1:t2]
|
| 1152 |
-
feat = wrfeat.WRFeatures.concat(features[t1:t2])
|
| 1153 |
-
|
| 1154 |
-
labels = feat.labels
|
| 1155 |
-
timepoints = feat.timepoints
|
| 1156 |
-
coords = feat.coords
|
| 1157 |
-
|
| 1158 |
-
if len(feat) == 0:
|
| 1159 |
-
A = np.zeros((0, 0), dtype=bool)
|
| 1160 |
-
coords = np.zeros((0, feat.ndim), dtype=int)
|
| 1161 |
-
else:
|
| 1162 |
-
# build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
|
| 1163 |
-
A = _ctc_assoc_matrix(
|
| 1164 |
-
labels,
|
| 1165 |
-
timepoints,
|
| 1166 |
-
self.gt_graph,
|
| 1167 |
-
matching,
|
| 1168 |
-
)
|
| 1169 |
-
w = dict(
|
| 1170 |
-
coords=coords,
|
| 1171 |
-
# TODO imgs and masks are unaltered here
|
| 1172 |
-
t1=t1,
|
| 1173 |
-
img=img,
|
| 1174 |
-
mask=mask,
|
| 1175 |
-
assoc_matrix=A,
|
| 1176 |
-
labels=labels,
|
| 1177 |
-
timepoints=timepoints,
|
| 1178 |
-
wrfeat=feat,
|
| 1179 |
-
)
|
| 1180 |
-
windows.append(w)
|
| 1181 |
-
|
| 1182 |
-
logger.debug(f"Built {len(windows)} track windows.\n")
|
| 1183 |
-
return windows
|
| 1184 |
-
|
| 1185 |
-
def _getitem_wrfeat(self, n: int, return_dense=None):
|
| 1186 |
-
# if not set, use default
|
| 1187 |
-
|
| 1188 |
-
if return_dense is None:
|
| 1189 |
-
return_dense = self.return_dense
|
| 1190 |
-
|
| 1191 |
-
track = self.windows[n]
|
| 1192 |
-
# coords = track["coords"]
|
| 1193 |
-
assoc_matrix = track["assoc_matrix"]
|
| 1194 |
-
labels = track["labels"]
|
| 1195 |
-
img = track["img"]
|
| 1196 |
-
mask = track["mask"]
|
| 1197 |
-
timepoints = track["timepoints"]
|
| 1198 |
-
# track["t1"]
|
| 1199 |
-
feat = track["wrfeat"]
|
| 1200 |
-
|
| 1201 |
-
if return_dense and isinstance(mask, _CompressedArray):
|
| 1202 |
-
mask = mask.decompress()
|
| 1203 |
-
if return_dense and isinstance(img, _CompressedArray):
|
| 1204 |
-
img = img.decompress()
|
| 1205 |
-
if isinstance(assoc_matrix, _CompressedArray):
|
| 1206 |
-
assoc_matrix = assoc_matrix.decompress()
|
| 1207 |
-
|
| 1208 |
-
# cropping
|
| 1209 |
-
if self.cropper is not None:
|
| 1210 |
-
# Use only if there is at least one timepoint per detection
|
| 1211 |
-
cropped_feat, cropped_idx = self.cropper(feat)
|
| 1212 |
-
cropped_timepoints = timepoints[cropped_idx]
|
| 1213 |
-
if len(np.unique(cropped_timepoints)) == self.window_size:
|
| 1214 |
-
idx = cropped_idx
|
| 1215 |
-
feat = cropped_feat
|
| 1216 |
-
labels = labels[idx]
|
| 1217 |
-
timepoints = timepoints[idx]
|
| 1218 |
-
assoc_matrix = assoc_matrix[idx][:, idx]
|
| 1219 |
-
else:
|
| 1220 |
-
logger.debug("Skipping cropping")
|
| 1221 |
-
|
| 1222 |
-
if self.augmenter is not None:
|
| 1223 |
-
feat = self.augmenter(feat)
|
| 1224 |
-
|
| 1225 |
-
coords0 = np.concatenate((feat.timepoints[:, None], feat.coords), axis=-1)
|
| 1226 |
-
coords0 = torch.from_numpy(coords0).float()
|
| 1227 |
-
assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32))
|
| 1228 |
-
features = torch.from_numpy(feat.features_stacked).float()
|
| 1229 |
-
labels = torch.from_numpy(feat.labels).long()
|
| 1230 |
-
timepoints = torch.from_numpy(feat.timepoints).long()
|
| 1231 |
-
|
| 1232 |
-
if self.max_tokens and len(timepoints) > self.max_tokens:
|
| 1233 |
-
time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
|
| 1234 |
-
n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
|
| 1235 |
-
timepoints = timepoints[:n_elems]
|
| 1236 |
-
labels = labels[:n_elems]
|
| 1237 |
-
coords0 = coords0[:n_elems]
|
| 1238 |
-
features = features[:n_elems]
|
| 1239 |
-
assoc_matrix = assoc_matrix[:n_elems, :n_elems]
|
| 1240 |
-
logger.debug(
|
| 1241 |
-
f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
|
| 1242 |
-
)
|
| 1243 |
-
|
| 1244 |
-
if self.augmenter is not None:
|
| 1245 |
-
coords = coords0.clone()
|
| 1246 |
-
coords[:, 1:] += torch.randint(0, 512, (1, self.ndim))
|
| 1247 |
-
else:
|
| 1248 |
-
coords = coords0.clone()
|
| 1249 |
-
res = dict(
|
| 1250 |
-
features=features,
|
| 1251 |
-
coords0=coords0,
|
| 1252 |
-
coords=coords,
|
| 1253 |
-
assoc_matrix=assoc_matrix,
|
| 1254 |
-
timepoints=timepoints,
|
| 1255 |
-
labels=labels,
|
| 1256 |
-
)
|
| 1257 |
-
|
| 1258 |
-
if return_dense:
|
| 1259 |
-
if all([x is not None for x in img]):
|
| 1260 |
-
img = torch.from_numpy(img).float()
|
| 1261 |
-
res["img"] = img
|
| 1262 |
-
|
| 1263 |
-
mask = torch.from_numpy(mask.astype(int)).long()
|
| 1264 |
-
res["mask"] = mask
|
| 1265 |
-
|
| 1266 |
-
return res
|
| 1267 |
-
|
| 1268 |
-
|
| 1269 |
-
def _ctc_lineages(df, masks, t1=0, t2=None):
|
| 1270 |
-
"""From a ctc dataframe, create a digraph that contains all sublineages
|
| 1271 |
-
between t1 and t2 (exclusive t2).
|
| 1272 |
-
|
| 1273 |
-
Args:
|
| 1274 |
-
df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt)
|
| 1275 |
-
masks: List of masks. If t1 is not 0, then the masks are assumed to be already cropped accordingly.
|
| 1276 |
-
t1: Start timepoint
|
| 1277 |
-
t2: End timepoint (exclusive). If None, then t2 is set to len(masks)
|
| 1278 |
-
|
| 1279 |
-
Returns:
|
| 1280 |
-
labels: List of label ids extracted from the masks, ordered by timepoint.
|
| 1281 |
-
ts: List of corresponding timepoints
|
| 1282 |
-
graph: The digraph of the lineages between t1 and t2.
|
| 1283 |
-
"""
|
| 1284 |
-
if t1 > 0:
|
| 1285 |
-
assert t2 is not None
|
| 1286 |
-
assert t2 - t1 == len(masks)
|
| 1287 |
-
if t2 is None:
|
| 1288 |
-
t2 = len(masks)
|
| 1289 |
-
|
| 1290 |
-
graph = nx.DiGraph()
|
| 1291 |
-
labels = []
|
| 1292 |
-
ts = []
|
| 1293 |
-
|
| 1294 |
-
# get all objects that are present in the time interval
|
| 1295 |
-
df_sub = df[(df.t1 < t2) & (df.t2 >= t1)]
|
| 1296 |
-
|
| 1297 |
-
# Correct offset
|
| 1298 |
-
df_sub.loc[:, "t1"] -= t1
|
| 1299 |
-
df_sub.loc[:, "t2"] -= t1
|
| 1300 |
-
|
| 1301 |
-
# all_labels = df_sub.label.unique()
|
| 1302 |
-
# TODO speed up by precalculating unique values once
|
| 1303 |
-
# in_masks = set(np.where(np.bincount(np.stack(masks[t1:t2]).ravel()))[0]) - {0}
|
| 1304 |
-
# all_labels = [l for l in all_labels if l in in_masks]
|
| 1305 |
-
all_labels = set()
|
| 1306 |
-
|
| 1307 |
-
for t in tqdm(
|
| 1308 |
-
range(0, t2 - t1), desc="Building and checking lineage graph", leave=False
|
| 1309 |
-
):
|
| 1310 |
-
# get all entities at timepoint
|
| 1311 |
-
obs = df_sub[(df_sub.t1 <= t) & (df_sub.t2 >= t)]
|
| 1312 |
-
in_t = set(np.where(np.bincount(masks[t].ravel()))[0]) - {0}
|
| 1313 |
-
all_labels.update(in_t)
|
| 1314 |
-
for row in obs.itertuples():
|
| 1315 |
-
label, t1, t2, parent = row.label, row.t1, row.t2, row.parent
|
| 1316 |
-
if label not in in_t:
|
| 1317 |
-
continue
|
| 1318 |
-
|
| 1319 |
-
labels.append(label)
|
| 1320 |
-
ts.append(t)
|
| 1321 |
-
|
| 1322 |
-
# add label as node if not already in graph
|
| 1323 |
-
if not graph.has_node(label):
|
| 1324 |
-
graph.add_node(label)
|
| 1325 |
-
|
| 1326 |
-
# Parents have been added in previous timepoints
|
| 1327 |
-
if parent in all_labels:
|
| 1328 |
-
if not graph.has_node(parent):
|
| 1329 |
-
graph.add_node(parent)
|
| 1330 |
-
graph.add_edge(parent, label)
|
| 1331 |
-
|
| 1332 |
-
labels = np.array(labels)
|
| 1333 |
-
ts = np.array(ts)
|
| 1334 |
-
return labels, ts, graph
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
@njit
|
| 1338 |
-
def _assoc(A: np.ndarray, labels: np.ndarray, family: np.ndarray):
|
| 1339 |
-
"""For each detection, associate with all detections that are."""
|
| 1340 |
-
for i in range(len(labels)):
|
| 1341 |
-
for j in range(len(labels)):
|
| 1342 |
-
A[i, j] = family[i, labels[j]]
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
def _ctc_assoc_matrix(detections, ts, graph, matching):
|
| 1346 |
-
"""Create the association matrix for a list of labels and a tracklet parent -> childrend graph.
|
| 1347 |
-
|
| 1348 |
-
Each detection is associated with all its ancestors and descendants, but not its siblings and their offspring.
|
| 1349 |
-
|
| 1350 |
-
Args:
|
| 1351 |
-
detections: list of integer labels, ordered by timepoint
|
| 1352 |
-
ts: list of timepoints corresponding to the detections
|
| 1353 |
-
graph: networkx DiGraph with each ground truth tracklet id (spanning n timepoints) as a single node
|
| 1354 |
-
and parent -> children relationships as edges.
|
| 1355 |
-
matching: for each timepoint, a dictionary that maps from detection id to gt tracklet id
|
| 1356 |
-
"""
|
| 1357 |
-
assert 0 not in graph
|
| 1358 |
-
matched_gt = []
|
| 1359 |
-
for i, (label, t) in enumerate(zip(detections, ts)):
|
| 1360 |
-
gt_tracklet_id = matching[t].get(label, 0)
|
| 1361 |
-
matched_gt.append(gt_tracklet_id)
|
| 1362 |
-
matched_gt = np.array(matched_gt, dtype=int)
|
| 1363 |
-
# Now we have the subset of gt nodes that is matched to any detection in the current window
|
| 1364 |
-
|
| 1365 |
-
# relabel to reduce the size of lookup matrices
|
| 1366 |
-
# offset 0 not allowed in skimage, which makes this very annoying
|
| 1367 |
-
relabeled_gt, fwd_map, _inv_map = relabel_sequential(matched_gt, offset=1)
|
| 1368 |
-
# dict is faster than arraymap
|
| 1369 |
-
fwd_map = dict(zip(fwd_map.in_values, fwd_map.out_values))
|
| 1370 |
-
# inv_map = dict(zip(inv_map.in_values, inv_map.out_values))
|
| 1371 |
-
|
| 1372 |
-
# the family relationships for each ground truth detection,
|
| 1373 |
-
# Maps from local detection number (0-indexed) to global gt tracklet id (1-indexed)
|
| 1374 |
-
family = np.zeros((len(detections), len(relabeled_gt) + 1), bool)
|
| 1375 |
-
|
| 1376 |
-
# Connects each tracklet id with its children and parent tracklets (according to man_track.txt)
|
| 1377 |
-
for i, (label, t) in enumerate(zip(detections, ts)):
|
| 1378 |
-
# Get the original label corresponding to the graph
|
| 1379 |
-
gt_tracklet_id = matching[t].get(label, None)
|
| 1380 |
-
if gt_tracklet_id is not None:
|
| 1381 |
-
ancestors = []
|
| 1382 |
-
descendants = []
|
| 1383 |
-
# This iterates recursively through the graph
|
| 1384 |
-
for n in nx.descendants(graph, gt_tracklet_id):
|
| 1385 |
-
if n in fwd_map:
|
| 1386 |
-
descendants.append(fwd_map[n])
|
| 1387 |
-
for n in nx.ancestors(graph, gt_tracklet_id):
|
| 1388 |
-
if n in fwd_map:
|
| 1389 |
-
ancestors.append(fwd_map[n])
|
| 1390 |
-
|
| 1391 |
-
family[i, np.array([fwd_map[gt_tracklet_id], *ancestors, *descendants])] = (
|
| 1392 |
-
True
|
| 1393 |
-
)
|
| 1394 |
-
else:
|
| 1395 |
-
pass
|
| 1396 |
-
# Now we match to nothing, so even the matrix diagonal will not be filled.
|
| 1397 |
-
|
| 1398 |
-
# This assures that matching to 0 is always false
|
| 1399 |
-
assert family[:, 0].sum() == 0
|
| 1400 |
-
|
| 1401 |
-
# Create the detection-to-detection association matrix
|
| 1402 |
-
A = np.zeros((len(detections), len(detections)), dtype=bool)
|
| 1403 |
-
|
| 1404 |
-
_assoc(A, relabeled_gt, family)
|
| 1405 |
-
|
| 1406 |
-
return A
|
| 1407 |
-
|
| 1408 |
-
|
| 1409 |
-
def sigmoid(x):
|
| 1410 |
-
return 1 / (1 + np.exp(-x))
|
| 1411 |
-
|
| 1412 |
-
|
| 1413 |
-
def _compress_img_mask_preproc(img, mask, features):
|
| 1414 |
-
"""Remove certain img pixels if not needed to save memory for large datasets."""
|
| 1415 |
-
# dont change anything if we need patch values
|
| 1416 |
-
if features in ("patch", "patch_regionprops"):
|
| 1417 |
-
# clear img pixels outside of patch_mask of size 16x16
|
| 1418 |
-
patch_width = 16 # TOD: hardcoded: change this if needed
|
| 1419 |
-
coords = tuple(np.array(r.centroid).astype(int) for r in regionprops(mask))
|
| 1420 |
-
img2 = np.zeros_like(img)
|
| 1421 |
-
if len(coords) > 0:
|
| 1422 |
-
coords = np.stack(coords)
|
| 1423 |
-
coords = np.clip(coords, 0, np.array(mask.shape)[None] - 1)
|
| 1424 |
-
patch_mask = np.zeros_like(img, dtype=bool)
|
| 1425 |
-
patch_mask[tuple(coords.T)] = True
|
| 1426 |
-
# retain 3*patch_width+1 around center to be safe...
|
| 1427 |
-
patch_mask = ndi.maximum_filter(patch_mask, 3 * patch_width + 1)
|
| 1428 |
-
img2[patch_mask] = img[patch_mask]
|
| 1429 |
-
|
| 1430 |
-
else:
|
| 1431 |
-
# otherwise set img value inside masks to mean
|
| 1432 |
-
# FIXME: change when using other intensity based regionprops
|
| 1433 |
-
img2 = np.zeros_like(img)
|
| 1434 |
-
for reg in regionprops(mask, intensity_image=img):
|
| 1435 |
-
m = mask[reg.slice] == reg.label
|
| 1436 |
-
img2[reg.slice][m] = reg.mean_intensity
|
| 1437 |
-
return img2
|
| 1438 |
-
|
| 1439 |
-
|
| 1440 |
-
def pad_tensor(x, n_max: int, dim=0, value=0):
|
| 1441 |
-
n = x.shape[dim]
|
| 1442 |
-
if n_max < n:
|
| 1443 |
-
raise ValueError(f"pad_tensor: n_max={n_max} must be larger than n={n} !")
|
| 1444 |
-
pad_shape = list(x.shape)
|
| 1445 |
-
pad_shape[dim] = n_max - n
|
| 1446 |
-
# pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype).to(x.device)
|
| 1447 |
-
pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype)
|
| 1448 |
-
return torch.cat((x, pad), dim=dim)
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
def collate_sequence_padding(batch):
|
| 1452 |
-
"""Collate function that pads all sequences to the same length."""
|
| 1453 |
-
lens = tuple(len(x["coords"]) for x in batch)
|
| 1454 |
-
n_max_len = max(lens)
|
| 1455 |
-
# print(tuple(len(x["coords"]) for x in batch))
|
| 1456 |
-
# print(tuple(len(x["features"]) for x in batch))
|
| 1457 |
-
# print(batch[0].keys())
|
| 1458 |
-
tuple(batch[0].keys())
|
| 1459 |
-
normal_keys = {
|
| 1460 |
-
"coords": 0,
|
| 1461 |
-
"features": 0,
|
| 1462 |
-
"labels": 0, # Not needed, remove for speed.
|
| 1463 |
-
"timepoints": -1, # There are real timepoints with t=0. -1 for distinction from that.
|
| 1464 |
-
}
|
| 1465 |
-
n_pads = tuple(n_max_len - s for s in lens)
|
| 1466 |
-
batch_new = dict(
|
| 1467 |
-
(
|
| 1468 |
-
k,
|
| 1469 |
-
torch.stack(
|
| 1470 |
-
[pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0
|
| 1471 |
-
),
|
| 1472 |
-
)
|
| 1473 |
-
for k, v in normal_keys.items()
|
| 1474 |
-
)
|
| 1475 |
-
batch_new["assoc_matrix"] = torch.stack(
|
| 1476 |
-
[
|
| 1477 |
-
pad_tensor(
|
| 1478 |
-
pad_tensor(x["assoc_matrix"], n_max_len, dim=0), n_max_len, dim=1
|
| 1479 |
-
)
|
| 1480 |
-
for x in batch
|
| 1481 |
-
],
|
| 1482 |
-
dim=0,
|
| 1483 |
-
)
|
| 1484 |
-
|
| 1485 |
-
# add boolean mask that signifies whether tokens are padded or not (such that they can be ignored later)
|
| 1486 |
-
pad_mask = torch.zeros((len(batch), n_max_len), dtype=torch.bool)
|
| 1487 |
-
for i, n_pad in enumerate(n_pads):
|
| 1488 |
-
pad_mask[i, n_max_len - n_pad :] = True
|
| 1489 |
-
|
| 1490 |
-
batch_new["padding_mask"] = pad_mask.bool()
|
| 1491 |
-
return batch_new
|
| 1492 |
-
|
| 1493 |
-
|
| 1494 |
-
if __name__ == "__main__":
|
| 1495 |
-
dummy_data = CTCData(
|
| 1496 |
-
root="../../scripts/data/synthetic_cells/01",
|
| 1497 |
-
ndim=2,
|
| 1498 |
-
detection_folders=["TRA"],
|
| 1499 |
-
window_size=4,
|
| 1500 |
-
max_tokens=None,
|
| 1501 |
-
augment=3,
|
| 1502 |
-
features="none",
|
| 1503 |
-
downscale_temporal=1,
|
| 1504 |
-
downscale_spatial=1,
|
| 1505 |
-
sanity_dist=False,
|
| 1506 |
-
crop_size=(256, 256),
|
| 1507 |
-
)
|
| 1508 |
-
|
| 1509 |
-
x = dummy_data[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/distributed.py
DELETED
|
@@ -1,316 +0,0 @@
|
|
| 1 |
-
"""Data loading and sampling utils for distributed training."""
|
| 2 |
-
|
| 3 |
-
import hashlib
|
| 4 |
-
import json
|
| 5 |
-
import logging
|
| 6 |
-
import pickle
|
| 7 |
-
# from collections.abc import Iterable
|
| 8 |
-
from copy import deepcopy
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from timeit import default_timer
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
import torch
|
| 14 |
-
# from lightning import LightningDataModule
|
| 15 |
-
from torch.utils.data import (
|
| 16 |
-
BatchSampler,
|
| 17 |
-
ConcatDataset,
|
| 18 |
-
DataLoader,
|
| 19 |
-
Dataset,
|
| 20 |
-
DistributedSampler,
|
| 21 |
-
)
|
| 22 |
-
from typing import Optional, Iterable
|
| 23 |
-
from .data import CTCData
|
| 24 |
-
|
| 25 |
-
logger = logging.getLogger(__name__)
|
| 26 |
-
logger.setLevel(logging.INFO)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def cache_class(cachedir=None):
|
| 30 |
-
"""A simple file cache for CTCData."""
|
| 31 |
-
|
| 32 |
-
def make_hashable(obj):
|
| 33 |
-
if isinstance(obj, tuple | list):
|
| 34 |
-
return tuple(make_hashable(e) for e in obj)
|
| 35 |
-
elif isinstance(obj, Path):
|
| 36 |
-
return obj.as_posix()
|
| 37 |
-
elif isinstance(obj, dict):
|
| 38 |
-
return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
|
| 39 |
-
else:
|
| 40 |
-
return obj
|
| 41 |
-
|
| 42 |
-
def hash_args_kwargs(*args, **kwargs):
|
| 43 |
-
hashable_args = tuple(make_hashable(arg) for arg in args)
|
| 44 |
-
hashable_kwargs = make_hashable(kwargs)
|
| 45 |
-
combined_serialized = json.dumps(
|
| 46 |
-
[hashable_args, hashable_kwargs], sort_keys=True
|
| 47 |
-
)
|
| 48 |
-
hash_obj = hashlib.sha256(combined_serialized.encode())
|
| 49 |
-
return hash_obj.hexdigest()
|
| 50 |
-
|
| 51 |
-
if cachedir is None:
|
| 52 |
-
return CTCData
|
| 53 |
-
else:
|
| 54 |
-
cachedir = Path(cachedir)
|
| 55 |
-
|
| 56 |
-
def _wrapped(*args, **kwargs):
|
| 57 |
-
h = hash_args_kwargs(*args, **kwargs)
|
| 58 |
-
cachedir.mkdir(exist_ok=True, parents=True)
|
| 59 |
-
cache_file = cachedir / f"{h}.pkl"
|
| 60 |
-
if cache_file.exists():
|
| 61 |
-
logger.info(f"Loading cached dataset from {cache_file}")
|
| 62 |
-
with open(cache_file, "rb") as f:
|
| 63 |
-
return pickle.load(f)
|
| 64 |
-
else:
|
| 65 |
-
c = CTCData(*args, **kwargs)
|
| 66 |
-
logger.info(f"Saving cached dataset to {cache_file}")
|
| 67 |
-
pickle.dump(c, open(cache_file, "wb"))
|
| 68 |
-
return c
|
| 69 |
-
|
| 70 |
-
return _wrapped
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class BalancedBatchSampler(BatchSampler):
|
| 74 |
-
"""samples batch indices such that the number of objects in each batch is balanced
|
| 75 |
-
(so to reduce the number of paddings in the batch).
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
"""
|
| 79 |
-
|
| 80 |
-
def __init__(
|
| 81 |
-
self,
|
| 82 |
-
dataset: torch.utils.data.Dataset,
|
| 83 |
-
batch_size: int,
|
| 84 |
-
n_pool: int = 10,
|
| 85 |
-
num_samples: Optional[int] = None,
|
| 86 |
-
weight_by_ndivs: bool = False,
|
| 87 |
-
weight_by_dataset: bool = False,
|
| 88 |
-
drop_last: bool = False,
|
| 89 |
-
):
|
| 90 |
-
"""Setting n_pool =1 will result in a regular random batch sampler.
|
| 91 |
-
|
| 92 |
-
weight_by_ndivs: if True, the probability of sampling an element is proportional to the number of divisions
|
| 93 |
-
weight_by_dataset: if True, the probability of sampling an element is inversely proportional to the length of the dataset
|
| 94 |
-
"""
|
| 95 |
-
if isinstance(dataset, CTCData):
|
| 96 |
-
self.n_objects = dataset.n_objects
|
| 97 |
-
self.n_divs = np.array(dataset.n_divs)
|
| 98 |
-
self.n_sizes = np.ones(len(dataset)) * len(dataset)
|
| 99 |
-
elif isinstance(dataset, ConcatDataset):
|
| 100 |
-
self.n_objects = tuple(n for d in dataset.datasets for n in d.n_objects)
|
| 101 |
-
self.n_divs = np.array(tuple(n for d in dataset.datasets for n in d.n_divs))
|
| 102 |
-
self.n_sizes = np.array(
|
| 103 |
-
tuple(len(d) for d in dataset.datasets for _ in range(len(d)))
|
| 104 |
-
)
|
| 105 |
-
else:
|
| 106 |
-
raise NotImplementedError(
|
| 107 |
-
f"BalancedBatchSampler: Unknown dataset type {type(dataset)}"
|
| 108 |
-
)
|
| 109 |
-
assert len(self.n_objects) == len(self.n_divs) == len(self.n_sizes)
|
| 110 |
-
|
| 111 |
-
self.batch_size = batch_size
|
| 112 |
-
self.n_pool = n_pool
|
| 113 |
-
self.drop_last = drop_last
|
| 114 |
-
self.num_samples = num_samples
|
| 115 |
-
self.weight_by_ndivs = weight_by_ndivs
|
| 116 |
-
self.weight_by_dataset = weight_by_dataset
|
| 117 |
-
logger.debug(f"{weight_by_ndivs=}")
|
| 118 |
-
logger.debug(f"{weight_by_dataset=}")
|
| 119 |
-
|
| 120 |
-
def get_probs(self, idx):
|
| 121 |
-
idx = np.array(idx)
|
| 122 |
-
if self.weight_by_ndivs:
|
| 123 |
-
probs = 1 + np.sqrt(self.n_divs[idx])
|
| 124 |
-
else:
|
| 125 |
-
probs = np.ones(len(idx))
|
| 126 |
-
if self.weight_by_dataset:
|
| 127 |
-
probs = probs / (self.n_sizes[idx] + 1e-6)
|
| 128 |
-
|
| 129 |
-
probs = probs / (probs.sum() + 1e-10)
|
| 130 |
-
return probs
|
| 131 |
-
|
| 132 |
-
def sample_batches(self, idx: Iterable[int]):
|
| 133 |
-
# we will split the indices into pools of size n_pool
|
| 134 |
-
num_samples = self.num_samples if self.num_samples is not None else len(idx)
|
| 135 |
-
# sample from the indices with replacement and given probabilites
|
| 136 |
-
idx = np.random.choice(idx, num_samples, replace=True, p=self.get_probs(idx))
|
| 137 |
-
|
| 138 |
-
n_pool = min(
|
| 139 |
-
self.n_pool * self.batch_size,
|
| 140 |
-
(len(idx) // self.batch_size) * self.batch_size,
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
batches = []
|
| 144 |
-
for i in range(0, len(idx), n_pool):
|
| 145 |
-
# the indices in the pool are sorted by their number of objects
|
| 146 |
-
idx_pool = idx[i : i + n_pool]
|
| 147 |
-
idx_pool = sorted(idx_pool, key=lambda i: self.n_objects[i])
|
| 148 |
-
|
| 149 |
-
# such that we can create batches where each element has a similar number of objects
|
| 150 |
-
jj = np.arange(0, len(idx_pool), self.batch_size)
|
| 151 |
-
np.random.shuffle(jj)
|
| 152 |
-
|
| 153 |
-
for j in jj:
|
| 154 |
-
# dont drop_last, as this leads to a lot of lightning problems....
|
| 155 |
-
# if j + self.batch_size > len(idx_pool): # assume drop_last=True
|
| 156 |
-
# continue
|
| 157 |
-
batch = idx_pool[j : j + self.batch_size]
|
| 158 |
-
batches.append(batch)
|
| 159 |
-
return batches
|
| 160 |
-
|
| 161 |
-
def __iter__(self):
|
| 162 |
-
idx = np.arange(len(self.n_objects))
|
| 163 |
-
batches = self.sample_batches(idx)
|
| 164 |
-
return iter(batches)
|
| 165 |
-
|
| 166 |
-
def __len__(self):
|
| 167 |
-
if self.num_samples is not None:
|
| 168 |
-
return self.num_samples // self.batch_size
|
| 169 |
-
else:
|
| 170 |
-
return len(self.n_objects) // self.batch_size
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
class BalancedDistributedSampler(DistributedSampler):
|
| 174 |
-
def __init__(
|
| 175 |
-
self,
|
| 176 |
-
dataset: Dataset,
|
| 177 |
-
batch_size: int,
|
| 178 |
-
n_pool: int,
|
| 179 |
-
num_samples: int,
|
| 180 |
-
weight_by_ndivs: bool = False,
|
| 181 |
-
weight_by_dataset: bool = False,
|
| 182 |
-
*args,
|
| 183 |
-
**kwargs,
|
| 184 |
-
) -> None:
|
| 185 |
-
super().__init__(dataset=dataset, *args, drop_last=True, **kwargs)
|
| 186 |
-
self._balanced_batch_sampler = BalancedBatchSampler(
|
| 187 |
-
dataset,
|
| 188 |
-
batch_size=batch_size,
|
| 189 |
-
n_pool=n_pool,
|
| 190 |
-
num_samples=max(1, num_samples // self.num_replicas),
|
| 191 |
-
weight_by_ndivs=weight_by_ndivs,
|
| 192 |
-
weight_by_dataset=weight_by_dataset,
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
def __len__(self) -> int:
|
| 196 |
-
if self.num_samples is not None:
|
| 197 |
-
return self._balanced_batch_sampler.num_samples
|
| 198 |
-
else:
|
| 199 |
-
return super().__len__()
|
| 200 |
-
|
| 201 |
-
def __iter__(self):
|
| 202 |
-
indices = list(super().__iter__())
|
| 203 |
-
batches = self._balanced_batch_sampler.sample_batches(indices)
|
| 204 |
-
for batch in batches:
|
| 205 |
-
yield from batch
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
# class BalancedDataModule(LightningDataModule):
|
| 209 |
-
# def __init__(
|
| 210 |
-
# self,
|
| 211 |
-
# input_train: list,
|
| 212 |
-
# input_val: list,
|
| 213 |
-
# cachedir: str,
|
| 214 |
-
# augment: int,
|
| 215 |
-
# distributed: bool,
|
| 216 |
-
# dataset_kwargs: dict,
|
| 217 |
-
# sampler_kwargs: dict,
|
| 218 |
-
# loader_kwargs: dict,
|
| 219 |
-
# ):
|
| 220 |
-
# super().__init__()
|
| 221 |
-
# self.input_train = input_train
|
| 222 |
-
# self.input_val = input_val
|
| 223 |
-
# self.cachedir = cachedir
|
| 224 |
-
# self.augment = augment
|
| 225 |
-
# self.distributed = distributed
|
| 226 |
-
# self.dataset_kwargs = dataset_kwargs
|
| 227 |
-
# self.sampler_kwargs = sampler_kwargs
|
| 228 |
-
# self.loader_kwargs = loader_kwargs
|
| 229 |
-
|
| 230 |
-
# def prepare_data(self):
|
| 231 |
-
# """Loads and caches the datasets if not already done.
|
| 232 |
-
|
| 233 |
-
# Running on the main CPU process.
|
| 234 |
-
# """
|
| 235 |
-
# CTCData = cache_class(self.cachedir)
|
| 236 |
-
# datasets = dict()
|
| 237 |
-
# for split, inps in zip(
|
| 238 |
-
# ("train", "val"),
|
| 239 |
-
# (self.input_train, self.input_val),
|
| 240 |
-
# ):
|
| 241 |
-
# logger.info(f"Loading {split.upper()} data")
|
| 242 |
-
# start = default_timer()
|
| 243 |
-
# datasets[split] = torch.utils.data.ConcatDataset(
|
| 244 |
-
# CTCData(
|
| 245 |
-
# root=Path(inp),
|
| 246 |
-
# augment=self.augment if split == "train" else 0,
|
| 247 |
-
# **self.dataset_kwargs,
|
| 248 |
-
# )
|
| 249 |
-
# for inp in inps
|
| 250 |
-
# )
|
| 251 |
-
# logger.info(
|
| 252 |
-
# f"Loaded {len(datasets[split])} {split.upper()} samples (in"
|
| 253 |
-
# f" {(default_timer() - start):.1f} s)\n\n"
|
| 254 |
-
# )
|
| 255 |
-
|
| 256 |
-
# del datasets
|
| 257 |
-
|
| 258 |
-
# def setup(self, stage: str):
|
| 259 |
-
# CTCData = cache_class(self.cachedir)
|
| 260 |
-
# self.datasets = dict()
|
| 261 |
-
# for split, inps in zip(
|
| 262 |
-
# ("train", "val"),
|
| 263 |
-
# (self.input_train, self.input_val),
|
| 264 |
-
# ):
|
| 265 |
-
# logger.info(f"Loading {split.upper()} data")
|
| 266 |
-
# start = default_timer()
|
| 267 |
-
# self.datasets[split] = torch.utils.data.ConcatDataset(
|
| 268 |
-
# CTCData(
|
| 269 |
-
# root=Path(inp),
|
| 270 |
-
# augment=self.augment if split == "train" else 0,
|
| 271 |
-
# **self.dataset_kwargs,
|
| 272 |
-
# )
|
| 273 |
-
# for inp in inps
|
| 274 |
-
# )
|
| 275 |
-
# logger.info(
|
| 276 |
-
# f"Loaded {len(self.datasets[split])} {split.upper()} samples (in"
|
| 277 |
-
# f" {(default_timer() - start):.1f} s)\n\n"
|
| 278 |
-
# )
|
| 279 |
-
|
| 280 |
-
# def train_dataloader(self):
|
| 281 |
-
# loader_kwargs = self.loader_kwargs.copy()
|
| 282 |
-
# if self.distributed:
|
| 283 |
-
# sampler = BalancedDistributedSampler(
|
| 284 |
-
# self.datasets["train"],
|
| 285 |
-
# **self.sampler_kwargs,
|
| 286 |
-
# )
|
| 287 |
-
# batch_sampler = None
|
| 288 |
-
# else:
|
| 289 |
-
# sampler = None
|
| 290 |
-
# batch_sampler = BalancedBatchSampler(
|
| 291 |
-
# self.datasets["train"],
|
| 292 |
-
# **self.sampler_kwargs,
|
| 293 |
-
# )
|
| 294 |
-
# if not loader_kwargs["batch_size"] == batch_sampler.batch_size:
|
| 295 |
-
# raise ValueError(
|
| 296 |
-
# f"Batch size in loader_kwargs ({loader_kwargs['batch_size']}) and sampler_kwargs ({batch_sampler.batch_size}) must match"
|
| 297 |
-
# )
|
| 298 |
-
# del loader_kwargs["batch_size"]
|
| 299 |
-
|
| 300 |
-
# loader = DataLoader(
|
| 301 |
-
# self.datasets["train"],
|
| 302 |
-
# sampler=sampler,
|
| 303 |
-
# batch_sampler=batch_sampler,
|
| 304 |
-
# **loader_kwargs,
|
| 305 |
-
# )
|
| 306 |
-
# return loader
|
| 307 |
-
|
| 308 |
-
# def val_dataloader(self):
|
| 309 |
-
# val_loader_kwargs = deepcopy(self.loader_kwargs)
|
| 310 |
-
# val_loader_kwargs["persistent_workers"] = False
|
| 311 |
-
# val_loader_kwargs["num_workers"] = 1
|
| 312 |
-
# return DataLoader(
|
| 313 |
-
# self.datasets["val"],
|
| 314 |
-
# shuffle=False,
|
| 315 |
-
# **val_loader_kwargs,
|
| 316 |
-
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/example_data.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
|
| 3 |
-
import tifffile
|
| 4 |
-
|
| 5 |
-
root = Path(__file__).parent / "resources"
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def example_data_bacteria():
|
| 9 |
-
"""Bacteria images and masks from.
|
| 10 |
-
|
| 11 |
-
Van Vliet et al. Spatially Correlated Gene Expression in Bacterial Groups: The Role of Lineage History, Spatial Gradients, and Cell-Cell Interactions (2018)
|
| 12 |
-
https://doi.org/10.1016/j.cels.2018.03.009
|
| 13 |
-
|
| 14 |
-
subset of timelapse trpL/150310-11
|
| 15 |
-
"""
|
| 16 |
-
img = tifffile.imread(root / "trpL_150310-11_img.tif")
|
| 17 |
-
mask = tifffile.imread(root / "trpL_150310-11_mask.tif")
|
| 18 |
-
return img, mask
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def example_data_hela():
|
| 22 |
-
"""Hela data from the cell tracking challenge.
|
| 23 |
-
|
| 24 |
-
Neumann et al. Phenotypic profiling of the human genome by time-lapse microscopy reveals cell division genes (2010)
|
| 25 |
-
|
| 26 |
-
subset of Fluo-N2DL-HeLa/train/02
|
| 27 |
-
"""
|
| 28 |
-
img = tifffile.imread(root / "Fluo_Hela_02_img.tif")
|
| 29 |
-
mask = tifffile.imread(root / "Fluo_Hela_02_ERR_SEG.tif")
|
| 30 |
-
print(img.shape, mask.shape)
|
| 31 |
-
return img, mask
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def example_data_fluo_3d():
|
| 35 |
-
"""Fluo-N3DH-CHO data from the cell tracking challenge.
|
| 36 |
-
|
| 37 |
-
Dzyubachyk et al. Advanced Level-Set-Based Cell Tracking in Time-Lapse Fluorescence Microscopy (2010)
|
| 38 |
-
|
| 39 |
-
subset of Fluo-N3DH-CHO/train/02
|
| 40 |
-
"""
|
| 41 |
-
img = tifffile.imread(root / "Fluo-N3DH-CHO_02_img.tif")
|
| 42 |
-
mask = tifffile.imread(root / "Fluo-N3DH-CHO_02_ERR_SEG.tif")
|
| 43 |
-
return img, mask
|
| 44 |
-
|
| 45 |
-
def data_hela():
|
| 46 |
-
img = tifffile.imread("02_imgs.tif")
|
| 47 |
-
mask = tifffile.imread("02_masks.tif")
|
| 48 |
-
return img, mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/features.py
DELETED
|
@@ -1,148 +0,0 @@
|
|
| 1 |
-
import itertools
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
from skimage.measure import regionprops_table
|
| 6 |
-
|
| 7 |
-
# the property keys that are supported for 2 and 3 dim
|
| 8 |
-
|
| 9 |
-
_PROPERTIES = {
|
| 10 |
-
2: {
|
| 11 |
-
# FIXME: The only image regionprop possible now (when compressing) is mean_intensity,
|
| 12 |
-
# since we store a mask with the mean intensity of each detection as the image.
|
| 13 |
-
"regionprops": (
|
| 14 |
-
"label",
|
| 15 |
-
"area",
|
| 16 |
-
"intensity_mean",
|
| 17 |
-
"eccentricity",
|
| 18 |
-
"solidity",
|
| 19 |
-
"inertia_tensor",
|
| 20 |
-
),
|
| 21 |
-
# faster
|
| 22 |
-
"regionprops2": (
|
| 23 |
-
"label",
|
| 24 |
-
"area",
|
| 25 |
-
"intensity_mean",
|
| 26 |
-
"inertia_tensor",
|
| 27 |
-
),
|
| 28 |
-
"patch_regionprops": (
|
| 29 |
-
"label",
|
| 30 |
-
"area",
|
| 31 |
-
"intensity_mean",
|
| 32 |
-
"inertia_tensor",
|
| 33 |
-
),
|
| 34 |
-
},
|
| 35 |
-
3: {
|
| 36 |
-
"regionprops2": (
|
| 37 |
-
"label",
|
| 38 |
-
"area",
|
| 39 |
-
"intensity_mean",
|
| 40 |
-
"inertia_tensor",
|
| 41 |
-
),
|
| 42 |
-
"patch_regionprops": (
|
| 43 |
-
"label",
|
| 44 |
-
"area",
|
| 45 |
-
"intensity_mean",
|
| 46 |
-
"inertia_tensor",
|
| 47 |
-
),
|
| 48 |
-
},
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def extract_features_regionprops(
|
| 53 |
-
mask: np.ndarray,
|
| 54 |
-
img: np.ndarray,
|
| 55 |
-
labels: np.ndarray,
|
| 56 |
-
properties="regionprops2",
|
| 57 |
-
):
|
| 58 |
-
ndim = mask.ndim
|
| 59 |
-
assert ndim in (2, 3)
|
| 60 |
-
assert mask.shape == img.shape
|
| 61 |
-
|
| 62 |
-
prop_dict = _PROPERTIES[ndim]
|
| 63 |
-
if properties not in prop_dict:
|
| 64 |
-
raise ValueError(f"properties must be one of {prop_dict.keys()}")
|
| 65 |
-
properties_tuple = prop_dict[properties]
|
| 66 |
-
|
| 67 |
-
assert properties_tuple[0] == "label"
|
| 68 |
-
|
| 69 |
-
labels = np.asarray(labels)
|
| 70 |
-
|
| 71 |
-
# remove mask labels that are not present
|
| 72 |
-
# not needed, remove for speed
|
| 73 |
-
# mask[~np.isin(mask, labels)] = 0
|
| 74 |
-
|
| 75 |
-
df = pd.DataFrame(
|
| 76 |
-
regionprops_table(mask, intensity_image=img, properties=properties_tuple)
|
| 77 |
-
)
|
| 78 |
-
assert df.columns[0] == "label"
|
| 79 |
-
assert df.columns[1] == "area"
|
| 80 |
-
|
| 81 |
-
# the bnumber of inertia tensor columns depends on the dimensionality
|
| 82 |
-
n_cols_inertia = ndim**2
|
| 83 |
-
assert np.all(["inertia_tensor" in col for col in df.columns[-n_cols_inertia:]])
|
| 84 |
-
|
| 85 |
-
# Hack for backwards compatibility
|
| 86 |
-
if properties in ("regionprops", "patch_regionprops"):
|
| 87 |
-
# Nice for conceptual clarity, but does not matter for speed
|
| 88 |
-
# drop upper triangular part of symmetric inertia tensor
|
| 89 |
-
for i, j in itertools.product(range(ndim), repeat=2):
|
| 90 |
-
if i > j:
|
| 91 |
-
df.drop(f"inertia_tensor-{i}-{j}", axis=1, inplace=True)
|
| 92 |
-
|
| 93 |
-
table = df.to_numpy()
|
| 94 |
-
table[:, 1] *= 0.001
|
| 95 |
-
table[:, -n_cols_inertia:] *= 0.01
|
| 96 |
-
# reorder according to labels
|
| 97 |
-
features = np.zeros((len(labels), len(df.columns) - 1))
|
| 98 |
-
|
| 99 |
-
# faster than iterating over pandas dataframe
|
| 100 |
-
for row in table:
|
| 101 |
-
# old version with tuple indexing, slow.
|
| 102 |
-
# n = labels.index(int(row.label))
|
| 103 |
-
# features[n] = row.to_numpy()[1:]
|
| 104 |
-
|
| 105 |
-
# Only process regions present in the labels
|
| 106 |
-
n = np.where(labels == int(row[0]))[0]
|
| 107 |
-
if len(n) > 0:
|
| 108 |
-
# Remove label column (0)!
|
| 109 |
-
features[n[0]] = row[1:]
|
| 110 |
-
|
| 111 |
-
return features
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def extract_features_patch(
|
| 115 |
-
mask: np.ndarray,
|
| 116 |
-
img: np.ndarray,
|
| 117 |
-
coords: np.ndarray,
|
| 118 |
-
labels: np.ndarray,
|
| 119 |
-
width_patch: int = 16,
|
| 120 |
-
):
|
| 121 |
-
"""16x16 Image patch around detection."""
|
| 122 |
-
ndim = mask.ndim
|
| 123 |
-
assert ndim in (2, 3) and mask.shape == img.shape
|
| 124 |
-
if len(coords) == 0:
|
| 125 |
-
return np.zeros((0, width_patch * width_patch))
|
| 126 |
-
|
| 127 |
-
pads = (width_patch // 2,) * ndim
|
| 128 |
-
|
| 129 |
-
img = np.pad(
|
| 130 |
-
img,
|
| 131 |
-
tuple((p, p) for p in pads),
|
| 132 |
-
mode="constant",
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
coords = coords.astype(int) + np.array(pads)
|
| 136 |
-
|
| 137 |
-
ss = tuple(
|
| 138 |
-
tuple(slice(_c - width_patch // 2, _c + width_patch // 2) for _c in c)
|
| 139 |
-
for c in coords
|
| 140 |
-
)
|
| 141 |
-
fs = tuple(img[_s] for _s in ss)
|
| 142 |
-
|
| 143 |
-
# max project along z if 3D
|
| 144 |
-
if ndim == 3:
|
| 145 |
-
fs = tuple(f.max(0) for f in fs)
|
| 146 |
-
|
| 147 |
-
features = np.stack([f.flatten() for f in fs])
|
| 148 |
-
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/matching.py
DELETED
|
@@ -1,251 +0,0 @@
|
|
| 1 |
-
# Adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
from numba import jit
|
| 5 |
-
from scipy.optimize import linear_sum_assignment
|
| 6 |
-
from scipy.spatial.distance import cdist
|
| 7 |
-
from skimage.measure import regionprops
|
| 8 |
-
|
| 9 |
-
matching_criteria = dict()
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def label_are_sequential(y):
|
| 13 |
-
"""Returns true if y has only sequential labels from 1..."""
|
| 14 |
-
labels = np.unique(y)
|
| 15 |
-
return (set(labels) - {0}) == set(range(1, 1 + labels.max()))
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def is_array_of_integers(y):
|
| 19 |
-
return isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.integer)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _check_label_array(y, name=None, check_sequential=False):
|
| 23 |
-
err = ValueError(
|
| 24 |
-
"{label} must be an array of {integers}.".format(
|
| 25 |
-
label="labels" if name is None else name,
|
| 26 |
-
integers=("sequential " if check_sequential else "")
|
| 27 |
-
+ "non-negative integers",
|
| 28 |
-
)
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
if not is_array_of_integers(y):
|
| 32 |
-
raise err
|
| 33 |
-
if len(y) == 0:
|
| 34 |
-
return True
|
| 35 |
-
if check_sequential and not label_are_sequential(y):
|
| 36 |
-
raise err
|
| 37 |
-
else:
|
| 38 |
-
if not y.min() >= 0:
|
| 39 |
-
raise err
|
| 40 |
-
return True
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def label_overlap(x, y, check=True):
|
| 44 |
-
if check:
|
| 45 |
-
_check_label_array(x, "x", True)
|
| 46 |
-
_check_label_array(y, "y", True)
|
| 47 |
-
if not x.shape == y.shape:
|
| 48 |
-
raise ValueError("x and y must have the same shape")
|
| 49 |
-
return _label_overlap(x, y)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@jit(nopython=True)
|
| 53 |
-
def _label_overlap(x, y):
|
| 54 |
-
x = x.ravel()
|
| 55 |
-
y = y.ravel()
|
| 56 |
-
overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint32)
|
| 57 |
-
for i in range(len(x)):
|
| 58 |
-
overlap[x[i], y[i]] += 1
|
| 59 |
-
return overlap[1:, 1:]
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def _safe_divide(x, y, eps=1e-10):
|
| 63 |
-
"""Computes a safe divide which returns 0 if y is zero."""
|
| 64 |
-
if np.isscalar(x) and np.isscalar(y):
|
| 65 |
-
return x / y if np.abs(y) > eps else 0.0
|
| 66 |
-
else:
|
| 67 |
-
out = np.zeros(np.broadcast(x, y).shape, np.float32)
|
| 68 |
-
np.divide(x, y, out=out, where=np.abs(y) > eps)
|
| 69 |
-
return out
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def intersection_over_union(overlap):
|
| 73 |
-
_check_label_array(overlap, "overlap")
|
| 74 |
-
if np.sum(overlap) == 0:
|
| 75 |
-
return overlap
|
| 76 |
-
n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
|
| 77 |
-
n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
|
| 78 |
-
return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap))
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def dist_score(y_true, y_pred, max_distance: int = 10):
|
| 82 |
-
"""Compute distance score between centroids of regions in y_true and y_pred
|
| 83 |
-
and returns a score matrix of shape (n_true, n_pred) with values in [0,1]
|
| 84 |
-
where
|
| 85 |
-
distance >= max_distance -> score = 0
|
| 86 |
-
distance = 0 -> score = 1.
|
| 87 |
-
"""
|
| 88 |
-
c_true = np.stack([r.centroid for r in regionprops(y_true)], axis=0)
|
| 89 |
-
c_pred = np.stack([r.centroid for r in regionprops(y_pred)], axis=0)
|
| 90 |
-
dist = np.minimum(cdist(c_true, c_pred), max_distance)
|
| 91 |
-
score = 1 - dist / max_distance
|
| 92 |
-
return score
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
# copied from scikit-image master for now (remove when part of a release)
|
| 96 |
-
def relabel_sequential(label_field, offset=1):
|
| 97 |
-
"""Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.
|
| 98 |
-
|
| 99 |
-
This function also returns the forward map (mapping the original labels to
|
| 100 |
-
the reduced labels) and the inverse map (mapping the reduced labels back
|
| 101 |
-
to the original ones).
|
| 102 |
-
|
| 103 |
-
Parameters
|
| 104 |
-
----------
|
| 105 |
-
label_field : numpy array of int, arbitrary shape
|
| 106 |
-
An array of labels, which must be non-negative integers.
|
| 107 |
-
offset : int, optional
|
| 108 |
-
The return labels will start at `offset`, which should be
|
| 109 |
-
strictly positive.
|
| 110 |
-
|
| 111 |
-
Returns:
|
| 112 |
-
-------
|
| 113 |
-
relabeled : numpy array of int, same shape as `label_field`
|
| 114 |
-
The input label field with labels mapped to
|
| 115 |
-
{offset, ..., number_of_labels + offset - 1}.
|
| 116 |
-
The data type will be the same as `label_field`, except when
|
| 117 |
-
offset + number_of_labels causes overflow of the current data type.
|
| 118 |
-
forward_map : numpy array of int, shape ``(label_field.max() + 1,)``
|
| 119 |
-
The map from the original label space to the returned label
|
| 120 |
-
space. Can be used to re-apply the same mapping. See examples
|
| 121 |
-
for usage. The data type will be the same as `relabeled`.
|
| 122 |
-
inverse_map : 1D numpy array of int, of length offset + number of labels
|
| 123 |
-
The map from the new label space to the original space. This
|
| 124 |
-
can be used to reconstruct the original label field from the
|
| 125 |
-
relabeled one. The data type will be the same as `relabeled`.
|
| 126 |
-
|
| 127 |
-
Notes:
|
| 128 |
-
-----
|
| 129 |
-
The label 0 is assumed to denote the background and is never remapped.
|
| 130 |
-
|
| 131 |
-
The forward map can be extremely big for some inputs, since its
|
| 132 |
-
length is given by the maximum of the label field. However, in most
|
| 133 |
-
situations, ``label_field.max()`` is much smaller than
|
| 134 |
-
``label_field.size``, and in these cases the forward map is
|
| 135 |
-
guaranteed to be smaller than either the input or output images.
|
| 136 |
-
|
| 137 |
-
Examples:
|
| 138 |
-
--------
|
| 139 |
-
>>> from skimage.segmentation import relabel_sequential
|
| 140 |
-
>>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])
|
| 141 |
-
>>> relab, fw, inv = relabel_sequential(label_field)
|
| 142 |
-
>>> relab
|
| 143 |
-
array([1, 1, 2, 2, 3, 5, 4])
|
| 144 |
-
>>> fw
|
| 145 |
-
array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 146 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,
|
| 147 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 148 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 149 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])
|
| 150 |
-
>>> inv
|
| 151 |
-
array([ 0, 1, 5, 8, 42, 99])
|
| 152 |
-
>>> (fw[label_field] == relab).all()
|
| 153 |
-
True
|
| 154 |
-
>>> (inv[relab] == label_field).all()
|
| 155 |
-
True
|
| 156 |
-
>>> relab, fw, inv = relabel_sequential(label_field, offset=5)
|
| 157 |
-
>>> relab
|
| 158 |
-
array([5, 5, 6, 6, 7, 9, 8])
|
| 159 |
-
"""
|
| 160 |
-
offset = int(offset)
|
| 161 |
-
if offset <= 0:
|
| 162 |
-
raise ValueError("Offset must be strictly positive.")
|
| 163 |
-
if np.min(label_field) < 0:
|
| 164 |
-
raise ValueError("Cannot relabel array that contains negative values.")
|
| 165 |
-
max_label = int(label_field.max()) # Ensure max_label is an integer
|
| 166 |
-
if not np.issubdtype(label_field.dtype, np.integer):
|
| 167 |
-
new_type = np.min_scalar_type(max_label)
|
| 168 |
-
label_field = label_field.astype(new_type)
|
| 169 |
-
labels = np.unique(label_field)
|
| 170 |
-
labels0 = labels[labels != 0]
|
| 171 |
-
new_max_label = offset - 1 + len(labels0)
|
| 172 |
-
new_labels0 = np.arange(offset, new_max_label + 1)
|
| 173 |
-
output_type = label_field.dtype
|
| 174 |
-
required_type = np.min_scalar_type(new_max_label)
|
| 175 |
-
if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:
|
| 176 |
-
output_type = required_type
|
| 177 |
-
forward_map = np.zeros(max_label + 1, dtype=output_type)
|
| 178 |
-
forward_map[labels0] = new_labels0
|
| 179 |
-
inverse_map = np.zeros(new_max_label + 1, dtype=output_type)
|
| 180 |
-
inverse_map[offset:] = labels0
|
| 181 |
-
relabeled = forward_map[label_field]
|
| 182 |
-
return relabeled, forward_map, inverse_map
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def matching(y_true, y_pred, threshold=0.5, max_distance: int = 16):
|
| 186 |
-
"""Computes IoU and distance score between all pairs of regions in y_true and y_pred.
|
| 187 |
-
|
| 188 |
-
returns the true/pred matching based on the higher of the two scores for each pair of regions
|
| 189 |
-
|
| 190 |
-
Parameters
|
| 191 |
-
----------
|
| 192 |
-
y_true: ndarray
|
| 193 |
-
ground truth label image (integer valued)
|
| 194 |
-
y_pred: ndarray
|
| 195 |
-
predicted label image (integer valued)
|
| 196 |
-
threshold: float
|
| 197 |
-
threshold for matching criterion (default 0.5)
|
| 198 |
-
max_distance: int
|
| 199 |
-
maximum distance between centroids of regions in y_true and y_pred (default 16)
|
| 200 |
-
|
| 201 |
-
Returns:
|
| 202 |
-
-------
|
| 203 |
-
gt_pred: tuple
|
| 204 |
-
tuple of all matched region label pairs in y_true and y_pred
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
"""
|
| 208 |
-
y_true, y_pred = y_true.astype(np.int32), y_pred.astype(np.int32)
|
| 209 |
-
_check_label_array(y_true, "y_true")
|
| 210 |
-
_check_label_array(y_pred, "y_pred")
|
| 211 |
-
if not y_true.shape == y_pred.shape:
|
| 212 |
-
raise ValueError(
|
| 213 |
-
f"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes"
|
| 214 |
-
)
|
| 215 |
-
if threshold is None:
|
| 216 |
-
threshold = 0
|
| 217 |
-
|
| 218 |
-
threshold = float(threshold) if np.isscalar(threshold) else map(float, threshold)
|
| 219 |
-
|
| 220 |
-
y_true, _, map_rev_true = relabel_sequential(y_true)
|
| 221 |
-
y_pred, _, map_rev_pred = relabel_sequential(y_pred)
|
| 222 |
-
|
| 223 |
-
overlap = label_overlap(y_true, y_pred, check=False)
|
| 224 |
-
|
| 225 |
-
scores_iou = intersection_over_union(overlap)
|
| 226 |
-
scores_dist = dist_score(y_true, y_pred, max_distance)
|
| 227 |
-
scores = np.maximum(scores_iou, scores_dist)
|
| 228 |
-
|
| 229 |
-
assert 0 <= np.min(scores) <= np.max(scores) <= 1
|
| 230 |
-
|
| 231 |
-
n_true, n_pred = scores.shape
|
| 232 |
-
n_matched = min(n_true, n_pred)
|
| 233 |
-
|
| 234 |
-
# not_trivial = n_matched > 0 and np.any(scores >= thr)
|
| 235 |
-
not_trivial = n_matched > 0
|
| 236 |
-
if not_trivial:
|
| 237 |
-
# compute optimal matching with scores as tie-breaker
|
| 238 |
-
costs = -(scores >= threshold).astype(float) - scores / (2 * n_matched)
|
| 239 |
-
true_ind, pred_ind = linear_sum_assignment(costs)
|
| 240 |
-
assert n_matched == len(true_ind) == len(pred_ind)
|
| 241 |
-
match_ok = scores[true_ind, pred_ind] >= threshold
|
| 242 |
-
true_ind = true_ind[match_ok]
|
| 243 |
-
pred_ind = pred_ind[match_ok]
|
| 244 |
-
matched = tuple(
|
| 245 |
-
(int(map_rev_true[i]), int(map_rev_pred[j]))
|
| 246 |
-
for i, j in zip(1 + true_ind, 1 + pred_ind)
|
| 247 |
-
)
|
| 248 |
-
else:
|
| 249 |
-
matched = ()
|
| 250 |
-
|
| 251 |
-
return matched
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/data/utils.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import sys
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pandas as pd
|
| 7 |
-
|
| 8 |
-
# from .data import CTCData
|
| 9 |
-
import tifffile
|
| 10 |
-
from tqdm import tqdm
|
| 11 |
-
from typing import Optional, Union, Tuple
|
| 12 |
-
|
| 13 |
-
logger = logging.getLogger(__name__)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def load_tiff_timeseries(
|
| 17 |
-
dir: Path,
|
| 18 |
-
dtype: Optional[Union[str, type]] = None,
|
| 19 |
-
downscale: Optional[Tuple[int, ...]] = None,
|
| 20 |
-
start_frame: int = 0,
|
| 21 |
-
end_frame: Optional[int] = None,
|
| 22 |
-
) -> np.ndarray:
|
| 23 |
-
"""Loads a folder of `.tif` or `.tiff` files into a numpy array.
|
| 24 |
-
Each file is interpreted as a frame of a time series.
|
| 25 |
-
|
| 26 |
-
Args:
|
| 27 |
-
folder:
|
| 28 |
-
dtype:
|
| 29 |
-
downscale: One int for each dimension of the data. Avoids memory overhead.
|
| 30 |
-
start_frame: The first frame to load.
|
| 31 |
-
end_frame: The last frame to load.
|
| 32 |
-
|
| 33 |
-
Returns:
|
| 34 |
-
np.ndarray: The loaded data.
|
| 35 |
-
"""
|
| 36 |
-
# TODO make safe for label arrays
|
| 37 |
-
logger.debug(f"Loading tiffs from {dir} as {dtype}")
|
| 38 |
-
files = sorted(list(dir.glob("*.tif")) + list(dir.glob("*.tiff")))[
|
| 39 |
-
start_frame:end_frame
|
| 40 |
-
]
|
| 41 |
-
shape = tifffile.imread(files[0]).shape
|
| 42 |
-
if downscale:
|
| 43 |
-
assert len(downscale) == len(shape)
|
| 44 |
-
else:
|
| 45 |
-
downscale = (1,) * len(shape)
|
| 46 |
-
|
| 47 |
-
files = files[:: downscale[0]]
|
| 48 |
-
|
| 49 |
-
x = []
|
| 50 |
-
for f in tqdm(
|
| 51 |
-
files,
|
| 52 |
-
leave=False,
|
| 53 |
-
desc=f"Loading [{start_frame}:{end_frame}:{downscale[0]}]",
|
| 54 |
-
):
|
| 55 |
-
_x = tifffile.imread(f)
|
| 56 |
-
if dtype:
|
| 57 |
-
_x = _x.astype(dtype)
|
| 58 |
-
assert _x.shape == shape
|
| 59 |
-
slices = tuple(slice(None, None, d) for d in downscale[1:])
|
| 60 |
-
_x = _x[slices]
|
| 61 |
-
x.append(_x)
|
| 62 |
-
|
| 63 |
-
x = np.stack(x)
|
| 64 |
-
logger.debug(f"Loaded array of shape {x.shape} from {dir}")
|
| 65 |
-
return x
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def load_tracklet_links(folder: Path) -> pd.DataFrame:
|
| 69 |
-
candidates = [
|
| 70 |
-
folder / "man_track.txt",
|
| 71 |
-
folder / "res_track.txt",
|
| 72 |
-
]
|
| 73 |
-
for c in candidates:
|
| 74 |
-
if c.exists():
|
| 75 |
-
path = c
|
| 76 |
-
break
|
| 77 |
-
else:
|
| 78 |
-
raise FileNotFoundError(f"Could not find tracklet links in {folder}")
|
| 79 |
-
|
| 80 |
-
df = pd.read_csv(
|
| 81 |
-
path,
|
| 82 |
-
delimiter=" ",
|
| 83 |
-
names=["label", "t1", "t2", "parent"],
|
| 84 |
-
dtype=int,
|
| 85 |
-
)
|
| 86 |
-
# Remove invalid tracks with t2 > t1
|
| 87 |
-
df = df[df.t1 <= df.t2]
|
| 88 |
-
|
| 89 |
-
n_dets = (df.t2 - df.t1 + 1).sum()
|
| 90 |
-
logger.debug(f"{folder} has {n_dets} detections")
|
| 91 |
-
|
| 92 |
-
n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum()
|
| 93 |
-
logger.debug(f"{folder} has {n_divs} divisions")
|
| 94 |
-
return df
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def filter_track_df(
|
| 98 |
-
df: pd.DataFrame,
|
| 99 |
-
start_frame: int = 0,
|
| 100 |
-
end_frame: int = sys.maxsize,
|
| 101 |
-
downscale: int = 1,
|
| 102 |
-
) -> pd.DataFrame:
|
| 103 |
-
"""Only keep tracklets that are present in the given time interval."""
|
| 104 |
-
df.columns = ["label", "t1", "t2", "parent"]
|
| 105 |
-
# only retain cells in interval
|
| 106 |
-
df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)]
|
| 107 |
-
|
| 108 |
-
# shift start and end of each cell
|
| 109 |
-
df.t1 = df.t1 - start_frame
|
| 110 |
-
df.t2 = df.t2 - start_frame
|
| 111 |
-
# set start/end to min/max
|
| 112 |
-
df.t1 = df.t1.clip(0, end_frame - start_frame - 1)
|
| 113 |
-
df.t2 = df.t2.clip(0, end_frame - start_frame - 1)
|
| 114 |
-
# set all parents to 0 that are not in the interval
|
| 115 |
-
df.loc[~df.parent.isin(df.label), "parent"] = 0
|
| 116 |
-
|
| 117 |
-
if downscale > 1:
|
| 118 |
-
if start_frame % downscale != 0:
|
| 119 |
-
raise ValueError("start_frame must be a multiple of downscale")
|
| 120 |
-
|
| 121 |
-
logger.debug(f"Temporal downscaling of tracklet links by {downscale}")
|
| 122 |
-
|
| 123 |
-
# remove tracklets that have been fully deleted by temporal downsampling
|
| 124 |
-
|
| 125 |
-
mask = (
|
| 126 |
-
# (df["t2"] - df["t1"] < downscale - 1)
|
| 127 |
-
(df["t1"] % downscale != 0)
|
| 128 |
-
& (df["t2"] % downscale != 0)
|
| 129 |
-
& (df["t1"] // downscale == df["t2"] // downscale)
|
| 130 |
-
)
|
| 131 |
-
logger.debug(
|
| 132 |
-
f"Remove {mask.sum()} tracklets that are fully deleted by downsampling"
|
| 133 |
-
)
|
| 134 |
-
logger.debug(f"Remove {df[mask]}")
|
| 135 |
-
|
| 136 |
-
df = df[~mask]
|
| 137 |
-
# set parent to 0 if it has been deleted
|
| 138 |
-
df.loc[~df.parent.isin(df.label), "parent"] = 0
|
| 139 |
-
|
| 140 |
-
df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int)
|
| 141 |
-
df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int)
|
| 142 |
-
|
| 143 |
-
# Correct for edge case of single frame tracklet
|
| 144 |
-
assert np.all(df["t1"] == np.minimum(df["t1"], df["t2"]))
|
| 145 |
-
|
| 146 |
-
return df
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
# TODO fix
|
| 150 |
-
# def dataset_to_ctc(dataset: CTCData, path, start: int = 0, stop: int | None = None):
|
| 151 |
-
# """save dataset to ctc format for debugging purposes"""
|
| 152 |
-
# out = Path(path)
|
| 153 |
-
# print(f"Saving dataset to {out}")
|
| 154 |
-
# out_img = out / "img"
|
| 155 |
-
# out_img.mkdir(exist_ok=True, parents=True)
|
| 156 |
-
# out_mask = out / "TRA"
|
| 157 |
-
# out_mask.mkdir(exist_ok=True, parents=True)
|
| 158 |
-
# if stop is None:
|
| 159 |
-
# stop = len(self)
|
| 160 |
-
# lines = []
|
| 161 |
-
# masks, imgs = [], []
|
| 162 |
-
# t_offset = 0
|
| 163 |
-
# max_mask = 0
|
| 164 |
-
# n_lines = 0
|
| 165 |
-
# all_coords = []
|
| 166 |
-
# for i in tqdm(range(start, stop)):
|
| 167 |
-
# d = dataset.__getitem__(i, return_dense=True)
|
| 168 |
-
# mask = d["mask"].numpy()
|
| 169 |
-
# mask[mask > 0] += max_mask
|
| 170 |
-
# max_mask = max(max_mask, mask.max())
|
| 171 |
-
# masks.extend(mask)
|
| 172 |
-
# imgs.extend(d["img"].numpy())
|
| 173 |
-
# # add vertices
|
| 174 |
-
# coords = d["coords0"].numpy()
|
| 175 |
-
# ts, coords = coords[:, 0].astype(int), coords[:, 1:]
|
| 176 |
-
# A = d["assoc_matrix"].numpy()
|
| 177 |
-
# t_unique = sorted(np.unique(ts))
|
| 178 |
-
# for t1, t2 in zip(t_unique[:-1], t_unique[1:]):
|
| 179 |
-
# A_sub = A[ts == t1][:, ts == t2]
|
| 180 |
-
# for i, a in enumerate(A_sub):
|
| 181 |
-
|
| 182 |
-
# v1 = coords[ts == t1][i]
|
| 183 |
-
# for j in np.where(a > 0)[0]:
|
| 184 |
-
# v2 = coords[ts == t2][j]
|
| 185 |
-
# # lines.append(
|
| 186 |
-
# # {
|
| 187 |
-
# # "index": n_lines,
|
| 188 |
-
# # "shape-type": "line",
|
| 189 |
-
# # "vertex-index": 0,
|
| 190 |
-
# # "axis-0": t2 + t_offset,
|
| 191 |
-
# # "axis-1": v1[0],
|
| 192 |
-
# # "axis-2": v1[1],
|
| 193 |
-
# # }
|
| 194 |
-
# # )
|
| 195 |
-
# # lines.append(
|
| 196 |
-
# # {
|
| 197 |
-
# # "index": n_lines,
|
| 198 |
-
# # "shape-type": "line",
|
| 199 |
-
# # "vertex-index": 1,
|
| 200 |
-
# # "axis-0": t2 + t_offset,
|
| 201 |
-
# # "axis-1": v2[0],
|
| 202 |
-
# # "axis-2": v2[1],
|
| 203 |
-
# # }
|
| 204 |
-
# # )
|
| 205 |
-
# lines.append([n_lines, "line", 0, t2 + t_offset] + v1.tolist())
|
| 206 |
-
# lines.append([n_lines, "line", 1, t2 + t_offset] + v2.tolist())
|
| 207 |
-
# n_lines += 1
|
| 208 |
-
|
| 209 |
-
# c = d["coords0"].numpy()
|
| 210 |
-
# c[:, 0] += t_offset
|
| 211 |
-
# all_coords.extend(c)
|
| 212 |
-
# t_offset += len(mask)
|
| 213 |
-
|
| 214 |
-
# ax_cols = [f"axis-{i}" for i in range(dataset.ndim + 1)]
|
| 215 |
-
# df = pd.DataFrame(lines, columns=["index", "shape-type", "vertex-index"] + ax_cols)
|
| 216 |
-
# df.to_csv(out / "lines.csv", index=False)
|
| 217 |
-
|
| 218 |
-
# df_c = pd.DataFrame(all_coords, columns=ax_cols)
|
| 219 |
-
# df_c.to_csv(out / "coords.csv", index=False)
|
| 220 |
-
|
| 221 |
-
# for i, m in enumerate(imgs):
|
| 222 |
-
# # tifffile.imwrite(out_img/f'img_{i:04d}.tif', m)
|
| 223 |
-
# if dataset.ndim == 2:
|
| 224 |
-
# imageio.imwrite(
|
| 225 |
-
# out_img / f"img_{i:04d}.jpg",
|
| 226 |
-
# np.clip(20 + 100 * m, 0, 255).astype(np.uint8),
|
| 227 |
-
# )
|
| 228 |
-
|
| 229 |
-
# for i, m in enumerate(masks):
|
| 230 |
-
# tifffile.imwrite(out_mask / f"mask_{i:04d}.tif", m, compression="zstd")
|
| 231 |
-
|
| 232 |
-
# return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
# ruff: noqa: F401
|
| 2 |
-
|
| 3 |
-
from .model import TrackingTransformer
|
| 4 |
-
from .model_api import Trackastra
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/model_api.py
DELETED
|
@@ -1,338 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
from typing import Literal, Union, Optional, Tuple
|
| 5 |
-
|
| 6 |
-
import dask.array as da
|
| 7 |
-
import numpy as np
|
| 8 |
-
import tifffile
|
| 9 |
-
import torch
|
| 10 |
-
import yaml
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
from ..data import build_windows, get_features, load_tiff_timeseries
|
| 14 |
-
from ..tracking import TrackGraph, build_graph, track_greedy
|
| 15 |
-
from ..utils import normalize
|
| 16 |
-
from .model import TrackingTransformer
|
| 17 |
-
from .predict import predict_windows
|
| 18 |
-
from .pretrained import download_pretrained
|
| 19 |
-
|
| 20 |
-
logging.basicConfig(level=logging.INFO)
|
| 21 |
-
logger = logging.getLogger(__name__)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Trackastra:
|
| 25 |
-
"""A transformer-based tracking model for time-lapse data.
|
| 26 |
-
|
| 27 |
-
Trackastra links segmented objects across time frames by predicting
|
| 28 |
-
associations with a transformer model trained on diverse time-lapse videos.
|
| 29 |
-
|
| 30 |
-
The model takes as input:
|
| 31 |
-
- A sequence of images of shape (T,(Z),Y,X)
|
| 32 |
-
- Corresponding instance segmentation masks of shape (T,(Z),Y,X)
|
| 33 |
-
|
| 34 |
-
It supports multiple tracking modes:
|
| 35 |
-
- greedy_nodiv: Fast greedy linking without division
|
| 36 |
-
- greedy: Fast greedy linking with division
|
| 37 |
-
- ilp: Integer Linear Programming based linking (more accurate but slower)
|
| 38 |
-
|
| 39 |
-
Examples:
|
| 40 |
-
>>> # Load example data
|
| 41 |
-
>>> from trackastra.data import example_data_bacteria
|
| 42 |
-
>>> imgs, masks = example_data_bacteria()
|
| 43 |
-
>>>
|
| 44 |
-
>>> # Load pretrained model and track
|
| 45 |
-
>>> model = Trackastra.from_pretrained("general_2d", device="cuda")
|
| 46 |
-
>>> track_graph = model.track(imgs, masks, mode="greedy")
|
| 47 |
-
"""
|
| 48 |
-
|
| 49 |
-
def __init__(
|
| 50 |
-
self,
|
| 51 |
-
transformer: TrackingTransformer,
|
| 52 |
-
train_args: dict,
|
| 53 |
-
device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
|
| 54 |
-
):
|
| 55 |
-
"""Initialize Trackastra model.
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
transformer: The underlying transformer model.
|
| 59 |
-
train_args: Training configuration arguments.
|
| 60 |
-
device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
|
| 61 |
-
"""
|
| 62 |
-
if device == "cuda":
|
| 63 |
-
if torch.cuda.is_available():
|
| 64 |
-
self.device = "cuda"
|
| 65 |
-
else:
|
| 66 |
-
logger.info("Cuda not available, falling back to cpu.")
|
| 67 |
-
self.device = "cpu"
|
| 68 |
-
elif device == "mps":
|
| 69 |
-
if (
|
| 70 |
-
torch.backends.mps.is_available()
|
| 71 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
|
| 72 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
|
| 73 |
-
):
|
| 74 |
-
self.device = "mps"
|
| 75 |
-
else:
|
| 76 |
-
logger.info("Mps not available, falling back to cpu.")
|
| 77 |
-
self.device = "cpu"
|
| 78 |
-
elif device == "cpu":
|
| 79 |
-
self.device = "cpu"
|
| 80 |
-
elif device == "automatic" or device is None:
|
| 81 |
-
should_use_mps = (
|
| 82 |
-
torch.backends.mps.is_available()
|
| 83 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
|
| 84 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
|
| 85 |
-
)
|
| 86 |
-
self.device = (
|
| 87 |
-
"cuda"
|
| 88 |
-
if torch.cuda.is_available()
|
| 89 |
-
else (
|
| 90 |
-
"mps"
|
| 91 |
-
if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
|
| 92 |
-
else "cpu"
|
| 93 |
-
)
|
| 94 |
-
)
|
| 95 |
-
else:
|
| 96 |
-
raise ValueError(f"Device {device} not recognized.")
|
| 97 |
-
|
| 98 |
-
logger.info(f"Using device {self.device}")
|
| 99 |
-
|
| 100 |
-
self.transformer = transformer.to(self.device)
|
| 101 |
-
self.train_args = train_args
|
| 102 |
-
|
| 103 |
-
@classmethod
|
| 104 |
-
def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
|
| 105 |
-
"""Load a Trackastra model from a local folder.
|
| 106 |
-
|
| 107 |
-
Args:
|
| 108 |
-
dir: Path to model folder containing:
|
| 109 |
-
- model weights
|
| 110 |
-
- train_config.yaml with training arguments
|
| 111 |
-
device: Device to run model on.
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
Trackastra model instance.
|
| 115 |
-
"""
|
| 116 |
-
# Always load to cpu first
|
| 117 |
-
transformer = TrackingTransformer.from_folder(
|
| 118 |
-
Path(dir).expanduser(), map_location="cpu"
|
| 119 |
-
)
|
| 120 |
-
train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
|
| 121 |
-
return cls(transformer=transformer, train_args=train_args, device=device)
|
| 122 |
-
|
| 123 |
-
@classmethod
|
| 124 |
-
def from_pretrained(
|
| 125 |
-
cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
|
| 126 |
-
):
|
| 127 |
-
"""Load a pretrained Trackastra model.
|
| 128 |
-
|
| 129 |
-
Available pretrained models are described in detail in pretrained.json.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
name: Name of pretrained model (e.g. "general_2d").
|
| 133 |
-
device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
|
| 134 |
-
download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
Trackastra model instance.
|
| 138 |
-
"""
|
| 139 |
-
folder = download_pretrained(name, download_dir)
|
| 140 |
-
# download zip from github to location/name, then unzip
|
| 141 |
-
return cls.from_folder(folder, device=device)
|
| 142 |
-
|
| 143 |
-
def _predict(
|
| 144 |
-
self,
|
| 145 |
-
imgs: Union[np.ndarray, da.Array],
|
| 146 |
-
masks: Union[np.ndarray, da.Array],
|
| 147 |
-
edge_threshold: float = 0.05,
|
| 148 |
-
n_workers: int = 0,
|
| 149 |
-
normalize_imgs: bool = True,
|
| 150 |
-
progbar_class=tqdm,
|
| 151 |
-
):
|
| 152 |
-
logger.info("Predicting weights for candidate graph")
|
| 153 |
-
if normalize_imgs:
|
| 154 |
-
if isinstance(imgs, da.Array):
|
| 155 |
-
imgs = imgs.map_blocks(normalize)
|
| 156 |
-
else:
|
| 157 |
-
imgs = normalize(imgs)
|
| 158 |
-
|
| 159 |
-
self.transformer.eval()
|
| 160 |
-
|
| 161 |
-
features = get_features(
|
| 162 |
-
detections=masks,
|
| 163 |
-
imgs=imgs,
|
| 164 |
-
ndim=self.transformer.config["coord_dim"],
|
| 165 |
-
n_workers=n_workers,
|
| 166 |
-
progbar_class=progbar_class,
|
| 167 |
-
)
|
| 168 |
-
logger.info("Building windows")
|
| 169 |
-
windows = build_windows(
|
| 170 |
-
features,
|
| 171 |
-
window_size=self.transformer.config["window"],
|
| 172 |
-
progbar_class=progbar_class,
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
logger.info("Predicting windows")
|
| 176 |
-
predictions = predict_windows(
|
| 177 |
-
windows=windows,
|
| 178 |
-
features=features,
|
| 179 |
-
model=self.transformer,
|
| 180 |
-
edge_threshold=edge_threshold,
|
| 181 |
-
spatial_dim=masks.ndim - 1,
|
| 182 |
-
progbar_class=progbar_class,
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
return predictions
|
| 186 |
-
|
| 187 |
-
def _track_from_predictions(
|
| 188 |
-
self,
|
| 189 |
-
predictions,
|
| 190 |
-
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
|
| 191 |
-
use_distance: bool = False,
|
| 192 |
-
max_distance: int = 256,
|
| 193 |
-
max_neighbors: int = 10,
|
| 194 |
-
delta_t: int = 1,
|
| 195 |
-
**kwargs,
|
| 196 |
-
):
|
| 197 |
-
logger.info("Running greedy tracker")
|
| 198 |
-
nodes = predictions["nodes"]
|
| 199 |
-
weights = predictions["weights"]
|
| 200 |
-
|
| 201 |
-
candidate_graph = build_graph(
|
| 202 |
-
nodes=nodes,
|
| 203 |
-
weights=weights,
|
| 204 |
-
use_distance=use_distance,
|
| 205 |
-
max_distance=max_distance,
|
| 206 |
-
max_neighbors=max_neighbors,
|
| 207 |
-
delta_t=delta_t,
|
| 208 |
-
)
|
| 209 |
-
if mode == "greedy":
|
| 210 |
-
return track_greedy(candidate_graph)
|
| 211 |
-
elif mode == "greedy_nodiv":
|
| 212 |
-
return track_greedy(candidate_graph, allow_divisions=False)
|
| 213 |
-
elif mode == "ilp":
|
| 214 |
-
from trackastra.tracking.ilp import track_ilp
|
| 215 |
-
|
| 216 |
-
return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
|
| 217 |
-
else:
|
| 218 |
-
raise ValueError(f"Tracking mode {mode} does not exist.")
|
| 219 |
-
|
| 220 |
-
def track(
|
| 221 |
-
self,
|
| 222 |
-
imgs: Union[np.ndarray, da.Array],
|
| 223 |
-
masks: Union[np.ndarray, da.Array],
|
| 224 |
-
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
|
| 225 |
-
normalize_imgs: bool = True,
|
| 226 |
-
progbar_class=tqdm,
|
| 227 |
-
n_workers: int = 0,
|
| 228 |
-
**kwargs,
|
| 229 |
-
) -> TrackGraph:
|
| 230 |
-
"""Track objects across time frames.
|
| 231 |
-
|
| 232 |
-
This method links segmented objects across time frames using the specified
|
| 233 |
-
tracking mode. No hyperparameters need to be chosen beyond the tracking mode.
|
| 234 |
-
|
| 235 |
-
Args:
|
| 236 |
-
imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
|
| 237 |
-
masks: Instance segmentation masks of shape (T,(Z),Y,X).
|
| 238 |
-
mode: Tracking mode:
|
| 239 |
-
- "greedy_nodiv": Fast greedy linking without division
|
| 240 |
-
- "greedy": Fast greedy linking with division
|
| 241 |
-
- "ilp": Integer Linear Programming based linking (more accurate but slower)
|
| 242 |
-
progbar_class: Progress bar class to use.
|
| 243 |
-
n_workers: Number of worker processes for feature extraction.
|
| 244 |
-
normalize_imgs: Whether to normalize the images.
|
| 245 |
-
**kwargs: Additional arguments passed to tracking algorithm.
|
| 246 |
-
|
| 247 |
-
Returns:
|
| 248 |
-
TrackGraph containing the tracking results.
|
| 249 |
-
"""
|
| 250 |
-
if not imgs.shape == masks.shape:
|
| 251 |
-
raise RuntimeError(
|
| 252 |
-
f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
|
| 256 |
-
raise RuntimeError(
|
| 257 |
-
f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
predictions = self._predict(
|
| 261 |
-
imgs,
|
| 262 |
-
masks,
|
| 263 |
-
normalize_imgs=normalize_imgs,
|
| 264 |
-
progbar_class=progbar_class,
|
| 265 |
-
n_workers=n_workers,
|
| 266 |
-
)
|
| 267 |
-
track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
|
| 268 |
-
return track_graph
|
| 269 |
-
|
| 270 |
-
def track_from_disk(
|
| 271 |
-
self,
|
| 272 |
-
imgs_path: Path,
|
| 273 |
-
masks_path: Path,
|
| 274 |
-
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
|
| 275 |
-
normalize_imgs: bool = True,
|
| 276 |
-
**kwargs,
|
| 277 |
-
) -> Tuple[TrackGraph, np.ndarray]:
|
| 278 |
-
"""Track objects directly from image and mask files on disk.
|
| 279 |
-
|
| 280 |
-
This method supports both single tiff files and directories
|
| 281 |
-
|
| 282 |
-
Args:
|
| 283 |
-
imgs_path: Path to input images. Can be:
|
| 284 |
-
- Directory containing numbered tiff files of shape (C),(Z),Y,X
|
| 285 |
-
- Single tiff file with time series of shape T,(C),(Z),Y,X
|
| 286 |
-
masks_path: Path to mask files. Can be:
|
| 287 |
-
- Directory containing numbered tiff files of shape (Z),Y,X
|
| 288 |
-
- Single tiff file with time series of shape T,(Z),Y,X
|
| 289 |
-
mode: Tracking mode:
|
| 290 |
-
- "greedy_nodiv": Fast greedy linking without division
|
| 291 |
-
- "greedy": Fast greedy linking with division
|
| 292 |
-
- "ilp": Integer Linear Programming based linking (more accurate but slower)
|
| 293 |
-
normalize_imgs: Whether to normalize the images.
|
| 294 |
-
**kwargs: Additional arguments passed to tracking algorithm.
|
| 295 |
-
|
| 296 |
-
Returns:
|
| 297 |
-
Tuple of (TrackGraph, tracked masks).
|
| 298 |
-
"""
|
| 299 |
-
if not imgs_path.exists():
|
| 300 |
-
raise FileNotFoundError(f"{imgs_path=} does not exist.")
|
| 301 |
-
if not masks_path.exists():
|
| 302 |
-
raise FileNotFoundError(f"{masks_path=} does not exist.")
|
| 303 |
-
|
| 304 |
-
if imgs_path.is_dir():
|
| 305 |
-
imgs = load_tiff_timeseries(imgs_path)
|
| 306 |
-
else:
|
| 307 |
-
imgs = tifffile.imread(imgs_path)
|
| 308 |
-
|
| 309 |
-
if masks_path.is_dir():
|
| 310 |
-
masks = load_tiff_timeseries(masks_path)
|
| 311 |
-
else:
|
| 312 |
-
masks = tifffile.imread(masks_path)
|
| 313 |
-
|
| 314 |
-
if len(imgs) != len(masks):
|
| 315 |
-
raise RuntimeError(
|
| 316 |
-
f"#imgs and #masks do not match. Found {len(imgs)} images,"
|
| 317 |
-
f" {len(masks)} masks."
|
| 318 |
-
)
|
| 319 |
-
|
| 320 |
-
if imgs.ndim - 1 == masks.ndim:
|
| 321 |
-
if imgs[1] == 1:
|
| 322 |
-
logger.info(
|
| 323 |
-
"Found a channel dimension with a single channel. Removing dim."
|
| 324 |
-
)
|
| 325 |
-
masks = np.squeeze(masks, 1)
|
| 326 |
-
else:
|
| 327 |
-
raise RuntimeError(
|
| 328 |
-
"Trackastra currently only supports single channel images."
|
| 329 |
-
)
|
| 330 |
-
|
| 331 |
-
if imgs.shape != masks.shape:
|
| 332 |
-
raise RuntimeError(
|
| 333 |
-
f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
|
| 334 |
-
)
|
| 335 |
-
|
| 336 |
-
return self.track(
|
| 337 |
-
imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
|
| 338 |
-
), masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/model_parts.py
DELETED
|
@@ -1,287 +0,0 @@
|
|
| 1 |
-
"""Transformer class."""
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
import math
|
| 5 |
-
from typing import Literal
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
from torch import nn
|
| 10 |
-
|
| 11 |
-
from .rope import RotaryPositionalEncoding
|
| 12 |
-
|
| 13 |
-
from typing import Tuple
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def _pos_embed_fourier1d_init(
|
| 19 |
-
cutoff: float = 256, n: int = 32, cutoff_start: float = 1
|
| 20 |
-
):
|
| 21 |
-
return (
|
| 22 |
-
torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n))
|
| 23 |
-
.unsqueeze(0)
|
| 24 |
-
.unsqueeze(0)
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class FeedForward(nn.Module):
|
| 29 |
-
def __init__(self, d_model, expand: float = 2, bias: bool = True):
|
| 30 |
-
super().__init__()
|
| 31 |
-
self.fc1 = nn.Linear(d_model, int(d_model * expand))
|
| 32 |
-
self.fc2 = nn.Linear(int(d_model * expand), d_model, bias=bias)
|
| 33 |
-
self.act = nn.GELU()
|
| 34 |
-
|
| 35 |
-
def forward(self, x):
|
| 36 |
-
return self.fc2(self.act(self.fc1(x)))
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class PositionalEncoding(nn.Module):
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
cutoffs: Tuple[float] = (256,),
|
| 43 |
-
n_pos: Tuple[int] = (32,),
|
| 44 |
-
cutoffs_start=None,
|
| 45 |
-
):
|
| 46 |
-
"""Positional encoding with given cutoff and number of frequencies for each dimension.
|
| 47 |
-
number of dimension is inferred from the length of cutoffs and n_pos.
|
| 48 |
-
"""
|
| 49 |
-
super().__init__()
|
| 50 |
-
if cutoffs_start is None:
|
| 51 |
-
cutoffs_start = (1,) * len(cutoffs)
|
| 52 |
-
|
| 53 |
-
assert len(cutoffs) == len(n_pos)
|
| 54 |
-
self.freqs = nn.ParameterList([
|
| 55 |
-
nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
|
| 56 |
-
for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
|
| 57 |
-
])
|
| 58 |
-
|
| 59 |
-
def forward(self, coords: torch.Tensor):
|
| 60 |
-
_B, _N, D = coords.shape
|
| 61 |
-
assert D == len(self.freqs)
|
| 62 |
-
embed = torch.cat(
|
| 63 |
-
tuple(
|
| 64 |
-
torch.cat(
|
| 65 |
-
(
|
| 66 |
-
torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
|
| 67 |
-
torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
|
| 68 |
-
),
|
| 69 |
-
axis=-1,
|
| 70 |
-
)
|
| 71 |
-
/ math.sqrt(len(freq))
|
| 72 |
-
for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
|
| 73 |
-
),
|
| 74 |
-
axis=-1,
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
return embed
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class NoPositionalEncoding(nn.Module):
|
| 81 |
-
def __init__(self, d):
|
| 82 |
-
"""One learnable input token that ignores positional information."""
|
| 83 |
-
super().__init__()
|
| 84 |
-
self.d = d
|
| 85 |
-
# self.token = nn.Parameter(torch.randn(d))
|
| 86 |
-
|
| 87 |
-
def forward(self, coords: torch.Tensor):
|
| 88 |
-
B, N, _ = coords.shape
|
| 89 |
-
return (
|
| 90 |
-
# torch.ones((B, N, self.d), device=coords.device) * 0.1
|
| 91 |
-
# torch.randn((1, 1, self.d), device=coords.device).expand(B, N, -1) * 0.01
|
| 92 |
-
torch.randn((B, N, self.d), device=coords.device) * 0.01
|
| 93 |
-
+ torch.randn((1, 1, self.d), device=coords.device).expand(B, N, -1) * 0.1
|
| 94 |
-
)
|
| 95 |
-
# return self.token.view(1, 1, -1).expand(B, N, -1)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def _bin_init_exp(cutoff: float, n: int):
|
| 99 |
-
return torch.exp(torch.linspace(0, math.log(cutoff + 1), n))
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def _bin_init_linear(cutoff: float, n: int):
|
| 103 |
-
return torch.linspace(-cutoff, cutoff, n)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
class RelativePositionalBias(nn.Module):
|
| 107 |
-
def __init__(
|
| 108 |
-
self,
|
| 109 |
-
n_head: int,
|
| 110 |
-
cutoff_spatial: float,
|
| 111 |
-
cutoff_temporal: float,
|
| 112 |
-
n_spatial: int = 32,
|
| 113 |
-
n_temporal: int = 16,
|
| 114 |
-
):
|
| 115 |
-
"""Learnt relative positional bias to add to self-attention matrix.
|
| 116 |
-
|
| 117 |
-
Spatial bins are exponentially spaced, temporal bins are linearly spaced.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
n_head (int): Number of pos bias heads. Equal to number of attention heads
|
| 121 |
-
cutoff_spatial (float): Maximum distance in space.
|
| 122 |
-
cutoff_temporal (float): Maxium distance in time. Equal to window size of transformer.
|
| 123 |
-
n_spatial (int, optional): Number of spatial bins.
|
| 124 |
-
n_temporal (int, optional): Number of temporal bins in each direction. Should be equal to window size. Total = 2 * n_temporal + 1. Defaults to 16.
|
| 125 |
-
"""
|
| 126 |
-
super().__init__()
|
| 127 |
-
self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
|
| 128 |
-
self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
|
| 129 |
-
self.register_buffer("spatial_bins", self._spatial_bins)
|
| 130 |
-
self.register_buffer("temporal_bins", self._temporal_bins)
|
| 131 |
-
self.n_spatial = n_spatial
|
| 132 |
-
self.n_head = n_head
|
| 133 |
-
self.bias = nn.Parameter(
|
| 134 |
-
-0.5 + torch.rand((2 * n_temporal + 1) * n_spatial, n_head)
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
def forward(self, coords: torch.Tensor):
|
| 138 |
-
_B, _N, _D = coords.shape
|
| 139 |
-
t = coords[..., 0]
|
| 140 |
-
yx = coords[..., 1:]
|
| 141 |
-
temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2)
|
| 142 |
-
spatial_dist = torch.cdist(yx, yx)
|
| 143 |
-
|
| 144 |
-
spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins)
|
| 145 |
-
torch.clamp_(spatial_idx, max=len(self.spatial_bins) - 1)
|
| 146 |
-
temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins)
|
| 147 |
-
torch.clamp_(temporal_idx, max=len(self.temporal_bins) - 1)
|
| 148 |
-
|
| 149 |
-
# do some index gymnastics such that backward is not super slow
|
| 150 |
-
# https://discuss.pytorch.org/t/how-to-select-multiple-indexes-over-multiple-dimensions-at-the-same-time/98532/2
|
| 151 |
-
idx = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial
|
| 152 |
-
bias = self.bias.index_select(0, idx).view((*spatial_idx.shape, self.n_head))
|
| 153 |
-
# -> B, nH, N, N
|
| 154 |
-
bias = bias.transpose(-1, 1)
|
| 155 |
-
return bias
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
class RelativePositionalAttention(nn.Module):
|
| 159 |
-
def __init__(
|
| 160 |
-
self,
|
| 161 |
-
coord_dim: int,
|
| 162 |
-
embed_dim: int,
|
| 163 |
-
n_head: int,
|
| 164 |
-
cutoff_spatial: float = 256,
|
| 165 |
-
cutoff_temporal: float = 16,
|
| 166 |
-
n_spatial: int = 32,
|
| 167 |
-
n_temporal: int = 16,
|
| 168 |
-
dropout: float = 0.0,
|
| 169 |
-
mode: Literal["bias", "rope", "none"] = "bias",
|
| 170 |
-
attn_dist_mode: str = "v0",
|
| 171 |
-
):
|
| 172 |
-
super().__init__()
|
| 173 |
-
|
| 174 |
-
if not embed_dim % (2 * n_head) == 0:
|
| 175 |
-
raise ValueError(
|
| 176 |
-
f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
# qkv projection
|
| 180 |
-
self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 181 |
-
self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 182 |
-
self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 183 |
-
|
| 184 |
-
# output projection
|
| 185 |
-
self.proj = nn.Linear(embed_dim, embed_dim)
|
| 186 |
-
# regularization
|
| 187 |
-
self.dropout = dropout
|
| 188 |
-
self.n_head = n_head
|
| 189 |
-
self.embed_dim = embed_dim
|
| 190 |
-
self.cutoff_spatial = cutoff_spatial
|
| 191 |
-
self.attn_dist_mode = attn_dist_mode
|
| 192 |
-
|
| 193 |
-
if mode == "bias" or mode is True:
|
| 194 |
-
self.pos_bias = RelativePositionalBias(
|
| 195 |
-
n_head=n_head,
|
| 196 |
-
cutoff_spatial=cutoff_spatial,
|
| 197 |
-
cutoff_temporal=cutoff_temporal,
|
| 198 |
-
n_spatial=n_spatial,
|
| 199 |
-
n_temporal=n_temporal,
|
| 200 |
-
)
|
| 201 |
-
elif mode == "rope":
|
| 202 |
-
# each part needs to be divisible by 2
|
| 203 |
-
n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
|
| 204 |
-
|
| 205 |
-
self.rot_pos_enc = RotaryPositionalEncoding(
|
| 206 |
-
cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
|
| 207 |
-
n_pos=(embed_dim // n_head - coord_dim * n_split,)
|
| 208 |
-
+ (n_split,) * coord_dim,
|
| 209 |
-
)
|
| 210 |
-
elif mode == "none":
|
| 211 |
-
pass
|
| 212 |
-
elif mode is None or mode is False:
|
| 213 |
-
logger.warning(
|
| 214 |
-
"attn_positional_bias is not set (None or False), no positional bias."
|
| 215 |
-
)
|
| 216 |
-
pass
|
| 217 |
-
else:
|
| 218 |
-
raise ValueError(f"Unknown mode {mode}")
|
| 219 |
-
|
| 220 |
-
self._mode = mode
|
| 221 |
-
|
| 222 |
-
def forward(
|
| 223 |
-
self,
|
| 224 |
-
query: torch.Tensor,
|
| 225 |
-
key: torch.Tensor,
|
| 226 |
-
value: torch.Tensor,
|
| 227 |
-
coords: torch.Tensor,
|
| 228 |
-
padding_mask: torch.Tensor = None,
|
| 229 |
-
):
|
| 230 |
-
B, N, D = query.size()
|
| 231 |
-
q = self.q_pro(query) # (B, N, D)
|
| 232 |
-
k = self.k_pro(key) # (B, N, D)
|
| 233 |
-
v = self.v_pro(value) # (B, N, D)
|
| 234 |
-
# (B, nh, N, hs)
|
| 235 |
-
k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
|
| 236 |
-
q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
|
| 237 |
-
v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
|
| 238 |
-
|
| 239 |
-
attn_mask = torch.zeros(
|
| 240 |
-
(B, self.n_head, N, N), device=query.device, dtype=q.dtype
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
# add negative value but not too large to keep mixed precision loss from becoming nan
|
| 244 |
-
attn_ignore_val = -1e3
|
| 245 |
-
|
| 246 |
-
# spatial cutoff
|
| 247 |
-
yx = coords[..., 1:]
|
| 248 |
-
spatial_dist = torch.cdist(yx, yx)
|
| 249 |
-
spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
|
| 250 |
-
attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
|
| 251 |
-
|
| 252 |
-
# dont add positional bias to self-attention if coords is None
|
| 253 |
-
if coords is not None:
|
| 254 |
-
if self._mode == "bias":
|
| 255 |
-
attn_mask = attn_mask + self.pos_bias(coords)
|
| 256 |
-
elif self._mode == "rope":
|
| 257 |
-
q, k = self.rot_pos_enc(q, k, coords)
|
| 258 |
-
else:
|
| 259 |
-
pass
|
| 260 |
-
|
| 261 |
-
if self.attn_dist_mode == "v0":
|
| 262 |
-
dist = torch.cdist(coords, coords, p=2)
|
| 263 |
-
attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
|
| 264 |
-
elif self.attn_dist_mode == "v1":
|
| 265 |
-
attn_mask += torch.exp(
|
| 266 |
-
-5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
|
| 267 |
-
)
|
| 268 |
-
else:
|
| 269 |
-
raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
|
| 270 |
-
|
| 271 |
-
# if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens)
|
| 272 |
-
if padding_mask is not None:
|
| 273 |
-
ignore_mask = torch.logical_or(
|
| 274 |
-
padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
|
| 275 |
-
).unsqueeze(1)
|
| 276 |
-
attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
|
| 277 |
-
|
| 278 |
-
# self.attn_mask = attn_mask.clone()
|
| 279 |
-
|
| 280 |
-
y = F.scaled_dot_product_attention(
|
| 281 |
-
q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
y = y.transpose(1, 2).contiguous().view(B, N, D)
|
| 285 |
-
# output projection
|
| 286 |
-
y = self.proj(y)
|
| 287 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/model_sd.py
DELETED
|
@@ -1,338 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
from typing import Literal, Union, Optional, Tuple
|
| 5 |
-
|
| 6 |
-
import dask.array as da
|
| 7 |
-
import numpy as np
|
| 8 |
-
import tifffile
|
| 9 |
-
import torch
|
| 10 |
-
import yaml
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
from ..data import build_windows, get_features, load_tiff_timeseries
|
| 14 |
-
from ..tracking import TrackGraph, build_graph, track_greedy
|
| 15 |
-
from ..utils import normalize
|
| 16 |
-
from .model import TrackingTransformer
|
| 17 |
-
from .predict import predict_windows
|
| 18 |
-
from .pretrained import download_pretrained
|
| 19 |
-
|
| 20 |
-
logging.basicConfig(level=logging.INFO)
|
| 21 |
-
logger = logging.getLogger(__name__)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Trackastra:
|
| 25 |
-
"""A transformer-based tracking model for time-lapse data.
|
| 26 |
-
|
| 27 |
-
Trackastra links segmented objects across time frames by predicting
|
| 28 |
-
associations with a transformer model trained on diverse time-lapse videos.
|
| 29 |
-
|
| 30 |
-
The model takes as input:
|
| 31 |
-
- A sequence of images of shape (T,(Z),Y,X)
|
| 32 |
-
- Corresponding instance segmentation masks of shape (T,(Z),Y,X)
|
| 33 |
-
|
| 34 |
-
It supports multiple tracking modes:
|
| 35 |
-
- greedy_nodiv: Fast greedy linking without division
|
| 36 |
-
- greedy: Fast greedy linking with division
|
| 37 |
-
- ilp: Integer Linear Programming based linking (more accurate but slower)
|
| 38 |
-
|
| 39 |
-
Examples:
|
| 40 |
-
>>> # Load example data
|
| 41 |
-
>>> from trackastra.data import example_data_bacteria
|
| 42 |
-
>>> imgs, masks = example_data_bacteria()
|
| 43 |
-
>>>
|
| 44 |
-
>>> # Load pretrained model and track
|
| 45 |
-
>>> model = Trackastra.from_pretrained("general_2d", device="cuda")
|
| 46 |
-
>>> track_graph = model.track(imgs, masks, mode="greedy")
|
| 47 |
-
"""
|
| 48 |
-
|
| 49 |
-
def __init__(
|
| 50 |
-
self,
|
| 51 |
-
transformer: TrackingTransformer,
|
| 52 |
-
train_args: dict,
|
| 53 |
-
device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
|
| 54 |
-
):
|
| 55 |
-
"""Initialize Trackastra model.
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
transformer: The underlying transformer model.
|
| 59 |
-
train_args: Training configuration arguments.
|
| 60 |
-
device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
|
| 61 |
-
"""
|
| 62 |
-
if device == "cuda":
|
| 63 |
-
if torch.cuda.is_available():
|
| 64 |
-
self.device = "cuda"
|
| 65 |
-
else:
|
| 66 |
-
logger.info("Cuda not available, falling back to cpu.")
|
| 67 |
-
self.device = "cpu"
|
| 68 |
-
elif device == "mps":
|
| 69 |
-
if (
|
| 70 |
-
torch.backends.mps.is_available()
|
| 71 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
|
| 72 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
|
| 73 |
-
):
|
| 74 |
-
self.device = "mps"
|
| 75 |
-
else:
|
| 76 |
-
logger.info("Mps not available, falling back to cpu.")
|
| 77 |
-
self.device = "cpu"
|
| 78 |
-
elif device == "cpu":
|
| 79 |
-
self.device = "cpu"
|
| 80 |
-
elif device == "automatic" or device is None:
|
| 81 |
-
should_use_mps = (
|
| 82 |
-
torch.backends.mps.is_available()
|
| 83 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
|
| 84 |
-
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
|
| 85 |
-
)
|
| 86 |
-
self.device = (
|
| 87 |
-
"cuda"
|
| 88 |
-
if torch.cuda.is_available()
|
| 89 |
-
else (
|
| 90 |
-
"mps"
|
| 91 |
-
if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
|
| 92 |
-
else "cpu"
|
| 93 |
-
)
|
| 94 |
-
)
|
| 95 |
-
else:
|
| 96 |
-
raise ValueError(f"Device {device} not recognized.")
|
| 97 |
-
|
| 98 |
-
logger.info(f"Using device {self.device}")
|
| 99 |
-
|
| 100 |
-
self.transformer = transformer.to(self.device)
|
| 101 |
-
self.train_args = train_args
|
| 102 |
-
|
| 103 |
-
@classmethod
|
| 104 |
-
def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
|
| 105 |
-
"""Load a Trackastra model from a local folder.
|
| 106 |
-
|
| 107 |
-
Args:
|
| 108 |
-
dir: Path to model folder containing:
|
| 109 |
-
- model weights
|
| 110 |
-
- train_config.yaml with training arguments
|
| 111 |
-
device: Device to run model on.
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
Trackastra model instance.
|
| 115 |
-
"""
|
| 116 |
-
# Always load to cpu first
|
| 117 |
-
transformer = TrackingTransformer.from_folder(
|
| 118 |
-
Path(dir).expanduser(), map_location="cpu"
|
| 119 |
-
)
|
| 120 |
-
train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
|
| 121 |
-
return cls(transformer=transformer, train_args=train_args, device=device)
|
| 122 |
-
|
| 123 |
-
@classmethod
|
| 124 |
-
def from_pretrained(
|
| 125 |
-
cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
|
| 126 |
-
):
|
| 127 |
-
"""Load a pretrained Trackastra model.
|
| 128 |
-
|
| 129 |
-
Available pretrained models are described in detail in pretrained.json.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
name: Name of pretrained model (e.g. "general_2d").
|
| 133 |
-
device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
|
| 134 |
-
download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
Trackastra model instance.
|
| 138 |
-
"""
|
| 139 |
-
folder = download_pretrained(name, download_dir)
|
| 140 |
-
# download zip from github to location/name, then unzip
|
| 141 |
-
return cls.from_folder(folder, device=device)
|
| 142 |
-
|
| 143 |
-
def _predict(
|
| 144 |
-
self,
|
| 145 |
-
imgs: Union[np.ndarray, da.Array],
|
| 146 |
-
masks: Union[np.ndarray, da.Array],
|
| 147 |
-
edge_threshold: float = 0.05,
|
| 148 |
-
n_workers: int = 0,
|
| 149 |
-
normalize_imgs: bool = True,
|
| 150 |
-
progbar_class=tqdm,
|
| 151 |
-
):
|
| 152 |
-
logger.info("Predicting weights for candidate graph")
|
| 153 |
-
if normalize_imgs:
|
| 154 |
-
if isinstance(imgs, da.Array):
|
| 155 |
-
imgs = imgs.map_blocks(normalize)
|
| 156 |
-
else:
|
| 157 |
-
imgs = normalize(imgs)
|
| 158 |
-
|
| 159 |
-
self.transformer.eval()
|
| 160 |
-
|
| 161 |
-
features = get_features(
|
| 162 |
-
detections=masks,
|
| 163 |
-
imgs=imgs,
|
| 164 |
-
ndim=self.transformer.config["coord_dim"],
|
| 165 |
-
n_workers=n_workers,
|
| 166 |
-
progbar_class=progbar_class,
|
| 167 |
-
)
|
| 168 |
-
logger.info("Building windows")
|
| 169 |
-
windows = build_windows(
|
| 170 |
-
features,
|
| 171 |
-
window_size=self.transformer.config["window"],
|
| 172 |
-
progbar_class=progbar_class,
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
logger.info("Predicting windows")
|
| 176 |
-
predictions = predict_windows(
|
| 177 |
-
windows=windows,
|
| 178 |
-
features=features,
|
| 179 |
-
model=self.transformer,
|
| 180 |
-
edge_threshold=edge_threshold,
|
| 181 |
-
spatial_dim=masks.ndim - 1,
|
| 182 |
-
progbar_class=progbar_class,
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
return predictions
|
| 186 |
-
|
| 187 |
-
def _track_from_predictions(
|
| 188 |
-
self,
|
| 189 |
-
predictions,
|
| 190 |
-
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
|
| 191 |
-
use_distance: bool = False,
|
| 192 |
-
max_distance: int = 256,
|
| 193 |
-
max_neighbors: int = 10,
|
| 194 |
-
delta_t: int = 1,
|
| 195 |
-
**kwargs,
|
| 196 |
-
):
|
| 197 |
-
logger.info("Running greedy tracker")
|
| 198 |
-
nodes = predictions["nodes"]
|
| 199 |
-
weights = predictions["weights"]
|
| 200 |
-
|
| 201 |
-
candidate_graph = build_graph(
|
| 202 |
-
nodes=nodes,
|
| 203 |
-
weights=weights,
|
| 204 |
-
use_distance=use_distance,
|
| 205 |
-
max_distance=max_distance,
|
| 206 |
-
max_neighbors=max_neighbors,
|
| 207 |
-
delta_t=delta_t,
|
| 208 |
-
)
|
| 209 |
-
if mode == "greedy":
|
| 210 |
-
return track_greedy(candidate_graph)
|
| 211 |
-
elif mode == "greedy_nodiv":
|
| 212 |
-
return track_greedy(candidate_graph, allow_divisions=False)
|
| 213 |
-
elif mode == "ilp":
|
| 214 |
-
from trackastra.tracking.ilp import track_ilp
|
| 215 |
-
|
| 216 |
-
return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
|
| 217 |
-
else:
|
| 218 |
-
raise ValueError(f"Tracking mode {mode} does not exist.")
|
| 219 |
-
|
| 220 |
-
def track(
|
| 221 |
-
self,
|
| 222 |
-
imgs: Union[np.ndarray, da.Array],
|
| 223 |
-
masks: Union[np.ndarray, da.Array],
|
| 224 |
-
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
|
| 225 |
-
normalize_imgs: bool = True,
|
| 226 |
-
progbar_class=tqdm,
|
| 227 |
-
n_workers: int = 0,
|
| 228 |
-
**kwargs,
|
| 229 |
-
) -> TrackGraph:
|
| 230 |
-
"""Track objects across time frames.
|
| 231 |
-
|
| 232 |
-
This method links segmented objects across time frames using the specified
|
| 233 |
-
tracking mode. No hyperparameters need to be chosen beyond the tracking mode.
|
| 234 |
-
|
| 235 |
-
Args:
|
| 236 |
-
imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
|
| 237 |
-
masks: Instance segmentation masks of shape (T,(Z),Y,X).
|
| 238 |
-
mode: Tracking mode:
|
| 239 |
-
- "greedy_nodiv": Fast greedy linking without division
|
| 240 |
-
- "greedy": Fast greedy linking with division
|
| 241 |
-
- "ilp": Integer Linear Programming based linking (more accurate but slower)
|
| 242 |
-
progbar_class: Progress bar class to use.
|
| 243 |
-
n_workers: Number of worker processes for feature extraction.
|
| 244 |
-
normalize_imgs: Whether to normalize the images.
|
| 245 |
-
**kwargs: Additional arguments passed to tracking algorithm.
|
| 246 |
-
|
| 247 |
-
Returns:
|
| 248 |
-
TrackGraph containing the tracking results.
|
| 249 |
-
"""
|
| 250 |
-
if not imgs.shape == masks.shape:
|
| 251 |
-
raise RuntimeError(
|
| 252 |
-
f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
|
| 256 |
-
raise RuntimeError(
|
| 257 |
-
f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
predictions = self._predict(
|
| 261 |
-
imgs,
|
| 262 |
-
masks,
|
| 263 |
-
normalize_imgs=normalize_imgs,
|
| 264 |
-
progbar_class=progbar_class,
|
| 265 |
-
n_workers=n_workers,
|
| 266 |
-
)
|
| 267 |
-
track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
|
| 268 |
-
return track_graph
|
| 269 |
-
|
| 270 |
-
def track_from_disk(
|
| 271 |
-
self,
|
| 272 |
-
imgs_path: Path,
|
| 273 |
-
masks_path: Path,
|
| 274 |
-
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
|
| 275 |
-
normalize_imgs: bool = True,
|
| 276 |
-
**kwargs,
|
| 277 |
-
) -> Tuple[TrackGraph, np.ndarray]:
|
| 278 |
-
"""Track objects directly from image and mask files on disk.
|
| 279 |
-
|
| 280 |
-
This method supports both single tiff files and directories
|
| 281 |
-
|
| 282 |
-
Args:
|
| 283 |
-
imgs_path: Path to input images. Can be:
|
| 284 |
-
- Directory containing numbered tiff files of shape (C),(Z),Y,X
|
| 285 |
-
- Single tiff file with time series of shape T,(C),(Z),Y,X
|
| 286 |
-
masks_path: Path to mask files. Can be:
|
| 287 |
-
- Directory containing numbered tiff files of shape (Z),Y,X
|
| 288 |
-
- Single tiff file with time series of shape T,(Z),Y,X
|
| 289 |
-
mode: Tracking mode:
|
| 290 |
-
- "greedy_nodiv": Fast greedy linking without division
|
| 291 |
-
- "greedy": Fast greedy linking with division
|
| 292 |
-
- "ilp": Integer Linear Programming based linking (more accurate but slower)
|
| 293 |
-
normalize_imgs: Whether to normalize the images.
|
| 294 |
-
**kwargs: Additional arguments passed to tracking algorithm.
|
| 295 |
-
|
| 296 |
-
Returns:
|
| 297 |
-
Tuple of (TrackGraph, tracked masks).
|
| 298 |
-
"""
|
| 299 |
-
if not imgs_path.exists():
|
| 300 |
-
raise FileNotFoundError(f"{imgs_path=} does not exist.")
|
| 301 |
-
if not masks_path.exists():
|
| 302 |
-
raise FileNotFoundError(f"{masks_path=} does not exist.")
|
| 303 |
-
|
| 304 |
-
if imgs_path.is_dir():
|
| 305 |
-
imgs = load_tiff_timeseries(imgs_path)
|
| 306 |
-
else:
|
| 307 |
-
imgs = tifffile.imread(imgs_path)
|
| 308 |
-
|
| 309 |
-
if masks_path.is_dir():
|
| 310 |
-
masks = load_tiff_timeseries(masks_path)
|
| 311 |
-
else:
|
| 312 |
-
masks = tifffile.imread(masks_path)
|
| 313 |
-
|
| 314 |
-
if len(imgs) != len(masks):
|
| 315 |
-
raise RuntimeError(
|
| 316 |
-
f"#imgs and #masks do not match. Found {len(imgs)} images,"
|
| 317 |
-
f" {len(masks)} masks."
|
| 318 |
-
)
|
| 319 |
-
|
| 320 |
-
if imgs.ndim - 1 == masks.ndim:
|
| 321 |
-
if imgs[1] == 1:
|
| 322 |
-
logger.info(
|
| 323 |
-
"Found a channel dimension with a single channel. Removing dim."
|
| 324 |
-
)
|
| 325 |
-
masks = np.squeeze(masks, 1)
|
| 326 |
-
else:
|
| 327 |
-
raise RuntimeError(
|
| 328 |
-
"Trackastra currently only supports single channel images."
|
| 329 |
-
)
|
| 330 |
-
|
| 331 |
-
if imgs.shape != masks.shape:
|
| 332 |
-
raise RuntimeError(
|
| 333 |
-
f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
|
| 334 |
-
)
|
| 335 |
-
|
| 336 |
-
return self.track(
|
| 337 |
-
imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
|
| 338 |
-
), masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/predict.py
DELETED
|
@@ -1,188 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import warnings
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
from scipy.sparse import SparseEfficiencyWarning, csr_array
|
| 7 |
-
from tqdm import tqdm
|
| 8 |
-
from typing import List
|
| 9 |
-
|
| 10 |
-
# TODO fix circular import
|
| 11 |
-
# from .model import TrackingTransformer
|
| 12 |
-
# from trackastra.data import WRFeatures
|
| 13 |
-
|
| 14 |
-
warnings.simplefilter("ignore", SparseEfficiencyWarning)
|
| 15 |
-
|
| 16 |
-
logger = logging.getLogger(__name__)
|
| 17 |
-
logger.setLevel(logging.INFO)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def predict(batch, model):
|
| 21 |
-
"""Predict association scores between objects in a batch.
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
batch: Dictionary containing:
|
| 25 |
-
- features: Object features array
|
| 26 |
-
- coords: Object coordinates array
|
| 27 |
-
- timepoints: Time points array
|
| 28 |
-
model: TrackingTransformer model to use for prediction.
|
| 29 |
-
|
| 30 |
-
Returns:
|
| 31 |
-
Array of association scores between objects.
|
| 32 |
-
"""
|
| 33 |
-
feats = torch.from_numpy(batch["features"])
|
| 34 |
-
coords = torch.from_numpy(batch["coords"])
|
| 35 |
-
timepoints = torch.from_numpy(batch["timepoints"]).long()
|
| 36 |
-
# Hack that assumes that all parameters of a model are on the same device
|
| 37 |
-
device = next(model.parameters()).device
|
| 38 |
-
feats = feats.unsqueeze(0).to(device)
|
| 39 |
-
timepoints = timepoints.unsqueeze(0).to(device)
|
| 40 |
-
coords = coords.unsqueeze(0).to(device)
|
| 41 |
-
|
| 42 |
-
# Concat timepoints to coordinates
|
| 43 |
-
coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
|
| 44 |
-
with torch.no_grad():
|
| 45 |
-
A = model(coords, features=feats)
|
| 46 |
-
A = model.normalize_output(A, timepoints, coords)
|
| 47 |
-
|
| 48 |
-
# # Spatially far entries should not influence the causal normalization
|
| 49 |
-
# dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
|
| 50 |
-
# invalid = dist > model.config["spatial_pos_cutoff"]
|
| 51 |
-
# A[invalid] = -torch.inf
|
| 52 |
-
|
| 53 |
-
A = A.squeeze(0).detach().cpu().numpy()
|
| 54 |
-
|
| 55 |
-
return A
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def predict_windows(
|
| 59 |
-
windows: List[dict],
|
| 60 |
-
# features: list[WRFeatures],
|
| 61 |
-
# model: TrackingTransformer,
|
| 62 |
-
features: list,
|
| 63 |
-
model,
|
| 64 |
-
intra_window_weight: float = 0,
|
| 65 |
-
delta_t: int = 1,
|
| 66 |
-
edge_threshold: float = 0.05,
|
| 67 |
-
spatial_dim: int = 3,
|
| 68 |
-
progbar_class=tqdm,
|
| 69 |
-
) -> dict:
|
| 70 |
-
"""Predict associations between objects across sliding windows.
|
| 71 |
-
|
| 72 |
-
This function processes a sequence of sliding windows to predict associations
|
| 73 |
-
between objects across time frames. It handles:
|
| 74 |
-
- Object tracking across time
|
| 75 |
-
- Weight normalization across windows
|
| 76 |
-
- Edge thresholding
|
| 77 |
-
- Time-based filtering
|
| 78 |
-
|
| 79 |
-
Args:
|
| 80 |
-
windows: List of window dictionaries containing:
|
| 81 |
-
- timepoints: Array of time points
|
| 82 |
-
- labels: Array of object labels
|
| 83 |
-
- features: Object features
|
| 84 |
-
- coords: Object coordinates
|
| 85 |
-
features: List of feature objects containing:
|
| 86 |
-
- labels: Object labels
|
| 87 |
-
- timepoints: Time points
|
| 88 |
-
- coords: Object coordinates
|
| 89 |
-
model: TrackingTransformer model to use for prediction.
|
| 90 |
-
intra_window_weight: Weight factor for objects in middle of window. Defaults to 0.
|
| 91 |
-
delta_t: Maximum time difference between objects to consider. Defaults to 1.
|
| 92 |
-
edge_threshold: Minimum association score to consider. Defaults to 0.05.
|
| 93 |
-
spatial_dim: Dimensionality of input masks. May be less than model.coord_dim.
|
| 94 |
-
progbar_class: Progress bar class to use. Defaults to tqdm.
|
| 95 |
-
|
| 96 |
-
Returns:
|
| 97 |
-
Dictionary containing:
|
| 98 |
-
- nodes: List of node properties (id, coords, time, label)
|
| 99 |
-
- weights: Tuple of ((node_i, node_j), weight) pairs
|
| 100 |
-
"""
|
| 101 |
-
# first get all objects/coords
|
| 102 |
-
time_labels_to_id = dict()
|
| 103 |
-
node_properties = list()
|
| 104 |
-
max_id = np.sum([len(f.labels) for f in features])
|
| 105 |
-
|
| 106 |
-
all_timepoints = np.concatenate([f.timepoints for f in features])
|
| 107 |
-
all_labels = np.concatenate([f.labels for f in features])
|
| 108 |
-
all_coords = np.concatenate([f.coords for f in features])
|
| 109 |
-
all_coords = all_coords[:, -spatial_dim:]
|
| 110 |
-
|
| 111 |
-
for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)):
|
| 112 |
-
time_labels_to_id[(t, la)] = i
|
| 113 |
-
node_properties.append(
|
| 114 |
-
dict(
|
| 115 |
-
id=i,
|
| 116 |
-
coords=tuple(c),
|
| 117 |
-
time=t,
|
| 118 |
-
# index=ix,
|
| 119 |
-
label=la,
|
| 120 |
-
)
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
# create assoc matrix between ids
|
| 124 |
-
sp_weights, sp_accum = (
|
| 125 |
-
csr_array((max_id, max_id), dtype=np.float32),
|
| 126 |
-
csr_array((max_id, max_id), dtype=np.float32),
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
for t in progbar_class(
|
| 130 |
-
range(len(windows)),
|
| 131 |
-
desc="Computing associations",
|
| 132 |
-
):
|
| 133 |
-
# This assumes that the samples in the dataset are ordered by time and start at 0.
|
| 134 |
-
batch = windows[t]
|
| 135 |
-
timepoints = batch["timepoints"]
|
| 136 |
-
labels = batch["labels"]
|
| 137 |
-
|
| 138 |
-
A = predict(batch, model)
|
| 139 |
-
|
| 140 |
-
dt = timepoints[None, :] - timepoints[:, None]
|
| 141 |
-
time_mask = np.logical_and(dt <= delta_t, dt > 0)
|
| 142 |
-
A[~time_mask] = 0
|
| 143 |
-
ii, jj = np.where(A >= edge_threshold)
|
| 144 |
-
|
| 145 |
-
if len(ii) == 0:
|
| 146 |
-
continue
|
| 147 |
-
|
| 148 |
-
labels_ii = labels[ii]
|
| 149 |
-
labels_jj = labels[jj]
|
| 150 |
-
ts_ii = timepoints[ii]
|
| 151 |
-
ts_jj = timepoints[jj]
|
| 152 |
-
nodes_ii = np.array(
|
| 153 |
-
tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii))
|
| 154 |
-
)
|
| 155 |
-
nodes_jj = np.array(
|
| 156 |
-
tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj))
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
# weight middle parts higher
|
| 160 |
-
t_middle = t + (model.config["window"] - 1) / 2
|
| 161 |
-
ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
|
| 162 |
-
window_weight = np.exp(-intra_window_weight * ddt**2) # default is 1
|
| 163 |
-
# window_weight = np.exp(4*A) # smooth max
|
| 164 |
-
sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
|
| 165 |
-
sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]
|
| 166 |
-
|
| 167 |
-
sp_weights_coo = sp_weights.tocoo()
|
| 168 |
-
sp_accum_coo = sp_accum.tocoo()
|
| 169 |
-
assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose(
|
| 170 |
-
sp_weights_coo.row, sp_accum_coo.row
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
# Normalize weights by the number of times they were written from different sliding window positions
|
| 174 |
-
weights = tuple(
|
| 175 |
-
((i, j), v / a)
|
| 176 |
-
for i, j, v, a in zip(
|
| 177 |
-
sp_weights_coo.row,
|
| 178 |
-
sp_weights_coo.col,
|
| 179 |
-
sp_weights_coo.data,
|
| 180 |
-
sp_accum_coo.data,
|
| 181 |
-
)
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
results = dict()
|
| 185 |
-
results["nodes"] = node_properties
|
| 186 |
-
results["weights"] = weights
|
| 187 |
-
|
| 188 |
-
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/pretrained.json
DELETED
|
@@ -1,81 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"general_2d": {
|
| 3 |
-
"tags": ["cells, nuclei, bacteria, epithelial, yeast, particles"],
|
| 4 |
-
"dimensionality": [2],
|
| 5 |
-
"description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane, budding yeast cells (PhC), fluorescent particles, .",
|
| 6 |
-
"url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
|
| 7 |
-
"datasets": {
|
| 8 |
-
"Subset of Cell Tracking Challenge 2d datasets": {
|
| 9 |
-
"url": "https://celltrackingchallenge.net/2d-datasets/",
|
| 10 |
-
"reference": "MaΕ‘ka M, Ulman V, Delgado-Rodriguez P, GΓ³mez-de-Mariscal E, NeΔasovΓ‘ T, Guerrero PeΓ±a FA, Ren TI, Meyerowitz EM, Scherr T, LΓΆffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20."
|
| 11 |
-
},
|
| 12 |
-
"Bacteria van Vliet": {
|
| 13 |
-
"url": "https://zenodo.org/records/268921",
|
| 14 |
-
"reference": "van Vliet S, Winkler AR, Spriewald S, Stecher B, Ackermann M. Spatially correlated gene expression in bacterial groups: the role of lineage history, spatial gradients, and cell-cell interactions. Cell systems. 2018 Apr 25;6(4):496-507."
|
| 15 |
-
},
|
| 16 |
-
"Bacteria ObiWan-Microbi": {
|
| 17 |
-
"url": "https://zenodo.org/records/7260137",
|
| 18 |
-
"reference": "Seiffarth J, Scherr T, Wollenhaupt B, Neumann O, Scharr H, Kohlheyer D, Mikut R, NΓΆh K. ObiWan-Microbi: OMERO-based integrated workflow for annotating microbes in the cloud. SoftwareX. 2024 May 1;26:101638."
|
| 19 |
-
},
|
| 20 |
-
"Bacteria Persat": {
|
| 21 |
-
"url": "https://www.p-lab.science",
|
| 22 |
-
"reference": "Datasets kindly provided by Persat lab, EPFL."
|
| 23 |
-
},
|
| 24 |
-
"DeepCell": {
|
| 25 |
-
"url": "https://datasets.deepcell.org/data",
|
| 26 |
-
"reference": "Schwartz, M, Moen E, Miller G, Dougherty T, Borba E, Ding R, Graf W, Pao E, Van Valen D. Caliban: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning. Biorxiv. 2023 Sept 13:803205."
|
| 27 |
-
},
|
| 28 |
-
"Ker phase contrast": {
|
| 29 |
-
"url": "https://osf.io/ysaq2/",
|
| 30 |
-
"reference": "Ker DF, Eom S, Sanami S, Bise R, Pascale C, Yin Z, Huh SI, Osuna-Highley E, Junkers SN, Helfrich CJ, Liang PY. Phase contrast time-lapse microscopy datasets with automated and manual cell tracking annotations. Scientific data. 2018 Nov 13;5(1):1-2."
|
| 31 |
-
},
|
| 32 |
-
"Epithelia benchmark": {
|
| 33 |
-
"reference": "Funke J, Mais L, Champion A, Dye N, Kainmueller D. A benchmark for epithelial cell tracking. InProceedings of The European Conference on Computer Vision (ECCV) Workshops 2018 (pp. 0-0)."
|
| 34 |
-
},
|
| 35 |
-
"T Cells": {
|
| 36 |
-
"url": "https://zenodo.org/records/5206119"
|
| 37 |
-
},
|
| 38 |
-
"Neisseria meningitidis bacterial growth": {
|
| 39 |
-
"url": "https://zenodo.org/records/5419619"
|
| 40 |
-
},
|
| 41 |
-
"Synthetic nuclei": {
|
| 42 |
-
"reference": "Weigert group live cell simulator."
|
| 43 |
-
},
|
| 44 |
-
"Synthetic particles": {
|
| 45 |
-
"reference": "Weigert group particle simulator."
|
| 46 |
-
},
|
| 47 |
-
"Particle Tracking Challenge": {
|
| 48 |
-
"url": "http://bioimageanalysis.org/track/#data",
|
| 49 |
-
"reference": "Chenouard, N., Smal, I., De Chaumont, F., MaΕ‘ka, M., Sbalzarini, I. F., Gong, Y., ... & Meijering, E. (2014). Objective comparison of particle tracking methods. Nature methods, 11(3), 281-289."
|
| 50 |
-
},
|
| 51 |
-
"Yeast Cell-ACDC": {
|
| 52 |
-
"url": "https://zenodo.org/records/6795124",
|
| 53 |
-
"reference": "Padovani, F., MairhΓΆrmann, B., Falter-Braun, P., Lengefeld, J., & Schmoller, K. M. (2022). Segmentation, tracking and cell cycle analysis of live-cell imaging data with Cell-ACDC. BMC biology, 20(1), 174."
|
| 54 |
-
},
|
| 55 |
-
"DeepSea": {
|
| 56 |
-
"url": "https://deepseas.org/datasets/",
|
| 57 |
-
"reference": "Zargari, A., Lodewijk, G. A., Mashhadi, N., Cook, N., Neudorf, C. W., Araghbidikashani, K., ... & Shariati, S. A. (2023). DeepSea is an efficient deep-learning model for single-cell segmentation and tracking in time-lapse microscopy. Cell Reports Methods, 3(6)."
|
| 58 |
-
},
|
| 59 |
-
"Btrack" : {
|
| 60 |
-
"url": "https://rdr.ucl.ac.uk/articles/dataset/Cell_tracking_reference_dataset/16595978",
|
| 61 |
-
"reference": "Ulicna, K., Vallardi, G., Charras, G., & Lowe, A. R. (2021). Automated deep lineage tree analysis using a Bayesian single cell tracking approach. Frontiers in Computer Science, 3, 734559."
|
| 62 |
-
},
|
| 63 |
-
"E. coli in mother machine": {
|
| 64 |
-
"url": "https://zenodo.org/records/11237127",
|
| 65 |
-
"reference": "OβConnor, O. M., & Dunlop, M. J. (2024). Cell-TRACTR: A transformer-based model for end-to-end segmentation and tracking of cells. bioRxiv, 2024-07."
|
| 66 |
-
}
|
| 67 |
-
}
|
| 68 |
-
},
|
| 69 |
-
"ctc": {
|
| 70 |
-
"tags": ["ctc", "Cell Tracking Challenge", "Cell Linking Benchmark"],
|
| 71 |
-
"dimensionality": [2, 3],
|
| 72 |
-
"description": "For tracking Cell Tracking Challenge datasets. This is the successor of the winning model of the ISBI 2024 CTC generalizable linking challenge.",
|
| 73 |
-
"url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
|
| 74 |
-
"datasets": {
|
| 75 |
-
"All Cell Tracking Challenge 2d+3d datasets with available GT and ERR_SEG": {
|
| 76 |
-
"url": "https://celltrackingchallenge.net/3d-datasets/",
|
| 77 |
-
"reference": "MaΕ‘ka M, Ulman V, Delgado-Rodriguez P, GΓ³mez-de-Mariscal E, NeΔasovΓ‘ T, Guerrero PeΓ±a FA, Ren TI, Meyerowitz EM, Scherr T, LΓΆffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20."
|
| 78 |
-
}
|
| 79 |
-
}
|
| 80 |
-
}
|
| 81 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/pretrained.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import shutil
|
| 3 |
-
import tempfile
|
| 4 |
-
import zipfile
|
| 5 |
-
try:
|
| 6 |
-
from importlib.resources import files
|
| 7 |
-
except:
|
| 8 |
-
from importlib_resources import files
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
|
| 11 |
-
import requests
|
| 12 |
-
from tqdm import tqdm
|
| 13 |
-
from typing import Optional
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
_MODELS = {
|
| 18 |
-
"ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
|
| 19 |
-
"general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def download_and_unzip(url: str, dst: Path):
|
| 24 |
-
# TODO make safe and use tempfile lib
|
| 25 |
-
if dst.exists():
|
| 26 |
-
print(f"{dst} already downloaded, skipping.")
|
| 27 |
-
return
|
| 28 |
-
|
| 29 |
-
# get the name of the zipfile
|
| 30 |
-
zip_base = Path(url.split("/")[-1])
|
| 31 |
-
|
| 32 |
-
with tempfile.TemporaryDirectory() as tmp:
|
| 33 |
-
tmp = Path(tmp)
|
| 34 |
-
zip_file = tmp / zip_base
|
| 35 |
-
# Download the zip file
|
| 36 |
-
download(url, zip_file)
|
| 37 |
-
|
| 38 |
-
# Unzip the file
|
| 39 |
-
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
| 40 |
-
zip_ref.extractall(tmp)
|
| 41 |
-
|
| 42 |
-
shutil.move(tmp / zip_base.stem, dst)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def download(url: str, fname: Path):
|
| 46 |
-
resp = requests.get(url, stream=True)
|
| 47 |
-
total = int(resp.headers.get("content-length", 0))
|
| 48 |
-
# try:
|
| 49 |
-
# with (open(str(fname), "wb") as file,
|
| 50 |
-
# tqdm(
|
| 51 |
-
# desc=str(fname),
|
| 52 |
-
# total=total,
|
| 53 |
-
# unit="iB",
|
| 54 |
-
# unit_scale=True,
|
| 55 |
-
# unit_divisor=1024,
|
| 56 |
-
# ) as bar,):
|
| 57 |
-
# for data in resp.iter_content(chunk_size=1024):
|
| 58 |
-
# size = file.write(data)
|
| 59 |
-
# bar.update(size)
|
| 60 |
-
# except:
|
| 61 |
-
with open(str(fname), "wb") as file, tqdm(
|
| 62 |
-
desc=str(fname),
|
| 63 |
-
total=total,
|
| 64 |
-
unit="iB",
|
| 65 |
-
unit_scale=True,
|
| 66 |
-
unit_divisor=1024,
|
| 67 |
-
) as bar:
|
| 68 |
-
for data in resp.iter_content(chunk_size=1024):
|
| 69 |
-
size = file.write(data)
|
| 70 |
-
bar.update(size)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def download_pretrained(name: str, download_dir: Optional[Path] = None):
|
| 74 |
-
# TODO make safe, introduce versioning
|
| 75 |
-
if download_dir is None:
|
| 76 |
-
download_dir = files("trackastra").joinpath(".models")
|
| 77 |
-
else:
|
| 78 |
-
download_dir = Path(download_dir)
|
| 79 |
-
|
| 80 |
-
download_dir.mkdir(exist_ok=True, parents=True)
|
| 81 |
-
try:
|
| 82 |
-
url = _MODELS[name]
|
| 83 |
-
except KeyError:
|
| 84 |
-
raise ValueError(
|
| 85 |
-
"Pretrained model `name` is not available. Choose from"
|
| 86 |
-
f" {list(_MODELS.keys())}"
|
| 87 |
-
)
|
| 88 |
-
folder = download_dir / name
|
| 89 |
-
download_and_unzip(url=url, dst=folder)
|
| 90 |
-
return folder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/model/rope.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
"""Transformer class."""
|
| 2 |
-
|
| 3 |
-
# from torch_geometric.nn import GATv2Conv
|
| 4 |
-
import math
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn
|
| 8 |
-
from typing import Tuple
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def _pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
|
| 12 |
-
# Maximum initial frequency is 1
|
| 13 |
-
return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# https://github.com/cvg/LightGlue/blob/b1cd942fc4a3a824b6aedff059d84f5c31c297f6/lightglue/lightglue.py#L51
|
| 17 |
-
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 18 |
-
"""Rotate pairs of scalars as 2d vectors by pi/2.
|
| 19 |
-
Refer to eq 34 in https://arxiv.org/pdf/2104.09864.pdf.
|
| 20 |
-
"""
|
| 21 |
-
x = x.unflatten(-1, (-1, 2))
|
| 22 |
-
x1, x2 = x.unbind(dim=-1)
|
| 23 |
-
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class RotaryPositionalEncoding(nn.Module):
|
| 27 |
-
def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
|
| 28 |
-
"""Rotary positional encoding with given cutoff and number of frequencies for each dimension.
|
| 29 |
-
number of dimension is inferred from the length of cutoffs and n_pos.
|
| 30 |
-
|
| 31 |
-
see
|
| 32 |
-
https://arxiv.org/pdf/2104.09864.pdf
|
| 33 |
-
"""
|
| 34 |
-
super().__init__()
|
| 35 |
-
assert len(cutoffs) == len(n_pos)
|
| 36 |
-
if not all(n % 2 == 0 for n in n_pos):
|
| 37 |
-
raise ValueError("n_pos must be even")
|
| 38 |
-
|
| 39 |
-
self._n_dim = len(cutoffs)
|
| 40 |
-
# theta in RoFormer https://arxiv.org/pdf/2104.09864.pdf
|
| 41 |
-
self.freqs = nn.ParameterList([
|
| 42 |
-
nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
|
| 43 |
-
for cutoff, n in zip(cutoffs, n_pos)
|
| 44 |
-
])
|
| 45 |
-
|
| 46 |
-
def get_co_si(self, coords: torch.Tensor):
|
| 47 |
-
_B, _N, D = coords.shape
|
| 48 |
-
assert D == len(self.freqs)
|
| 49 |
-
co = torch.cat(
|
| 50 |
-
tuple(
|
| 51 |
-
torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
|
| 52 |
-
for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
|
| 53 |
-
),
|
| 54 |
-
axis=-1,
|
| 55 |
-
)
|
| 56 |
-
si = torch.cat(
|
| 57 |
-
tuple(
|
| 58 |
-
torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq))
|
| 59 |
-
for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
|
| 60 |
-
),
|
| 61 |
-
axis=-1,
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
return co, si
|
| 65 |
-
|
| 66 |
-
def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
|
| 67 |
-
_B, _N, D = coords.shape
|
| 68 |
-
_B, _H, _N, _C = q.shape
|
| 69 |
-
|
| 70 |
-
if not D == self._n_dim:
|
| 71 |
-
raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
|
| 72 |
-
|
| 73 |
-
co, si = self.get_co_si(coords)
|
| 74 |
-
|
| 75 |
-
co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
|
| 76 |
-
si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
|
| 77 |
-
q2 = q * co + _rotate_half(q) * si
|
| 78 |
-
k2 = k * co + _rotate_half(k) * si
|
| 79 |
-
|
| 80 |
-
return q2, k2
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
if __name__ == "__main__":
|
| 84 |
-
model = RotaryPositionalEncoding((256, 256), (32, 32))
|
| 85 |
-
|
| 86 |
-
x = 100 * torch.rand(1, 17, 2)
|
| 87 |
-
q = torch.rand(1, 4, 17, 64)
|
| 88 |
-
k = torch.rand(1, 4, 17, 64)
|
| 89 |
-
|
| 90 |
-
q1, k1 = model(q, k, x)
|
| 91 |
-
A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2)
|
| 92 |
-
|
| 93 |
-
q2, k2 = model(q, k, x + 10)
|
| 94 |
-
A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/trackastra/utils/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# ruff: noqa: F401
|
| 2 |
-
|
| 3 |
-
from .utils import (
|
| 4 |
-
blockwise_causal_norm,
|
| 5 |
-
blockwise_sum,
|
| 6 |
-
normalize,
|
| 7 |
-
normalize_01,
|
| 8 |
-
preallocate_memory,
|
| 9 |
-
random_label_cmap,
|
| 10 |
-
render_label,
|
| 11 |
-
seed,
|
| 12 |
-
str2bool,
|
| 13 |
-
str2path,
|
| 14 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/tra_post_model/{trackastra/tracking β tracking}/__init__.py
RENAMED
|
@@ -7,9 +7,6 @@ from .tracking import (
|
|
| 7 |
)
|
| 8 |
from .utils import (
|
| 9 |
ctc_to_graph,
|
| 10 |
-
ctc_to_napari_tracks,
|
| 11 |
graph_to_ctc,
|
| 12 |
graph_to_edge_table,
|
| 13 |
-
graph_to_napari_tracks,
|
| 14 |
-
linear_chains,
|
| 15 |
)
|
|
|
|
| 7 |
)
|
| 8 |
from .utils import (
|
| 9 |
ctc_to_graph,
|
|
|
|
| 10 |
graph_to_ctc,
|
| 11 |
graph_to_edge_table,
|
|
|
|
|
|
|
| 12 |
)
|
models/tra_post_model/{trackastra/tracking β tracking}/ilp.py
RENAMED
|
File without changes
|
models/tra_post_model/{trackastra/tracking β tracking}/track_graph.py
RENAMED
|
File without changes
|
models/tra_post_model/{trackastra/tracking β tracking}/tracking.py
RENAMED
|
@@ -9,8 +9,6 @@ from tqdm import tqdm
|
|
| 9 |
from .track_graph import TrackGraph
|
| 10 |
from typing import Optional, Tuple
|
| 11 |
|
| 12 |
-
# from trackastra.tracking import graph_to_napari_tracks, graph_to_ctc
|
| 13 |
-
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
|
@@ -43,9 +41,6 @@ def track_greedy(
|
|
| 43 |
|
| 44 |
solution_graph = nx.DiGraph()
|
| 45 |
|
| 46 |
-
# TODO bring back
|
| 47 |
-
# if args.gt_as_dets:
|
| 48 |
-
# solution_graph.add_nodes_from(candidate_graph.nodes(data=True))
|
| 49 |
|
| 50 |
edges = candidate_graph.edges(data=True)
|
| 51 |
edges = sorted(
|
|
@@ -75,12 +70,9 @@ def track_greedy(
|
|
| 75 |
# otherwise add to solution
|
| 76 |
copy_edge(edge, candidate_graph, solution_graph)
|
| 77 |
|
| 78 |
-
# df, masks = graph_to_ctc(solution_graph, masks_original)
|
| 79 |
-
# tracks, tracks_graph, _ = graph_to_napari_tracks(solution_graph)
|
| 80 |
|
| 81 |
return solution_graph
|
| 82 |
-
|
| 83 |
-
# return df, masks, solution_graph, tracks_graph, tracks, candidate_graph
|
| 84 |
|
| 85 |
|
| 86 |
def build_graph(
|
|
|
|
| 9 |
from .track_graph import TrackGraph
|
| 10 |
from typing import Optional, Tuple
|
| 11 |
|
|
|
|
|
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
|
|
|
|
| 41 |
|
| 42 |
solution_graph = nx.DiGraph()
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
edges = candidate_graph.edges(data=True)
|
| 46 |
edges = sorted(
|
|
|
|
| 70 |
# otherwise add to solution
|
| 71 |
copy_edge(edge, candidate_graph, solution_graph)
|
| 72 |
|
|
|
|
|
|
|
| 73 |
|
| 74 |
return solution_graph
|
| 75 |
+
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
def build_graph(
|
models/tra_post_model/{trackastra/tracking β tracking}/utils.py
RENAMED
|
@@ -14,38 +14,8 @@ logger = logging.getLogger(__name__)
|
|
| 14 |
logger.setLevel(logging.INFO)
|
| 15 |
|
| 16 |
|
| 17 |
-
class FoundTracks(Exception):
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def ctc_to_napari_tracks(segmentation: np.ndarray, man_track: pd.DataFrame):
|
| 22 |
-
"""Convert tracks in CTC format to tracks in napari format.
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
segmentation: Dims time, spatial_0, ... , spatial_n
|
| 26 |
-
man_track: columns id, start, end, parent
|
| 27 |
-
"""
|
| 28 |
-
tracks = []
|
| 29 |
-
for t, frame in tqdm(
|
| 30 |
-
enumerate(segmentation),
|
| 31 |
-
total=len(segmentation),
|
| 32 |
-
leave=False,
|
| 33 |
-
desc="Computing centroids",
|
| 34 |
-
):
|
| 35 |
-
for r in regionprops(frame):
|
| 36 |
-
tracks.append((r.label, t, *r.centroid))
|
| 37 |
-
|
| 38 |
-
tracks_graph = {}
|
| 39 |
-
for idx, _, _, parent in tqdm(
|
| 40 |
-
man_track.to_numpy(),
|
| 41 |
-
desc="Converting CTC to napari tracks",
|
| 42 |
-
leave=False,
|
| 43 |
-
):
|
| 44 |
-
if parent != 0:
|
| 45 |
-
tracks_graph[idx] = [parent]
|
| 46 |
-
|
| 47 |
-
return tracks, tracks_graph
|
| 48 |
-
|
| 49 |
|
| 50 |
class CtcTracklet:
|
| 51 |
def __init__(self, parent: int, nodes: List[int], start_frame: int) -> None:
|
|
@@ -125,77 +95,77 @@ def ctc_tracklets(G: nx.DiGraph, frame_attribute: str = "time") -> List[CtcTrack
|
|
| 125 |
return tracklets
|
| 126 |
|
| 127 |
|
| 128 |
-
def linear_chains(G: nx.DiGraph):
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def graph_to_napari_tracks(
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
):
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
|
| 200 |
|
| 201 |
def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray):
|
|
|
|
| 14 |
logger.setLevel(logging.INFO)
|
| 15 |
|
| 16 |
|
| 17 |
+
# class FoundTracks(Exception):
|
| 18 |
+
# pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
class CtcTracklet:
|
| 21 |
def __init__(self, parent: int, nodes: List[int], start_frame: int) -> None:
|
|
|
|
| 95 |
return tracklets
|
| 96 |
|
| 97 |
|
| 98 |
+
# def linear_chains(G: nx.DiGraph):
|
| 99 |
+
# """Find all linear chains in a tree/graph, i.e. paths that.
|
| 100 |
+
|
| 101 |
+
# i) either start/end at a node with out_degree>in_degree or and have no internal branches, or
|
| 102 |
+
# ii) consists of a single node or a single splitting node
|
| 103 |
+
|
| 104 |
+
# Note that each chain includes its start/end node, i.e. they can be appear in multiple chains.
|
| 105 |
+
# """
|
| 106 |
+
# # get all nodes with out_degree>in_degree (i.e. start of chain)
|
| 107 |
+
# nodes = tuple(n for n in G.nodes if G.out_degree[n] > G.in_degree[n])
|
| 108 |
+
# # single nodes are those that are not starting a linear chain
|
| 109 |
+
# # single_nodes = tuple(n for n in G.nodes if G.out_degree[n] == G.in_degree[n] == 0)
|
| 110 |
+
# single_nodes = tuple(
|
| 111 |
+
# n for n in G.nodes if G.in_degree[n] == 0 and G.out_degree[n] != 1
|
| 112 |
+
# )
|
| 113 |
+
|
| 114 |
+
# for ni in single_nodes:
|
| 115 |
+
# yield [ni]
|
| 116 |
+
|
| 117 |
+
# for ni in nodes:
|
| 118 |
+
# neighs = tuple(G.neighbors(ni))
|
| 119 |
+
# for child in neighs:
|
| 120 |
+
# path = [ni, child]
|
| 121 |
+
# while len(childs := tuple(G.neighbors(path[-1]))) == 1:
|
| 122 |
+
# path.append(childs[0])
|
| 123 |
+
# yield path
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# def graph_to_napari_tracks(
|
| 127 |
+
# graph: nx.DiGraph,
|
| 128 |
+
# properties: List[str] = [],
|
| 129 |
+
# ):
|
| 130 |
+
# """Convert a track graph to napari tracks."""
|
| 131 |
+
# # each tracklet is a linear chain in the graph
|
| 132 |
+
# chains = tuple(linear_chains(graph))
|
| 133 |
+
|
| 134 |
+
# track_end_to_track_id = dict()
|
| 135 |
+
# labels = []
|
| 136 |
+
# for i, cs in enumerate(chains):
|
| 137 |
+
# label = i + 1
|
| 138 |
+
# labels.append(label)
|
| 139 |
+
# # if len(cs) == 1:
|
| 140 |
+
# # print(cs)
|
| 141 |
+
# # # Non-connected node
|
| 142 |
+
# # continue
|
| 143 |
+
# end = cs[-1]
|
| 144 |
+
# track_end_to_track_id[end] = label
|
| 145 |
+
|
| 146 |
+
# tracks = []
|
| 147 |
+
# tracks_graph = dict()
|
| 148 |
+
# tracks_props = {p: [] for p in properties}
|
| 149 |
+
|
| 150 |
+
# for label, cs in tqdm(zip(labels, chains), total=len(chains)):
|
| 151 |
+
# start = cs[0]
|
| 152 |
+
# if start in track_end_to_track_id and len(cs) > 1:
|
| 153 |
+
# tracks_graph[label] = track_end_to_track_id[start]
|
| 154 |
+
# nodes = cs[1:]
|
| 155 |
+
# else:
|
| 156 |
+
# nodes = cs
|
| 157 |
+
|
| 158 |
+
# for c in nodes:
|
| 159 |
+
# node = graph.nodes[c]
|
| 160 |
+
# t = node["time"]
|
| 161 |
+
# coord = node["coords"]
|
| 162 |
+
# tracks.append([label, t, *list(coord)])
|
| 163 |
+
|
| 164 |
+
# for p in properties:
|
| 165 |
+
# tracks_props[p].append(node[p])
|
| 166 |
+
|
| 167 |
+
# tracks = np.array(tracks)
|
| 168 |
+
# return tracks, tracks_graph, tracks_props
|
| 169 |
|
| 170 |
|
| 171 |
def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray):
|
models/tra_post_model/{trackastra/utils/utils.py β utils.py}
RENAMED
|
@@ -1,13 +1,6 @@
|
|
| 1 |
-
import colorsys
|
| 2 |
-
import itertools
|
| 3 |
import logging
|
| 4 |
-
import random
|
| 5 |
-
import sys
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from timeit import default_timer
|
| 8 |
|
| 9 |
import dask.array as da
|
| 10 |
-
import matplotlib
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
from typing import Optional, Union
|
|
@@ -15,174 +8,6 @@ from typing import Optional, Union
|
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
|
| 18 |
-
def _single_color_integer_cmap(color=(0.3, 0.4, 0.5)):
|
| 19 |
-
from matplotlib.colors import Colormap
|
| 20 |
-
|
| 21 |
-
assert len(color) in (3, 4)
|
| 22 |
-
|
| 23 |
-
class BinaryMap(Colormap):
|
| 24 |
-
def __init__(self, color):
|
| 25 |
-
self.color = np.array(color)
|
| 26 |
-
if len(self.color) == 3:
|
| 27 |
-
self.color = np.concatenate([self.color, [1]])
|
| 28 |
-
|
| 29 |
-
def __call__(self, X, alpha=None, bytes=False):
|
| 30 |
-
res = np.zeros((*X.shape, 4), np.float32)
|
| 31 |
-
res[..., -1] = self.color[-1]
|
| 32 |
-
res[X > 0] = np.expand_dims(self.color, 0)
|
| 33 |
-
if bytes:
|
| 34 |
-
return np.clip(256 * res, 0, 255).astype(np.uint8)
|
| 35 |
-
else:
|
| 36 |
-
return res
|
| 37 |
-
|
| 38 |
-
return BinaryMap(color)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def render_label(
|
| 42 |
-
lbl,
|
| 43 |
-
img=None,
|
| 44 |
-
cmap=None,
|
| 45 |
-
cmap_img="gray",
|
| 46 |
-
alpha=0.5,
|
| 47 |
-
alpha_boundary=None,
|
| 48 |
-
normalize_img=True,
|
| 49 |
-
):
|
| 50 |
-
"""Renders a label image and optionally overlays it with another image. Used for generating simple output images to asses the label quality.
|
| 51 |
-
|
| 52 |
-
Parameters
|
| 53 |
-
----------
|
| 54 |
-
lbl: np.ndarray of dtype np.uint16
|
| 55 |
-
The 2D label image
|
| 56 |
-
img: np.ndarray
|
| 57 |
-
The array to overlay the label image with (optional)
|
| 58 |
-
cmap: string, tuple, or callable
|
| 59 |
-
The label colormap. If given as rgb(a) only a single color is used, if None uses a random colormap
|
| 60 |
-
cmap_img: string or callable
|
| 61 |
-
The colormap of img (optional)
|
| 62 |
-
alpha: float
|
| 63 |
-
The alpha value of the overlay. Set alpha=1 to get fully opaque labels
|
| 64 |
-
alpha_boundary: float
|
| 65 |
-
The alpha value of the boundary (if None, use the same as for labels, i.e. no boundaries are visible)
|
| 66 |
-
normalize_img: bool
|
| 67 |
-
If True, normalizes the img (if given)
|
| 68 |
-
|
| 69 |
-
Returns:
|
| 70 |
-
-------
|
| 71 |
-
img: np.ndarray
|
| 72 |
-
the (m,n,4) RGBA image of the rendered label
|
| 73 |
-
|
| 74 |
-
Example:
|
| 75 |
-
-------
|
| 76 |
-
from scipy.ndimage import label, zoom
|
| 77 |
-
img = zoom(np.random.uniform(0,1,(16,16)),(8,8),order=3)
|
| 78 |
-
lbl,_ = label(img>.8)
|
| 79 |
-
u1 = render_label(lbl, img = img, alpha = .7)
|
| 80 |
-
u2 = render_label(lbl, img = img, alpha = 0, alpha_boundary =.8)
|
| 81 |
-
plt.subplot(1,2,1);plt.imshow(u1)
|
| 82 |
-
plt.subplot(1,2,2);plt.imshow(u2)
|
| 83 |
-
|
| 84 |
-
"""
|
| 85 |
-
from matplotlib import cm
|
| 86 |
-
from skimage.segmentation import find_boundaries
|
| 87 |
-
|
| 88 |
-
alpha = np.clip(alpha, 0, 1)
|
| 89 |
-
|
| 90 |
-
if alpha_boundary is None:
|
| 91 |
-
alpha_boundary = alpha
|
| 92 |
-
|
| 93 |
-
if cmap is None:
|
| 94 |
-
cmap = random_label_cmap()
|
| 95 |
-
elif isinstance(cmap, tuple):
|
| 96 |
-
cmap = _single_color_integer_cmap(cmap)
|
| 97 |
-
else:
|
| 98 |
-
pass
|
| 99 |
-
|
| 100 |
-
cmap = cm.get_cmap(cmap) if isinstance(cmap, str) else cmap
|
| 101 |
-
cmap_img = cm.get_cmap(cmap_img) if isinstance(cmap_img, str) else cmap_img
|
| 102 |
-
|
| 103 |
-
# render image if given
|
| 104 |
-
if img is None:
|
| 105 |
-
im_img = np.zeros((*lbl.shape, 4), np.float32)
|
| 106 |
-
im_img[..., -1] = 1
|
| 107 |
-
|
| 108 |
-
else:
|
| 109 |
-
assert lbl.shape[:2] == img.shape[:2]
|
| 110 |
-
img = normalize(img) if normalize_img else img
|
| 111 |
-
if img.ndim == 2:
|
| 112 |
-
im_img = cmap_img(img)
|
| 113 |
-
elif img.ndim == 3:
|
| 114 |
-
im_img = img[..., :4]
|
| 115 |
-
if img.shape[-1] < 4:
|
| 116 |
-
im_img = np.concatenate(
|
| 117 |
-
[img, np.ones(img.shape[:2] + (4 - img.shape[-1],))], axis=-1
|
| 118 |
-
)
|
| 119 |
-
else:
|
| 120 |
-
raise ValueError("img should be 2 or 3 dimensional")
|
| 121 |
-
|
| 122 |
-
# render label
|
| 123 |
-
im_lbl = cmap(lbl)
|
| 124 |
-
|
| 125 |
-
mask_lbl = lbl > 0
|
| 126 |
-
mask_bound = np.bitwise_and(mask_lbl, find_boundaries(lbl, mode="thick"))
|
| 127 |
-
|
| 128 |
-
# blend
|
| 129 |
-
im = im_img.copy()
|
| 130 |
-
|
| 131 |
-
im[mask_lbl] = alpha * im_lbl[mask_lbl] + (1 - alpha) * im_img[mask_lbl]
|
| 132 |
-
im[mask_bound] = (
|
| 133 |
-
alpha_boundary * im_lbl[mask_bound] + (1 - alpha_boundary) * im_img[mask_bound]
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
return im
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def random_label_cmap(n=2**16, h=(0, 1), lightness=(0.4, 1), s=(0.2, 0.8)):
|
| 140 |
-
h, lightness, s = (
|
| 141 |
-
np.random.uniform(*h, n),
|
| 142 |
-
np.random.uniform(*lightness, n),
|
| 143 |
-
np.random.uniform(*s, n),
|
| 144 |
-
)
|
| 145 |
-
cols = np.stack(
|
| 146 |
-
[colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(h, lightness, s)], axis=0
|
| 147 |
-
)
|
| 148 |
-
cols[0] = 0
|
| 149 |
-
return matplotlib.colors.ListedColormap(cols)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
# @torch.jit.script
|
| 153 |
-
def _blockwise_sum_with_bounds(A: torch.Tensor, bounds: torch.Tensor, dim: int = 0):
|
| 154 |
-
A = A.transpose(dim, 0)
|
| 155 |
-
cum = torch.cumsum(A, dim=0)
|
| 156 |
-
cum = torch.cat((torch.zeros_like(cum[:1]), cum), dim=0)
|
| 157 |
-
B = torch.zeros_like(A, device=A.device)
|
| 158 |
-
for i, j in itertools.pairwise(bounds[:-1], bounds[1:]):
|
| 159 |
-
B[i:j] = cum[j] - cum[i]
|
| 160 |
-
B = B.transpose(0, dim)
|
| 161 |
-
return B
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def _bounds_from_timepoints(timepoints: torch.Tensor):
|
| 165 |
-
assert timepoints.ndim == 1
|
| 166 |
-
bounds = torch.cat((
|
| 167 |
-
torch.tensor([0], device=timepoints.device),
|
| 168 |
-
# torch.nonzero faster than torch.where
|
| 169 |
-
torch.nonzero(timepoints[1:] - timepoints[:-1], as_tuple=False)[:, 0] + 1,
|
| 170 |
-
torch.tensor([len(timepoints)], device=timepoints.device),
|
| 171 |
-
))
|
| 172 |
-
return bounds
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
# def blockwise_sum(A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0):
|
| 176 |
-
# # get block boundaries
|
| 177 |
-
# assert A.shape[dim] == len(timepoints)
|
| 178 |
-
|
| 179 |
-
# bounds = _bounds_from_timepoints(timepoints)
|
| 180 |
-
|
| 181 |
-
# # normalize within blocks
|
| 182 |
-
# u = _blockwise_sum_with_bounds(A, bounds, dim=dim)
|
| 183 |
-
# return u
|
| 184 |
-
|
| 185 |
-
|
| 186 |
def blockwise_sum(
|
| 187 |
A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum"
|
| 188 |
):
|
|
@@ -270,11 +95,6 @@ def blockwise_causal_norm(
|
|
| 270 |
else:
|
| 271 |
raise NotImplementedError(f"Mode {mode} not implemented")
|
| 272 |
|
| 273 |
-
# get block boundaries and normalize within blocks
|
| 274 |
-
# bounds = _bounds_from_timepoints(timepoints)
|
| 275 |
-
# u0_sum = _blockwise_sum_with_bounds(u0, bounds, dim=0) + eps
|
| 276 |
-
# u1_sum = _blockwise_sum_with_bounds(u1, bounds, dim=1) + eps
|
| 277 |
-
|
| 278 |
u0_sum = blockwise_sum(u0, timepoints, dim=0) + eps
|
| 279 |
u1_sum = blockwise_sum(u1, timepoints, dim=1) + eps
|
| 280 |
|
|
@@ -296,12 +116,6 @@ def blockwise_causal_norm(
|
|
| 296 |
return res
|
| 297 |
|
| 298 |
|
| 299 |
-
def normalize_tensor(x: torch.Tensor, dim: Optional[int] = None, eps: float = 1e-8):
|
| 300 |
-
if dim is None:
|
| 301 |
-
dim = tuple(range(x.ndim))
|
| 302 |
-
|
| 303 |
-
mi, ma = torch.amin(x, dim=dim, keepdim=True), torch.amax(x, dim=dim, keepdim=True)
|
| 304 |
-
return (x - mi) / (ma - mi + eps)
|
| 305 |
|
| 306 |
|
| 307 |
def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
|
|
@@ -340,155 +154,3 @@ def normalize_01(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
|
|
| 340 |
x /= ma - mi + 1e-8
|
| 341 |
return x
|
| 342 |
|
| 343 |
-
|
| 344 |
-
def batched(x, batch_size, device):
|
| 345 |
-
return x.unsqueeze(0).expand(batch_size, *((-1,) * x.ndim)).to(device)
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
def preallocate_memory(dataset, model_lightning, batch_size, max_tokens, device):
|
| 349 |
-
"""https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#preallocate-memory-in-case-of-variable-input-length."""
|
| 350 |
-
start = default_timer()
|
| 351 |
-
|
| 352 |
-
if max_tokens is None:
|
| 353 |
-
logger.warning(
|
| 354 |
-
"Preallocating memory without specifying max_tokens not implemented."
|
| 355 |
-
)
|
| 356 |
-
return
|
| 357 |
-
|
| 358 |
-
# max_len = 0
|
| 359 |
-
# max_idx = -1
|
| 360 |
-
# # TODO speed up
|
| 361 |
-
# # find largest training sample
|
| 362 |
-
# if isinstance(dataset, torch.utils.data.dataset.ConcatDataset):
|
| 363 |
-
# lens = tuple(
|
| 364 |
-
# len(t["timepoints"]) for data in dataset.datasets for t in data.windows
|
| 365 |
-
# )
|
| 366 |
-
# elif isinstance(dataset, torch.utils.data.Dataset):
|
| 367 |
-
# lens = tuple(len(t["timepoints"]) for t in dataset.windows)
|
| 368 |
-
# else:
|
| 369 |
-
# lens = tuple(
|
| 370 |
-
# len(s["timepoints"])
|
| 371 |
-
# for i, s in tqdm(
|
| 372 |
-
# enumerate(dataset),
|
| 373 |
-
# desc="Iterate over training set to find largest training sample",
|
| 374 |
-
# total=len(dataset),
|
| 375 |
-
# leave=False,
|
| 376 |
-
# )
|
| 377 |
-
# )
|
| 378 |
-
|
| 379 |
-
# max_len = max(lens)
|
| 380 |
-
# max_idx = lens.index(max_len)
|
| 381 |
-
|
| 382 |
-
# # build random batch
|
| 383 |
-
# x = dataset[max_idx]
|
| 384 |
-
# batch = dict(
|
| 385 |
-
# features=batched(x["features"], batch_size, device),
|
| 386 |
-
# coords=batched(x["coords"], batch_size, device),
|
| 387 |
-
# assoc_matrix=batched(x["assoc_matrix"], batch_size, device),
|
| 388 |
-
# timepoints=batched(x["timepoints"], batch_size, device),
|
| 389 |
-
# padding_mask=batched(torch.zeros_like(x["timepoints"]), batch_size, device),
|
| 390 |
-
# )
|
| 391 |
-
|
| 392 |
-
else:
|
| 393 |
-
max_len = max_tokens
|
| 394 |
-
x = dataset[0]
|
| 395 |
-
batch = dict(
|
| 396 |
-
features=batched(
|
| 397 |
-
torch.zeros(
|
| 398 |
-
(max_len,) + x["features"].shape[1:], dtype=x["features"].dtype
|
| 399 |
-
),
|
| 400 |
-
batch_size,
|
| 401 |
-
device,
|
| 402 |
-
),
|
| 403 |
-
coords=batched(
|
| 404 |
-
torch.zeros(
|
| 405 |
-
(max_len,) + x["coords"].shape[1:], dtype=x["coords"].dtype
|
| 406 |
-
),
|
| 407 |
-
batch_size,
|
| 408 |
-
device,
|
| 409 |
-
),
|
| 410 |
-
assoc_matrix=batched(
|
| 411 |
-
torch.zeros((max_len, max_len), dtype=x["assoc_matrix"].dtype),
|
| 412 |
-
batch_size,
|
| 413 |
-
device,
|
| 414 |
-
),
|
| 415 |
-
timepoints=batched(
|
| 416 |
-
torch.zeros(max_len, dtype=x["timepoints"].dtype), batch_size, device
|
| 417 |
-
),
|
| 418 |
-
padding_mask=batched(torch.zeros(max_len, dtype=bool), batch_size, device),
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
-
loss = model_lightning._common_step(batch)["loss"]
|
| 422 |
-
loss.backward()
|
| 423 |
-
model_lightning.zero_grad()
|
| 424 |
-
|
| 425 |
-
logger.info(
|
| 426 |
-
f"Preallocated memory for largest training batch (length {max_len}) in"
|
| 427 |
-
f" {default_timer() - start:.02f} s"
|
| 428 |
-
)
|
| 429 |
-
if device.type == "cuda":
|
| 430 |
-
logger.info(
|
| 431 |
-
"Memory allocated for model:"
|
| 432 |
-
f" {torch.cuda.max_memory_allocated() / 1024**3:.02f} GB"
|
| 433 |
-
)
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
def seed(s=None):
|
| 437 |
-
"""Seed random number generators.
|
| 438 |
-
|
| 439 |
-
Defaults to unix timestamp of function call.
|
| 440 |
-
|
| 441 |
-
Args:
|
| 442 |
-
s (``int``): Manual seed.
|
| 443 |
-
"""
|
| 444 |
-
if s is None:
|
| 445 |
-
s = int(default_timer())
|
| 446 |
-
|
| 447 |
-
random.seed(s)
|
| 448 |
-
logger.debug(f"Seed `random` rng with {s}.")
|
| 449 |
-
np.random.seed(s)
|
| 450 |
-
logger.debug(f"Seed `numpy` rng with {s}.")
|
| 451 |
-
if "torch" in sys.modules:
|
| 452 |
-
torch.manual_seed(s)
|
| 453 |
-
logger.debug(f"Seed `torch` rng with {s}.")
|
| 454 |
-
|
| 455 |
-
return s
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
def str2bool(x: str) -> bool:
|
| 459 |
-
"""Cast string to boolean.
|
| 460 |
-
|
| 461 |
-
Useful for parsing command line arguments.
|
| 462 |
-
"""
|
| 463 |
-
if not isinstance(x, str):
|
| 464 |
-
raise TypeError("String expected.")
|
| 465 |
-
elif x.lower() in ("true", "t", "1"):
|
| 466 |
-
return True
|
| 467 |
-
elif x.lower() in ("false", "f", "0"):
|
| 468 |
-
return False
|
| 469 |
-
else:
|
| 470 |
-
raise ValueError(f"'{x}' does not seem to be boolean.")
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
def str2path(x: str) -> Path:
|
| 474 |
-
"""Cast string to resolved absolute path.
|
| 475 |
-
|
| 476 |
-
Useful for parsing command line arguments.
|
| 477 |
-
"""
|
| 478 |
-
if not isinstance(x, str):
|
| 479 |
-
raise TypeError("String expected.")
|
| 480 |
-
else:
|
| 481 |
-
return Path(x).expanduser().resolve()
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
if __name__ == "__main__":
|
| 485 |
-
A = torch.rand(50, 50)
|
| 486 |
-
idx = torch.tensor([0, 10, 20, A.shape[0]])
|
| 487 |
-
|
| 488 |
-
A = torch.eye(50)
|
| 489 |
-
|
| 490 |
-
B = _blockwise_sum_with_bounds(A, idx)
|
| 491 |
-
|
| 492 |
-
tps = torch.repeat_interleave(torch.arange(5), 10)
|
| 493 |
-
|
| 494 |
-
C = blockwise_causal_norm(A, tps)
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import dask.array as da
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
from typing import Optional, Union
|
|
|
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def blockwise_sum(
|
| 12 |
A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum"
|
| 13 |
):
|
|
|
|
| 95 |
else:
|
| 96 |
raise NotImplementedError(f"Mode {mode} not implemented")
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
u0_sum = blockwise_sum(u0, timepoints, dim=0) + eps
|
| 99 |
u1_sum = blockwise_sum(u1, timepoints, dim=1) + eps
|
| 100 |
|
|
|
|
| 116 |
return res
|
| 117 |
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
|
|
|
|
| 154 |
x /= ma - mi + 1e-8
|
| 155 |
return x
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tracking_one.py
CHANGED
|
@@ -27,12 +27,12 @@ from _utils.load_models import load_stable_diffusion_model
|
|
| 27 |
from models.model import Counting_with_SD_features_track as Counting
|
| 28 |
from models.enc_model.loca import build_model as build_loca_model
|
| 29 |
import time
|
| 30 |
-
from models.tra_post_model.
|
| 31 |
-
from models.tra_post_model.
|
| 32 |
normalize,
|
| 33 |
)
|
| 34 |
-
from models.tra_post_model.
|
| 35 |
-
from models.tra_post_model.
|
| 36 |
from _utils.track_args import parse_train_args as get_track_args
|
| 37 |
import torchvision.transforms as T
|
| 38 |
from pathlib import Path
|
|
@@ -142,9 +142,7 @@ class TrackingModule(pl.LightningModule):
|
|
| 142 |
# others
|
| 143 |
self.placeholder_token = placeholder_token
|
| 144 |
self.placeholder_token_id = placeholder_token_id
|
| 145 |
-
|
| 146 |
-
# tracking model
|
| 147 |
-
# fpath = Path("models/tra_post_model/trackastra/.models/general_2d/model.pt")
|
| 148 |
fpath = Path("_utils/config.yaml")
|
| 149 |
args_ = get_track_args()
|
| 150 |
|
|
@@ -935,7 +933,7 @@ class TrackingModule(pl.LightningModule):
|
|
| 935 |
elif mode == "greedy_nodiv":
|
| 936 |
return track_greedy(candidate_graph, allow_divisions=False)
|
| 937 |
elif mode == "ilp":
|
| 938 |
-
from models.tra_post_model.
|
| 939 |
|
| 940 |
return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
|
| 941 |
else:
|
|
|
|
| 27 |
from models.model import Counting_with_SD_features_track as Counting
|
| 28 |
from models.enc_model.loca import build_model as build_loca_model
|
| 29 |
import time
|
| 30 |
+
from models.tra_post_model.model import TrackingTransformer
|
| 31 |
+
from models.tra_post_model.utils import (
|
| 32 |
normalize,
|
| 33 |
)
|
| 34 |
+
from models.tra_post_model.data import build_windows_sd, get_features
|
| 35 |
+
from models.tra_post_model.tracking import TrackGraph, build_graph, track_greedy
|
| 36 |
from _utils.track_args import parse_train_args as get_track_args
|
| 37 |
import torchvision.transforms as T
|
| 38 |
from pathlib import Path
|
|
|
|
| 142 |
# others
|
| 143 |
self.placeholder_token = placeholder_token
|
| 144 |
self.placeholder_token_id = placeholder_token_id
|
| 145 |
+
|
|
|
|
|
|
|
| 146 |
fpath = Path("_utils/config.yaml")
|
| 147 |
args_ = get_track_args()
|
| 148 |
|
|
|
|
| 933 |
elif mode == "greedy_nodiv":
|
| 934 |
return track_greedy(candidate_graph, allow_divisions=False)
|
| 935 |
elif mode == "ilp":
|
| 936 |
+
from models.tra_post_model.tracking.ilp import track_ilp
|
| 937 |
|
| 938 |
return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
|
| 939 |
else:
|