"""
Created on Fri Apr 1 19:18:58 2022

@author: jma
"""

from typing import Any, Callable, List, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.utils import fall_back_tuple

__all__ = ["multi_task_sliding_window_inference"]


def multi_task_sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    overlap: float = 0.25,
    mode: str = "constant",
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: str = "constant",
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
| """ |
| Sliding window inference on `inputs` with `predictor`. |
| |
| When roi_size is larger than the inputs' spatial size, the input image are padded during inference. |
| To maintain the same spatial sizes, the output image will be cropped to the original input size. |
| |
| Args: |
| inputs: input image to be processed (assuming NCHW[D]) |
| roi_size: the spatial window size for inferences. |
| When its components have None or non-positives, the corresponding inputs dimension will be used. |
| if the components of the `roi_size` are non-positive values, the transform will use the |
| corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted |
| to `(32, 64)` if the second spatial dimension size of img is `64`. |
| sw_batch_size: the batch size to run window slices. |
| predictor: given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)` |
| should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D]; |
| where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`. |
| overlap: Amount of overlap between scans. |
| mode: {``"constant"``, ``"gaussian"``} |
| How to blend output of overlapping windows. Defaults to ``"constant"``. |
| |
| - ``"constant``": gives equal weight to all predictions. |
| - ``"gaussian``": gives less weight to predictions on edges of windows. |
| |
| sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. |
| Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. |
| When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding |
| spatial dimensions. |
| padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} |
| Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` |
| See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html |
| cval: fill value for 'constant' padding mode. Default: 0 |
| sw_device: device for the window data. |
| By default the device (and accordingly the memory) of the `inputs` is used. |
| Normally `sw_device` should be consistent with the device where `predictor` is defined. |
| device: device for the stitched output prediction. |
| By default the device (and accordingly the memory) of the `inputs` is used. If for example |
| set to device=torch.device('cpu') the gpu memory consumption is less and independent of the |
| `inputs` and `roi_size`. Output is on the `device`. |
| args: optional args to be passed to ``predictor``. |
| kwargs: optional keyword args to be passed to ``predictor``. |
| |
| Note: |
| - input must be channel-first and have a batch dim, supports N-D sliding window. |
| |
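    Example (a minimal sketch; ``toy_predictor`` is a hypothetical stand-in for a real
    two-headed network, not part of this module)::

        def toy_predictor(patch):
            # 3-class segmentation logits and a 1-channel distance map
            n, _, h, w = patch.shape
            return torch.randn(n, 3, h, w), torch.randn(n, 1, h, w)

        image = torch.randn(1, 3, 200, 200)  # NCHW
        seg, dist = multi_task_sliding_window_inference(image, (96, 96), 4, toy_predictor)
        # seg: (1, 3, 200, 200), dist: (1, 1, 200, 200)
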
| """ |
| num_spatial_dims = len(inputs.shape) - 2 |
| if overlap < 0 or overlap >= 1: |
| raise AssertionError("overlap must be >= 0 and < 1.") |
|
|
| |
| |
| image_size_ = list(inputs.shape[2:]) |
| batch_size = inputs.shape[0] |
|
|
| if device is None: |
| device = inputs.device |
| if sw_device is None: |
| sw_device = inputs.device |
|
|
| roi_size = fall_back_tuple(roi_size, image_size_) |
| |
| image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) |
| pad_size = [] |
| for k in range(len(inputs.shape) - 1, 1, -1): |
| diff = max(roi_size[k - 2] - inputs.shape[k], 0) |
| half = diff // 2 |
| pad_size.extend([half, diff - half]) |
| inputs = F.pad(inputs, pad=pad_size, mode=mode, value=cval) |
|
|
| scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) |

    # store all slicing windows that cover the (padded) image
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # blending weights for overlapping windows; use the requested `mode` rather than
    # a hard-coded "gaussian", so the documented "constant" default behaves as described
    importance_map = compute_importance_map(
        get_valid_patch_size(image_size, roi_size), mode=mode, sigma_scale=sigma_scale, device=device
    )

    # placeholders; allocated lazily once the predictor's output channel counts are known
    output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
    output_dist = torch.tensor(0.0, device=device)
    _initialized = False
    for slice_g in range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            (slice(idx // num_win, idx // num_win + 1), slice(None)) + tuple(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
        seg_logit, seg_dist = predictor(window_data, *args, **kwargs)
        # resize both heads to the window size; bilinear interpolation assumes 2D (NCHW) data
        seg_logit = F.interpolate(seg_logit, size=roi_size, mode="bilinear", align_corners=False)
        seg_logit = torch.softmax(seg_logit, dim=1)
        seg_dist = F.interpolate(seg_dist, size=roi_size, mode="bilinear", align_corners=False)
        seg_dist = torch.sigmoid(seg_dist)

        # lazily allocate the stitched outputs once the output channel counts are known
        if not _initialized:
            output_classes = seg_logit.shape[1]
            dist_classes = seg_dist.shape[1]
            output_shape = [batch_size, output_classes] + list(image_size)
            output_dist_shape = [batch_size, dist_classes] + list(image_size)
            output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
            output_dist = torch.zeros(output_dist_shape, dtype=torch.float32, device=device)
            count_map = torch.zeros(output_shape, dtype=torch.float32, device=device)
            count_dist_map = torch.zeros(output_dist_shape, dtype=torch.float32, device=device)
            _initialized = True

        # accumulate the importance-weighted window predictions into the full-size outputs;
        # move window outputs to `device` in case it differs from `sw_device`
        for idx, original_idx in zip(slice_range, unravel_slice):
            output_image[original_idx] += importance_map * seg_logit[idx - slice_g].to(device)
            output_dist[original_idx] += importance_map * seg_dist[idx - slice_g].to(device)
            count_map[original_idx] += importance_map
            count_dist_map[original_idx] += importance_map

    # account for any overlapping sections by dividing by the accumulated blend weights
    output_image = output_image / count_map
    output_dist = output_dist / count_dist_map

    # crop away the padding so the outputs match the original input spatial size
    final_slicing: List[slice] = []
    for sp in range(num_spatial_dims):
        slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
        final_slicing.insert(0, slice_dim)
    while len(final_slicing) < len(output_image.shape):
        final_slicing.insert(0, slice(None))
    # index with a tuple: indexing tensors with a list of slices is deprecated
    return output_image[tuple(final_slicing)], output_dist[tuple(final_slicing)]


def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute the scan interval according to the image size, roi size and overlap.
    The scan interval is `int((1 - overlap) * roi_size)`; if an interval is 0,
    use 1 instead so the sliding window still advances.

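    Example (worked arithmetic, assuming a ``(200, 200)`` image, ``roi_size=(96, 96)``
    and ``overlap=0.25``; ``int(96 * (1 - 0.25)) = 72``)::

        >>> _get_scan_interval((200, 200), (96, 96), 2, 0.25)
        (72, 72)
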
| """ |
| if len(image_size) != num_spatial_dims: |
| raise ValueError("image coord different from spatial dims.") |
| if len(roi_size) != num_spatial_dims: |
| raise ValueError("roi coord different from spatial dims.") |
|
|
| scan_interval = [] |
| for i in range(num_spatial_dims): |
| if roi_size[i] == image_size[i]: |
| scan_interval.append(int(roi_size[i])) |
| else: |
| interval = int(roi_size[i] * (1 - overlap)) |
| scan_interval.append(interval if interval > 0 else 1) |
| return tuple(scan_interval) |
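

# Minimal smoke test (a sketch, not part of the original API): exercises the sliding window
# stitching end-to-end with a hypothetical two-headed predictor on random data.
if __name__ == "__main__":

    def _toy_predictor(patch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # stand-in for a real multi-task network: 3-class logits and a 1-channel distance map
        n, _, h, w = patch.shape
        return torch.randn(n, 3, h, w), torch.randn(n, 1, h, w)

    img = torch.randn(1, 3, 200, 200)
    seg, dist = multi_task_sliding_window_inference(
        img, roi_size=(96, 96), sw_batch_size=4, predictor=_toy_predictor, mode="gaussian"
    )
    print(seg.shape, dist.shape)  # expected: (1, 3, 200, 200) and (1, 1, 200, 200)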