"""Regionprops features and their augmentations.

WindowedRegionFeatures (WRFeatures) is a class that holds regionprops
features for a windowed track region.

Modified from Trackastra (https://github.com/weigertlab/trackastra)
"""
| import itertools | |
| import logging | |
| from collections import OrderedDict | |
| from collections.abc import Iterable #, Sequence | |
| from functools import reduce | |
| from typing import Literal | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from edt import edt | |
| from skimage.measure import regionprops, regionprops_table | |
| from tqdm import tqdm | |
| from typing import Tuple, Optional, Sequence, Union, List | |
| import typing | |
| import torch | |
| logger = logging.getLogger(__name__) | |
# Named feature sets selectable via the ``properties`` argument of
# ``WRFeatures.from_mask_img``. Every entry except "border_dist" must be a
# valid skimage ``regionprops`` property name; "border_dist" is not a
# regionprops property and is computed separately by ``_border_dist_fast``.
_PROPERTIES = {
    "regionprops": (
        "area",
        "intensity_mean",
        "intensity_max",
        "intensity_min",
        "inertia_tensor",
    ),
    "regionprops2": (
        "equivalent_diameter_area",
        "intensity_mean",
        "inertia_tensor",
        "border_dist",
    ),
}
def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
    """Per-label proximity to the image border.

    Builds a map that ramps linearly from 1 at the image edge down to 0 at
    ``cutoff`` pixels inside (and is 0 everywhere deeper), then returns, for
    every labelled region in ``mask``, the maximum of that map over the
    region's pixels — i.e. 1 means the region touches the border.
    """
    cutoff = int(cutoff)
    shape = mask.shape
    ndim = len(shape)
    border = np.ones(shape, dtype=np.float32)
    # Fade values for a band of width ``cutoff``: 0, 1/cutoff, ..., (cutoff-1)/cutoff
    ramp = np.arange(cutoff, dtype=np.float32) / cutoff
    for axis in range(ndim):
        size = shape[axis]
        # Align the 1-D ramp with ``axis`` by appending singleton dimensions,
        # so it broadcasts across all trailing axes.
        aligned = ramp.reshape(ramp.shape + (1,) * (ndim - axis - 1))
        lo = [slice(None)] * ndim
        lo[axis] = slice(0, cutoff)
        lo = tuple(lo)
        border[lo] = np.minimum(border[lo], aligned)
        hi = [slice(None)] * ndim
        hi[axis] = slice(size - cutoff, size)
        hi = tuple(hi)
        # Reversed ramp for the far edge of this axis.
        border[hi] = np.minimum(border[hi], aligned[::-1])
    dist = 1 - border
    return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
