"""Regionprops features and its augmentations. WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region. Modified from Trackastra (https://github.com/weigertlab/trackastra) """ import itertools import logging from collections import OrderedDict from collections.abc import Iterable #, Sequence from functools import reduce from typing import Literal import joblib import numpy as np import pandas as pd from edt import edt from skimage.measure import regionprops, regionprops_table from tqdm import tqdm from typing import Tuple, Optional, Sequence, Union, List import typing import torch logger = logging.getLogger(__name__) _PROPERTIES = { "regionprops": ( "area", "intensity_mean", "intensity_max", "intensity_min", "inertia_tensor", ), "regionprops2": ( "equivalent_diameter_area", "intensity_mean", "inertia_tensor", "border_dist", ), } def _border_dist_fast(mask: np.ndarray, cutoff: float = 5): cutoff = int(cutoff) border = np.ones(mask.shape, dtype=np.float32) ndim = len(mask.shape) for axis, size in enumerate(mask.shape): # Create fade values for the band [0, cutoff) band_vals = np.arange(cutoff, dtype=np.float32) / cutoff # Build slices for the low border low_slices = [slice(None)] * ndim low_slices[axis] = slice(0, cutoff) border_low = border[tuple(low_slices)] border_low_vals = np.minimum( border_low, band_vals[(...,) + (None,) * (ndim - axis - 1)] ) border[tuple(low_slices)] = border_low_vals # Build slices for the high border high_slices = [slice(None)] * ndim high_slices[axis] = slice(size - cutoff, size) band_vals_rev = band_vals[::-1] border_high = border[tuple(high_slices)] border_high_vals = np.minimum( border_high, band_vals_rev[(...,) + (None,) * (ndim - axis - 1)] ) border[tuple(high_slices)] = border_high_vals dist = 1 - border return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist)) class WRFeatures: """regionprops features for a windowed track region.""" def __init__( self, coords: np.ndarray, labels: np.ndarray, timepoints: np.ndarray, features: typing.OrderedDict[str, np.ndarray], ): self.ndim = coords.shape[-1] if self.ndim not in (2, 3): raise ValueError("Only 2D or 3D data is supported") self.coords = coords self.labels = labels self.features = features.copy() self.timepoints = timepoints def __repr__(self): s = ( f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)}," f" ntimepoints={len(np.unique(self.timepoints))})\n\n" ) for k, v in self.features.items(): s += f"{k:>20} -> {v.shape}\n" return s @property def features_stacked(self): return np.concatenate([v for k, v in self.features.items()], axis=-1) def __len__(self): return len(self.labels) def __getitem__(self, key): if key in self.features: return self.features[key] else: raise KeyError(f"Key {key} not found in features") @classmethod def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures": """Concatenate multiple WRFeatures into a single one.""" if len(feats) == 0: raise ValueError("Cannot concatenate empty list of features") return reduce(lambda x, y: x + y, feats) def __add__(self, other: "WRFeatures") -> "WRFeatures": """Concatenate two WRFeatures.""" if self.ndim != other.ndim: raise ValueError("Cannot concatenate features of different dimensions") if self.features.keys() != other.features.keys(): raise ValueError("Cannot concatenate features with different properties") coords = np.concatenate([self.coords, other.coords], axis=0) labels = np.concatenate([self.labels, other.labels], axis=0) timepoints = 
class WRFeatures:
    """Regionprops features for a windowed track region."""

    def __init__(
        self,
        coords: np.ndarray,
        labels: np.ndarray,
        timepoints: np.ndarray,
        features: typing.OrderedDict[str, np.ndarray],
    ):
        self.ndim = coords.shape[-1]
        if self.ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        self.coords = coords
        self.labels = labels
        self.features = features.copy()
        self.timepoints = timepoints

    def __repr__(self):
        s = (
            f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
            f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
        )
        for k, v in self.features.items():
            s += f"{k:>20} -> {v.shape}\n"
        return s

    @property
    def features_stacked(self):
        """All feature arrays concatenated along the last axis."""
        return np.concatenate([v for v in self.features.values()], axis=-1)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, key):
        if key in self.features:
            return self.features[key]
        else:
            raise KeyError(f"Key {key} not found in features")

    @classmethod
    def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures":
        """Concatenate multiple WRFeatures into a single one."""
        if len(feats) == 0:
            raise ValueError("Cannot concatenate empty list of features")
        return reduce(lambda x, y: x + y, feats)

    def __add__(self, other: "WRFeatures") -> "WRFeatures":
        """Concatenate two WRFeatures."""
        if self.ndim != other.ndim:
            raise ValueError("Cannot concatenate features of different dimensions")
        if self.features.keys() != other.features.keys():
            raise ValueError("Cannot concatenate features with different properties")
        coords = np.concatenate([self.coords, other.coords], axis=0)
        labels = np.concatenate([self.labels, other.labels], axis=0)
        timepoints = np.concatenate([self.timepoints, other.timepoints], axis=0)
        features = OrderedDict(
            (k, np.concatenate([v, other.features[k]], axis=0))
            for k, v in self.features.items()
        )
        return WRFeatures(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )

    @classmethod
    def from_mask_img(
        cls,
        mask: np.ndarray,
        img: np.ndarray,
        properties: str = "regionprops2",
        t_start: int = 0,
    ):
        """Extract per-region features from a time-first label movie and image."""
        img = np.asarray(img)
        mask = np.asarray(mask)
        ndim = mask.ndim - 1
        if ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        properties = tuple(_PROPERTIES[properties])
        if "label" in properties or "centroid" in properties:
            raise ValueError(
                f"label and centroid should not be in properties {properties}"
            )
        # border_dist is computed separately, not by regionprops_table.
        use_border_dist = "border_dist" in properties
        if use_border_dist:
            properties = tuple(p for p in properties if p != "border_dist")

        df_properties = ("label", "centroid", *properties)

        dfs = []
        for i, (y, x) in enumerate(zip(mask, img)):
            _df = pd.DataFrame(
                regionprops_table(y, intensity_image=x, properties=df_properties)
            )
            _df["timepoint"] = i + t_start
            if use_border_dist:
                _df["border_dist"] = _border_dist_fast(y)
            dfs.append(_df)
        df = pd.concat(dfs)

        if use_border_dist:
            properties = (*properties, "border_dist")

        timepoints = df["timepoint"].values.astype(np.int32)
        labels = df["label"].values.astype(np.int32)
        coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32)
        # Group all columns that belong to one property (e.g. the
        # inertia_tensor-i-j columns) into a single (nregions, k) array.
        features = OrderedDict(
            (
                p,
                np.stack(
                    [
                        df[c].values.astype(np.float32)
                        for c in df.columns
                        if c.startswith(p)
                    ],
                    axis=-1,
                ),
            )
            for p in properties
        )

        return cls(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )


def get_features(
    detections: np.ndarray,
    imgs: Optional[np.ndarray] = None,
    features: Literal["none", "wrfeat"] = "wrfeat",
    ndim: int = 2,
    n_workers: int = 0,
    progbar_class=tqdm,
) -> List[WRFeatures]:
    """Extract WRFeatures for each timepoint of a detection movie."""
    if imgs is None:
        raise ValueError("imgs is required for feature extraction")
    detections = _check_dimensions(detections, ndim)
    imgs = _check_dimensions(imgs, ndim)
    logger.info(f"Extracting features from {len(detections)} detections")
    if n_workers > 0:
        logger.info(f"Using {n_workers} processes for feature extraction")
        feats = joblib.Parallel(n_jobs=n_workers, backend="loky")(
            joblib.delayed(WRFeatures.from_mask_img)(
                # New axis for the time component; copy so the slices pickle
                # cleanly for the worker processes.
                mask=mask[np.newaxis, ...].copy(),
                img=img[np.newaxis, ...].copy(),
                t_start=t,
            )
            for t, (mask, img) in progbar_class(
                enumerate(zip(detections, imgs)),
                total=len(imgs),
                desc="Extracting features",
            )
        )
    else:
        logger.info("Using single process for feature extraction")
        feats = [
            WRFeatures.from_mask_img(
                mask=mask[np.newaxis, ...],
                img=img[np.newaxis, ...],
                t_start=t,
            )
            for t, (mask, img) in progbar_class(
                enumerate(zip(detections, imgs)),
                total=len(imgs),
                desc="Extracting features",
            )
        ]
    return list(feats)


def _check_dimensions(x: np.ndarray, ndim: int):
    """Validate (and for ndim=3, if needed expand) the time-first data layout."""
    if ndim == 2 and not x.ndim == 3:
        raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
    elif ndim == 3:
        # If ndim=3 and the data is two-dimensional, cast it to 3D by
        # inserting a singleton z axis.
        if x.ndim == 3:
            x = np.expand_dims(x, axis=1)
        elif x.ndim == 4:
            pass
        else:
            raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
    return x
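
# Illustrative usage sketch for `get_features` (an addition, not part of the
# original module); the constant-label movie and random intensities below are
# stand-ins for real detections.
def _demo_get_features():  # pragma: no cover
    """Sketch: one WRFeatures object per timepoint of a (T, H, W) label movie."""
    rng = np.random.default_rng(0)
    detections = np.zeros((4, 32, 32), dtype=np.int32)
    detections[:, 5:10, 5:10] = 1  # one persistent object across all frames
    imgs = rng.random((4, 32, 32)).astype(np.float32)
    feats = get_features(detections, imgs, ndim=2, n_workers=0)
    print(len(feats))  # -> 4
    print(feats[0])    # repr lists each feature name and its per-region shape
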
def build_windows_sd(
    features: List[WRFeatures],
    imgs_enc,
    imgs_stable,
    boxes,
    imgs,
    masks,
    window_size: int,
    progbar_class=tqdm,
) -> List[dict]:
    """Build sliding windows of size ``window_size`` over the per-timepoint features."""
    windows = []
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=len(features) - window_size + 1,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        labels = feat.labels
        timepoints = feat.timepoints
        coords = feat.coords
        if len(feat) == 0:
            coords = np.zeros((0, feat.ndim), dtype=int)

        w = dict(
            coords=coords,
            t1=t1,
            labels=labels,
            timepoints=timepoints,
            features=feat.features_stacked,
            img_enc=imgs_enc[t1:t2],
            image_stable=imgs_stable[t1:t2],
            boxes=boxes,
            img=imgs[t1:t2],
            mask=masks[t1:t2],
            coords_t=torch.tensor(coords, dtype=torch.float32),
            labels_t=torch.tensor(labels, dtype=torch.int32),
            timepoints_t=torch.tensor(timepoints, dtype=torch.int64),
            features_t=torch.tensor(feat.features_stacked, dtype=torch.float32),
            img_t=torch.tensor(imgs[t1:t2], dtype=torch.float32),
            mask_t=torch.tensor(masks[t1:t2], dtype=torch.int32),
        )
        windows.append(w)

    logger.debug(f"Built {len(windows)} track windows.")
    return windows
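
# Illustrative sketch of feeding `get_features` output into `build_windows_sd`
# (an addition, not part of the original module); `imgs_enc`, `imgs_stable`,
# and `boxes` are placeholder arrays chosen only so the slicing works.
def _demo_build_windows():  # pragma: no cover
    """Sketch: T - window_size + 1 windows, each pairing features with tensors."""
    T, H, W, window_size = 6, 32, 32, 4
    masks = np.zeros((T, H, W), dtype=np.int32)
    masks[:, 5:10, 5:10] = 1
    imgs = np.random.default_rng(0).random((T, H, W)).astype(np.float32)
    feats = get_features(masks, imgs, ndim=2, n_workers=0)
    windows = build_windows_sd(
        features=feats,
        imgs_enc=np.zeros((T, 16), dtype=np.float32),  # placeholder encodings
        imgs_stable=imgs.copy(),                       # placeholder stabilized movie
        boxes=np.zeros((T, 4), dtype=np.float32),      # placeholder boxes
        imgs=imgs,
        masks=masks,
        window_size=window_size,
    )
    print(len(windows))                    # -> T - window_size + 1 == 3
    print(windows[0]["features_t"].shape)  # (regions in window, feature dim)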