| """ |
| Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN |
| (which is based on https://github.com/seoungwugoh/ivs-demo) |
| |
| This version is much simplified. |
| In this repo, we don't have |
| - local control |
| - fusion module |
| - undo |
| - timers |
| |
We do, however, use XMem as the backbone, and this version is more memory-friendly (for both CPU and GPU).
| """ |
|
|
| import functools |
|
|
| import os |
| import cv2 |
| |
# Importing cv2 may point QT_QPA_PLATFORM_PLUGIN_PATH at OpenCV's bundled Qt
# plugins, which can clash with PyQt5's own. Remove it if present; the default
# avoids a KeyError at import time when the variable is not set at all.
os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH", None)
|
|
| import numpy as np |
| import torch |
|
|
| from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox, |
| QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog, |
| QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton) |
|
|
| from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon |
| from PyQt5.QtCore import Qt, QTimer |
|
|
| from model.network import XMem |
|
|
| from inference.inference_core import InferenceCore |
| from .s2m_controller import S2MController |
| from .fbrs_controller import FBRSController |
|
|
| from .interactive_utils import * |
| from .interaction import * |
| from .resource_manager import ResourceManager |
| from .gui_utils import * |
|
|
|
|
| class App(QWidget): |
def __init__(self, net: XMem,
             resource_manager: ResourceManager,
             s2m_ctrl: S2MController,
             fbrs_ctrl: FBRSController, config):
    """Build the interactive segmentation window and show the first frame.

    Args:
        net: XMem network, wrapped here in an ``InferenceCore``.
        resource_manager: source of frames and masks; ``len()`` gives the
            number of frames and ``.h`` / ``.w`` the frame size.
        s2m_ctrl: controller handed to scribble interactions.
        fbrs_ctrl: controller handed to click interactions (some methods
            check it for ``None`` before use).
        config: settings dict; ``config['num_objects']`` is read here and the
            whole dict is passed to ``InferenceCore`` (and later rewritten
            by ``update_config``).
    """
    super().__init__()

    # widget callbacks that fire during construction check this flag
    self.initialized = False
    self.num_objects = config['num_objects']
    self.s2m_controller = s2m_ctrl
    self.fbrs_controller = fbrs_ctrl
    self.config = config
    self.processor = InferenceCore(net, config)
    # object labels are 1..num_objects
    self.processor.set_all_labels(list(range(1, self.num_objects+1)))
    self.res_man = resource_manager

    # video properties
    self.num_frames = len(self.res_man)
    self.height, self.width = self.res_man.h, self.res_man.w

    # window
    self.setWindowTitle('XMem Demo')
    self.setGeometry(100, 100, self.width, self.height+100)
    self.setWindowIcon(QIcon('docs/icon.png'))

    # playback / commit / propagation / reset buttons
    self.play_button = QPushButton('Play Video')
    self.play_button.clicked.connect(self.on_play_video)
    self.commit_button = QPushButton('Commit')
    self.commit_button.clicked.connect(self.on_commit)

    self.forward_run_button = QPushButton('Forward Propagate')
    self.forward_run_button.clicked.connect(self.on_forward_propagation)
    self.forward_run_button.setMinimumWidth(200)

    self.backward_run_button = QPushButton('Backward Propagate')
    self.backward_run_button.clicked.connect(self.on_backward_propagation)
    self.backward_run_button.setMinimumWidth(200)

    self.reset_button = QPushButton('Reset Frame')
    self.reset_button.clicked.connect(self.on_reset_mask)

    # frame counter ("current / last")
    self.lcd = QTextEdit()
    self.lcd.setReadOnly(True)
    self.lcd.setMaximumHeight(28)
    self.lcd.setMaximumWidth(120)
    self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1))

    # timeline slider
    self.tl_slider = QSlider(Qt.Horizontal)
    self.tl_slider.valueChanged.connect(self.tl_slide)
    self.tl_slider.setMinimum(0)
    self.tl_slider.setMaximum(self.num_frames-1)
    self.tl_slider.setValue(0)
    self.tl_slider.setTickPosition(QSlider.TicksBelow)
    self.tl_slider.setTickInterval(1)

    # brush size; valueChanged is connected before the range/value are set, so
    # brush_slide already runs here, before self.interaction exists
    self.brush_label = QLabel()
    self.brush_label.setAlignment(Qt.AlignCenter)
    self.brush_label.setMinimumWidth(100)

    self.brush_slider = QSlider(Qt.Horizontal)
    self.brush_slider.valueChanged.connect(self.brush_slide)
    self.brush_slider.setMinimum(1)
    self.brush_slider.setMaximum(100)
    self.brush_slider.setValue(3)
    self.brush_slider.setTickPosition(QSlider.TicksBelow)
    self.brush_slider.setTickInterval(2)
    self.brush_slider.setMinimumWidth(300)

    # overlay (visualization) mode
    self.combo = QComboBox(self)
    self.combo.addItem("davis")
    self.combo.addItem("fade")
    self.combo.addItem("light")
    self.combo.addItem("popup")
    self.combo.addItem("layered")
    self.combo.currentTextChanged.connect(self.set_viz_mode)

    self.save_visualization_checkbox = QCheckBox(self)
    self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle)
    self.save_visualization_checkbox.setChecked(False)
    self.save_visualization = False

    # interaction mode radio buttons; toggling "Click" runs interaction_radio_clicked
    self.curr_interaction = 'Click'
    self.interaction_group = QButtonGroup()
    self.radio_fbrs = QRadioButton('Click')
    self.radio_s2m = QRadioButton('Scribble')
    self.radio_free = QRadioButton('Free')
    self.interaction_group.addButton(self.radio_fbrs)
    self.interaction_group.addButton(self.radio_s2m)
    self.interaction_group.addButton(self.radio_free)
    self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
    self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
    self.radio_free.toggled.connect(self.interaction_radio_clicked)
    self.radio_fbrs.toggle()

    # main canvas; mouse handlers are patched onto the label directly
    self.main_canvas = QLabel()
    self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
            QSizePolicy.Expanding)
    self.main_canvas.setAlignment(Qt.AlignCenter)
    self.main_canvas.setMinimumSize(100, 100)

    self.main_canvas.mousePressEvent = self.on_mouse_press
    self.main_canvas.mouseMoveEvent = self.on_mouse_motion
    self.main_canvas.setMouseTracking(True)
    self.main_canvas.mouseReleaseEvent = self.on_mouse_release

    # zoomed minimap around the cursor
    self.minimap = QLabel()
    self.minimap.setSizePolicy(QSizePolicy.Expanding,
            QSizePolicy.Expanding)
    self.minimap.setAlignment(Qt.AlignTop)
    self.minimap.setMinimumSize(100, 100)

    self.zoom_p_button = QPushButton('Zoom +')
    self.zoom_p_button.clicked.connect(self.on_zoom_plus)
    self.zoom_m_button = QPushButton('Zoom -')
    self.zoom_m_button.clicked.connect(self.on_zoom_minus)

    # memory controls and gauges
    self.clear_mem_button = QPushButton('Clear memory')
    self.clear_mem_button.clicked.connect(self.on_clear_memory)

    self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size')
    self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size')
    self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)')
    self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)')

    self.update_memory_size()
    self.update_gpu_usage()

    self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames',
        callback=self.on_work_min_change)
    self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames',
        callback=self.on_work_max_change)
    self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000,
        'Max. long-term memory size', step=1000, callback=self.update_config)
    self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280,
        'Number of prototypes', step=32, callback=self.update_config)
    self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)',
        callback=self.update_config)

    # initial values come from the processor; callbacks are no-ops until initialized
    self.work_mem_min.setValue(self.processor.memory.min_mt_frames)
    self.work_mem_max.setValue(self.processor.memory.max_mt_frames)
    self.long_mem_max.setValue(self.processor.memory.max_long_elements)
    self.num_prototypes_box.setValue(self.processor.memory.num_prototypes)
    self.mem_every_box.setValue(self.processor.mem_every)

    # import buttons
    self.import_mask_button = QPushButton('Import mask')
    self.import_mask_button.clicked.connect(self.on_import_mask)
    self.import_layer_button = QPushButton('Import layer')
    self.import_layer_button.clicked.connect(self.on_import_layer)

    # console
    self.console = QPlainTextEdit()
    self.console.setReadOnly(True)
    self.console.setMinimumHeight(100)
    self.console.setMaximumHeight(100)

    # bottom navigation bar
    navi = QHBoxLayout()
    navi.addWidget(self.lcd)
    navi.addWidget(self.play_button)

    interact_subbox = QVBoxLayout()
    interact_topbox = QHBoxLayout()
    interact_botbox = QHBoxLayout()
    interact_topbox.setAlignment(Qt.AlignCenter)
    interact_topbox.addWidget(self.radio_s2m)
    interact_topbox.addWidget(self.radio_fbrs)
    interact_topbox.addWidget(self.radio_free)
    interact_topbox.addWidget(self.brush_label)
    interact_botbox.addWidget(self.brush_slider)
    interact_subbox.addLayout(interact_topbox)
    interact_subbox.addLayout(interact_botbox)
    navi.addLayout(interact_subbox)

    navi.addStretch(1)
    navi.addWidget(self.reset_button)

    navi.addStretch(1)
    navi.addWidget(QLabel('Overlay Mode'))
    navi.addWidget(self.combo)
    navi.addWidget(QLabel('Save overlay during propagation'))
    navi.addWidget(self.save_visualization_checkbox)
    navi.addStretch(1)
    navi.addWidget(self.commit_button)
    navi.addWidget(self.forward_run_button)
    navi.addWidget(self.backward_run_button)

    # drawing area: canvas on the left, side panel on the right
    draw_area = QHBoxLayout()
    draw_area.addWidget(self.main_canvas, 4)

    minimap_area = QVBoxLayout()
    minimap_area.setAlignment(Qt.AlignTop)
    mini_label = QLabel('Minimap')
    mini_label.setAlignment(Qt.AlignTop)
    minimap_area.addWidget(mini_label)

    minimap_ctrl = QHBoxLayout()
    minimap_ctrl.setAlignment(Qt.AlignTop)
    minimap_ctrl.addWidget(self.zoom_p_button)
    minimap_ctrl.addWidget(self.zoom_m_button)
    minimap_area.addLayout(minimap_ctrl)
    minimap_area.addWidget(self.minimap)

    minimap_area.addLayout(self.work_mem_gauge_layout)
    minimap_area.addLayout(self.long_mem_gauge_layout)
    minimap_area.addLayout(self.gpu_mem_gauge_layout)
    minimap_area.addLayout(self.torch_mem_gauge_layout)
    minimap_area.addWidget(self.clear_mem_button)
    minimap_area.addLayout(self.work_mem_min_layout)
    minimap_area.addLayout(self.work_mem_max_layout)
    minimap_area.addLayout(self.long_mem_max_layout)
    minimap_area.addLayout(self.num_prototypes_box_layout)
    minimap_area.addLayout(self.mem_every_box_layout)

    import_area = QHBoxLayout()
    import_area.setAlignment(Qt.AlignTop)
    import_area.addWidget(self.import_mask_button)
    import_area.addWidget(self.import_layer_button)
    minimap_area.addLayout(import_area)

    minimap_area.addWidget(self.console)

    draw_area.addLayout(minimap_area, 1)

    layout = QVBoxLayout()
    layout.addLayout(draw_area)
    layout.addWidget(self.tl_slider)
    layout.addLayout(navi)
    self.setLayout(layout)

    # "Play Video" timer: every tick advances one frame. Without this connection
    # the timer would run but playback would never move.
    self.timer = QTimer()
    self.timer.setSingleShot(False)
    self.timer.timeout.connect(self.on_play_video_timer)

    # periodic GPU-usage refresh
    self.gpu_timer = QTimer()
    self.gpu_timer.setSingleShot(False)
    self.gpu_timer.timeout.connect(self.on_gpu_timer)
    self.gpu_timer.setInterval(2000)
    self.gpu_timer.start()

    # current frame state; current_prob is a placeholder that
    # load_current_image_mask (called below) resets to None
    self.curr_frame_dirty = False
    self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
    self.current_image_torch = None
    self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8)
    self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda()

    # visualization buffers: stroke overlay and brush cursor overlay
    self.viz_mode = 'davis'
    self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
    self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
    self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
    self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
    self.cursur = 0
    self.on_showing = None

    # minimap side length in image pixels
    self.zoom_pixels = 150

    # interaction state
    self.interaction = None
    self.pressed = False
    self.right_click = False
    self.current_object = 1
    self.last_ex = self.last_ey = 0

    self.propagating = False

    # number keys 1..num_objects select the object being annotated
    for i in range(1, self.num_objects+1):
        QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i))

    # arrow keys step through frames
    QShortcut(QKeySequence(Qt.Key_Left), self).activated.connect(self.on_prev_frame)
    QShortcut(QKeySequence(Qt.Key_Right), self).activated.connect(self.on_next_frame)

    self.interacted_prob = None
    self.overlay_layer = None
    self.overlay_layer_torch = None

    # objects drawn by the visualization modes
    self.vis_target_objects = [1]
    # default overlay layer
    self._try_load_layer('./docs/ECCV-logo.png')

    self.load_current_image_mask()
    self.show_current_frame()
    self.show()

    self.console_push_text('Initialized.')
    self.initialized = True