class WRFeatures:
    """Regionprops features for a windowed track region.

    All per-region arrays (``coords``, ``labels``, ``timepoints`` and each
    value of ``features``) are aligned along their first axis, one row per
    region detection.
    """

    def __init__(
        self,
        coords: np.ndarray,
        labels: np.ndarray,
        timepoints: np.ndarray,
        features: typing.OrderedDict[str, np.ndarray],
    ):
        """
        Args:
            coords: (N, ndim) array of region centroids; ndim must be 2 or 3.
            labels: (N,) array of region label ids.
            timepoints: (N,) array of the frame index of each region.
            features: mapping from property name to an (N, k) feature array.

        Raises:
            ValueError: if the coordinate dimensionality is not 2 or 3.
        """
        self.ndim = coords.shape[-1]
        if self.ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        self.coords = coords
        self.labels = labels
        # Shallow copy so concatenation cannot mutate the caller's dict.
        self.features = features.copy()
        self.timepoints = timepoints

    def __repr__(self):
        s = (
            f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
            f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
        )
        for k, v in self.features.items():
            s += f"{k:>20} -> {v.shape}\n"
        return s

    def features_stacked(self):
        """Return all feature arrays concatenated along the last axis, shape (N, sum of k)."""
        return np.concatenate([v for k, v in self.features.items()], axis=-1)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, key):
        if key in self.features:
            return self.features[key]
        else:
            raise KeyError(f"Key {key} not found in features")

    # BUGFIX: @classmethod was missing, so ``WRFeatures.concat(feats)`` bound
    # the list to ``cls`` and raised a TypeError for the missing argument.
    @classmethod
    def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures":
        """Concatenate multiple WRFeatures into a single one.

        Raises:
            ValueError: if ``feats`` is empty.
        """
        if len(feats) == 0:
            raise ValueError("Cannot concatenate empty list of features")
        return reduce(lambda x, y: x + y, feats)

    def __add__(self, other: "WRFeatures") -> "WRFeatures":
        """Row-wise concatenation of two WRFeatures.

        Raises:
            ValueError: if dimensionality or feature keys differ.
        """
        if self.ndim != other.ndim:
            raise ValueError("Cannot concatenate features of different dimensions")
        if self.features.keys() != other.features.keys():
            raise ValueError("Cannot concatenate features with different properties")
        coords = np.concatenate([self.coords, other.coords], axis=0)
        labels = np.concatenate([self.labels, other.labels], axis=0)
        timepoints = np.concatenate([self.timepoints, other.timepoints], axis=0)
        features = OrderedDict(
            (k, np.concatenate([v, other.features[k]], axis=0))
            for k, v in self.features.items()
        )
        return WRFeatures(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )

    # BUGFIX: @classmethod was missing, so class-level calls (see
    # ``get_features``) dropped ``cls`` and the final ``cls(...)`` call failed.
    @classmethod
    def from_mask_img(
        cls,
        mask: np.ndarray,
        img: np.ndarray,
        properties="regionprops2",
        t_start: int = 0,
    ):
        """Build WRFeatures from a label stack and its intensity stack.

        Args:
            mask: (T, *spatial) integer label image stack, 2D or 3D frames.
            img: intensity image stack, same leading shape as ``mask``.
            properties: key into the module-level ``_PROPERTIES`` table.
            t_start: timepoint offset assigned to the first frame.

        Raises:
            ValueError: if frames are not 2D/3D, or the selected property
                set contains "label" or "centroid" (added automatically).
        """
        img = np.asarray(img)
        mask = np.asarray(mask)
        _ntime, ndim = mask.shape[0], mask.ndim - 1
        if ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        properties = tuple(_PROPERTIES[properties])
        if "label" in properties or "centroid" in properties:
            raise ValueError(
                f"label and centroid should not be in properties {properties}"
            )
        # "border_dist" is not a regionprops property; strip it here and
        # compute it per frame with ``_border_dist_fast`` instead.
        use_border_dist = "border_dist" in properties
        if use_border_dist:
            properties = tuple(p for p in properties if p != "border_dist")
        df_properties = ("label", "centroid", *properties)
        dfs = []
        for i, (y, x) in enumerate(zip(mask, img)):
            _df = pd.DataFrame(
                regionprops_table(y, intensity_image=x, properties=df_properties)
            )
            _df["timepoint"] = i + t_start
            if use_border_dist:
                _df["border_dist"] = _border_dist_fast(y)
            dfs.append(_df)
        df = pd.concat(dfs)
        if use_border_dist:
            properties = (*properties, "border_dist")
        timepoints = df["timepoint"].values.astype(np.int32)
        labels = df["label"].values.astype(np.int32)
        coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32)
        # regionprops_table splits multi-valued properties into several
        # columns ("inertia_tensor-0-0", ...); regroup them by name prefix
        # into one (N, k) array per property.
        features = OrderedDict(
            (
                p,
                np.stack(
                    [
                        df[c].values.astype(np.float32)
                        for c in df.columns
                        if c.startswith(p)
                    ],
                    axis=-1,
                ),
            )
            for p in properties
        )
        return cls(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )
