| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import copy |
| from typing import List, Union |
|
|
| import torch |
| from monai.apps.vista3d.inferer import point_based_window_inferer |
| from monai.inferers import Inferer, SlidingWindowInfererAdapt |
| from torch import Tensor |
|
|
|
|
class Vista3dInferer(Inferer):
    """
    Sliding-window inferer for the VISTA3D model, supporting both class-prompted and
    point-prompted inference.

    Args:
        roi_size: the sliding window patch size.
        overlap: sliding window overlap ratio.
        use_point_window: when True and point prompts are given, run the
            point-centered ``point_based_window_inferer`` instead of the generic
            sliding-window inferer. The point-based inferer only supports single
            object inference with B=1 (interactive segmentation).
        sw_batch_size: number of sliding-window patches processed per forward pass.
    """

    def __init__(self, roi_size, overlap, use_point_window: bool = False, sw_batch_size: int = 1) -> None:
        super().__init__()
        self.roi_size = roi_size
        self.overlap = overlap
        self.sw_batch_size = sw_batch_size
        self.use_point_window = use_point_window

    def __call__(
        self,
        inputs: Union[List[Tensor], Tensor],
        network,
        point_coords,
        point_labels,
        class_vector,
        labels=None,
        label_set=None,
        prev_mask=None,
    ):
        """
        Unified callable function API of Inferers.

        Notice: ``point_based_window_inferer`` currently only supports SINGLE OBJECT
        INFERENCE with B=1. It is only used in interactive segmentation.

        Args:
            inputs: input tensor images.
            network: vista3d model.
            point_coords: point click coordinates, [B, N, 3].
            point_labels: point click labels (0 for negative, 1 for positive), [B, N].
            class_vector: class vector of length B.
            labels: groundtruth labels. Used for sampling validation points.
            label_set: [0, 1, 2, 3, ..., output_classes].
            prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID!

        Returns:
            the segmentation output produced by the selected window inferer.

        Raises:
            AttributeError: if ``network`` (or ``network.module`` for wrapped models)
                has no ``point_head`` attribute when both class and point prompts
                are supplied.
        """
        # Keep the original class prompt: the window inferers need it even when
        # class_vector is cleared below for zero-shot (unsupported) classes.
        prompt_class = copy.deepcopy(class_vector)
        if class_vector is not None and (point_labels is not None and torch.any(point_labels != -1)):
            # Both class and point prompts are present; the point head decides
            # whether the class branch can be trusted for these class indices.
            # Support DataParallel/DistributedDataParallel-wrapped networks.
            if hasattr(network, "point_head"):
                point_head = network.point_head
            elif hasattr(network, "module") and hasattr(network.module, "point_head"):
                point_head = network.module.point_head
            else:
                raise AttributeError("Network does not have attribute 'point_head'.")
            # Classes beyond the point head's supported range are treated as
            # zero-shot: drop the class prompt and rely on points alone.
            if torch.any(class_vector > point_head.last_supported):
                class_vector = None
        # Free cached GPU memory before launching the (memory-heavy) window inference.
        torch.cuda.empty_cache()
        if self.use_point_window and point_coords is not None:
            device = inputs[0].device if isinstance(inputs, list) else inputs.device
            return point_based_window_inferer(
                inputs=inputs,
                roi_size=self.roi_size,
                sw_batch_size=self.sw_batch_size,
                transpose=True,
                with_coord=True,
                predictor=network,
                mode="gaussian",
                sw_device=device,
                device=device,
                overlap=self.overlap,
                point_coords=point_coords,
                point_labels=point_labels,
                class_vector=class_vector,
                prompt_class=prompt_class,
                prev_mask=prev_mask,
                labels=labels,
                label_set=label_set,
            )
        return SlidingWindowInfererAdapt(
            roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True, padding_mode="replicate"
        )(
            inputs,
            network,
            transpose=True,
            point_coords=point_coords,
            point_labels=point_labels,
            class_vector=class_vector,
            prompt_class=prompt_class,
            prev_mask=prev_mask,
            labels=labels,
            label_set=label_set,
        )
|
|