|
|
def resizeEvent(self, event):
    """Qt resize hook: redraw so the frame is rescaled to the new canvas size."""
    # the event itself is not needed; the redraw reads the canvas size directly
    self.show_current_frame()
|
|
def console_push_text(self, text):
    """Append *text* as a new line at the end of the read-only console."""
    console = self.console
    # move to the end first so the text is appended even if the user clicked elsewhere
    console.moveCursor(QTextCursor.End)
    console.insertPlainText(text + '\n')
|
|
def interaction_radio_clicked(self, event):
    """Sync the interaction mode, brush size and button states with the checked radio button.

    Connected to ``toggled`` of all three radio buttons, so it runs both for
    the button being unchecked and for the one being checked.
    """
    self.last_interaction = self.curr_interaction

    # the model-driven modes use a fixed 3-pixel brush and lock the slider
    for radio, mode in ((self.radio_s2m, 'Scribble'), (self.radio_fbrs, 'Click')):
        if radio.isChecked():
            self.curr_interaction = mode
            self.brush_size = 3
            self.brush_slider.setDisabled(True)
            break
    else:
        if self.radio_free.isChecked():
            # free-hand drawing takes its size from the slider
            self.brush_slider.setDisabled(False)
            self.brush_slide()
            self.curr_interaction = 'Free'

    # only scribbles need an explicit commit
    self.commit_button.setEnabled(self.curr_interaction == 'Scribble')