def get_features(
    detections: np.ndarray,
    imgs: Optional[np.ndarray] = None,
    features: Literal["none", "wrfeat"] = "wrfeat",
    ndim: int = 2,
    n_workers: int = 0,
    progbar_class=tqdm,
) -> List[WRFeatures]:
    """Extract one ``WRFeatures`` per frame from a detection/intensity stack.

    Args:
        detections: (T, *spatial) label image stack.
        imgs: (T, *spatial) intensity image stack; required (see below).
        features: currently unused here — NOTE(review): the switch appears to
            be consumed by callers; confirm before removing.
        ndim: spatial dimensionality of the frames (2 or 3).
        n_workers: if > 0, extract frames in parallel with joblib.
        progbar_class: progress-bar factory wrapping the frame iterator.

    Returns:
        List of per-frame ``WRFeatures``, in frame order.

    Raises:
        ValueError: if ``imgs`` is None or the stacks have the wrong rank.
    """
    # The Optional annotation notwithstanding, an intensity image is required
    # for regionprops intensity features — fail early with a clear message.
    if imgs is None:
        raise ValueError("get_features requires an intensity image stack (imgs is None)")
    detections = _check_dimensions(detections, ndim)
    imgs = _check_dimensions(imgs, ndim)
    logger.info(f"Extracting features from {len(detections)} detections")
    if n_workers > 0:
        logger.info(f"Using {n_workers} processes for feature extraction")
        # .copy() so each worker ships a contiguous, independent frame.
        wrfeats = list(
            joblib.Parallel(n_jobs=n_workers, backend="loky")(
                joblib.delayed(WRFeatures.from_mask_img)(
                    # New axis for the time component
                    mask=mask[np.newaxis, ...].copy(),
                    img=img[np.newaxis, ...].copy(),
                    t_start=t,
                )
                for t, (mask, img) in progbar_class(
                    enumerate(zip(detections, imgs)),
                    total=len(imgs),
                    desc="Extracting features",
                )
            )
        )
    else:
        logger.info("Using single process for feature extraction")
        # List (not tuple) to match the annotated return type in both branches.
        wrfeats = [
            WRFeatures.from_mask_img(
                mask=mask[np.newaxis, ...],
                img=img[np.newaxis, ...],
                t_start=t,
            )
            for t, (mask, img) in progbar_class(
                enumerate(zip(detections, imgs)),
                total=len(imgs),
                desc="Extracting features",
            )
        ]
    return wrfeats
| def _check_dimensions(x: np.ndarray, ndim: int): | |
| if ndim == 2 and not x.ndim == 3: | |
| raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data") | |
| elif ndim == 3: | |
| # if ndim=3 and data is two dimensional, it will be cast to 3D | |
| if x.ndim == 3: | |
| x = np.expand_dims(x, axis=1) | |
| elif x.ndim == 4: | |
| pass | |
| else: | |
| raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data") | |
| return x | |
def build_windows_sd(
    features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
) -> List[dict]:
    """Build sliding windows of ``window_size`` frames over per-frame features.

    For each window [t1, t2) the concatenated region features plus the
    corresponding image/mask slices are collected into a dict, both as numpy
    arrays and as torch tensors (``*_t`` keys).

    Args:
        features: per-frame WRFeatures, one entry per timepoint.
        imgs_enc, imgs_stable, imgs, masks: per-frame stacks sliced per window.
        boxes: passed through whole. NOTE(review): unlike the other stacks it
            is not sliced to [t1:t2] — confirm this is intended.
        window_size: number of consecutive frames per window.
        progbar_class: progress-bar factory wrapping the window iterator.

    Returns:
        List of window dicts, one per valid window start.
    """
    windows = []
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=len(features) - window_size + 1,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        labels = feat.labels
        timepoints = feat.timepoints
        coords = feat.coords
        if len(feat) == 0:
            # An empty window still needs coords of the right rank.
            coords = np.zeros((0, feat.ndim), dtype=int)
        # BUGFIX: features_stacked is a method; it was previously passed
        # uncalled, so torch.tensor() received a bound method and raised.
        stacked = feat.features_stacked()
        w = dict(
            coords=coords,
            t1=t1,
            labels=labels,
            timepoints=timepoints,
            features=stacked,
            img_enc=imgs_enc[t1:t2],
            image_stable=imgs_stable[t1:t2],
            boxes=boxes,
            img=imgs[t1:t2],
            mask=masks[t1:t2],
            coords_t=torch.tensor(coords, dtype=torch.float32),
            labels_t=torch.tensor(labels, dtype=torch.int32),
            timepoints_t=torch.tensor(timepoints, dtype=torch.int64),
            features_t=torch.tensor(stacked, dtype=torch.float32),
            img_t=torch.tensor(imgs[t1:t2], dtype=torch.float32),
            mask_t=torch.tensor(masks[t1:t2], dtype=torch.int32),
        )
        windows.append(w)
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows