Spaces:
Running on Zero
Running on Zero
| # Modified from Trackastra (https://github.com/weigertlab/trackastra) | |
| import logging | |
| import dask.array as da | |
| import numpy as np | |
| import torch | |
| from typing import Optional, Union | |
| logger = logging.getLogger(__name__) | |
def blockwise_sum(
    A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum"
):
    """Reduce entries of ``A`` within blocks of equal timepoints along ``dim``.

    Each slice of ``A`` along ``dim`` is replaced by the reduction
    (``sum``/``amax``/...) over all slices that share its timepoint.
    Timepoints equal to -1 mark padded/invalid entries; they are funneled
    into a dummy block 0.

    Args:
        A: input tensor; ``A.shape[dim]`` must equal ``len(timepoints)``.
        timepoints: timepoint of each entry along ``dim`` (-1 = padding).
        dim: dimension to reduce blockwise over.
        reduce: reduction passed to ``torch.scatter_reduce``.
    """
    if A.shape[dim] != len(timepoints):
        raise ValueError(
            f"Dimension {dim} of A ({A.shape[dim]}) must match length of timepoints"
            f" ({len(timepoints)})"
        )
    # Work with the reduced dimension in front; undo the transpose at the end.
    A = A.transpose(dim, 0)

    if len(timepoints) == 0:
        logger.warning("Empty timepoints in block_sum. Returning zero tensor.")
        return A

    valid = timepoints[timepoints >= 0]
    if len(valid) == 0:
        logger.warning("All timepoints are -1 in block_sum. Returning zero tensor.")
        return A
    first_t = valid.min()

    # Shift so valid timepoints start at 1; padded (-1) entries collapse to block 0.
    block_ids = torch.clamp(timepoints - first_t + 1, min=0)
    scatter_idx = block_ids.unsqueeze(1).expand(-1, len(block_ids))
    n_blocks = block_ids.max().long() + 1

    acc = torch.zeros((n_blocks, A.shape[1]), device=A.device, dtype=A.dtype)
    # NOTE(review): include_self defaults to True, so the zero init participates
    # in the reduction (relevant for reduce="amax" with all-negative blocks).
    acc = torch.scatter_reduce(acc, 0, scatter_idx, A, reduce=reduce)

    # Broadcast each block's reduction back to every member of the block.
    return acc[block_ids].transpose(0, dim)
def blockwise_causal_norm(
    A: torch.Tensor,
    timepoints: torch.Tensor,
    mode: str = "quiet_softmax",
    mask_invalid: torch.BoolTensor = None,
    eps: float = 1e-6,
):
    """Normalize ``A`` blockwise over its causal dimension.

    Within each block of constant timepoints, entries of ``A`` are rescaled so
    that they sum to one along the causal dimension: entries whose column
    timepoint is strictly later than the row timepoint are normalized along
    dim=0, the remainder (including the blockwise diagonal) along dim=1.

    Args:
        A (torch.Tensor): square input tensor.
        timepoints (torch.Tensor): timepoint of each element along the causal
            dimension.
        mode: normalization mode.
            `linear`: sigmoid, then simple linear normalization.
            `softmax`: exponentiate A before normalization.
            `quiet_softmax`: exponentiate A, and add 1 to the denominator of
                each row/column, so outputs may sum to less than 1.
        mask_invalid: boolean mask of entries that must not influence the
            normalization.
        eps (float, optional): epsilon for numerical stability.
    """
    assert A.ndim == 2 and A.shape[0] == A.shape[1]
    A = A.clone()

    if mode == "linear":
        A = torch.sigmoid(A)
        if mask_invalid is not None:
            assert mask_invalid.shape == A.shape
            A[mask_invalid] = 0
        exp0 = exp1 = A
        max0 = max1 = 0
    elif mode in ("softmax", "quiet_softmax"):
        if mask_invalid is not None:
            assert mask_invalid.shape == A.shape
            # Invalid entries contribute exp(-inf) = 0 to the sums.
            A[mask_invalid] = -torch.inf
        # TODO set to min, then to 0 after exp
        # Subtract the blockwise max for numerical stability (detached):
        # https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning
        with torch.no_grad():
            max0 = blockwise_sum(A, timepoints, dim=0, reduce="amax")
            max1 = blockwise_sum(A, timepoints, dim=1, reduce="amax")
        exp0 = torch.exp(A - max0)
        exp1 = torch.exp(A - max1)
    else:
        raise NotImplementedError(f"Mode {mode} not implemented")

    denom0 = blockwise_sum(exp0, timepoints, dim=0) + eps
    denom1 = blockwise_sum(exp1, timepoints, dim=1) + eps

    if mode == "quiet_softmax":
        # Trick: after max subtraction, adding exp(-max) to the denominator is
        # equivalent to adding 1 before it. Outputs can then all be ~0 when the
        # logits are all very negative, and sum to 1 when they are positive.
        denom0 = denom0 + torch.exp(-max0)
        denom1 = denom1 + torch.exp(-max1)

    # Strictly-later entries normalize along dim=0; everything else, including
    # the blockwise diagonal (t1 == t2, masked out in the final loss anyway),
    # along dim=1.
    later = timepoints.unsqueeze(0) > timepoints.unsqueeze(1)
    res = later * exp0 / denom0 + ~later * exp1 / denom1
    return torch.clamp(res, 0, 1)
def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
    """Normalize image intensities to the 1..99.8 percentile range.

    When ``subsample`` is given and both of the last two axes are large enough
    (> 64 * subsample), the percentiles are estimated on a strided subsample,
    which is much faster for big images.
    """
    x = x.astype(np.float32)
    # Estimate the percentiles on a strided view only for large images.
    use_sub = subsample is not None and all(
        s > 64 * subsample for s in x.shape[-2:]
    )
    sample = x[..., ::subsample, ::subsample] if use_sub else x
    mi, ma = np.percentile(sample, (1, 99.8)).astype(np.float32)
    x -= mi
    x /= ma - mi + 1e-8
    return x
def normalize_01(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
    """Min-max normalize the image to approximately [0, 1].

    If subsample is not None, calculate the min/max values over a subsampled
    image (last two axes), which is way faster for large images.
    """
    x = x.astype(np.float32)
    if subsample is not None and all(s > 64 * subsample for s in x.shape[-2:]):
        y = x[..., ::subsample, ::subsample]
    else:
        y = x
    # Fix: the subsampled view `y` was previously dead code — the statistics
    # were taken from the full array, defeating the documented speedup. Take
    # them from the (possibly subsampled) view, as in `normalize` above.
    # NOTE(review): with subsampling, results may dip slightly below 0.
    mi = y.min()
    ma = y.max()
    x -= mi
    x /= ma - mi + 1e-8
    return x