| """ |
| Contains all the types of interaction related to the GUI |
| Not related to automatic evaluation in the DAVIS dataset |
| |
| You can inherit the Interaction class to create new interaction types |
| undo is (sometimes partially) supported |
| """ |
|
|
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import cv2 |
| import time |
| from .interactive_utils import color_map, index_numpy_to_one_hot_torch |
|
|
|
|
| def aggregate_sbg(prob, keep_bg=False, hard=False): |
| device = prob.device |
| k, h, w = prob.shape |
| ex_prob = torch.zeros((k+1, h, w), device=device) |
| ex_prob[0] = 0.5 |
| ex_prob[1:] = prob |
| ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7) |
| logits = torch.log((ex_prob /(1-ex_prob))) |
|
|
| if hard: |
| |
| logits *= 1000 |
|
|
| if keep_bg: |
| return F.softmax(logits, dim=0) |
| else: |
| return F.softmax(logits, dim=0)[1:] |
|
|
| def aggregate_wbg(prob, keep_bg=False, hard=False): |
| k, h, w = prob.shape |
| new_prob = torch.cat([ |
| torch.prod(1-prob, dim=0, keepdim=True), |
| prob |
| ], 0).clamp(1e-7, 1-1e-7) |
| logits = torch.log((new_prob /(1-new_prob))) |
|
|
| if hard: |
| |
| logits *= 1000 |
|
|
| if keep_bg: |
| return F.softmax(logits, dim=0) |
| else: |
| return F.softmax(logits, dim=0)[1:] |
|
|
| class Interaction: |
| def __init__(self, image, prev_mask, true_size, controller): |
| self.image = image |
| self.prev_mask = prev_mask |
| self.controller = controller |
| self.start_time = time.time() |
|
|
| self.h, self.w = true_size |
|
|
| self.out_prob = None |
| self.out_mask = None |
|
|
| def predict(self): |
| pass |
|
|
|
|
| class FreeInteraction(Interaction): |
| def __init__(self, image, prev_mask, true_size, num_objects): |
| """ |
| prev_mask should be index format numpy array |
| """ |
| super().__init__(image, prev_mask, true_size, None) |
|
|
| self.K = num_objects |
|
|
| self.drawn_map = self.prev_mask.copy() |
| self.curr_path = [[] for _ in range(self.K + 1)] |
|
|
| self.size = None |
|
|
| def set_size(self, size): |
| self.size = size |
|
|
| """ |
| k - object id |
| vis - a tuple (visualization map, pass through alpha). None if not needed. |
| """ |
| def push_point(self, x, y, k, vis=None): |
| if vis is not None: |
| vis_map, vis_alpha = vis |
| selected = self.curr_path[k] |
| selected.append((x, y)) |
| if len(selected) >= 2: |
| cv2.line(self.drawn_map, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| k, thickness=self.size) |
|
|
| |
| if vis is not None: |
| |
| if k == 0: |
| vis_map = cv2.line(vis_map, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| color_map[k], thickness=self.size) |
| else: |
| vis_map = cv2.line(vis_map, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| color_map[k], thickness=self.size) |
| |
| vis_alpha = cv2.line(vis_alpha, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| 0.75, thickness=self.size) |
|
|
| if vis is not None: |
| return vis_map, vis_alpha |
|
|
| def end_path(self): |
| |
| self.curr_path = [[] for _ in range(self.K + 1)] |
|
|
| def predict(self): |
| self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda() |
| |
| |
| |
| return self.out_prob |
|
|
| class ScribbleInteraction(Interaction): |
| def __init__(self, image, prev_mask, true_size, controller, num_objects): |
| """ |
| prev_mask should be in an indexed form |
| """ |
| super().__init__(image, prev_mask, true_size, controller) |
|
|
| self.K = num_objects |
|
|
| self.drawn_map = np.empty((self.h, self.w), dtype=np.uint8) |
| self.drawn_map.fill(255) |
| |
| self.curr_path = [[] for _ in range(self.K + 1)] |
| self.size = 3 |
|
|
| """ |
| k - object id |
| vis - a tuple (visualization map, pass through alpha). None if not needed. |
| """ |
| def push_point(self, x, y, k, vis=None): |
| if vis is not None: |
| vis_map, vis_alpha = vis |
| selected = self.curr_path[k] |
| selected.append((x, y)) |
| if len(selected) >= 2: |
| self.drawn_map = cv2.line(self.drawn_map, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| k, thickness=self.size) |
|
|
| |
| if vis is not None: |
| |
| if k == 0: |
| vis_map = cv2.line(vis_map, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| color_map[k], thickness=self.size) |
| else: |
| vis_map = cv2.line(vis_map, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| color_map[k], thickness=self.size) |
| |
| vis_alpha = cv2.line(vis_alpha, |
| (int(round(selected[-2][0])), int(round(selected[-2][1]))), |
| (int(round(selected[-1][0])), int(round(selected[-1][1]))), |
| 0.75, thickness=self.size) |
|
|
| |
| if vis is not None: |
| return vis_map, vis_alpha |
|
|
| def end_path(self): |
| |
| self.curr_path = [[] for _ in range(self.K + 1)] |
|
|
| def predict(self): |
| self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map) |
| self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True) |
| return self.out_prob |
|
|
|
|
| class ClickInteraction(Interaction): |
| def __init__(self, image, prev_mask, true_size, controller, tar_obj): |
| """ |
| prev_mask in a prob. form |
| """ |
| super().__init__(image, prev_mask, true_size, controller) |
| self.tar_obj = tar_obj |
|
|
| |
| self.pos_clicks = [] |
| self.neg_clicks = [] |
|
|
| self.out_prob = self.prev_mask.clone() |
|
|
| """ |
| neg - Negative interaction or not |
| vis - a tuple (visualization map, pass through alpha). None if not needed. |
| """ |
| def push_point(self, x, y, neg, vis=None): |
| |
| if neg: |
| self.neg_clicks.append((x, y)) |
| else: |
| self.pos_clicks.append((x, y)) |
|
|
| |
| self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg) |
|
|
| |
| if vis is not None: |
| vis_map, vis_alpha = vis |
| |
| if neg: |
| vis_map = cv2.circle(vis_map, |
| (int(round(x)), int(round(y))), |
| 2, color_map[0], thickness=-1) |
| else: |
| vis_map = cv2.circle(vis_map, |
| (int(round(x)), int(round(y))), |
| 2, color_map[self.tar_obj], thickness=-1) |
|
|
| vis_alpha = cv2.circle(vis_alpha, |
| (int(round(x)), int(round(y))), |
| 2, 1, thickness=-1) |
|
|
| |
| return vis_map, vis_alpha |
|
|
| def predict(self): |
| self.out_prob = self.prev_mask.clone() |
| |
| |
| self.out_prob = torch.clamp(self.out_prob, max=0.9) |
| self.out_prob[self.tar_obj] = self.obj_mask |
| self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True) |
| return self.out_prob |
|
|