| from transformers import PreTrainedModel, PretrainedConfig |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import monai.transforms as transforms |
| import nibabel as nib |
| from scipy import sparse |
| import ast |
|
|
| class SegVolConfig(PretrainedConfig): |
| model_type = "segvol" |
|
|
| def __init__( |
| self, |
| test_mode=True, |
| **kwargs, |
| ): |
| self.spatial_size = [32, 256, 256] |
| self.patch_size = [4, 16, 16] |
| self.test_mode = test_mode |
| super().__init__(**kwargs) |
|
|
| class SegVolModel(PreTrainedModel): |
| config_class = SegVolConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| sam_model = _build_sam( |
| image_encoder_type='vit', |
| embed_dim = 768, |
| patch_size=self.config.patch_size, |
| checkpoint=None, |
| image_size=self.config.spatial_size, |
| ) |
| self.model = SegVol( |
| image_encoder=sam_model.image_encoder, |
| mask_decoder=sam_model.mask_decoder, |
| prompt_encoder=sam_model.prompt_encoder, |
| roi_size=self.config.spatial_size, |
| patch_size=self.config.patch_size, |
| test_mode=self.config.test_mode, |
| ) |
| |
| self.processor = SegVolProcessor(spatial_size=self.config.spatial_size) |
|
|
| def forward_test(self, |
| image, |
| zoomed_image=None, |
| text_prompt=None, |
| bbox_prompt_group=None, |
| point_prompt_group=None, |
| use_zoom=True): |
| # Zoom-out-zoom-in inference: predict on the resized (zoomed-out) volume, |
| # locate the foreground ROI, then refine that crop at full resolution with |
| # sliding-window inference. |
| device = image.device |
| assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1' |
| assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt' |
| bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map = None, None, None, None |
| if bbox_prompt_group is not None: |
| bbox_prompt, bbox_prompt_map = bbox_prompt_group |
| if point_prompt_group is not None: |
| point_prompt, point_prompt_map = point_prompt_group |
| volume_shape = image[0][0].shape |
|
|
| with torch.no_grad(): |
| logits_global_single = self.model(zoomed_image, |
| text=text_prompt, |
| boxes=bbox_prompt, |
| points=point_prompt) |
| logits_global_single = F.interpolate( |
| logits_global_single.cpu(), |
| size=volume_shape, mode='nearest') |
| if not use_zoom: |
| return logits_global_single |
| |
| if point_prompt_map is not None: |
| binary_points = F.interpolate( |
| point_prompt_map.float(), |
| size=volume_shape, mode='nearest') |
| if bbox_prompt_map is not None: |
| binary_cube = F.interpolate( |
| bbox_prompt_map.float(), |
| size=volume_shape, mode='nearest') |
| |
| min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(self.config.spatial_size, logits_global_single[0][0]) |
| if min_d is None: |
| print('Failed to detect foreground!') |
| return logits_global_single |
|
|
| |
| image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] |
| global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]) > 0.5).long() |
| |
| assert not (bbox_prompt is not None and point_prompt is not None), 'Do not use point prompt and box prompt at the same time.' |
| prompt_reflection = None |
| if bbox_prompt is not None: |
| binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] |
| prompt_reflection = ( |
| binary_cube_cropped, |
| global_preds |
| ) |
| if point_prompt is not None: |
| binary_points_cropped = binary_points[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] |
| prompt_reflection = ( |
| binary_points_cropped, |
| global_preds |
| ) |
| |
| |
| with torch.no_grad(): |
| logits_single_cropped = sliding_window_inference( |
| image_single_cropped.to(device), prompt_reflection, |
| self.config.spatial_size, 1, self.model, 0.5, |
| text=text_prompt, |
| use_box=bbox_prompt is not None, |
| use_point=point_prompt is not None, |
| ) |
| logits_single_cropped = logits_single_cropped.cpu().squeeze() |
| logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped |
| return logits_global_single |
|
|
| def forward_train(self, image, train_organs, train_labels): |
| loss = self.model(image, text=None, boxes=None, points=None, |
| train_organs=train_organs, |
| train_labels=train_labels) |
| return loss |
|
|
| def forward(self, **kwargs): |
| if self.config.test_mode: |
| return self.forward_test(kwargs['image'], |
| kwargs['zoomed_image'], |
| kwargs['text_prompt'], |
| kwargs['bbox_prompt_group'], |
| kwargs['point_prompt_group'], |
| kwargs['use_zoom']) |
| else: |
| return self.forward_train(kwargs['image'], |
| kwargs['train_organs'], |
| kwargs['train_labels']) |
| |
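| # Hedged usage sketch (not part of the upstream model): one way to drive |
| # forward_test end to end. The paths and the 'liver' category are hypothetical |
| # placeholders; assumes a model built with SegVolConfig(test_mode=True). |
| def _example_segvol_inference(model, processor, ct_path, gt_path): |
| ct, gt = processor.preprocess_ct_gt(ct_path, gt_path, category=['liver']) |
| item = processor.zoom_transform(ct, gt) |
| image = item['image'].unsqueeze(0)            # (1, 1, D, H, W) |
| zoomed = item['zoom_out_image'].unsqueeze(0)  # (1, 1, 32, 256, 256) |
| bbox_group = processor.bbox_prompt_b(item['zoom_out_label'][0]) |
| logits = model(image=image, zoomed_image=zoomed, |
| text_prompt=['liver'], bbox_prompt_group=bbox_group, |
| point_prompt_group=None, use_zoom=True) |
| return logits |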
| |
| class SegVolProcessor: |
| def __init__(self, spatial_size) -> None: |
| self.img_loader = transforms.LoadImage() |
| self.transform4test = transforms.Compose( |
| [ |
| DimTranspose(keys=["image", "label"]), |
| MinMaxNormalization(), |
| transforms.CropForegroundd(keys=["image", "label"], source_key="image"), |
| transforms.ToTensord(keys=["image", "label"]), |
| ] |
| ) |
| self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact') |
| self.transform4train = transforms.Compose( |
| [ |
| DimTranspose(keys=["image", "label"]), |
| MinMaxNormalization(), |
| transforms.CropForegroundd(keys=["image", "label"], source_key="image"), |
| transforms.SpatialPadd(keys=["image", "label"], spatial_size=spatial_size, mode='constant'), |
| transforms.OneOf(transforms=[ |
| transforms.Resized(keys=["image", "label"],spatial_size=spatial_size), |
| transforms.RandCropByPosNegLabeld( |
| keys=["image", "label"], |
| label_key="label", |
| spatial_size=spatial_size, |
| pos=2, |
| neg=1, |
| num_samples=1, |
| image_key="image", |
| image_threshold=0, |
| ), |
| ], |
| weights=[1, 3] |
| ), |
| transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0), |
| transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1), |
| transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2), |
| transforms.RandScaleIntensityd(keys="image", factors=0.2, prob=0.2), |
| transforms.RandShiftIntensityd(keys="image", offsets=0.2, prob=0.2), |
| transforms.ToTensord(keys=["image", "label"]), |
| ] |
| ) |
| |
| |
| |
| def preprocess_ct_gt(self, ct_path, gt_path, category): |
| item = {} |
| |
| ct_voxel_ndarray, _ = self.img_loader(ct_path) |
| ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze() |
| ct_shape = ct_voxel_ndarray.shape |
| ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0) |
| ct_voxel_ndarray = self.ForegroundNorm(ct_voxel_ndarray) |
| item['image'] = ct_voxel_ndarray |
|
|
| |
| gt_voxel_ndarray, _ = self.img_loader(gt_path) |
| gt_voxel_ndarray = np.array(gt_voxel_ndarray) |
| present_categories = np.unique(gt_voxel_ndarray) |
| gt_masks = [] |
| for cls_idx in range(len(category)): |
| # GT labels are 1-indexed; 0 is background |
| cls = cls_idx + 1 |
| if cls not in present_categories: |
| gt_voxel_ndarray_category = np.zeros(ct_shape) |
| gt_masks.append(gt_voxel_ndarray_category) |
| else: |
| gt_voxel_ndarray_category = gt_voxel_ndarray.copy() |
| gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0 |
| gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1 |
| gt_masks.append(gt_voxel_ndarray_category) |
| gt_voxel_ndarray = np.stack(gt_masks, axis=0) |
| assert gt_voxel_ndarray.shape[0] == len(category) and gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:] |
| item['label'] = gt_voxel_ndarray.astype(np.int32) |
|
|
| |
| return item['image'], item['label'] |
|
|
| def load_uniseg_case(self, ct_npy_path, gt_npy_path): |
| img_array = np.load(ct_npy_path) |
| allmatrix_sp = sparse.load_npz(gt_npy_path) |
| # The dense GT shape is encoded in the .npz filename: after the last '_' |
| # for files containing 'mask_', otherwise as the second-to-last '.'-field. |
| if 'mask_' in gt_npy_path: |
| gt_shape = ast.literal_eval(gt_npy_path.split('_')[-1].replace('.npz', '')) |
| else: |
| gt_shape = ast.literal_eval(gt_npy_path.split('.')[-2]) |
| gt_array = allmatrix_sp.toarray().reshape(gt_shape) |
| return img_array, gt_array |
|
|
| def ForegroundNorm(self, ct_narray): |
| # Foreground-aware normalization: voxels above the volume mean are treated |
| # as foreground; the volume is clipped to that subset's 0.05/99.95 |
| # percentiles and z-scored with the subset's mean/std. |
| ct_voxel_ndarray = ct_narray.copy().flatten() |
| thred = np.mean(ct_voxel_ndarray) |
| voxel_filtered = ct_voxel_ndarray[ct_voxel_ndarray > thred] |
| upper_bound = np.percentile(voxel_filtered, 99.95) |
| lower_bound = np.percentile(voxel_filtered, 0.05) |
| mean = np.mean(voxel_filtered) |
| std = np.std(voxel_filtered) |
| ct_narray = np.clip(ct_narray, lower_bound, upper_bound) |
| ct_narray = (ct_narray - mean) / max(std, 1e-8) |
| return ct_narray |
|
|
| def zoom_transform(self, ct_npy, gt_npy): |
| item = { |
| 'image': ct_npy, |
| 'label': gt_npy |
| } |
| item = self.transform4test(item) |
| item_zoom_out = self.zoom_out_transform(item) |
| item['zoom_out_image'] = item_zoom_out['image'] |
| item['zoom_out_label'] = item_zoom_out['label'] |
| return item |
|
|
| def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0, device='cpu'): |
| point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra) |
| points_single = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device)) |
| binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0) |
| return points_single, binary_points_resize |
| |
| def bbox_prompt_b(self, label_single_resize, device='cpu'): |
| box_single = generate_box(label_single_resize).unsqueeze(0).float().to(device) |
| binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0) |
| return box_single, binary_cube_resize |
|
|
| def dice_score(self, preds, labels, device='cpu'): |
| assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape) |
| predict = preds.view(1, -1).to(device) |
| target = labels.view(1, -1).to(device) |
|
|
| predict = torch.sigmoid(predict) |
| predict = torch.where(predict > 0.5, 1., 0.) |
| |
| tp = torch.sum(torch.mul(predict, target)) |
| den = torch.sum(predict) + torch.sum(target) + 1  # +1 smooths the empty-mask case |
| dice = 2 * tp / den |
| return dice |
| |
| def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord): |
| ct = nib.load(ct_path) |
| logits_mask = logits_mask.transpose(-1, -3) |
| start_coord[-1], start_coord[-3] = start_coord[-3], start_coord[-1] |
| end_coord[-1], end_coord[-3] = end_coord[-3], end_coord[-1] |
| preds_save = torch.zeros(ct.shape) |
| preds_save[start_coord[0]:end_coord[0], |
| start_coord[1]:end_coord[1], |
| start_coord[2]:end_coord[2]] = torch.sigmoid(logits_mask) |
| preds_save = torch.where(preds_save > 0.5, 1., 0.).numpy() |
| preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header) |
| nib.save(preds_nii, save_path) |
|
|
| def train_transform(self, ct_npy, gt_npy): |
| item = { |
| 'image': ct_npy, |
| 'label': gt_npy |
| } |
| item = self.transform4train(item) |
| if isinstance(item, list): |
| assert len(item) == 1 |
| item = item[0] |
| return item |
|
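| # Hedged training-side sketch: feeding train_transform output into the |
| # supervised path. Assumes a model built with SegVolConfig(test_mode=False); |
| # ct_npy/gt_npy are arrays shaped like the outputs of preprocess_ct_gt. |
| def _example_segvol_train_step(model, processor, ct_npy, gt_npy, organ_name='liver'): |
| item = processor.train_transform(ct_npy, gt_npy) |
| image = item['image'].unsqueeze(0)   # (1, 1, 32, 256, 256) |
| labels = item['label'].unsqueeze(0)  # (1, num_classes, 32, 256, 256) |
| # train on the first category; the model expects train_labels as (bs, D, H, W) |
| return model(image=image, train_organs=organ_name, train_labels=labels[:, 0]) |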
|
| class MinMaxNormalization(transforms.Transform): |
| def __call__(self, data): |
| d = dict(data) |
| k = "image" |
| d[k] = d[k] - d[k].min() |
| d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None) |
| return d |
|
|
| class DimTranspose(transforms.Transform): |
| def __init__(self, keys): |
| self.keys = keys |
| |
| def __call__(self, data): |
| d = dict(data) |
| for key in self.keys: |
| d[key] = np.swapaxes(d[key], -1, -3) |
| return d |
|
|
| |
| def generate_box(pred_pre, bbox_shift=None): |
| meaning_post_label = pred_pre |
| ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True) |
| if all(tensor.nelement() == 0 for tensor in ones_idx): |
| bboxes = torch.tensor([-1,-1,-1,-1,-1,-1]) |
| return bboxes |
| min_coords = [dim.min() for dim in ones_idx] |
| max_coords = [dim.max() for dim in ones_idx] |
|
|
|
|
| if bbox_shift is None: |
| corner_min = [] |
| corner_max = [] |
| shape = meaning_post_label.shape |
| for coor in min_coords: |
| coor_ = max(0, coor) |
| corner_min.append(coor_) |
| for idx, coor in enumerate(max_coords): |
| coor_ = min(shape[idx], coor) |
| corner_max.append(coor_) |
| corner_min = torch.tensor(corner_min) |
| corner_max = torch.tensor(corner_max) |
| return torch.cat((corner_min, corner_max), dim=0) |
| else: |
| |
| corner_min = [] |
| corner_max = [] |
| shape = meaning_post_label.shape |
| for coor in min_coords: |
| coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift)) |
| corner_min.append(coor_) |
| for idx, coor in enumerate(max_coords): |
| coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift)) |
| corner_max.append(coor_) |
| corner_min = torch.tensor(corner_min) |
| corner_max = torch.tensor(corner_max) |
| return torch.cat((corner_min, corner_max), dim=0) |
|
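| # Minimal sketch of generate_box on a toy mask: the returned tensor packs the |
| # minimum corner followed by the maximum corner, one entry per spatial axis. |
| def _example_generate_box(): |
| mask = torch.zeros(8, 8, 8) |
| mask[2:5, 3:6, 1:4] = 1 |
| return generate_box(mask)  # tensor([2, 3, 1, 4, 5, 3]) |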
|
|
|
| def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None): |
| spatial_dim = 3 |
| points = torch.zeros((0, 3)) |
| labels = torch.zeros((0)) |
| pos_thred = 0.9 |
| neg_thred = 0.1 |
| |
| |
| positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) |
| negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True) |
|
|
| if all(tmp.nelement() == 0 for tmp in positive_indices): |
| |
| num_positive_extra = 0 |
| selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0) |
| points = torch.cat((points, selected_positive_point), dim=0) |
| labels = torch.cat((labels, torch.tensor([-1]).reshape(1))) |
| else: |
| |
| random_idx = torch.randint(len(positive_indices[0]), (1,)) |
| selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spatial_dim)]).unsqueeze(dim=0) |
| points = torch.cat((points, selected_positive_point), dim=0) |
| labels = torch.cat((labels, torch.ones((1)))) |
|
|
| if num_positive_extra > 0: |
| pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra] |
| extra_positive_points = [] |
| for pos_idx in pos_idx_list: |
| extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spatial_dim)]) |
| extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3) |
| points = torch.cat((points, extra_positive_points), dim=0) |
| labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0])))) |
|
|
| if num_negative_extra > 0: |
| neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra] |
| extra_negative_points = [] |
| for neg_idx in neg_idx_list: |
| extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spatial_dim)]) |
| extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3) |
| points = torch.cat((points, extra_negative_points), dim=0) |
| labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0])))) |
| |
| if fix_extra_point_num is None: |
| left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0] |
| else: |
| left_point_num = fix_extra_point_num + 1 - labels.shape[0] |
|
|
| for _ in range(left_point_num): |
| ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0) |
| points = torch.cat((points, ignore_point), dim=0) |
| labels = torch.cat((labels, torch.tensor([-1]).reshape(1))) |
|
|
| return points, labels |
|
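| # Minimal sketch of select_points on a toy mask: one guaranteed positive point |
| # plus the requested extras, padded to a fixed count with (-1, -1, -1) ignore |
| # points labeled -1. |
| def _example_select_points(): |
| mask = torch.zeros(8, 8, 8) |
| mask[2:5, 2:5, 2:5] = 1 |
| points, labels = select_points(mask, num_positive_extra=2, num_negative_extra=1) |
| return points.shape, labels  # (4, 3); labels drawn from {1., 0., -1.} |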
|
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from transformers import CLIPTextModel, CLIPTextConfig |
| import random |
|
|
| |
| class SegVol(nn.Module): |
| def __init__(self, |
| image_encoder, |
| mask_decoder, |
| prompt_encoder, |
| roi_size, |
| patch_size, |
| |
| test_mode=False, |
| ): |
| super().__init__() |
| self.image_encoder = image_encoder |
| self.mask_decoder = mask_decoder |
| self.prompt_encoder = prompt_encoder |
| self.text_encoder = TextEncoder() |
| self.feat_shape = np.array(roi_size)/np.array(patch_size) |
| self.test_mode = test_mode |
| self.dice_loss = BinaryDiceLoss() |
| self.bce_loss = BCELoss() |
| self.decoder_iter = 6 |
|
|
| def forward(self, image, text=None, boxes=None, points=None, **kwargs): |
| bs = image.shape[0] |
| img_shape = (image.shape[2], image.shape[3], image.shape[4]) |
| image_embedding, _ = self.image_encoder(image) |
| image_embedding = image_embedding.transpose(1, 2).view(bs, -1, |
| int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2])) |
| |
| if self.test_mode: |
| return self.forward_decoder(image_embedding, img_shape, text, boxes, points) |
| |
| # supervised training path |
| sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels']) |
| return sl_loss |
|
|
| def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None): |
| device = image_embedding.device |
| with torch.no_grad(): |
| if boxes is not None: |
| if len(boxes.shape) == 2: |
| boxes = boxes[:, None, :] |
| if text is not None: |
| text_embedding = self.text_encoder(text, device) |
| else: |
| text_embedding = None |
| sparse_embeddings, dense_embeddings = self.prompt_encoder( |
| points=points, |
| boxes=boxes, |
| masks=None, |
| text_embedding=text_embedding, |
| ) |
|
|
| dense_pe = self.prompt_encoder.get_dense_pe() |
| low_res_masks, _ = self.mask_decoder( |
| image_embeddings=image_embedding, |
| text_embedding=text_embedding, |
| image_pe=dense_pe, |
| sparse_prompt_embeddings=sparse_embeddings, |
| dense_prompt_embeddings=dense_embeddings, |
| multimask_output=False, |
| ) |
| logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False) |
| return logits |
|
|
| def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels): |
| device = image_embedding.device |
| iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels, device) |
| |
| # Iterate over the six prompt combinations: point+text, box+text, text only, |
| # box only, point only, box+point. |
| prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs], |
| [None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None], |
| [iter_bboxes, iter_points, None]] |
| sl_loss = 0 |
| for prompt in prompt_options: |
| bboxes, points, organs = prompt |
| logits = self.forward_decoder(image_embedding, img_shape, text=organs, boxes=bboxes, points=points) |
| |
| sl_loss_dice = self.dice_loss.forward(logits.squeeze().float(), train_labels.squeeze().float()) |
| sl_loss_bce = self.bce_loss.forward(logits.squeeze().float(), train_labels.squeeze().float()) |
| sl_loss += sl_loss_dice + sl_loss_bce |
| return sl_loss |
|
| def build_prompt_label(self, bs, training_organs, train_labels, device): |
| |
| iter_organs = [] |
| iter_bboxes = [] |
| iter_points_ax = [] |
| iter_point_labels = [] |
| for sample_idx in range(bs): |
| |
| iter_organs.append(training_organs) |
| |
| box = generate_box(train_labels[sample_idx], bbox_shift=10) |
| iter_bboxes.append(box) |
| |
| num_positive_extra_max, num_negative_extra_max = 10, 10 |
| num_positive_extra = random.randint(0, num_positive_extra_max) |
| num_negative_extra = random.randint(0, num_negative_extra_max) |
| point, point_label = select_points( |
| train_labels[sample_idx], |
| num_positive_extra=num_positive_extra, |
| num_negative_extra=num_negative_extra, |
| fix_extra_point_num=num_positive_extra_max + num_negative_extra_max) |
| iter_points_ax.append(point) |
| iter_point_labels.append(point_label) |
| |
| iter_points_ax = torch.stack(iter_points_ax, dim=0).to(device) |
| iter_point_labels = torch.stack(iter_point_labels, dim=0).to(device) |
| iter_points = (iter_points_ax, iter_point_labels) |
| iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(device) |
| return iter_points, iter_bboxes, iter_organs |
|
|
| class TextEncoder(nn.Module): |
| def __init__(self): |
| super().__init__() |
| config = CLIPTextConfig() |
| self.clip_text_model = CLIPTextModel(config) |
| self.tokenizer = None  # set externally (e.g., a CLIP tokenizer) before calling forward |
| self.dim_align = nn.Linear(512, 768) |
| |
| for param in self.clip_text_model.parameters(): |
| param.requires_grad = False |
|
|
| def organ2tokens(self, organ_names, device): |
| text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names] |
| tokens = self.tokenizer(text_list, padding=True, return_tensors="pt") |
| for key in tokens.keys(): |
| tokens[key] = tokens[key].to(device) |
| return tokens |
| |
| def forward(self, text, device): |
| if text is None: |
| return None |
| if isinstance(text, str): |
| text = [text] |
| tokens = self.organ2tokens(text, device) |
| clip_outputs = self.clip_text_model(**tokens) |
| text_embedding = clip_outputs.pooler_output |
| text_embedding = self.dim_align(text_embedding) |
| return text_embedding |
|
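| # Hedged sketch: the tokenizer is attached after construction. A CLIP tokenizer |
| # is assumed here to match CLIPTextModel; the checkpoint name is illustrative. |
| def _example_text_encoder(): |
| from transformers import CLIPTokenizer |
| enc = TextEncoder() |
| enc.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32') |
| return enc(['liver', 'kidney'], device='cpu')  # (2, 768) after dim_align |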
|
| |
| import torch |
| import torch.nn as nn |
|
|
| class BinaryDiceLoss(nn.Module): |
| def __init__(self, smooth=1, p=2, reduction='mean'): |
| super(BinaryDiceLoss, self).__init__() |
| self.smooth = smooth |
| self.p = p |
| self.reduction = reduction |
|
|
| def forward(self, predict, target): |
| predict = torch.sigmoid(predict) |
| target_ = target.clone() |
| target_[target == -1] = 0 |
| assert predict.shape[0] == target.shape[0], "predict & target batch size don't match\n" + str(predict.shape) + '\n' + str(target.shape[0]) |
| predict = predict.contiguous().view(predict.shape[0], -1) |
| target_ = target_.contiguous().view(target_.shape[0], -1) |
|
|
| num = torch.sum(torch.mul(predict, target_), dim=1) |
| den = torch.sum(predict, dim=1) + torch.sum(target_, dim=1) + self.smooth |
|
|
| dice_score = 2*num / den |
| dice_loss = 1 - dice_score |
|
|
| |
| dice_loss_avg = dice_loss.sum() / dice_loss.shape[0] |
|
|
| return dice_loss_avg |
|
|
| class BCELoss(nn.Module): |
| def __init__(self): |
| super(BCELoss, self).__init__() |
| self.criterion = nn.BCEWithLogitsLoss() |
|
|
| def forward(self, predict, target): |
| assert predict.shape == target.shape, 'predict & target shape do not match\n' + str(predict.shape) + '\n' + str(target.shape) |
| target_ = target.clone() |
| target_[target == -1] = 0 |
|
|
| ce_loss = self.criterion(predict, target_) |
|
|
| return ce_loss |
|
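| # Hedged sanity-check sketch for the two losses on toy logits; both remap |
| # -1 labels to 0 before computing their terms. |
| def _example_losses(): |
| logits = torch.tensor([[2.0, -2.0, 0.0]]) |
| target = torch.tensor([[1.0, 0.0, -1.0]]) |
| return BinaryDiceLoss()(logits, target), BCELoss()(logits, target) |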
|
|
| import warnings |
| from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import random |
|
|
| from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size |
| from monai.transforms import Resize |
| from monai.utils import ( |
| BlendMode, |
| PytorchPadMode, |
| convert_data_type, |
| ensure_tuple, |
| fall_back_tuple, |
| look_up_option, |
| optional_import, |
| ) |
|
|
| tqdm, _ = optional_import("tqdm", name="tqdm") |
|
|
| __all__ = ["sliding_window_inference"] |
|
|
| def logits2roi_coor(spatial_size, logits_global_single): |
| |
| pred_global_single = torch.sigmoid(logits_global_single) > 0.5 |
| |
| nonzero_indices = torch.nonzero(pred_global_single) |
| if nonzero_indices.shape[0] == 0: |
| return None, None, None, None, None, None |
| |
| min_d, max_d = nonzero_indices[:, 0].min(), nonzero_indices[:, 0].max() |
| min_h, max_h = nonzero_indices[:, 1].min(), nonzero_indices[:, 1].max() |
| min_w, max_w = nonzero_indices[:, 2].min(), nonzero_indices[:, 2].max() |
| |
| crop_d, crop_h, crop_w = max_d - min_d + 1, max_h - min_h + 1, max_w - min_w + 1 |
| window_d, window_h, window_w = spatial_size |
| padding_d, padding_h, padding_w = max(0, window_d-crop_d), max(0, window_h-crop_h), max(0, window_w-crop_w) |
| global_d, global_h, global_w = logits_global_single.shape |
| min_d = max(0, min_d - int(padding_d)//2) |
| min_h = max(0, min_h - int(padding_h)//2) |
| min_w = max(0, min_w - int(padding_w)//2) |
| max_d = min(global_d, max_d + int(padding_d)//2) |
| max_h = min(global_h, max_h + int(padding_h)//2) |
| max_w = min(global_w, max_w + int(padding_w)//2) |
| return min_d, min_h, min_w, max_d, max_h, max_w |
|
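| # Minimal sketch of logits2roi_coor: the ROI is the bounding box of |
| # sigmoid(logits) > 0.5, padded (where possible) up to the window size. |
| def _example_logits2roi(): |
| logits = torch.full((64, 64, 64), -10.0) |
| logits[20:30, 20:30, 20:30] = 10.0 |
| return logits2roi_coor((32, 32, 32), logits)  # (9, 9, 9, 40, 40, 40) |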
|
| def build_binary_cube(bbox, binary_cube_shape): |
| min_coord = bbox[0][:3].int().tolist() |
| max_coord = bbox[0][3:].int().tolist() |
| binary_cube = torch.zeros(binary_cube_shape) |
| binary_cube[min_coord[0]:max_coord[0]+1, min_coord[1]:max_coord[1]+1, min_coord[2]:max_coord[2]+1] = 1 |
| return binary_cube |
|
|
| def build_binary_points(points, labels, shape): |
| binary_points = torch.zeros(shape, dtype=torch.int16) |
| binary_points[points[labels == 1, 0].long(), points[labels == 1, 1].long(), points[labels == 1, 2].long()] = 1 |
| return binary_points |
|
|
| def sliding_window_inference( |
| inputs: torch.Tensor, |
| prompt_reflection: Union[torch.Tensor, Tuple[torch.Tensor, ...]], |
| roi_size: Union[Sequence[int], int], |
| sw_batch_size: int, |
| predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], |
| overlap: float = 0.25, |
| mode: Union[BlendMode, str] = BlendMode.CONSTANT, |
| sigma_scale: Union[Sequence[float], float] = 0.125, |
| padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, |
| cval: float = 0.0, |
| sw_device: Union[torch.device, str, None] = None, |
| device: Union[torch.device, str, None] = None, |
| progress: bool = False, |
| roi_weight_map: Union[torch.Tensor, None] = None, |
| *args: Any, |
| **kwargs: Any, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: |
| """ |
| Sliding window inference on `inputs` with `predictor`. |
| |
| The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. |
| Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. |
| e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes |
| could be ([128,64,256], [64,32,128]). |
| In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still |
| an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters |
| so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). |
| |
| When roi_size is larger than the inputs' spatial size, the input image is padded during inference. |
| To maintain the same spatial sizes, the output image will be cropped to the original input size. |
| |
| Args: |
| inputs: input image to be processed (assuming NCHW[D]) |
| roi_size: the spatial window size for inferences. |
| When its components have None or non-positives, the corresponding inputs dimension will be used. |
| if the components of the `roi_size` are non-positive values, the transform will use the |
| corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted |
| to `(32, 64)` if the second spatial dimension size of img is `64`. |
| sw_batch_size: the batch size to run window slices. |
| predictor: given input tensor ``patch_data`` in shape NCHW[D], |
| The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary |
| with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D']; |
| where H'W'[D'] represents the output patch's spatial size, M is the number of output channels, |
| N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128), |
| the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). |
| In this case, the parameter `overlap` and `roi_size` need to be carefully chosen |
| to ensure the scaled output ROI sizes are still integers. |
| If the `predictor`'s input and output spatial sizes are different, |
| we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. |
| overlap: Amount of overlap between scans. |
| mode: {``"constant"``, ``"gaussian"``} |
| How to blend output of overlapping windows. Defaults to ``"constant"``. |
| |
| - ``"constant``": gives equal weight to all predictions. |
| - ``"gaussian``": gives less weight to predictions on edges of windows. |
| |
| sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. |
| Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. |
| When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding |
| spatial dimensions. |
| padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} |
| Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` |
| See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html |
| cval: fill value for 'constant' padding mode. Default: 0 |
| sw_device: device for the window data. |
| By default the device (and accordingly the memory) of the `inputs` is used. |
| Normally `sw_device` should be consistent with the device where `predictor` is defined. |
| device: device for the stitched output prediction. |
| By default the device (and accordingly the memory) of the `inputs` is used. If for example |
| set to device=torch.device('cpu') the gpu memory consumption is less and independent of the |
| `inputs` and `roi_size`. Output is on the `device`. |
| progress: whether to print a `tqdm` progress bar. |
| roi_weight_map: pre-computed (non-negative) weight map for each ROI. |
| If not given, and ``mode`` is not `constant`, this map will be computed on the fly. |
| args: optional args to be passed to ``predictor``. |
| kwargs: optional keyword args to be passed to ``predictor``. |
| |
| Note: |
| - input must be channel-first and have a batch dim, supports N-D sliding window. |
| |
| """ |
| print('sliding window inference for ROI') |
| text = kwargs['text'] |
| use_box = kwargs['use_box'] |
| use_point = kwargs['use_point'] |
| assert not (use_box and use_point) |
| compute_dtype = inputs.dtype |
| num_spatial_dims = len(inputs.shape) - 2 |
| if overlap < 0 or overlap >= 1: |
| raise ValueError("overlap must be >= 0 and < 1.") |
|
|
| |
| |
| batch_size, _, *image_size_ = inputs.shape |
|
|
| if device is None: |
| device = inputs.device |
| if sw_device is None: |
| sw_device = inputs.device |
|
|
| roi_size = fall_back_tuple(roi_size, image_size_) |
| |
| image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) |
| pad_size = [] |
| for k in range(len(inputs.shape) - 1, 1, -1): |
| diff = max(roi_size[k - 2] - inputs.shape[k], 0) |
| half = diff // 2 |
| pad_size.extend([half, diff - half]) |
| inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) |
| |
| if use_point or use_box: |
| binary_prompt_map, global_preds = prompt_reflection |
| global_preds = F.pad(global_preds, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) |
| |
| scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) |
|
|
| |
| slices = dense_patch_slices(image_size, roi_size, scan_interval) |
| num_win = len(slices) |
| total_slices = num_win * batch_size |
|
|
| |
| valid_patch_size = get_valid_patch_size(image_size, roi_size) |
| if valid_patch_size == roi_size and (roi_weight_map is not None): |
| importance_map = roi_weight_map |
| else: |
| try: |
| importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device) |
| except BaseException as e: |
| raise RuntimeError( |
| "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." |
| ) from e |
| importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] |
| |
| min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3) |
| importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype) |
|
|
| |
| dict_key, output_image_list, count_map_list = None, [], [] |
| _initialized_ss = -1 |
| is_tensor_output = True |
|
|
| |
| for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size): |
| slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) |
| unravel_slice = [ |
| [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) |
| for idx in slice_range |
| ] |
| window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) |
| |
| |
| boxes = None |
| points = None |
| if use_point: |
| window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device) |
| point, point_label = select_points(window_binary_prompt_map.squeeze()) |
| points = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device)) |
| pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device) |
| boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device) |
| if use_box: |
| if num_win == 1: |
| window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device) |
| boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().to(device) |
| else: |
| pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device) |
| boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device) |
| seg_prob_out = predictor(window_data, text, boxes, points) |
| |
| |
| seg_prob_tuple: Tuple[torch.Tensor, ...] |
| if isinstance(seg_prob_out, torch.Tensor): |
| seg_prob_tuple = (seg_prob_out,) |
| elif isinstance(seg_prob_out, Mapping): |
| if dict_key is None: |
| dict_key = sorted(seg_prob_out.keys()) |
| seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key) |
| is_tensor_output = False |
| else: |
| seg_prob_tuple = ensure_tuple(seg_prob_out) |
| is_tensor_output = False |
|
|
| |
| for ss, seg_prob in enumerate(seg_prob_tuple): |
| seg_prob = seg_prob.to(device) |
|
|
| |
| zoom_scale = [] |
| for axis, (img_s_i, out_w_i, in_w_i) in enumerate( |
| zip(image_size, seg_prob.shape[2:], window_data.shape[2:]) |
| ): |
| _scale = out_w_i / float(in_w_i) |
| if not (img_s_i * _scale).is_integer(): |
| warnings.warn( |
| f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial " |
| f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs." |
| ) |
| zoom_scale.append(_scale) |
|
|
| if _initialized_ss < ss: |
| |
| output_classes = seg_prob.shape[1] |
| output_shape = [batch_size, output_classes] + [ |
| int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale) |
| ] |
| |
| output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) |
| count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) |
| _initialized_ss += 1 |
|
|
| |
| resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False) |
|
|
| |
| for idx, original_idx in zip(slice_range, unravel_slice): |
| |
| original_idx_zoom = list(original_idx) |
| for axis in range(2, len(original_idx_zoom)): |
| zoomed_start = original_idx[axis].start * zoom_scale[axis - 2] |
| zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2] |
| if not zoomed_start.is_integer() or (not zoomed_end.is_integer()): |
| warnings.warn( |
| f"For axis-{axis-2} of output[{ss}], the output roi range is not int. " |
| f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). " |
| f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. " |
| f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n" |
| f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. " |
| "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works." |
| ) |
| original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None) |
| importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) |
| |
| output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g] |
| count_map_list[ss][original_idx_zoom] += ( |
| importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape) |
| ) |
|
|
| |
| for ss in range(len(output_image_list)): |
| output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype) |
|
|
| |
| for ss, output_i in enumerate(output_image_list): |
| if torch.isnan(output_i).any() or torch.isinf(output_i).any(): |
| warnings.warn("Sliding window inference results contain NaN or Inf.") |
|
|
| zoom_scale = [ |
| seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size) |
| ] |
|
|
| final_slicing: List[slice] = [] |
| for sp in range(num_spatial_dims): |
| slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) |
| slice_dim = slice( |
| int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])), |
| int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])), |
| ) |
| final_slicing.insert(0, slice_dim) |
| while len(final_slicing) < len(output_i.shape): |
| final_slicing.insert(0, slice(None)) |
| output_image_list[ss] = output_i[final_slicing] |
|
|
| if dict_key is not None: |
| final_output = dict(zip(dict_key, output_image_list)) |
| else: |
| final_output = tuple(output_image_list) |
| return final_output[0] if is_tensor_output else final_output |
|
|
|
|
| def _get_scan_interval( |
| image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float |
| ) -> Tuple[int, ...]: |
| """ |
| Compute scan interval according to the image size, roi size and overlap. |
| Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0, |
| use 1 instead to make sure sliding window works. |
| |
| """ |
| if len(image_size) != num_spatial_dims: |
| raise ValueError("image coord different from spatial dims.") |
| if len(roi_size) != num_spatial_dims: |
| raise ValueError("roi coord different from spatial dims.") |
|
|
| scan_interval = [] |
| for i in range(num_spatial_dims): |
| if roi_size[i] == image_size[i]: |
| scan_interval.append(int(roi_size[i])) |
| else: |
| interval = int(roi_size[i] * (1 - overlap)) |
| scan_interval.append(interval if interval > 0 else 1) |
| return tuple(scan_interval) |
|
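| # Worked-example sketch: with 50% overlap the window advances half a window |
| # along depth and uses a single full-size window along the in-plane axes. |
| def _example_scan_interval(): |
| return _get_scan_interval((64, 256, 256), (32, 256, 256), 3, 0.5)  # (16, 256, 256) |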
|
| |
| import torch |
| import numpy as np |
| from monai.networks.nets import ViT |
|
|
| def _build_sam( |
| image_encoder_type, |
| embed_dim, |
| patch_size, |
| checkpoint, |
| image_size, |
| ): |
| mlp_dim = 3072 |
| num_layers = 12 |
| num_heads = 12 |
| pos_embed = 'perceptron' |
| dropout_rate = 0.0 |
| |
| image_encoder=ViT( |
| in_channels=1, |
| img_size=image_size, |
| patch_size=patch_size, |
| hidden_size=embed_dim, |
| mlp_dim=mlp_dim, |
| num_layers=num_layers, |
| num_heads=num_heads, |
| pos_embed=pos_embed, |
| classification=False, |
| dropout_rate=dropout_rate, |
| ) |
| image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))] |
|
|
| if checkpoint is not None: |
| with open(checkpoint, "rb") as f: |
| state_dict = torch.load(f, map_location='cpu')['state_dict'] |
| encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k} |
| image_encoder.load_state_dict(encoder_dict) |
| print(f'===> image_encoder.load_param: {checkpoint}') |
| sam = Sam( |
| image_encoder=image_encoder, |
| prompt_encoder=PromptEncoder( |
| embed_dim=embed_dim, |
| image_embedding_size=image_embedding_size, |
| input_image_size=image_size, |
| mask_in_chans=16, |
| ), |
| mask_decoder=MaskDecoder( |
| image_encoder_type=image_encoder_type, |
| num_multimask_outputs=3, |
| transformer=TwoWayTransformer( |
| depth=2, |
| embedding_dim=embed_dim, |
| mlp_dim=2048, |
| num_heads=8, |
| ), |
| transformer_dim=embed_dim, |
| iou_head_depth=3, |
| iou_head_hidden_dim=256, |
| image_size=np.array(image_size), |
| patch_size=np.array(patch_size), |
| ), |
| pixel_mean=[123.675, 116.28, 103.53], |
| pixel_std=[58.395, 57.12, 57.375], |
| ) |
| sam.eval() |
| return sam |
|
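| # Sketch: with the default SegVolConfig (spatial_size=[32, 256, 256], |
| # patch_size=[4, 16, 16]) the ViT sees 8*16*16 patch tokens and the prompt |
| # encoder an (8, 16, 16) embedding grid. |
| def _example_build_sam(): |
| return _build_sam(image_encoder_type='vit', embed_dim=768, |
| patch_size=[4, 16, 16], checkpoint=None, |
| image_size=[32, 256, 256]) |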
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from typing import List, Tuple, Type, Optional |
|
|
| class MaskDecoder(nn.Module): |
| def __init__( |
| self, |
| *, |
| image_encoder_type: str, |
| transformer_dim: int, |
| transformer: nn.Module, |
| num_multimask_outputs: int = 3, |
| activation: Type[nn.Module] = nn.GELU, |
| iou_head_depth: int = 3, |
| iou_head_hidden_dim: int = 256, |
| image_size, |
| patch_size, |
| ) -> None: |
| """ |
| Predicts masks given an image and prompt embeddings, using a |
| transformer architecture. |
| |
| Arguments: |
| transformer_dim (int): the channel dimension of the transformer |
| transformer (nn.Module): the transformer used to predict masks |
| num_multimask_outputs (int): the number of masks to predict |
| when disambiguating masks |
| activation (nn.Module): the type of activation to use when |
| upscaling masks |
| iou_head_depth (int): the depth of the MLP used to predict |
| mask quality |
| iou_head_hidden_dim (int): the hidden dimension of the MLP |
| used to predict mask quality |
| image_size (np.ndarray): the input spatial size per axis |
| patch_size (np.ndarray): the ViT patch size per axis |
| """ |
| super().__init__() |
| self.transformer_dim = transformer_dim |
| self.transformer = transformer |
|
|
| self.num_multimask_outputs = num_multimask_outputs |
|
|
| self.iou_token = nn.Embedding(1, transformer_dim) |
| self.num_mask_tokens = num_multimask_outputs + 1 |
| self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) |
|
|
| if image_encoder_type == 'swin_vit': |
| self.feat_shape = image_size/patch_size |
| self.output_upscaling = nn.Sequential( |
| nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), |
| nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), |
| activation(), |
| nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), |
| activation(), |
| ) |
| else: |
| self.feat_shape = image_size/patch_size * 2 |
| self.output_upscaling = nn.Sequential( |
| nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), |
| nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), |
| activation(), |
| nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), |
| activation(), |
| ) |
| self.output_hypernetworks_mlps = nn.ModuleList( |
| [ |
| MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) |
| for i in range(self.num_mask_tokens) |
| ] |
| ) |
|
|
| self.iou_prediction_head = MLP( |
| transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth |
| ) |
|
|
| self.txt_align_upscaled_embedding = nn.Linear(768, 96)  # 96 == transformer_dim // 8 for embed_dim 768 |
|
|
| def forward( |
| self, |
| image_embeddings: torch.Tensor, |
| text_embedding: Optional[torch.Tensor], |
| image_pe: torch.Tensor, |
| sparse_prompt_embeddings: torch.Tensor, |
| dense_prompt_embeddings: torch.Tensor, |
| multimask_output: bool, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Predict masks given image and prompt embeddings. |
| |
| Returns: |
| torch.Tensor: batched predicted masks |
| torch.Tensor: batched predictions of mask quality |
| """ |
| |
| masks, iou_pred = self.predict_masks( |
| image_embeddings=image_embeddings, |
| text_embedding=text_embedding, |
| image_pe=image_pe, |
| sparse_prompt_embeddings=sparse_prompt_embeddings, |
| dense_prompt_embeddings=dense_prompt_embeddings, |
| ) |
|
|
| |
| if multimask_output: |
| mask_slice = slice(1, None) |
| else: |
| mask_slice = slice(0, 1) |
| masks = masks[:, mask_slice, :, :, :] |
| iou_pred = iou_pred[:, mask_slice] |
|
|
| |
| return masks, iou_pred |
|
|
| def predict_masks( |
| self, |
| image_embeddings: torch.Tensor, |
| text_embedding: torch.Tensor, |
| image_pe: torch.Tensor, |
| sparse_prompt_embeddings: torch.Tensor, |
| dense_prompt_embeddings: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Predicts masks. See 'forward' for more details.""" |
| |
| output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) |
| output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) |
| tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) |
| |
| if image_embeddings.shape[0] != tokens.shape[0]: |
| src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) |
| else: |
| src = image_embeddings |
| src = src + dense_prompt_embeddings |
| pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) |
| b, c, h, w, d = src.shape |
|
|
| |
| hs, src = self.transformer(src, pos_src, tokens) |
| iou_token_out = hs[:, 0, :] |
| mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] |
|
|
| |
| src = src.transpose(1, 2).view(b, c, h, w, d) |
| upscaled_embedding = self.output_upscaling(src) |
| hyper_in_list: List[torch.Tensor] = [] |
| for i in range(self.num_mask_tokens): |
| hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) |
| hyper_in = torch.stack(hyper_in_list, dim=1) |
| b, c, h, w, d = upscaled_embedding.shape |
| masks = (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d) |
| |
| if text_embedding is not None: |
| text_embedding_down = self.txt_align_upscaled_embedding(text_embedding).unsqueeze(dim=1) |
| upscaled_embedding = upscaled_embedding.view(b, c, h * w * d) |
| sim = (text_embedding_down @ upscaled_embedding).view(b, -1, h, w, d) |
| sim = sim.repeat(1, masks.shape[1], 1, 1, 1) |
| masks = masks + sim |
| iou_pred = self.iou_prediction_head(iou_token_out) |
|
|
| return masks, iou_pred |
|
|
| |
| |
| class MLP(nn.Module): |
| def __init__( |
| self, |
| input_dim: int, |
| hidden_dim: int, |
| output_dim: int, |
| num_layers: int, |
| sigmoid_output: bool = False, |
| ) -> None: |
| super().__init__() |
| self.num_layers = num_layers |
| h = [hidden_dim] * (num_layers - 1) |
| self.layers = nn.ModuleList( |
| nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) |
| ) |
| self.sigmoid_output = sigmoid_output |
|
|
| def forward(self, x): |
| for i, layer in enumerate(self.layers): |
| x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) |
| if self.sigmoid_output: |
| x = F.sigmoid(x) |
| return x |
|
|
|
| import numpy as np |
| import torch |
| from torch import nn |
|
|
| from typing import Any, Optional, Tuple, Type |
|
|
| class PromptEncoder(nn.Module): |
| def __init__( |
| self, |
| embed_dim: int, |
| image_embedding_size: Tuple[int, int, int], |
| input_image_size: Tuple[int, int, int], |
| mask_in_chans: int, |
| activation: Type[nn.Module] = nn.GELU, |
| ) -> None: |
| """ |
| Encodes prompts for input to SAM's mask decoder. |
| |
| Arguments: |
| embed_dim (int): The prompts' embedding dimension |
| image_embedding_size (tuple(int, int, int)): The spatial size of the |
| image embedding, as (H, W, D). |
| input_image_size (tuple(int, int, int)): The padded size of the image as input |
| to the image encoder, as (H, W, D). |
| mask_in_chans (int): The number of hidden channels used for |
| encoding input masks. |
| activation (nn.Module): The activation to use when encoding |
| input masks. |
| """ |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.input_image_size = input_image_size |
| self.image_embedding_size = image_embedding_size |
| self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) |
|
|
| self.num_point_embeddings: int = 4 |
| point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] |
| self.point_embeddings = nn.ModuleList(point_embeddings) |
| self.not_a_point_embed = nn.Embedding(1, embed_dim) |
|
|
| self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2]) |
| self.mask_downscaling = nn.Sequential( |
| nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), |
| LayerNorm2d(mask_in_chans // 4), |
| activation(), |
| nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), |
| LayerNorm2d(mask_in_chans), |
| activation(), |
| nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), |
| ) |
| self.no_mask_embed = nn.Embedding(1, embed_dim) |
|
|
| def get_dense_pe(self) -> torch.Tensor: |
| """ |
| Returns the positional encoding used to encode point prompts, |
| applied to a dense set of points the shape of the image encoding. |
| |
| Returns: |
| torch.Tensor: Positional encoding with shape |
| 1x(embed_dim)x(embedding_h)x(embedding_w)x(embedding_d) |
| """ |
| return self.pe_layer(self.image_embedding_size).unsqueeze(0) |
|
|
| def _embed_points( |
| self, |
| points: torch.Tensor, |
| labels: torch.Tensor, |
| pad: bool, |
| ) -> torch.Tensor: |
| """Embeds point prompts.""" |
| points = points + 0.5 |
| if pad: |
| padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device) |
| padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) |
| points = torch.cat([points, padding_point], dim=1) |
| labels = torch.cat([labels, padding_label], dim=1) |
| point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) |
| point_embedding[labels == -1] = 0.0 |
| point_embedding[labels == -1] += self.not_a_point_embed.weight |
| point_embedding[labels == 0] += self.point_embeddings[0].weight |
| point_embedding[labels == 1] += self.point_embeddings[1].weight |
| return point_embedding |
|
|
| def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: |
| """Embeds box prompts.""" |
| boxes = boxes + 0.5 |
| coords = boxes.reshape(-1, 2, 3) |
| corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) |
| corner_embedding[:, 0, :] += self.point_embeddings[2].weight |
| corner_embedding[:, 1, :] += self.point_embeddings[3].weight |
| return corner_embedding |
|
|
| def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: |
| """Embeds mask inputs.""" |
| mask_embedding = self.mask_downscaling(masks) |
| return mask_embedding |
|
|
| def _get_batch_size( |
| self, |
| points: Optional[Tuple[torch.Tensor, torch.Tensor]], |
| boxes: Optional[torch.Tensor], |
| masks: Optional[torch.Tensor], |
| text_embedding: Optional[torch.Tensor], |
| ) -> int: |
| """ |
| Gets the batch size of the output given the batch size of the input prompts. |
| """ |
| if points is not None: |
| return points[0].shape[0] |
| elif boxes is not None: |
| return boxes.shape[0] |
| elif masks is not None: |
| return masks.shape[0] |
| elif text_embedding is not None: |
| return text_embedding.shape[0] |
| else: |
| return 1 |
|
|
| def _get_device(self) -> torch.device: |
| return self.point_embeddings[0].weight.device |
|
|
| def forward( |
| self, |
| points: Optional[Tuple[torch.Tensor, torch.Tensor]], |
| boxes: Optional[torch.Tensor], |
| masks: Optional[torch.Tensor], |
| text_embedding: Optional[torch.Tensor], |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| |
| bs = self._get_batch_size(points, boxes, masks, text_embedding) |
| sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) |
|
|
| if points is not None: |
| coords, labels = points |
| point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) |
| sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) |
| |
| if boxes is not None: |
| box_embeddings = self._embed_boxes(boxes) |
| sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) |
| |
| if text_embedding is not None: |
| sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1) |
| |
| if masks is not None: |
| dense_embeddings = self._embed_masks(masks) |
| else: |
| dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand( |
| bs, -1, int(self.image_embedding_size[0]), int(self.image_embedding_size[1]), int(self.image_embedding_size[2]) |
| ) |
|
|
| return sparse_embeddings, dense_embeddings |
|
|
|
|
| class PositionEmbeddingRandom(nn.Module): |
| """ |
| Positional encoding using random spatial frequencies. |
| """ |
|
|
| def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: |
| super().__init__() |
| if scale is None or scale <= 0.0: |
| scale = 1.0 |
| self.register_buffer( |
| "positional_encoding_gaussian_matrix", |
| scale * torch.randn((3, num_pos_feats)), |
| ) |
|
|
| def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: |
| """Positionally encode points that are normalized to [0,1].""" |
| |
| coords = 2 * coords - 1 |
| coords = coords @ self.positional_encoding_gaussian_matrix |
| coords = 2 * np.pi * coords |
| |
| return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) |
|
|
| def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: |
| """Generate positional encoding for a grid of the specified size.""" |
| h, w, d = size |
| device: Any = self.positional_encoding_gaussian_matrix.device |
| grid = torch.ones((h, w, d), device=device, dtype=torch.float32) |
| y_embed = grid.cumsum(dim=0) - 0.5 |
| x_embed = grid.cumsum(dim=1) - 0.5 |
| z_embed = grid.cumsum(dim=2) - 0.5 |
| y_embed = y_embed / h |
| x_embed = x_embed / w |
| z_embed = z_embed / d |
|
|
| pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) |
| return pe.permute(3, 0, 1, 2) |
|
|
| def forward_with_coords( |
| self, coords_input: torch.Tensor, image_size: Tuple[int, int, int] |
| ) -> torch.Tensor: |
| """Positionally encode points that are not normalized to [0,1].""" |
| coords = coords_input.clone() |
| coords[:, :, 0] = coords[:, :, 0] / image_size[1] |
| coords[:, :, 1] = coords[:, :, 1] / image_size[0] |
| coords[:, :, 2] = coords[:, :, 2] / image_size[2] |
| return self._pe_encoding(coords.to(torch.float)) |
|
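| # Sketch: random-Fourier positional features. Normalized (x, y, z) coordinates |
| # are projected by a fixed Gaussian matrix and passed through sin/cos, giving |
| # 2 * num_pos_feats channels per grid cell. |
| def _example_position_embedding(): |
| pe = PositionEmbeddingRandom(num_pos_feats=64)((8, 16, 16)) |
| return pe.shape  # (128, 8, 16, 16) |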
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| import math |
| from typing import Tuple, Type |
|
|
| class TwoWayTransformer(nn.Module): |
| def __init__( |
| self, |
| depth: int, |
| embedding_dim: int, |
| num_heads: int, |
| mlp_dim: int, |
| activation: Type[nn.Module] = nn.ReLU, |
| attention_downsample_rate: int = 2, |
| ) -> None: |
| """ |
| A transformer decoder that attends to an input image using |
| queries whose positional embedding is supplied. |
| |
| Args: |
| depth (int): number of layers in the transformer |
| embedding_dim (int): the channel dimension for the input embeddings |
| num_heads (int): the number of heads for multihead attention. Must |
| divide embedding_dim |
| mlp_dim (int): the channel dimension internal to the MLP block |
| activation (nn.Module): the activation to use in the MLP block |
| """ |
| super().__init__() |
| self.depth = depth |
| self.embedding_dim = embedding_dim |
| self.num_heads = num_heads |
| self.mlp_dim = mlp_dim |
| self.layers = nn.ModuleList() |
|
|
| for i in range(depth): |
| self.layers.append( |
| TwoWayAttentionBlock( |
| embedding_dim=embedding_dim, |
| num_heads=num_heads, |
| mlp_dim=mlp_dim, |
| activation=activation, |
| attention_downsample_rate=attention_downsample_rate, |
| skip_first_layer_pe=(i == 0), |
| ) |
| ) |
|
|
| self.final_attn_token_to_image = Attention( |
| embedding_dim, num_heads, downsample_rate=attention_downsample_rate |
| ) |
| self.norm_final_attn = nn.LayerNorm(embedding_dim) |
|
|
| def forward( |
| self, |
| image_embedding: Tensor, |
| image_pe: Tensor, |
| point_embedding: Tensor, |
| ) -> Tuple[Tensor, Tensor]: |
| """ |
| Args: |
| image_embedding (torch.Tensor): image to attend to. Should be shape |
| B x embedding_dim x h x w x d for any h, w and d. |
| image_pe (torch.Tensor): the positional encoding to add to the image. Must |
| have the same shape as image_embedding. |
| point_embedding (torch.Tensor): the embedding to add to the query points. |
| Must have shape B x N_points x embedding_dim for any N_points. |
| |
| Returns: |
| torch.Tensor: the processed point_embedding |
| torch.Tensor: the processed image_embedding |
| """ |
| |
| bs, c, h, w, d = image_embedding.shape |
| image_embedding = image_embedding.flatten(2).permute(0, 2, 1) |
| image_pe = image_pe.flatten(2).permute(0, 2, 1) |
|
|
| |
| queries = point_embedding |
| keys = image_embedding |
|
|
| |
| for layer in self.layers: |
| queries, keys = layer( |
| queries=queries, |
| keys=keys, |
| query_pe=point_embedding, |
| key_pe=image_pe, |
| ) |
|
|
| |
| q = queries + point_embedding |
| k = keys + image_pe |
| attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) |
| queries = queries + attn_out |
| queries = self.norm_final_attn(queries) |
|
|
| return queries, keys |
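|
| # Hedged usage sketch (shapes are illustrative, not from this file): running |
| # the two-way decoder over a small 3D embedding grid. |
| # |
| # transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048) |
| # image_embedding = torch.randn(1, 256, 8, 8, 4) # B x C x h x w x d |
| # image_pe = torch.randn(1, 256, 8, 8, 4) # same shape as image_embedding |
| # point_embedding = torch.randn(1, 5, 256) # B x N_points x C |
| # queries, keys = transformer(image_embedding, image_pe, point_embedding) |
| # # queries: 1 x 5 x 256, keys: 1 x 256 x 256 (8*8*4 image tokens) |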
|
|
|
|
| class TwoWayAttentionBlock(nn.Module): |
| def __init__( |
| self, |
| embedding_dim: int, |
| num_heads: int, |
| mlp_dim: int = 2048, |
| activation: Type[nn.Module] = nn.ReLU, |
| attention_downsample_rate: int = 2, |
| skip_first_layer_pe: bool = False, |
| ) -> None: |
| """ |
| A transformer block with four layers: (1) self-attention of sparse |
| inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp |
| block on sparse inputs, and (4) cross attention of dense inputs to sparse |
| inputs. |
| |
| Arguments: |
| embedding_dim (int): the channel dimension of the embeddings |
| num_heads (int): the number of heads in the attention layers |
| mlp_dim (int): the hidden dimension of the mlp block |
| activation (nn.Module): the activation of the mlp block |
| skip_first_layer_pe (bool): skip the PE on the first layer |
| """ |
| super().__init__() |
| self.self_attn = Attention(embedding_dim, num_heads) |
| self.norm1 = nn.LayerNorm(embedding_dim) |
|
|
| self.cross_attn_token_to_image = Attention( |
| embedding_dim, num_heads, downsample_rate=attention_downsample_rate |
| ) |
| self.norm2 = nn.LayerNorm(embedding_dim) |
|
|
| self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) |
| self.norm3 = nn.LayerNorm(embedding_dim) |
|
|
| self.norm4 = nn.LayerNorm(embedding_dim) |
| self.cross_attn_image_to_token = Attention( |
| embedding_dim, num_heads, downsample_rate=attention_downsample_rate |
| ) |
|
|
| self.skip_first_layer_pe = skip_first_layer_pe |
|
|
| def forward( |
| self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor |
| ) -> Tuple[Tensor, Tensor]: |
|
| # Self attention block |
| if self.skip_first_layer_pe: |
| queries = self.self_attn(q=queries, k=queries, v=queries) |
| else: |
| q = queries + query_pe |
| attn_out = self.self_attn(q=q, k=q, v=queries) |
| queries = queries + attn_out |
| queries = self.norm1(queries) |
|
|
|
| # Cross attention block, tokens attending to image embedding |
| q = queries + query_pe |
| k = keys + key_pe |
| attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) |
| queries = queries + attn_out |
| queries = self.norm2(queries) |
|
|
|
| # MLP block |
| mlp_out = self.mlp(queries) |
| queries = queries + mlp_out |
| queries = self.norm3(queries) |
|
|
|
| # Cross attention block, image embedding attending to tokens |
| q = queries + query_pe |
| k = keys + key_pe |
| attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) |
| keys = keys + attn_out |
| keys = self.norm4(keys) |
|
|
| return queries, keys |
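|
| # Hedged sketch (values are illustrative): a single block preserves the |
| # query/key shapes, so blocks can be stacked freely. |
| # |
| # block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8) |
| # queries, keys = torch.randn(1, 5, 256), torch.randn(1, 64, 256) |
| # query_pe, key_pe = torch.randn(1, 5, 256), torch.randn(1, 64, 256) |
| # queries, keys = block(queries, keys, query_pe, key_pe) # shapes unchanged |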
|
|
|
|
| class Attention(nn.Module): |
| """ |
| An attention layer that allows for downscaling the size of the embedding |
| after projection to queries, keys, and values. |
| """ |
|
|
| def __init__( |
| self, |
| embedding_dim: int, |
| num_heads: int, |
| downsample_rate: int = 1, |
| ) -> None: |
| super().__init__() |
| self.embedding_dim = embedding_dim |
| self.internal_dim = embedding_dim // downsample_rate |
| self.num_heads = num_heads |
| assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim // downsample_rate." |
|
|
| self.q_proj = nn.Linear(embedding_dim, self.internal_dim) |
| self.k_proj = nn.Linear(embedding_dim, self.internal_dim) |
| self.v_proj = nn.Linear(embedding_dim, self.internal_dim) |
| self.out_proj = nn.Linear(self.internal_dim, embedding_dim) |
|
|
| def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: |
| b, n, c = x.shape |
| x = x.reshape(b, n, num_heads, c // num_heads) |
| return x.transpose(1, 2) |
|
|
| def _recombine_heads(self, x: Tensor) -> Tensor: |
| b, n_heads, n_tokens, c_per_head = x.shape |
| x = x.transpose(1, 2) |
| return x.reshape(b, n_tokens, n_heads * c_per_head) |
|
|
| def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: |
|
| # Input projections |
| q = self.q_proj(q) |
| k = self.k_proj(k) |
| v = self.v_proj(v) |
|
|
|
| # Separate into heads: B x N x C -> B x num_heads x N x C_per_head |
| q = self._separate_heads(q, self.num_heads) |
| k = self._separate_heads(k, self.num_heads) |
| v = self._separate_heads(v, self.num_heads) |
|
|
|
| # Scaled dot-product attention |
| _, _, _, c_per_head = q.shape |
| attn = q @ k.permute(0, 1, 3, 2) |
| attn = attn / math.sqrt(c_per_head) |
| attn = torch.softmax(attn, dim=-1) |
|
|
|
| # Recombine heads and project the output |
| out = attn @ v |
| out = self._recombine_heads(out) |
| out = self.out_proj(out) |
|
|
| return out |
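|
| # Hedged sketch (values are illustrative): with downsample_rate=2 a 256-dim |
| # embedding attends in a 128-dim internal space (128 // 8 = 16 dims per head) |
| # and is projected back to 256 on output. |
| # |
| # attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2) |
| # q = k = v = torch.randn(1, 5, 256) |
| # out = attn(q, k, v) # -> 1 x 5 x 256 |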
|
|
| |
| |
| class LayerNorm2d(nn.Module): |
| def __init__(self, num_channels: int, eps: float = 1e-6) -> None: |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(num_channels)) |
| self.bias = nn.Parameter(torch.zeros(num_channels)) |
| self.eps = eps |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| u = x.mean(1, keepdim=True) |
| s = (x - u).pow(2).mean(1, keepdim=True) |
| x = (x - u) / torch.sqrt(s + self.eps) |
| x = self.weight[:, None, None] * x + self.bias[:, None, None] |
| return x |
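|
| # Hedged note: unlike nn.LayerNorm, this normalizes over the channel dimension |
| # (dim 1) of a channels-first B x C x H x W tensor, so no permute is needed. |
| # |
| # ln = LayerNorm2d(64) |
| # y = ln(torch.randn(2, 64, 32, 32)) # each (b, :, i, j) vector is standardized |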
|
|
| class MLPBlock(nn.Module): |
| def __init__( |
| self, |
| embedding_dim: int, |
| mlp_dim: int, |
| act: Type[nn.Module] = nn.GELU, |
| ) -> None: |
| super().__init__() |
| self.lin1 = nn.Linear(embedding_dim, mlp_dim) |
| self.lin2 = nn.Linear(mlp_dim, embedding_dim) |
| self.act = act() |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.lin2(self.act(self.lin1(x))) |
| |
|
|
| |
| class Sam(nn.Module): |
| mask_threshold: float = 0.0 |
| image_format: str = "RGB" |
|
|
| def __init__( |
| self, |
| image_encoder, |
| prompt_encoder, |
| mask_decoder, |
| pixel_mean: List[float] = [123.675, 116.28, 103.53], |
| pixel_std: List[float] = [58.395, 57.12, 57.375], |
| ) -> None: |
| """ |
| SAM predicts object masks from an image and input prompts. |
| |
| Arguments: |
| image_encoder (ImageEncoderViT): The backbone used to encode the |
| image into image embeddings that allow for efficient mask prediction. |
| prompt_encoder (PromptEncoder): Encodes various types of input prompts. |
| mask_decoder (MaskDecoder): Predicts masks from the image embeddings |
| and encoded prompts. |
| pixel_mean (list(float)): Mean values for normalizing pixels in the input image. |
| pixel_std (list(float)): Std values for normalizing pixels in the input image. |
| """ |
| super().__init__() |
| self.image_encoder = image_encoder |
| self.prompt_encoder = prompt_encoder |
| self.mask_decoder = mask_decoder |
| self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) |
| self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) |
|
|
| @property |
| def device(self) -> Any: |
| return self.pixel_mean.device |
|
|
| @torch.no_grad() |
| def forward( |
| self, |
| batched_input: List[Dict[str, Any]], |
| multimask_output: bool, |
| ) -> List[Dict[str, torch.Tensor]]: |
| """ |
| Predicts masks end-to-end from provided images and prompts. |
| If prompts are not known in advance, using SamPredictor is |
| recommended over calling the model directly. |
| |
| Arguments: |
| batched_input (list(dict)): A list over input images, each a |
| dictionary with the following keys. A prompt key can be |
| excluded if it is not present. |
| 'image': The image as a torch tensor in 3xHxW format, |
| already transformed for input to the model. |
| 'original_size': (tuple(int, int)) The original size of |
| the image before transformation, as (H, W). |
| 'point_coords': (torch.Tensor) Batched point prompts for |
| this image, with shape BxNx2. Already transformed to the |
| input frame of the model. |
| 'point_labels': (torch.Tensor) Batched labels for point prompts, |
| with shape BxN. |
| 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. |
| Already transformed to the input frame of the model. |
| 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, |
| in the form Bx1xHxW. |
| multimask_output (bool): Whether the model should predict multiple |
| disambiguating masks, or return a single mask. |
| |
| Returns: |
| (list(dict)): A list over input images, where each element is |
| a dictionary with the following keys. |
| 'masks': (torch.Tensor) Batched binary mask predictions, |
| with shape BxCxHxW, where B is the number of input prompts, |
| C is determined by multimask_output, and (H, W) is the |
| original size of the image. |
| 'iou_predictions': (torch.Tensor) The model's predictions |
| of mask quality, in shape BxC. |
| 'low_res_logits': (torch.Tensor) Low resolution logits with |
| shape BxCxHxW, where H=W=256. Can be passed as mask input |
| to subsequent iterations of prediction. |
| """ |
| input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) |
| image_embeddings = self.image_encoder(input_images) |
|
|
| outputs = [] |
| for image_record, curr_embedding in zip(batched_input, image_embeddings): |
| if "point_coords" in image_record: |
| points = (image_record["point_coords"], image_record["point_labels"]) |
| else: |
| points = None |
| sparse_embeddings, dense_embeddings = self.prompt_encoder( |
| points=points, |
| boxes=image_record.get("boxes", None), |
| masks=image_record.get("mask_inputs", None), |
| ) |
| low_res_masks, iou_predictions = self.mask_decoder( |
| image_embeddings=curr_embedding.unsqueeze(0), |
| image_pe=self.prompt_encoder.get_dense_pe(), |
| sparse_prompt_embeddings=sparse_embeddings, |
| dense_prompt_embeddings=dense_embeddings, |
| multimask_output=multimask_output, |
| ) |
| masks = self.postprocess_masks( |
| low_res_masks, |
| input_size=image_record["image"].shape[-2:], |
| original_size=image_record["original_size"], |
| ) |
| masks = masks > self.mask_threshold |
| outputs.append( |
| { |
| "masks": masks, |
| "iou_predictions": iou_predictions, |
| "low_res_logits": low_res_masks, |
| } |
| ) |
| return outputs |
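|
| # Hedged sketch of the batched_input contract (tensor values are illustrative |
| # and `sam` is an already-constructed Sam instance): |
| # |
| # batched_input = [{ |
| #     "image": torch.randn(3, 1024, 1024), # 3 x H x W, already transformed |
| #     "original_size": (768, 1024), # (H, W) before resizing |
| #     "point_coords": torch.rand(1, 2, 2) * 1024, # B x N x 2, input frame |
| #     "point_labels": torch.ones(1, 2), # B x N |
| # }] |
| # outputs = sam(batched_input, multimask_output=False) |
| # masks = outputs[0]["masks"] # boolean, B x 1 x 768 x 1024 |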
|
|
| def postprocess_masks( |
| self, |
| masks: torch.Tensor, |
| input_size: Tuple[int, ...], |
| original_size: Tuple[int, ...], |
| ) -> torch.Tensor: |
| """ |
| Remove padding and upscale masks to the original image size. |
| |
| Arguments: |
| masks (torch.Tensor): Batched masks from the mask_decoder, |
| in BxCxHxW format. |
| input_size (tuple(int, int)): The size of the image input to the |
| model, in (H, W) format. Used to remove padding. |
| original_size (tuple(int, int)): The original size of the image |
| before resizing for input to the model, in (H, W) format. |
| |
| Returns: |
| (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) |
| is given by original_size. |
| """ |
| masks = F.interpolate( |
| masks, |
| (self.image_encoder.img_size, self.image_encoder.img_size), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| masks = masks[..., : input_size[0], : input_size[1]] |
| masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) |
| return masks |
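|
| # Hedged sketch (assuming image_encoder.img_size == 1024): low-res logits from |
| # a 1024 x 768 padded input, restored to a 2048 x 1536 original. |
| # |
| # low_res = torch.randn(1, 1, 256, 256) |
| # masks = sam.postprocess_masks(low_res, input_size=(1024, 768), original_size=(2048, 1536)) |
| # # -> 1 x 1 x 2048 x 1536 |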
|
|
| def preprocess(self, x: torch.Tensor) -> torch.Tensor: |
| """Normalize pixel values and pad to a square input.""" |
|
| # Normalize colors |
| x = (x - self.pixel_mean) / self.pixel_std |
|
|
|
| # Pad on the bottom/right to a square img_size x img_size input |
| h, w = x.shape[-2:] |
| padh = self.image_encoder.img_size - h |
| padw = self.image_encoder.img_size - w |
| x = F.pad(x, (0, padw, 0, padh)) |
| return x |
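|
| # Hedged sketch (assuming image_encoder.img_size == 1024): a 3 x 800 x 600 |
| # image is normalized, then zero-padded to 1024 x 1024. |
| # |
| # x = sam.preprocess(torch.randn(3, 800, 600)) # -> 3 x 1024 x 1024 |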