# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import cv2
import mmcv
import numpy as np
import torch
from mmengine.dist import master_only
from mmengine.structures import PixelData
from mmengine.visualization import Visualizer

from mmseg.registry import VISUALIZERS
from mmseg.structures import SegDataSample
from mmseg.utils import get_classes, get_palette


@VISUALIZERS.register_module()
class SegLocalVisualizer(Visualizer):
    """Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): The original image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        classes (list, optional): Input classes for result rendering, as the
            prediction of the segmentation model is a segment map with label
            indices, `classes` is a list which includes items corresponding
            to the label indices. If classes is not defined, the visualizer
            will take `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which
            is a list of colors corresponding to the classes. Defaults to
            None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
            The visualizer will use the meta information of the dataset,
            i.e. classes and palette, but `classes` and `palette` have
            higher priority. Defaults to None.
        alpha (float): The transparency of the segmentation mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample
        >>> from mmseg.visualization import SegLocalVisualizer

        >>> seg_local_visualizer = SegLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                           size=(10, 12, 3)).astype('uint8')
        >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12)))
        >>> gt_sem_seg = PixelData(**gt_sem_seg_data)
        >>> gt_seg_data_sample = SegDataSample()
        >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg
        >>> seg_local_visualizer.dataset_meta = dict(
        ...     classes=('background', 'foreground'),
        ...     palette=[[120, 120, 120], [6, 230, 230]])
        >>> seg_local_visualizer.add_datasample('visualizer_example',
        ...                                     image, gt_seg_data_sample)
        >>> seg_local_visualizer.add_datasample(
        ...     'visualizer_example', image,
        ...     gt_seg_data_sample, show=True)
    """

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 classes: Optional[List] = None,
                 palette: Optional[List] = None,
                 dataset_name: Optional[str] = None,
                 alpha: float = 0.8,
                 **kwargs):
        super().__init__(name, image, vis_backends, save_dir, **kwargs)
        self.alpha: float = alpha
        # ``set_dataset_meta`` declares ``classes`` before ``palette``,
        # so pass the arguments in that order.
        self.set_dataset_meta(classes, palette, dataset_name)

    def _get_center_loc(self, mask: np.ndarray) -> np.ndarray:
        """Get the center coordinate of a semantic segmentation mask.

        Args:
            mask (np.ndarray): The binary mask of a single class, taken
                from ``sem_seg``, where foreground pixels are 1.

        Returns:
            np.ndarray: The ``(x, y)`` coordinate of the mask center.
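
        Example (a minimal doctest sketch; the mask layout below is
        illustrative):
            >>> import numpy as np
            >>> mask = np.zeros((5, 5), dtype=np.uint8)
            >>> mask[2, 1:4] = 1
            >>> SegLocalVisualizer()._get_center_loc(mask)
            array([2, 2])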
| | """ |
        loc = np.argwhere(mask == 1)

        loc_sort = np.array(
            sorted(loc.tolist(), key=lambda row: (row[0], row[1])))
        y_list = loc_sort[:, 0]
        unique, indices, counts = np.unique(
            y_list, return_index=True, return_counts=True)
        # Pick the row (y) that contains the most foreground pixels, then
        # take the middle foreground pixel of that row as the center.
        y_loc = unique[counts.argmax()]
        y_most_freq_loc = loc[loc_sort[:, 0] == y_loc]
        center_num = len(y_most_freq_loc) // 2
        x = y_most_freq_loc[center_num][1]
        y = y_most_freq_loc[center_num][0]
        return np.array([x, y])

    def _draw_sem_seg(self,
                      image: np.ndarray,
                      sem_seg: PixelData,
                      classes: Optional[List],
                      palette: Optional[List],
                      with_labels: Optional[bool] = True) -> np.ndarray:
        """Draw semantic segmentation of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for pixel-level
                annotations or predictions.
            classes (list, optional): Input classes for result rendering, as
                the prediction of the segmentation model is a segment map
                with label indices, `classes` is a list which includes items
                corresponding to the label indices. If classes is not
                defined, the visualizer will take `cityscapes` classes by
                default. Defaults to None.
            palette (list, optional): Input palette for result rendering,
                which is a list of colors corresponding to the classes.
                Defaults to None.
            with_labels (bool, optional): Add semantic labels in the
                visualization result. Defaults to True.

        Returns:
            np.ndarray: The drawn image whose channel order is RGB.
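
        Example (a minimal sketch; the single-class setup is illustrative
        and ``with_labels`` is disabled to skip text rendering):
            >>> import numpy as np
            >>> import torch
            >>> from mmengine.structures import PixelData
            >>> visualizer = SegLocalVisualizer()
            >>> image = np.zeros((10, 12, 3), dtype=np.uint8)
            >>> sem_seg = PixelData(data=torch.zeros(1, 10, 12).long())
            >>> drawn = visualizer._draw_sem_seg(
            ...     image, sem_seg, classes=('background',),
            ...     palette=[[120, 120, 120]], with_labels=False)
            >>> drawn.shape
            (10, 12, 3)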
| | """ |
        num_classes = len(classes)

        sem_seg = sem_seg.cpu().data
        ids = np.unique(sem_seg)[::-1]
        legal_indices = ids < num_classes
        ids = ids[legal_indices]
        labels = np.array(ids, dtype=np.int64)

        colors = [palette[label] for label in labels]

        mask = np.zeros_like(image, dtype=np.uint8)
        for label, color in zip(labels, colors):
            mask[sem_seg[0] == label, :] = color

        if with_labels:
            font = cv2.FONT_HERSHEY_SIMPLEX
            # ``scale`` controls the text size relative to the image.
            scale = 0.05
            fontScale = min(image.shape[0], image.shape[1]) / (25 / scale)
            fontColor = (255, 255, 255)
            if image.shape[0] < 300 or image.shape[1] < 300:
                thickness = 1
                rectangleThickness = 1
            else:
                thickness = 2
                rectangleThickness = 2
            lineType = 2

            if isinstance(sem_seg[0], torch.Tensor):
                masks = sem_seg[0].numpy() == labels[:, None, None]
            else:
                masks = sem_seg[0] == labels[:, None, None]
            masks = masks.astype(np.uint8)
            for mask_num in range(len(labels)):
                classes_id = labels[mask_num]
                classes_color = colors[mask_num]
                loc = self._get_center_loc(masks[mask_num])
                text = classes[classes_id]
                (label_width, label_height), baseline = cv2.getTextSize(
                    text, font, fontScale, thickness)
                mask = cv2.rectangle(mask, loc,
                                     (loc[0] + label_width + baseline,
                                      loc[1] + label_height + baseline),
                                     classes_color, -1)
                mask = cv2.rectangle(mask, loc,
                                     (loc[0] + label_width + baseline,
                                      loc[1] + label_height + baseline),
                                     (0, 0, 0), rectangleThickness)
                mask = cv2.putText(mask, text,
                                   (loc[0], loc[1] + label_height), font,
                                   fontScale, fontColor, thickness, lineType)
        color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype(
            np.uint8)
        self.set_image(color_seg)
        return color_seg

    def _draw_depth_map(self, image: np.ndarray,
                        depth_map: PixelData) -> np.ndarray:
        """Draw a depth map on a given image.

        This function takes an image and a depth map as input, renders the
        depth map, and concatenates it with the original image. Finally, it
        updates the internal image state of the visualizer with the
        concatenated result.

        Args:
            image (np.ndarray): The original image where the depth map will
                be drawn. The array should be in the format HxWx3 where H is
                the height and W is the width.
            depth_map (PixelData): Depth map to be drawn. The depth map
                should be in the form of a PixelData object. It will be
                converted to a torch tensor if it is a numpy array.

        Returns:
            np.ndarray: The concatenated image with the depth map drawn.

        Example:
            >>> depth_map_data = PixelData(data=torch.rand(1, 10, 10))
            >>> image = np.random.randint(0, 256,
            ...                           size=(10, 10, 3)).astype('uint8')
            >>> visualizer = SegLocalVisualizer()
            >>> visualizer._draw_depth_map(image, depth_map_data)
        """
        depth_map = depth_map.cpu().data
        if isinstance(depth_map, np.ndarray):
            depth_map = torch.from_numpy(depth_map)
        if depth_map.ndim == 2:
            depth_map = depth_map[None]

        depth_map = self.draw_featmap(depth_map, resize_shape=image.shape[:2])
        out_image = np.concatenate((image, depth_map), axis=0)
        self.set_image(out_image)
        return out_image

    def set_dataset_meta(self,
                         classes: Optional[List] = None,
                         palette: Optional[List] = None,
                         dataset_name: Optional[str] = None) -> None:
        """Set meta information for the visualizer.

        Args:
            classes (list, optional): Input classes for result rendering, as
                the prediction of the segmentation model is a segment map
                with label indices, `classes` is a list which includes items
                corresponding to the label indices. If classes is not
                defined, the visualizer will take `cityscapes` classes by
                default. Defaults to None.
            palette (list, optional): Input palette for result rendering,
                which is a list of colors corresponding to the classes.
                Defaults to None.
            dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
                The visualizer will use the meta information of the dataset,
                i.e. classes and palette, but `classes` and `palette` have
                higher priority. Defaults to None.
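
        Example (an illustrative sketch; assumes the built-in ``ade20k``
        dataset alias):
            >>> visualizer = SegLocalVisualizer()
            >>> visualizer.set_dataset_meta(dataset_name='ade20k')
            >>> len(visualizer.dataset_meta['classes'])
            150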
| | """ |
| | |
| | |
| | |
| | if dataset_name is None: |
| | dataset_name = 'cityscapes' |
| | classes = classes if classes else get_classes(dataset_name) |
| | palette = palette if palette else get_palette(dataset_name) |
| | assert len(classes) == len( |
| | palette), 'The length of classes should be equal to palette' |
| | self.dataset_meta: dict = {'classes': classes, 'palette': palette} |

    @master_only
    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional[SegDataSample] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            out_file: Optional[str] = None,
            step: int = 0,
            with_labels: Optional[bool] = True) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. It is usually used when the display
          is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`SegDataSample`, optional): The data sample
                that usually holds GT and/or prediction annotations, e.g.
                ``gt_sem_seg`` and ``pred_sem_seg``. Defaults to None.
            draw_gt (bool): Whether to draw GT SegDataSample.
                Defaults to True.
            draw_pred (bool): Whether to draw Prediction SegDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_file (str, optional): Path to the output file.
                Defaults to None.
            step (int): Global step value to record. Defaults to 0.
            with_labels (bool, optional): Add semantic labels in the
                visualization result. Defaults to True.
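
        Example (a minimal sketch; the class names, palette, and
        ``out_file`` path are illustrative):
            >>> import numpy as np
            >>> import torch
            >>> from mmengine.structures import PixelData
            >>> from mmseg.structures import SegDataSample
            >>> visualizer = SegLocalVisualizer(
            ...     classes=('background', 'foreground'),
            ...     palette=[[120, 120, 120], [6, 230, 230]])
            >>> image = np.random.randint(
            ...     0, 256, size=(10, 12, 3)).astype('uint8')
            >>> data_sample = SegDataSample()
            >>> data_sample.gt_sem_seg = PixelData(
            ...     data=torch.randint(0, 2, (1, 10, 12)))
            >>> visualizer.add_datasample('example', image, data_sample,
            ...                           with_labels=False,
            ...                           out_file='example.png')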
| | """ |
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        gt_img_data = None
        pred_img_data = None

        if draw_gt and data_sample is not None:
            if 'gt_sem_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing semantic ' \
                                            'segmentation results.'
                gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg,
                                                 classes, palette, with_labels)

            if 'gt_depth_map' in data_sample:
                gt_img_data = gt_img_data if gt_img_data is not None else image
                gt_img_data = self._draw_depth_map(gt_img_data,
                                                   data_sample.gt_depth_map)

        if draw_pred and data_sample is not None:
            if 'pred_sem_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing semantic ' \
                                            'segmentation results.'
                pred_img_data = self._draw_sem_seg(image,
                                                   data_sample.pred_sem_seg,
                                                   classes, palette,
                                                   with_labels)

            if 'pred_depth_map' in data_sample:
                pred_img_data = pred_img_data if pred_img_data is not None \
                    else image
                pred_img_data = self._draw_depth_map(
                    pred_img_data, data_sample.pred_depth_map)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        else:
            drawn_img = pred_img_data

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file)
        else:
            self.add_image(name, drawn_img, step)