| """Utility functions for visualization. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| from matplotlib import pyplot as plt |
|
|
# NOTE(review): presumably the upper clamp, in meters, applied to metric depth
# values (inferred from the name only; the constant is not referenced by the
# visible functions of this module) - confirm against callers.
METRIC_DEPTH_MAX_CLAMP_METER = 50.0
|
|
|
|
def colorize_depth(depth: torch.Tensor, val_max: float = 10.0) -> torch.Tensor:
    """Colorize a depth map with the "turbo" color map.

    Depth values are mapped from the range [0, val_max]; see
    ``colorize_scalar_map`` for the normalization and output layout.

    Args:
        depth: Depth tensor whose third-from-last axis is the channel axis
            (e.g. ... x C x H x W).
        val_max: Depth value mapped to the upper end of the color map.

    Returns:
        The colorized depth as returned by ``colorize_scalar_map``. When there
        is more than one depth channel, every channel is colorized on its own
        and the results are concatenated along the last axis.
    """
    depth_channels = depth.shape[-3]

    if depth_channels == 1:
        # Drop the singleton channel axis so a plain scalar map is passed on.
        return colorize_scalar_map(
            depth.squeeze(-3), val_min=0.0, val_max=val_max, color_map="turbo"
        )

    # Indexing one channel also removes the channel axis, so each slice is a
    # scalar map just like in the single-channel case above.
    colored_depths = [
        colorize_scalar_map(
            depth[..., c, :, :], val_min=0.0, val_max=val_max, color_map="turbo"
        )
        for c in range(depth_channels)
    ]
    return torch.cat(colored_depths, dim=-1)
|
|
|
|
def colorize_alpha(alpha: torch.Tensor) -> torch.Tensor:
    """Colorize an alpha map with the "coolwarm" color map.

    Args:
        alpha: Alpha tensor. A size-1 axis third from last (the channel axis
            of a ... x 1 x H x W layout) is squeezed away before colorization.

    Returns:
        The colorized alpha map as returned by ``colorize_scalar_map``, with
        values mapped from the range [0, 1].
    """
    # ``squeeze(-3)`` leaves the tensor untouched when that axis is not size 1.
    return colorize_scalar_map(
        alpha.squeeze(-3), val_min=0.0, val_max=1.0, color_map="coolwarm"
    )
|
|
|
|
def colorize_scalar_map(
    scalar_map: torch.Tensor,
    val_min: float = 0.0,
    val_max: float = 1.0,
    color_map: str = "jet",
) -> torch.Tensor:
    """Colorize a scalar map with a matplotlib color map.

    Values are normalized linearly from [val_min, val_max] to [0, 1], clipped
    to that range and then looked up in the color map.

    Args:
        scalar_map: Scalar map with 2, 3 or 4 dimensions (e.g. HW or BHW). It
            is detached, moved to the CPU and converted to float first.
        val_min: Minimum value to display.
        val_max: Maximum value to display. Should differ from ``val_min``;
            otherwise the normalization divides by zero.
        color_map: Which color map to use. Will be passed to matplotlib.

    Returns:
        A uint8 CPU tensor with the RGB colors, where the 3-element color axis
        is placed third from last (CHW for an HW input, BCHW for a BHW input).

    Raises:
        ValueError: If ``scalar_map`` does not have 2, 3 or 4 dimensions.
    """
    if scalar_map.ndim not in (2, 3, 4):
        raise ValueError("Only scalar maps of 2 or 3 or 4 dimensions supported.")

    cmap = plt.get_cmap(color_map)

    normalized = scalar_map.detach().cpu().float().numpy()
    normalized = (normalized - val_min) / (val_max - val_min)
    normalized = np.clip(normalized, a_min=0.0, a_max=1.0)

    # The color map appends a trailing RGBA axis; keep RGB and scale to bytes.
    rgb = cmap(normalized)[..., :3]
    tensor = torch.as_tensor(rgb * 255.0, dtype=torch.uint8)

    # Move the trailing color axis in front of the two spatial axes while
    # keeping any leading (batch) axes in place. The ndim check above
    # guarantees that ``tensor`` has 3, 4 or 5 dimensions here.
    ndim = tensor.ndim
    leading_axes = tuple(range(ndim - 3))
    return tensor.permute(*leading_axes, ndim - 1, ndim - 3, ndim - 2)
|
|