| | """Visualization of predicted and ground truth for a single batch.""" |
| | """Adapted from https://github.com/cvg/GeoCalib""" |
| |
|
| | from typing import Any, Dict |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from scripts.camera.geometry.perspective_fields import get_latitude_field |
| | from scripts.camera.utils.conversions import rad2deg |
| | from scripts.camera.utils.tensor import batch_to_device |
| | from scripts.camera.visualization.viz2d import ( |
| | plot_confidences, |
| | plot_heatmaps, |
| | plot_image_grid, |
| | plot_latitudes, |
| | plot_vector_fields, |
| | ) |
| |
|
| |
|
def make_up_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Plot the ground-truth up field (and optional predicted confidence) per sample.

    Args:
        pred (Dict[str, torch.Tensor]): Predictions; reads "up_field" (presence
            check only) and, if present, "up_confidence".
        data (Dict[str, torch.Tensor]): Ground truth; reads "image" and "up_field".
        n_pairs (int): Number of batch items to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: {"up": figure}, or {} if nothing can be plotted.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "up_field" not in pred.keys():
        return {}

    up_fields = []
    titles = ["Up GT"]
    for i in range(n_pairs):
        row = [data["up_field"][i]]
        titles = ["Up GT"]

        if "up_confidence" in pred.keys():
            row += [pred["up_confidence"][i]]
            titles += ["Up Confidence"]

        row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
        up_fields.append(row)

    # Empty batch: previously crashed on up_fields[0] below.
    if not up_fields:
        return {}

    N, M = len(up_fields), len(up_fields[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
    # Pass the column titles (they were previously built but never used),
    # matching the style of make_camera_figure.
    fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_vector_fields([up_fields[i][0]], axes=ax[i, [1]])

        # BUG FIX: the confidence map is the last entry of each row (index 1
        # here), not index 3, and it belongs in the grid's last column, not
        # column 4 — the old indices were stale leftovers from a layout that
        # also plotted predicted/error columns and raised IndexError here.
        if "up_confidence" in pred.keys():
            plot_confidences([up_fields[i][-1]], axes=ax[i, [M - 1]])

    return {"up": fig}
| |
|
| |
|
def make_latitude_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Plot the ground-truth latitude field (and optional predicted confidence).

    Args:
        pred (Dict[str, torch.Tensor]): Predictions; reads "latitude_field"
            (presence check only) and, if present, "latitude_confidence".
        data (Dict[str, torch.Tensor]): Ground truth; reads "image" and
            "latitude_field".
        n_pairs (int, optional): Number of batch items to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: {"latitude": figure}, or {} if nothing can be plotted.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))
    latitude_fields = []

    if "latitude_field" not in pred.keys():
        return {}

    titles = ["Latitude GT"]
    for i in range(n_pairs):
        # GT latitude is stored in radians; convert for degree-based plotting.
        row = [rad2deg(data["latitude_field"][i][0])]
        titles = ["Latitude GT"]

        if "latitude_confidence" in pred.keys():
            row += [pred["latitude_confidence"][i]]
            titles += ["Latitude Confidence"]

        row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
        latitude_fields.append(row)

    # Empty batch: previously crashed on latitude_fields[0] below.
    if not latitude_fields:
        return {}

    N, M = len(latitude_fields), len(latitude_fields[0]) + 1
    imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
    # Pass the column titles (they were previously built but never used),
    # matching the style of make_camera_figure.
    fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for i in range(n_pairs):
        plot_latitudes([latitude_fields[i][0]], is_radians=False, axes=ax[i, [1]])

        # BUG FIX: the confidence map is the last entry of each row (index 1
        # here), not index 3, and it belongs in the grid's last column, not
        # column 4 — the old indices were stale leftovers from a layout that
        # also plotted predicted/error columns and raised IndexError here.
        if "latitude_confidence" in pred.keys():
            plot_confidences([latitude_fields[i][-1]], axes=ax[i, [M - 1]])

    return {"latitude": fig}
| |
|
| |
|
def make_camera_figure(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Visualize latitude fields derived from GT and predicted camera parameters.

    Args:
        pred (Dict[str, torch.Tensor]): Predictions; reads "camera" and "gravity".
        data (Dict[str, torch.Tensor]): Ground truth; reads "image", "camera"
            and "gravity".
        n_pairs (int, optional): Number of batch items to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: {"camera": figure}, or {} if no camera was predicted.
    """
    pred = batch_to_device(pred, "cpu", detach=True)
    data = batch_to_device(data, "cpu", detach=True)

    n_pairs = min(n_pairs, len(data["image"]))

    if "camera" not in pred.keys():
        return {}

    latitudes = []
    for idx in range(n_pairs):
        titles = ["Cameras GT"]
        fields = [get_latitude_field(data["camera"][idx], data["gravity"][idx])]

        if "camera" in pred.keys() and "gravity" in pred.keys():
            fields.append(get_latitude_field(pred["camera"][idx], pred["gravity"][idx]))
            titles.append("Cameras Pred")

        # Latitude comes back in radians; convert to degrees and strip the
        # batch/channel dimensions for plotting.
        latitudes.append([rad2deg(f).squeeze(-1).float().numpy()[0] for f in fields])

    N = len(latitudes)
    M = len(latitudes[0]) + 1  # one extra column for the raw image
    imgs = [
        [data["image"][idx].permute(1, 2, 0).cpu().clip(0, 1)] * M
        for idx in range(n_pairs)
    ]
    fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
    ax = np.array(ax)

    for idx in range(n_pairs):
        plot_latitudes(latitudes[idx], is_radians=False, axes=ax[idx, 1:])

    return {"camera": fig}
| |
|
| |
|
def make_perspective_figures(
    pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
) -> Dict[str, Any]:
    """Build the up-field and latitude-field figures for a batch.

    Args:
        pred (Dict[str, torch.Tensor]): Predicted perspective fields.
        data (Dict[str, torch.Tensor]): Ground truth perspective fields.
        n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.

    Returns:
        Dict[str, Any]: Figures keyed by name ("up", "latitude"); a key is
        absent when the corresponding field was not predicted.
    """
    n_pairs = min(n_pairs, len(data["image"]))
    figures = make_up_figure(pred, data, n_pairs)
    figures |= make_latitude_figure(pred, data, n_pairs)

    # Plain loop for the side effect: the previous set comprehension
    # ({f.tight_layout() for ...}) built and discarded a {None} set.
    for fig in figures.values():
        fig.tight_layout()

    return figures
| |
|