| from albumentations.pytorch import ToTensorV2 |
| import cv2 |
| import albumentations as A |
| import torch |
| import numpy as np |
| from torch.nn import functional as F |
| from skimage.measure import label, regionprops |
| from matplotlib import pyplot as plt |
| import random |
| import torch.nn as nn |
| import logging |
| import os |
|
|
|
|
def get_boxes_from_mask(mask, box_num=1, std = 0.1, max_pixel = 5):
    """
    Extract bounding boxes from the connected components of a binary mask
    and perturb each box with random pixel noise.

    Args:
        mask: Mask, can be a torch.Tensor or a numpy array of binary mask.
        box_num: Number of bounding boxes, default is 1.
        std: Standard deviation of the noise, as a fraction of the smaller
            box side, default is 0.1.
        max_pixel: Maximum noise pixel value, default is 5.
    Returns:
        noise_boxes: Bounding boxes after noise perturbation, returned as a
            torch.Tensor of shape (box_num, 4) in (x0, y0, x1, y1) order.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()

    label_img = label(mask)
    regions = regionprops(label_img)

    if regions:
        # Keep the box_num largest components (or all of them if fewer).
        sorted_regions = sorted(regions, key=lambda r: r.area, reverse=True)[:box_num]
        # regionprops bbox order is (min_row, min_col, max_row, max_col).
        boxes = [tuple(r.bbox) for r in sorted_regions]
    else:
        # Empty mask: fall back to the full-image box. The original code
        # crashed here with ZeroDivisionError when duplicating from an
        # empty box list.
        h, w = mask.shape[:2]
        boxes = [(0, 0, h, w)]

    if len(boxes) < box_num:
        # Duplicate existing boxes round-robin until we reach box_num.
        num_duplicates = box_num - len(boxes)
        boxes += [boxes[i % len(boxes)] for i in range(num_duplicates)]

    noise_boxes = []
    for box in boxes:
        y0, x0, y1, x1 = box
        width, height = abs(x1 - x0), abs(y1 - y0)
        # Noise magnitude scales with box size, capped by max_pixel and by
        # five standard deviations.
        noise_std = min(width, height) * std
        max_noise = min(max_pixel, int(noise_std * 5))
        # High bound is made inclusive (+1) so that max_noise == 0 (tiny
        # regions) yields zero noise instead of raising ValueError in
        # np.random.randint (which requires low < high).
        noise_x = np.random.randint(-max_noise, max_noise + 1)
        noise_y = np.random.randint(-max_noise, max_noise + 1)
        x0, y0 = x0 + noise_x, y0 + noise_y
        x1, y1 = x1 + noise_x, y1 + noise_y
        noise_boxes.append((x0, y0, x1, y1))
    return torch.as_tensor(noise_boxes, dtype=torch.float)
|
|
|
|
def select_random_points(pr, gt, point_num = 9):
    """
    Selects random points from the predicted and ground truth masks and assigns labels to them.
    Args:
        pr (torch.Tensor): Predicted binary mask tensor, shape [B, 1, H, W].
        gt (torch.Tensor): Ground truth binary mask tensor, shape [B, 1, H, W].
        point_num (int): Number of random points to select. Default is 9.
    Returns:
        batch_points (np.array): Array of selected points coordinates (x, y) for each batch.
        batch_labels (np.array): Array of corresponding labels (0 for background, 1 for foreground) for each batch.
    """
    pred, gt = pr.data.cpu().numpy(), gt.data.cpu().numpy()
    # error == 1 wherever prediction and ground truth disagree.
    error = np.zeros_like(pred)
    error[pred != gt] = 1

    batch_points = []
    batch_labels = []
    for j in range(error.shape[0]):
        one_pred = pred[j].squeeze(0)
        one_gt = gt[j].squeeze(0)
        one_error = error[j].squeeze(0)

        indices = np.argwhere(one_error == 1)
        if indices.shape[0] > 0:
            # Sample (with replacement) from the disagreement region.
            selected_indices = indices[np.random.choice(indices.shape[0], point_num, replace=True)]
        else:
            # No disagreement: sample uniformly over the actual mask extent.
            # (The original hard-coded a 256x256 range, which raised
            # IndexError for other mask sizes.)
            h, w = one_pred.shape[:2]
            selected_indices = np.column_stack(
                (np.random.randint(0, h, point_num), np.random.randint(0, w, point_num))
            )
        selected_indices = selected_indices.reshape(-1, 2)

        points, labels = [], []
        for x, y in selected_indices:
            if one_pred[x, y] == 0 and one_gt[x, y] == 1:
                point_label = 1  # missed foreground -> positive click
            elif one_pred[x, y] == 1 and one_gt[x, y] == 0:
                point_label = 0  # false positive -> negative click
            else:
                # Agreement point (only reachable on the random-fallback
                # path); label it by the ground truth. The original left
                # the label unbound here and raised UnboundLocalError.
                point_label = int(one_gt[x, y])
            points.append((y, x))
            labels.append(point_label)

        batch_points.append(points)
        batch_labels.append(labels)
    return np.array(batch_points), np.array(batch_labels)
|
|
|
|
def init_point_sampling(mask, get_point=1):
    """
    Initialization samples points from the mask and assigns labels to them.
    Args:
        mask (torch.Tensor or np.ndarray): Input binary mask of shape (H, W).
        get_point (int): Number of points to sample. Default is 1.
    Returns:
        coords (torch.Tensor): Tensor containing the sampled points' coordinates (x, y), shape (get_point, 2).
        labels (torch.Tensor): Tensor containing the corresponding labels (0 for background, 1 for foreground).
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()

    # argwhere yields (row, col); reverse to (x, y) point-prompt order.
    fg_coords = np.argwhere(mask == 1)[:, ::-1]
    bg_coords = np.argwhere(mask == 0)[:, ::-1]

    fg_size = len(fg_coords)
    bg_size = len(bg_coords)

    if get_point == 1:
        # Prefer a foreground point; fall back to background on empty masks.
        if fg_size > 0:
            index = np.random.randint(fg_size)
            fg_coord = fg_coords[index]
            label = 1
        else:
            index = np.random.randint(bg_size)
            fg_coord = bg_coords[index]
            label = 0
        return torch.as_tensor([fg_coord.tolist()], dtype=torch.float), torch.as_tensor([label], dtype=torch.int)
    else:
        # Split roughly half foreground / half background, but draw all
        # points from the non-empty class when one class is absent (the
        # original raised in np.random.choice over an empty pool).
        if fg_size == 0:
            num_fg, num_bg = 0, get_point
        elif bg_size == 0:
            num_fg, num_bg = get_point, 0
        else:
            num_fg = get_point // 2
            num_bg = get_point - num_fg

        coord_parts, label_parts = [], []
        if num_fg > 0:
            fg_indices = np.random.choice(fg_size, size=num_fg, replace=True)
            coord_parts.append(fg_coords[fg_indices])
            label_parts.append(np.ones(num_fg))
        if num_bg > 0:
            bg_indices = np.random.choice(bg_size, size=num_bg, replace=True)
            coord_parts.append(bg_coords[bg_indices])
            label_parts.append(np.zeros(num_bg))

        coords = np.concatenate(coord_parts, axis=0)
        labels = np.concatenate(label_parts).astype(int)
        # Shuffle so foreground and background points are interleaved.
        indices = np.random.permutation(get_point)
        coords, labels = torch.as_tensor(coords[indices], dtype=torch.float), torch.as_tensor(labels[indices], dtype=torch.int)
        return coords, labels