|
|
def load_current_image_mask(self, no_mask=False):
    """Load frame ``self.cursur`` from the resource manager and, unless *no_mask*, its mask.

    The cached torch image is dropped so ``load_current_torch_image_mask``
    rebuilds it. When a mask is loaded, the cached probability tensor is
    dropped as well. A frame without a stored mask gets the existing mask
    buffer zeroed in place.
    """
    self.current_image = self.res_man.get_image(self.cursur)
    self.current_image_torch = None

    if no_mask:
        return

    stored = self.res_man.get_mask(self.cursur)
    if stored is not None:
        self.current_mask = stored.copy()
    else:
        self.current_mask.fill(0)
    self.current_prob = None
|
|
def load_current_torch_image_mask(self, no_mask=False):
    """Lazily build the torch versions of the current image and (unless *no_mask*) its mask.

    Each tensor is only rebuilt when its cache slot is ``None``.
    """
    if self.current_image_torch is None:
        torch_pair = image_to_torch(self.current_image)
        self.current_image_torch, self.current_image_torch_no_norm = torch_pair

    if not no_mask and self.current_prob is None:
        # one channel per object plus one extra (index 0)
        one_hot = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1)
        self.current_prob = one_hot.cuda()
|
|
def compose_current_im(self):
    """Render ``self.viz`` from the numpy frame, mask and overlay layer in the selected mode."""
    self.viz = get_visualization(
        self.viz_mode,
        self.current_image,
        self.current_mask,
        self.overlay_layer,
        self.vis_target_objects,
    )
