| |
| |
| |
| |
| |
|
|
| import contextlib |
| import math |
| import os |
| import unittest |
| from typing import Tuple |
|
|
| import torch |
| from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset |
| from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud |
|
|
| from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround |
| from pytorch3d.implicitron.tools.config import expand_args_fields |
| from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d |
| from pytorch3d.renderer.cameras import CamerasBase |
| from tests.common_testing import interactive_testing_requested |
| from visdom import Visdom |
|
|
| from .common_resources import get_skateboard_data |
|
|
|
|
class TestModelVisualize(unittest.TestCase):
    """Interactive-only visual checks for the implicitron fly-around renderer."""

    def test_flyaround_one_sequence(
        self,
        image_size: int = 256,
    ):
        """Render fly-around videos of the first skateboard training sequence.

        Does nothing (returns immediately) unless interactive testing has been
        requested. Otherwise, for each value of ``load_dataset_pointcloud``
        (True, then False) a point-cloud rendering model is built and
        ``render_flyaround`` is invoked twice: once with
        ``output_video_frames_dir=None`` and once with it set to the video
        path. All outputs are written into the directory containing this file.

        Args:
            image_size: Height and width the dataset images are loaded at.
        """
        if not interactive_testing_requested():
            return

        category = "skateboard"

        # Keep the dataset context open for the whole test; unittest closes
        # the stack through the registered cleanup.
        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)

        def _category_file(file_name: str) -> str:
            # Annotation files live under <dataset_root>/<category>/.
            return os.path.join(dataset_root, category, file_name)

        expand_args_fields(JsonIndexDataset)
        train_dataset = JsonIndexDataset(
            frame_annotations_file=_category_file("frame_annotations.jgz"),
            sequence_annotations_file=_category_file("sequence_annotations.jgz"),
            subset_lists_file=_category_file("set_lists.json"),
            dataset_root=dataset_root,
            image_height=image_size,
            image_width=image_size,
            box_crop=True,
            load_point_clouds=True,
            path_manager=path_manager,
            subsets=["train_known"],
        )

        # Only the first annotated sequence is visualized.
        show_sequence_name = list(train_dataset.seq_annots.keys())[0]

        # Videos/frames are written next to this test file.
        output_dir = os.path.dirname(os.path.abspath(__file__))

        # Predictions are streamed to visdom only if a server is reachable.
        visdom_show_preds = Visdom().check_connection()

        for load_dataset_pointcloud in (True, False):
            # NOTE: this requires a CUDA device ("cuda:0").
            model = _PointcloudRenderingModel(
                train_dataset,
                show_sequence_name,
                device="cuda:0",
                load_dataset_pointcloud=load_dataset_pointcloud,
            )
            video_path = os.path.join(
                output_dir, f"load_pcl_{load_dataset_pointcloud}"
            )
            os.makedirs(output_dir, exist_ok=True)

            # Exercise both the default frames location and an explicit one.
            for output_video_frames_dir in (None, video_path):
                render_flyaround(
                    train_dataset,
                    show_sequence_name,
                    model,
                    video_path,
                    n_flyaround_poses=10,
                    fps=5,
                    max_angle=2 * math.pi,
                    trajectory_type="circular_lsq_fit",
                    trajectory_scale=1.1,
                    scene_center=(0.0, 0.0, 0.0),
                    up=(0.0, 1.0, 0.0),
                    traj_offset=1.0,
                    n_source_views=1,
                    visdom_show_preds=visdom_show_preds,
                    visdom_environment="test_model_visalize",
                    visdom_server="http://127.0.0.1",
                    visdom_port=8097,
                    num_workers=10,
                    seed=None,
                    video_resize=None,
                    visualize_preds_keys=[
                        "images_render",
                        "depths_render",
                        "masks_render",
                        "_all_source_images",
                    ],
                    output_video_frames_dir=output_video_frames_dir,
                )
|
|
|
|
class _PointcloudRenderingModel(torch.nn.Module):
    """Stand-in model that renders one precomputed sequence point cloud.

    The point cloud is built once in ``__init__`` from ``sequence_name`` and
    then rendered from the camera handed to ``forward``. The returned keys
    match the ``visualize_preds_keys`` requested by the fly-around test.
    """

    def __init__(
        self,
        train_dataset: JsonIndexDataset,
        sequence_name: str,
        render_size: Tuple[int, int] = (400, 400),
        device=None,
        load_dataset_pointcloud: bool = False,
        max_frames: int = 30,
        num_workers: int = 10,
    ):
        """Build the sequence point cloud and move it to ``device``.

        Args:
            train_dataset: Dataset the sequence point cloud is extracted from.
            sequence_name: Name of the sequence to visualize.
            render_size: Output size forwarded to the point-cloud renderer.
            device: Device the point cloud is moved to (``None`` keeps it
                wherever it was created).
            load_dataset_pointcloud: Forwarded as ``load_dataset_point_cloud``
                to ``get_implicitron_sequence_pointcloud``.
            max_frames: Frame cap forwarded to the point-cloud extraction.
            num_workers: Worker count forwarded to the point-cloud extraction.
        """
        super().__init__()
        self._render_size = render_size

        extraction_options = dict(
            sequence_name=sequence_name,
            mask_points=True,
            max_frames=max_frames,
            num_workers=num_workers,
            load_dataset_point_cloud=load_dataset_pointcloud,
        )
        # The second returned value is not needed here.
        point_cloud, _unused = get_implicitron_sequence_pointcloud(
            train_dataset, **extraction_options
        )
        self._point_cloud = point_cloud.to(device)

    def forward(
        self,
        camera: CamerasBase,
        **kwargs,
    ):
        """Render the stored point cloud from ``camera[0]``.

        Any extra keyword arguments are accepted and ignored.

        Returns:
            A dict with ``"images_render"`` (clamped to ``[0, 1]``),
            ``"masks_render"`` and ``"depths_render"``.
        """
        rendered = render_point_cloud_pytorch3d(
            camera[0],
            self._point_cloud,
            render_size=self._render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        image, mask, depth = rendered

        predictions = {}
        predictions["images_render"] = torch.clamp(image, 0.0, 1.0)
        predictions["masks_render"] = mask
        predictions["depths_render"] = depth
        return predictions
|
|