| |
|
|
def train_transforms(img_size, ori_h, ori_w):
    """Build the training transform pipeline for one image.

    If the original image is strictly smaller than img_size in both
    dimensions it is zero-padded up to img_size; otherwise it is resized
    (nearest-neighbor) to img_size x img_size. The result is converted to
    a tensor.
    """
    if ori_h < img_size and ori_w < img_size:
        resize_op = A.PadIfNeeded(min_height=img_size, min_width=img_size,
                                  border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0))
    else:
        resize_op = A.Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST)
    return A.Compose([resize_op, ToTensorV2(p=1.0)], p=1.)
|
|
|
|
def get_logger(filename, verbosity=1, name=None):
    """Create a logger that writes to both `filename` (truncated) and stderr.

    Args:
        filename: Path of the log file; parent directories are created
            as needed.
        verbosity: 0 = DEBUG, 1 = INFO (default), 2 = WARNING.
        name: Logger name passed to logging.getLogger (None = root logger).
    Returns:
        The configured logging.Logger instance.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # logging.getLogger returns the same object for the same name, so drop
    # handlers from previous calls; otherwise every record would be emitted
    # once per prior get_logger() invocation.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    # A bare filename has an empty dirname; os.makedirs("") would raise.
    log_dir = os.path.dirname(filename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger
|
|
|
|
def generate_point(masks, labels, low_res_masks, batched_input, point_num):
    """Sample correction points from prediction errors and refresh the prompt dict.

    The predicted logits are binarized at 0.5 (after sigmoid) and compared to
    the ground-truth labels to pick point_num correction points; the low-res
    sigmoid masks and the sampled points are written into batched_input, and
    any box prompt is cleared.
    """
    binary_pred = (torch.sigmoid(masks.clone()) > 0.5).float()
    low_res_probs = torch.sigmoid(low_res_masks.clone())

    points, point_labels = select_random_points(binary_pred, labels, point_num=point_num)

    batched_input["mask_inputs"] = low_res_probs
    batched_input["point_coords"] = torch.as_tensor(points)
    batched_input["point_labels"] = torch.as_tensor(point_labels)
    batched_input["boxes"] = None
    return batched_input
|
|
|
|
def setting_prompt_none(batched_input):
    """Clear all sparse prompts (points and boxes) from the input dict in place."""
    for key in ("point_coords", "point_labels", "boxes"):
        batched_input[key] = None
    return batched_input
|
|
|
|
def draw_boxes(img, boxes):
    """Return a copy of img with each (x0, y0, x1, y1) box drawn in green."""
    canvas = np.copy(img)
    for x0, y0, x1, y1 in boxes:
        cv2.rectangle(canvas, (x0, y0), (x1, y1), (0, 255, 0), 2)
    return canvas
|
|
|
|
def save_masks(preds, save_path, mask_name, image_size, original_size, pad=None, boxes=None, points=None, visual_prompt=False):
    """
    Binarize predicted mask logits and write them to disk as a BGR image,
    optionally overlaying the box / point prompts mapped back toward the
    original image coordinates.

    Args:
        preds (torch.Tensor): Raw mask logits; sigmoid + 0.5 threshold applied.
        save_path (str): Output directory (created if missing).
        mask_name (str): Output file name.
        image_size (int): Model input resolution the prompts are expressed in.
        original_size (tuple): (ori_h, ori_w) of the source image.
        pad: Optional padding offsets; when given, prompts are un-padded
            rather than rescaled. Indexed as pad[0]/pad[1] — assumed
            (pad_y, pad_x), TODO confirm against the caller.
        boxes (torch.Tensor, optional): Box prompt to draw.
        points (tuple, optional): (coords, labels) point prompts to draw.
        visual_prompt (bool): If True, overlay prompts on the saved mask.
    """
    ori_h, ori_w = original_size

    # Threshold the logits into a hard {0, 1} mask.
    preds = torch.sigmoid(preds)
    preds[preds > 0.5] = int(1)
    preds[preds <= 0.5] = int(0)

    # Convert to a 3-channel image so colored prompt overlays can be drawn.
    mask = preds.squeeze().cpu().numpy()
    mask = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR)

    if visual_prompt:
        if boxes is not None:
            boxes = boxes.squeeze().cpu().numpy()

            x0, y0, x1, y1 = boxes
            if pad is not None:
                # Padded input: subtract the pad offsets (+0.5 rounds).
                x0_ori = int((x0 - pad[1]) + 0.5)
                y0_ori = int((y0 - pad[0]) + 0.5)
                x1_ori = int((x1 - pad[1]) + 0.5)
                y1_ori = int((y1 - pad[0]) + 0.5)
            else:
                # Resized input: rescale from model resolution to original.
                x0_ori = int(x0 * ori_w / image_size)
                y0_ori = int(y0 * ori_h / image_size)
                x1_ori = int(x1 * ori_w / image_size)
                y1_ori = int(y1 * ori_h / image_size)

            boxes = [(x0_ori, y0_ori, x1_ori, y1_ori)]
            mask = draw_boxes(mask, boxes)

        if points is not None:
            point_coords, point_labels = points[0].squeeze(0).cpu().numpy(), points[1].squeeze(0).cpu().numpy()
            point_coords = point_coords.tolist()
            if pad is not None:
                # NOTE(review): background points (l == 0) are rescaled while
                # foreground points are un-padded — verify this asymmetry is
                # intended and not a bug.
                ori_points = [[int((x * ori_w / image_size)) , int((y * ori_h / image_size))]if l==0 else [x - pad[1], y - pad[0]] for (x, y), l in zip(point_coords, point_labels)]
            else:
                ori_points = [[int((x * ori_w / image_size)) , int((y * ori_h / image_size))] for x, y in point_coords]

            for point, label in zip(ori_points, point_labels):
                x, y = map(int, point)
                # Green marker for foreground clicks, red for background.
                color = (0, 255, 0) if label == 1 else (0, 0, 255)
                mask[y, x] = color
                cv2.drawMarker(mask, (x, y), color, markerType=cv2.MARKER_CROSS , markerSize=7, thickness=2)
    os.makedirs(save_path, exist_ok=True)
    mask_path = os.path.join(save_path, f"{mask_name}")
    cv2.imwrite(mask_path, np.uint8(mask))
|
|
|
|
| |
class FocalLoss(nn.Module):
    """Binary focal loss computed on raw logits."""

    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma  # focusing parameter: down-weights easy examples
        self.alpha = alpha  # positive/negative class balance weight

    def forward(self, pred, mask):
        """
        pred: [B, 1, H, W] logits
        mask: [B, 1, H, W] binary ground truth
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        prob = torch.sigmoid(pred)
        pos_count = torch.sum(mask)
        neg_count = mask.numel() - pos_count

        # Modulating factors (1 - p_t) ** gamma for each class.
        pos_term = -self.alpha * mask * (1 - prob) ** self.gamma * torch.log(prob + 1e-12)
        neg_term = -(1 - self.alpha) * (1 - mask) * prob ** self.gamma * torch.log(1 - prob + 1e-12)

        total = torch.sum(pos_term) + torch.sum(neg_term)
        # Normalize by the total pixel count (epsilon guards empty input).
        return total / (pos_count + neg_count + 1e-12)