|
|
def update_interact_vis(self):
    """Blend the stroke and brush overlays onto ``self.viz`` and show the result on the canvas.

    Also records the canvas size and the unscaled image size, which
    ``pixel_pos_to_image_pos`` uses to map mouse positions back to the image.
    """
    frame_h, frame_w, _ = self.viz.shape
    stroke_map, stroke_alpha = self.vis_map, self.vis_alpha
    brush_map, brush_alpha = self.brush_vis_map, self.brush_vis_alpha

    # strokes first, brush cursor on top
    blended = self.viz*(1-stroke_alpha) + stroke_map*stroke_alpha
    blended = blended*(1-brush_alpha) + brush_map*brush_alpha
    self.viz_with_stroke = blended.astype(np.uint8)

    # RGB888: 3 bytes per pixel per row
    q_image = QImage(self.viz_with_stroke.data, frame_w, frame_h, 3 * frame_w, QImage.Format_RGB888)
    fitted = q_image.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation)
    self.main_canvas.setPixmap(QPixmap(fitted))

    self.main_canvas_size = self.main_canvas.size()
    self.image_size = q_image.size()
|
|
def update_minimap(self):
    """Show a square crop of the annotated frame, centred near the last cursor position.

    The centre is clamped so the crop stays inside the frame whenever the
    frame is at least ``zoom_pixels`` wide/high.
    """
    half = self.zoom_pixels // 2
    cx = int(round(max(half, min(self.width - half, self.last_ex))))
    cy = int(round(max(half, min(self.height - half, self.last_ey))))

    crop = self.viz_with_stroke[cy-half:cy+half, cx-half:cx+half, :].astype(np.uint8)

    crop_h, crop_w, _ = crop.shape
    q_image = QImage(crop.data, crop_w, crop_h, 3 * crop_w, QImage.Format_RGB888)
    fitted = q_image.scaled(self.minimap.size(), Qt.KeepAspectRatio, Qt.FastTransformation)
    self.minimap.setPixmap(QPixmap(fitted))
|
|
def update_current_image_fast(self):
    """Draw the frame directly from the torch tensors, without the stroke overlay.

    When the "save overlay during propagation" checkbox is on, the rendered
    visualization is also written through the resource manager.
    """
    viz = get_visualization_torch(self.viz_mode, self.current_image_torch_no_norm,
                                  self.current_prob, self.overlay_layer_torch, self.vis_target_objects)
    self.viz = viz
    if self.save_visualization:
        self.res_man.save_visualization(self.cursur, viz)

    rows, cols, _ = viz.shape
    q_image = QImage(viz.data, cols, rows, 3 * cols, QImage.Format_RGB888)
    fitted = q_image.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation)
    self.main_canvas.setPixmap(QPixmap(fitted))
|
|
def show_current_frame(self, fast=False):
    """Redraw the current frame and sync the frame counter and timeline slider.

    With *fast* the torch-based renderer is used and the stroke overlay and
    minimap are not refreshed (this is the path used during propagation).
    """
    if not fast:
        self.compose_current_im()
        self.update_interact_vis()
        self.update_minimap()
    else:
        self.update_current_image_fast()

    counter = '{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)
    self.lcd.setText(counter)
    self.tl_slider.setValue(self.cursur)
|
|
def pixel_pos_to_image_pos(self, x, y):
    """Map a canvas-widget position to image coordinates (not clamped).

    The frame is drawn scaled with its aspect ratio kept and centred in the
    canvas, so the mapping undoes that uniform scale and then the centring
    offset.
    """
    img_h, img_w = self.image_size.height(), self.image_size.width()
    canvas_h, canvas_w = self.main_canvas_size.height(), self.main_canvas_size.width()

    scale = min(canvas_h/img_h, canvas_w/img_w)

    # canvas extent measured in image pixels
    span_h, span_w = canvas_h/scale, canvas_w/scale
    image_x = x/scale - (span_w-img_w)/2
    image_y = y/scale - (span_h-img_h)/2
    return image_x, image_y
|
|
def is_pos_out_of_bound(self, x, y):
    """Return True when canvas position (x, y) lies outside the displayed image."""
    ix, iy = self.pixel_pos_to_image_pos(x, y)
    return any((
        ix < 0,
        iy < 0,
        ix > self.width - 1,
        iy > self.height - 1,
    ))
|
|
def get_scaled_pos(self, x, y):
    """Map a canvas position to image coordinates clamped into the image."""
    def _clamp(value, upper):
        return max(0, min(upper, value))

    ix, iy = self.pixel_pos_to_image_pos(x, y)
    return _clamp(ix, self.width - 1), _clamp(iy, self.height - 1)
|
|
def clear_visualization(self):
    """Erase the stroke overlay (colour and alpha buffers) in place."""
    for buffer in (self.vis_map, self.vis_alpha):
        buffer.fill(0)