|
|
|
|
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, with Laplace smoothing."""

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth  # smoothing constant avoiding division by zero

    def forward(self, pred, mask):
        """
        pred: [B, 1, H, W] logits
        mask: [B, 1, H, W] binary ground truth
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        prob = torch.sigmoid(pred)
        overlap = torch.sum(prob * mask)
        total = torch.sum(prob) + torch.sum(mask)
        dice = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1 - dice
|
|
|
|
class MaskIoULoss(nn.Module):
    """MSE between a predicted IoU score and the actual soft-mask IoU."""

    def __init__(self):
        super(MaskIoULoss, self).__init__()

    def forward(self, pred_mask, ground_truth_mask, pred_iou):
        """
        pred_mask: [B, 1, H, W] logits
        ground_truth_mask: [B, 1, H, W] binary ground truth
        pred_iou: [B, 1] predicted IoU scores
        """
        assert pred_mask.shape == ground_truth_mask.shape, "pred_mask and ground_truth_mask should have the same shape."

        prob = torch.sigmoid(pred_mask)
        inter = torch.sum(prob * ground_truth_mask)
        union = torch.sum(prob) + torch.sum(ground_truth_mask) - inter
        actual_iou = (inter + 1e-7) / (union + 1e-7)
        return torch.mean((actual_iou - pred_iou) ** 2)
|
|
|
|
class FocalDiceloss_IoULoss(nn.Module):
    """Combined segmentation loss: weighted focal + dice, plus an IoU-prediction term."""

    def __init__(self, weight=20.0, iou_scale=1.0):
        super(FocalDiceloss_IoULoss, self).__init__()
        self.weight = weight        # multiplier on the focal component
        self.iou_scale = iou_scale  # multiplier on the IoU-regression component
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()
        self.maskiou_loss = MaskIoULoss()

    def forward(self, pred, mask, pred_iou):
        """
        pred: [B, 1, H, W] logits
        mask: [B, 1, H, W] binary ground truth
        pred_iou: [B, 1] predicted IoU scores
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."

        seg_loss = self.weight * self.focal_loss(pred, mask) + self.dice_loss(pred, mask)
        iou_loss = self.maskiou_loss(pred, mask, pred_iou)
        return seg_loss + iou_loss * self.iou_scale
|
|