|
|
def reset_this_interaction(self):
    """Drop the in-progress interaction and its overlay, and release any click anchor."""
    self.complete_interaction()
    self.clear_visualization()
    self.interaction = None

    controller = self.fbrs_controller
    if controller is not None:
        controller.unanchor()
|
|
def set_viz_mode(self):
    """Overlay-mode combo box handler: adopt the selected mode and redraw."""
    selected = self.combo.currentText()
    self.viz_mode = selected
    self.show_current_frame()
|
|
def save_current_mask(self):
    """Store the current frame's mask through the resource manager."""
    frame_index, mask = self.cursur, self.current_mask
    self.res_man.save_mask(frame_index, mask)
|
|
def tl_slide(self):
    """Timeline slider handler: switch to the slider's frame, saving pending edits first.

    Ignored while propagating, because propagation moves the slider itself.
    """
    if self.propagating:
        return

    if self.curr_frame_dirty:
        self.save_current_mask()
        self.curr_frame_dirty = False

    self.reset_this_interaction()
    self.cursur = self.tl_slider.value()
    self.load_current_image_mask()
    self.show_current_frame()
|
|
def brush_slide(self):
    """Brush-size slider handler: update the label and resize an active free-hand brush."""
    size = self.brush_slider.value()
    self.brush_size = size
    self.brush_label.setText('Brush size: %d' % size)
    try:
        active = self.interaction
        if type(active) == FreeInteraction:
            active.set_size(size)
    except AttributeError:
        # the slider fires during __init__, before self.interaction is assigned
        pass
|
|
def on_forward_propagation(self):
    """Forward button: start propagating towards the last frame, or pause a running propagation."""
    if not self.propagating:
        self.propagate_fn = self.on_next_frame
        self.backward_run_button.setEnabled(False)
        self.forward_run_button.setText('Pause Propagation')
        self.on_propagation()
    else:
        # the loop in on_propagation sees this flag and stops
        self.propagating = False
|
|
def on_backward_propagation(self):
    """Backward button: start propagating towards frame 0, or pause a running propagation."""
    if not self.propagating:
        self.propagate_fn = self.on_prev_frame
        self.forward_run_button.setEnabled(False)
        self.backward_run_button.setText('Pause Propagation')
        self.on_propagation()
    else:
        # the loop in on_propagation sees this flag and stops
        self.propagating = False
|
|
def on_pause(self):
    """Mark propagation as stopped and return its buttons to the idle state."""
    self.propagating = False
    for button in (self.forward_run_button, self.backward_run_button, self.clear_mem_button):
        button.setEnabled(True)
    self.forward_run_button.setText('Forward Propagate')
    self.backward_run_button.setText('Backward Propagate')
    self.console_push_text('Propagation stopped.')
|
|
def on_propagation(self):
    """Propagate masks frame by frame in the direction given by ``self.propagate_fn``.

    The current frame and its mask are first passed to the processor. The
    cursor is then stepped one frame at a time, and each new frame is
    predicted, saved and shown. The loop ends when:

    * the user pauses (a button handler clears ``self.propagating`` while
      ``QApplication.processEvents`` runs), or
    * the first or last frame is reached, or
    * ``propagate_fn`` cannot move the cursor at all, i.e. propagation was
      started on the boundary frame for that direction.

    Afterwards the buttons are restored and the slider's frame is reloaded.
    """
    # start to propagate
    self.load_current_torch_image_mask()
    self.show_current_frame(fast=True)

    self.console_push_text('Propagation started.')
    # current_prob[1:] drops the first one-hot channel (presumably background)
    self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:])
    self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
    # any pending interaction refers to the pre-propagation state
    self.interacted_prob = None
    self.reset_this_interaction()

    self.propagating = True
    self.clear_mem_button.setEnabled(False)
    while self.propagating:
        start = self.cursur
        self.propagate_fn()
        if self.cursur == start:
            # Already at the first/last frame in this direction: there is no new
            # frame. Stepping again would feed the same frame to the processor
            # twice and overwrite its saved mask with a re-prediction.
            break

        self.load_current_image_mask(no_mask=True)
        self.load_current_torch_image_mask(no_mask=True)

        # no mask is given, so the processor predicts this frame
        self.current_prob = self.processor.step(self.current_image_torch)
        self.current_mask = torch_prob_to_numpy_mask(self.current_prob)

        self.save_current_mask()
        self.show_current_frame(fast=True)

        self.update_memory_size()
        # keep the UI responsive; pause clicks are handled here
        QApplication.processEvents()

        if self.cursur == 0 or self.cursur == self.num_frames-1:
            break

    self.propagating = False
    self.curr_frame_dirty = False
    self.on_pause()
    self.tl_slide()
    QApplication.processEvents()
|
|
def pause_propagation(self):
    """Ask a running propagation loop to stop once the current frame is done."""
    self.propagating = False
|
|
def on_commit(self):
    """Commit button: finish the current interaction and apply its prediction to this frame.

    Does nothing when there is no interaction on the current frame, or when no
    prediction exists yet. ``interacted_prob`` is only refreshed on mouse
    release. It is not cleared when the frame changes or the mask is reset, so
    committing in those states would otherwise stamp a stale prediction (from
    another frame) onto this frame, or fail on ``None``.
    """
    if self.interaction is None or self.interacted_prob is None:
        return
    self.complete_interaction()
    self.update_interacted_mask()
|
|
def on_prev_frame(self):
    """Move one frame back (never before frame 0) and sync the timeline slider."""
    self.cursur = self.cursur - 1 if self.cursur > 0 else 0
    self.tl_slider.setValue(self.cursur)
|
|
def on_next_frame(self):
    """Move one frame forward (never past the last frame) and sync the timeline slider."""
    last = self.num_frames - 1
    self.cursur = self.cursur + 1 if self.cursur < last else last
    self.tl_slider.setValue(self.cursur)
|
|
def on_play_video_timer(self):
    """Playback tick: advance one frame, wrapping to frame 0 after the last, and move the slider."""
    upcoming = self.cursur + 1
    self.cursur = 0 if upcoming > self.num_frames - 1 else upcoming
    self.tl_slider.setValue(self.cursur)
|
|
def on_play_video(self):
    """Toggle video playback and update the play button's label."""
    if self.timer.isActive():
        self.timer.stop()
        self.play_button.setText('Play Video')
    else:
        # QTimer.start() takes an integer millisecond interval; a float such as
        # 1000 / 30 is rejected with TypeError by PyQt5 on Python 3.10+.
        self.timer.start(1000 // 30)  # ~30 fps
        self.play_button.setText('Stop Video')
|
|
def on_reset_mask(self):
    """Reset Frame button: erase every object from this frame's mask, save it and redraw."""
    self.current_mask.fill(0)
    prob = self.current_prob
    if prob is not None:
        prob.fill_(0)
    self.curr_frame_dirty = True
    self.save_current_mask()
    self.reset_this_interaction()
    self.show_current_frame()
|
|
def on_zoom_plus(self):
    """Zoom the minimap in by showing 25 fewer pixels (no fewer than 50)."""
    self.zoom_pixels = max(50, self.zoom_pixels - 25)
    self.update_minimap()
|
|
def on_zoom_minus(self):
    """Zoom the minimap out by showing 25 more pixels (no more than 300)."""
    self.zoom_pixels = min(self.zoom_pixels + 25, 300)
    self.update_minimap()
|
|
def set_navi_enable(self, boolean):
    """Enable or disable the zoom, propagation, timeline and playback controls.

    Args:
        boolean: True to enable the controls, False to disable them.
    """
    # There is no single ``run_button``. Propagation is driven by separate
    # forward and backward buttons, so both are toggled here.
    for widget in (self.zoom_p_button, self.zoom_m_button,
                   self.forward_run_button, self.backward_run_button,
                   self.tl_slider, self.play_button, self.lcd):
        widget.setEnabled(boolean)
|
|
def hit_number_key(self, number):
    """Number-key shortcut: make object *number* the one being annotated and redraw the cursor."""
    if number != self.current_object:
        self.current_object = number
        if self.fbrs_controller is not None:
            self.fbrs_controller.unanchor()
        self.console_push_text(f'Current object changed to {number}.')
        # redraw the brush cursor in the new object's colour
        self.clear_brush()
        self.vis_brush(self.last_ex, self.last_ey)
        self.update_interact_vis()
        self.show_current_frame()
|
|
def clear_brush(self):
    """Erase the brush-cursor overlay (colour and alpha buffers) in place."""
    for buffer in (self.brush_vis_map, self.brush_vis_alpha):
        buffer.fill(0)
|
|
def vis_brush(self, ex, ey):
    """Draw the brush cursor, a filled circle at (ex, ey), into the brush overlay buffers."""
    centre = (int(round(ex)), int(round(ey)))
    radius = self.brush_size//2 + 1
    self.brush_vis_map = cv2.circle(self.brush_vis_map, centre, radius,
                                    color_map[self.current_object], thickness=-1)
    # alpha of 0.5 makes the cursor half-transparent
    self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha, centre, radius, 0.5, thickness=-1)
|
|
def on_mouse_press(self, event):
    """Canvas mouse-press handler.

    Presses outside the displayed image are ignored. A middle click toggles
    the object under the cursor in ``self.vis_target_objects`` and redraws.
    Any other button marks the mouse as pressed, makes sure an interaction of
    the currently selected type exists (reusing the previous one when it still
    matches), and forwards the event to ``on_mouse_motion`` so the first point
    is handled immediately.
    """
    if self.is_pos_out_of_bound(event.x(), event.y()):
        return

    # middle click: toggle a visualization target, no interaction is started
    if (event.button() == Qt.MidButton):
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        target_object = self.current_mask[int(ey),int(ex)]
        if target_object in self.vis_target_objects:
            self.vis_target_objects.remove(target_object)
        else:
            self.vis_target_objects.append(target_object)
        self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}')
        self.show_current_frame()
        return

    # remembered for motion/release: strokes then draw object 0, and clicks
    # receive it as their push_point flag
    self.right_click = (event.button() == Qt.RightButton)
    self.pressed = True

    h, w = self.height, self.width

    # make sure the torch image (and, if missing, the one-hot prob) of this frame exist
    self.load_current_torch_image_mask()
    image = self.current_image_torch

    # Reuse the previous interaction when its type matches the selected mode
    # (and, for clicks, the current object); otherwise close it and start a new one.
    last_interaction = self.interaction
    new_interaction = None
    if self.curr_interaction == 'Scribble':
        if last_interaction is None or type(last_interaction) != ScribbleInteraction:
            self.complete_interaction()
            new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().cuda(),
                    (h, w), self.s2m_controller, self.num_objects)
    elif self.curr_interaction == 'Free':
        if last_interaction is None or type(last_interaction) != FreeInteraction:
            self.complete_interaction()
            new_interaction = FreeInteraction(image, self.current_mask, (h, w),
                    self.num_objects)
            new_interaction.set_size(self.brush_size)
    elif self.curr_interaction == 'Click':
        if (last_interaction is None or type(last_interaction) != ClickInteraction
                or last_interaction.tar_obj != self.current_object):
            self.complete_interaction()
            # NOTE(review): reset_this_interaction and hit_number_key check
            # fbrs_controller for None before calling unanchor(); this path does not
            self.fbrs_controller.unanchor()
            new_interaction = ClickInteraction(image, self.current_prob, (h, w),
                        self.fbrs_controller, self.current_object)

    if new_interaction is not None:
        self.interaction = new_interaction

    # handle the press position as the first motion point
    self.on_mouse_motion(event)
|
|
def on_mouse_motion(self, event):
    """Track the cursor: move the brush preview and, while pressed, extend the stroke.

    Only scribble and free-hand modes draw on motion; clicks are placed on release.
    """
    ex, ey = self.get_scaled_pos(event.x(), event.y())
    self.last_ex, self.last_ey = ex, ey
    self.clear_brush()
    self.vis_brush(ex, ey)

    drawing = self.pressed and self.curr_interaction in ('Scribble', 'Free')
    if drawing:
        # the right button draws object 0 instead of the selected object
        label = 0 if self.right_click else self.current_object
        self.vis_map, self.vis_alpha = self.interaction.push_point(
            ex, ey, label, (self.vis_map, self.vis_alpha)
        )

    self.update_interact_vis()
    self.update_minimap()
|
|
def update_interacted_mask(self):
    """Adopt the last interaction's prediction as this frame's mask, redraw and save it."""
    predicted = self.interacted_prob
    self.current_prob = predicted
    self.current_mask = torch_prob_to_numpy_mask(predicted)
    self.show_current_frame()
    self.save_current_mask()
    # already saved above
    self.curr_frame_dirty = False
|
|
def complete_interaction(self):
    """End the current interaction, if any, and erase its stroke overlay."""
    if self.interaction is None:
        return
    self.clear_visualization()
    self.interaction = None
|
|
def on_mouse_release(self, event):
    """Finish the stroke or click started in on_mouse_press and apply the model's prediction."""
    if not self.pressed:
        # the press was ignored (outside the image, or a middle click): nothing to finish
        return

    mode = self.curr_interaction
    self.console_push_text('%s interaction at frame %d.' % (mode, self.cursur))
    interaction = self.interaction

    if mode in ('Scribble', 'Free'):
        # include the release point in the stroke, then close the path
        self.on_mouse_motion(event)
        interaction.end_path()
        if mode == 'Free':
            self.clear_visualization()
    elif mode == 'Click':
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        self.vis_map, self.vis_alpha = interaction.push_point(
            ex, ey, self.right_click, (self.vis_map, self.vis_alpha))

    self.interacted_prob = interaction.predict()
    self.update_interacted_mask()
    self.update_gpu_usage()

    self.pressed = self.right_click = False
|
|
def wheelEvent(self, event):
    """Mouse wheel: in free-hand mode resize the brush; always refresh the brush preview.

    Each 30 units of vertical wheel delta change the brush slider by one.
    """
    # NOTE(review): this handler belongs to the window, so event.x()/y() are
    # presumably window coordinates, while get_scaled_pos expects canvas
    # coordinates (the mouse handlers are installed on main_canvas). Confirm
    # whether the preview lands offset.
    ex, ey = self.get_scaled_pos(event.x(), event.y())
    if self.curr_interaction == 'Free':
        step = event.angleDelta().y()//30
        self.brush_slider.setValue(self.brush_slider.value() + step)
    self.clear_brush()
    self.vis_brush(ex, ey)
    self.update_interact_vis()
    self.update_minimap()
|
|
def update_gpu_usage(self):
    """Refresh the two GPU-memory gauges.

    The first gauge shows device-wide usage from ``torch.cuda.mem_get_info``.
    The second shows ``torch.cuda.max_memory_allocated`` in MB, against the
    device total.
    """
    gib = 2**30
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    total_gb = total_bytes / gib
    used_gb = total_gb - free_bytes / gib

    self.gpu_mem_gauge.setFormat(f'{used_gb:.01f} GB / {total_gb:.01f} GB')
    self.gpu_mem_gauge.setValue(round(used_gb/total_gb*100))

    # NOTE(review): max_memory_allocated reports the peak, not the current
    # allocation, although the gauge is labelled "used by torch"
    torch_mb = torch.cuda.max_memory_allocated() / (2**20)
    self.torch_mem_gauge.setFormat(f'{torch_mb:.0f} MB / {total_gb:.01f} GB')
    self.torch_mem_gauge.setValue(round(torch_mb/total_gb*100/1024))
|
|
def on_gpu_timer(self):
    """GPU timer tick: refresh the GPU-memory gauges."""
    refresh = self.update_gpu_usage
    refresh()
|
|
def update_memory_size(self):
    """Refresh the working- and long-term memory gauges from ``self.processor.memory``.

    If reading any of the expected memory attributes raises AttributeError,
    both gauges show 'Unknown' at 0.
    """
    try:
        memory = self.processor.memory
        work_cap = memory.max_work_elements
        long_cap = memory.max_long_elements
        work_used = memory.work_mem.size
        long_used = memory.long_mem.size

        for gauge, used, cap in ((self.work_mem_gauge, work_used, work_cap),
                                 (self.long_mem_gauge, long_used, long_cap)):
            gauge.setFormat(f'{used} / {cap}')
            gauge.setValue(round(used/cap*100))
    except AttributeError:
        gauges = (self.work_mem_gauge, self.long_mem_gauge)
        for gauge in gauges:
            gauge.setFormat('Unknown')
        for gauge in gauges:
            gauge.setValue(0)
|
|
def on_work_min_change(self):
    """Keep min working-memory frames strictly below the max, then push the config."""
    if not self.initialized:
        return
    ceiling = self.work_mem_max.value() - 1
    self.work_mem_min.setValue(min(self.work_mem_min.value(), ceiling))
    self.update_config()
|
|
def on_work_max_change(self):
    """Keep max working-memory frames strictly above the min, then push the config."""
    if not self.initialized:
        return
    floor = self.work_mem_min.value() + 1
    self.work_mem_max.setValue(max(self.work_mem_max.value(), floor))
    self.update_config()
|
|
def update_config(self):
    """Copy the memory-parameter boxes into ``self.config`` and apply it to the processor.

    No-op until construction has finished (``self.initialized``).
    """
    if not self.initialized:
        return
    settings = (
        ('min_mid_term_frames', self.work_mem_min),
        ('max_mid_term_frames', self.work_mem_max),
        ('max_long_term_elements', self.long_mem_max),
        ('num_prototypes', self.num_prototypes_box),
        ('mem_every', self.mem_every_box),
    )
    for key, box in settings:
        self.config[key] = box.value()

    self.processor.update_config(self.config)
|
|
def on_clear_memory(self):
    """Clear memory button: empty the processor's memory and torch's CUDA cache, then refresh gauges."""
    self.processor.clear_memory()
    torch.cuda.empty_cache()
    for refresh in (self.update_gpu_usage, self.update_memory_size):
        refresh()
|
|
def _open_file(self, prompt):
    """Show a file-open dialog titled *prompt* and return the selected path.

    Callers treat an empty string as "nothing selected".
    """
    path, _selected_filter = QFileDialog.getOpenFileName(
        self, prompt, "", "Image files (*)", options=QFileDialog.Options())
    return path
|
|
def on_import_mask(self):
    """Import mask button: load an index mask from disk as the current frame's mask.

    The file is read at the video's size and must be a 2-D (height, width)
    array whose largest label does not exceed ``num_objects``; otherwise the
    problem is reported on the console and nothing changes.
    """
    file_name = self._open_file('Mask')
    if not file_name:
        return

    mask = self.res_man.read_external_image(file_name, size=(self.height, self.width))

    shape_ok = len(mask.shape) == 2 and tuple(mask.shape) == (self.height, self.width)
    labels_ok = mask.max() <= self.num_objects

    if not shape_ok:
        self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.')
        return
    if not labels_ok:
        self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.')
        return

    self.console_push_text(f'Mask file {file_name} loaded.')
    # cached torch tensors belong to the old mask
    self.current_image_torch = self.current_prob = None
    self.current_mask = mask
    self.show_current_frame()
    self.save_current_mask()
|
|
def on_import_layer(self):
    """Import layer button: ask for an image and try to use it as the overlay layer."""
    chosen = self._open_file('Layer')
    if chosen:
        self._try_load_layer(chosen)
|
|
def _try_load_layer(self, file_name):
    """Try to load *file_name* as the overlay layer used by the visualizations.

    The image is read at the video's size. A 3-channel image gets a fully
    opaque (255) alpha channel appended. Anything that is not then
    (height, width, 4) is rejected with a console message. A missing file is
    reported on the console instead of raising.

    Args:
        file_name: path of the image to load.
    """
    try:
        layer = self.res_man.read_external_image(file_name, size=(self.height, self.width))
    except FileNotFoundError:
        self.console_push_text(f'{file_name} not found.')
        return

    # Check the rank before indexing a channel axis. Otherwise a 2-D
    # (single-channel) image whose width happens to be 3 would crash on
    # layer[:, :, 0:1] instead of being reported below.
    if len(layer.shape) == 3 and layer.shape[-1] == 3:
        alpha = np.ones_like(layer[:, :, 0:1]) * 255
        layer = np.concatenate([layer, alpha], axis=-1)

    condition = (
        (len(layer.shape) == 3) and
        (layer.shape[-1] == 4) and
        (layer.shape[-2] == self.width) and
        (layer.shape[-3] == self.height)
    )

    if not condition:
        self.console_push_text(f'Expected ({self.height}, {self.width}, 4). Got {layer.shape}.')
    else:
        self.console_push_text(f'Layer file {file_name} loaded.')
        self.overlay_layer = layer
        # float tensor divided by 255 (assumes 8-bit channels) for the torch renderer
        self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255
        self.show_current_frame()
|
|
def on_save_visualization_toggle(self):
    """Checkbox handler: copy the "save overlay during propagation" state into a flag."""
    checked = self.save_visualization_checkbox.isChecked()
    self.save_visualization = checked
|
|