| """Contains modules to initialize Gaussians from RGBD. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import NamedTuple |
|
|
| import torch |
| from torch import nn |
|
|
| from .params import ColorInitOption, DepthInitOption, InitializerParams |
|
|
|
|
def create_initializer(params: InitializerParams) -> nn.Module:
    """Create a multi-layer Gaussian initializer from the given params.

    Args:
        params: Configuration forwarded field-by-field to MultiLayerInitializer.

    Returns:
        A MultiLayerInitializer module.
    """
    return MultiLayerInitializer(
        num_layers=params.num_layers,
        stride=params.stride,
        base_depth=params.base_depth,
        scale_factor=params.scale_factor,
        disparity_factor=params.disparity_factor,
        color_option=params.color_option,
        first_layer_depth_option=params.first_layer_depth_option,
        rest_layer_depth_option=params.rest_layer_depth_option,
        normalize_depth=params.normalize_depth,
        feature_input_stop_grad=params.feature_input_stop_grad,
    )
|
|
|
|
class GaussianBaseValues(NamedTuple):
    """Base values for gaussian predictor.

    We predict x and y in normalized device coordinates (NDC) where (-1, -1) is the top
    left corner and (1, 1) the bottom right corner. The last component of
    mean_vectors_ndc is inverse depth.
    """

    # Gaussian center x coordinate in NDC ([-1, 1] across image width).
    mean_x_ndc: torch.Tensor
    # Gaussian center y coordinate in NDC ([-1, 1] across image height).
    mean_y_ndc: torch.Tensor
    # Inverse depth (disparity) of the Gaussian centers.
    mean_inverse_z_ndc: torch.Tensor

    # Per-Gaussian base scales.
    scales: torch.Tensor
    # Per-Gaussian rotation quaternions (identity rotation at init).
    quaternions: torch.Tensor
    # Per-Gaussian RGB colors.
    colors: torch.Tensor
    # Per-Gaussian opacities.
    opacities: torch.Tensor
|
|
|
|
class InitializerOutput(NamedTuple):
    """Output of initializer."""

    # Initial values for the Gaussian parameters.
    gaussian_base_values: GaussianBaseValues

    # Feature tensor (image + normalized disparity, mapped to [-1, 1]) that is
    # fed to the downstream Gaussian predictor.
    feature_input: torch.Tensor

    # Inverse of the depth-normalization factor (1 / depth_factor), or None
    # when depth normalization is disabled.
    global_scale: torch.Tensor | None = None
|
|
|
|
class MultiLayerInitializer(nn.Module):
    """Initialize Gaussians with multilayer representation.

    The returned tensors have the shape

    batch_size x dim x num_layers x height x width

    where dim indicates the dimensionality of the property.
    Some of the dimensions might be set to 1 for efficiency reasons.
    """

    def __init__(
        self,
        num_layers: int,
        stride: int,
        base_depth: float,
        scale_factor: float,
        disparity_factor: float,
        color_option: ColorInitOption = "first_layer",
        first_layer_depth_option: DepthInitOption = "surface_min",
        rest_layer_depth_option: DepthInitOption = "surface_min",
        normalize_depth: bool = True,
        feature_input_stop_grad: bool = True,
    ) -> None:
        """Initialize MultiLayerInitializer.

        Args:
            num_layers: How many layers of Gaussians to predict.
            stride: The downsample rate of the output feature map.
            base_depth: The depth of the first layer (after the foreground
                layer if use_depth=True).
            scale_factor: Multiply scale of Gaussians by this factor.
            disparity_factor: Factor to convert inverse depth to disparity.
            color_option: Which color option to initialize the multi-layer
                gaussians.
            first_layer_depth_option: Which depth option to initialize the
                first layer of gaussians.
            rest_layer_depth_option: Which depth option to initialize the rest
                layers of gaussians.
            normalize_depth: Whether to normalize depth to
                [DepthTransformParam.depth_min, DepthTransformParam.depth_max).
            feature_input_stop_grad: Whether to not propagate gradients
                through feature inputs.
        """
        super().__init__()
        self.num_layers = num_layers
        self.stride = stride
        self.base_depth = base_depth
        self.scale_factor = scale_factor
        self.disparity_factor = disparity_factor
        self.color_option = color_option
        self.first_layer_depth_option = first_layer_depth_option
        self.rest_layer_depth_option = rest_layer_depth_option
        self.normalize_depth = normalize_depth
        self.feature_input_stop_grad = feature_input_stop_grad

    def prepare_feature_input(self, image: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
        """Prepare the feature input to the Gaussian predictor.

        Concatenates the image with a normalized-disparity channel and maps
        the result from [0, 1] to [-1, 1].
        """
        if self.feature_input_stop_grad:
            # Keep the downstream predictor from back-propagating into the
            # image / monodepth inputs.
            image = image.detach()
            depth = depth.detach()

        normalized_disparity = self.disparity_factor / depth
        features_in = torch.cat([image, normalized_disparity], dim=1)
        # Map [0, 1]-range features to [-1, 1].
        features_in = 2.0 * features_in - 1.0
        return features_in

    def forward(self, image: torch.Tensor, depth: torch.Tensor) -> InitializerOutput:
        """Construct Gaussian base values and prepare feature input.

        Args:
            image: The image to process.
            depth: The corresponding depth map from the monodepth network.

        Returns:
            The base value for Gaussians.
        """
        image = image.contiguous()
        depth = depth.contiguous()
        device = depth.device
        batch_size, _, image_height, image_width = depth.shape
        # The Gaussian grid is downsampled by `stride` relative to the image.
        base_height, base_width = (
            image_height // self.stride,
            image_width // self.stride,
        )

        global_scale: torch.Tensor | None = None
        if self.normalize_depth:
            # Normalize so the per-sample minimum depth is fixed;
            # global_scale (1 / depth_factor) lets callers undo it.
            depth, depth_factor = _rescale_depth(depth)
            global_scale = 1.0 / depth_factor

        def _create_disparity_layers(num_layers: int = 1) -> torch.Tensor:
            """Create multiple disparity layers.

            Layers are evenly spaced in disparity from 1 / base_depth down to
            (but excluding) 0, then broadcast over the base grid.
            """
            disparity = torch.linspace(1.0 / self.base_depth, 0.0, num_layers + 1, device=device)
            # Drop the trailing 0 so no layer sits at infinite depth.
            return disparity[None, None, :-1, None, None].repeat(
                batch_size, 1, 1, base_height, base_width
            )

        def _create_surface_layer(
            depth: torch.Tensor,
            depth_pooling_mode: str,
        ) -> torch.Tensor:
            """Create a surface-following disparity layer from a depth map."""
            disparity = 1.0 / depth
            if depth_pooling_mode == "min":
                # Max-pooling disparity keeps the MINIMUM depth in each cell.
                disparity = torch.max_pool2d(disparity, self.stride, self.stride)
            elif depth_pooling_mode == "max":
                # Negated max-pool is a min-pool on disparity, i.e. the
                # MAXIMUM depth in each cell.
                disparity = -torch.max_pool2d(-disparity, self.stride, self.stride)
            else:
                raise ValueError(f"Invalid depth pooling mode {depth_pooling_mode}.")

            # Insert a singleton layer dimension: B x 1 x 1 x H' x W'.
            return disparity[:, :, None, :, :]

        # First layer: either snapped to the observed surface or at base_depth.
        if self.first_layer_depth_option == "surface_min":
            first_disparity = _create_surface_layer(depth[:, 0:1], "min")
        elif self.first_layer_depth_option == "surface_max":
            first_disparity = _create_surface_layer(depth[:, 0:1], "max")
        elif self.first_layer_depth_option in ("base_depth", "linear_disparity"):
            # With a single layer both options reduce to a 1 / base_depth plane.
            first_disparity = _create_disparity_layers()
        else:
            raise ValueError(f"Unknown depth init option: {self.first_layer_depth_option}.")

        if self.num_layers == 1:
            disparity = first_disparity
        else:
            # Reuse the single-channel depth for all remaining layers, or take
            # per-layer channels when the depth map provides them.
            following_depth = depth if depth.shape[1] == 1 else depth[:, 1:]
            if self.rest_layer_depth_option == "surface_min":
                following_disparity = _create_surface_layer(following_depth, "min")
            elif self.rest_layer_depth_option == "surface_max":
                following_disparity = _create_surface_layer(following_depth, "max")
            elif self.rest_layer_depth_option == "base_depth":
                # All remaining layers start at the base-depth plane.
                following_disparity = torch.cat(
                    [_create_disparity_layers() for i in range(self.num_layers - 1)],
                    dim=2,
                )
            elif self.rest_layer_depth_option == "linear_disparity":
                following_disparity = _create_disparity_layers(self.num_layers - 1)
            else:
                raise ValueError(f"Unknown depth init option: {self.rest_layer_depth_option}.")

            disparity = torch.cat([first_disparity, following_disparity], dim=2)

        base_x_ndc, base_y_ndc = _create_base_xy(depth, self.stride, self.num_layers)
        # One grid cell spans 2 * stride / width in NDC; scale Gaussians to
        # cover their cell (times the configured scale_factor).
        disparity_scale_factor = 2 * self.scale_factor * self.stride / float(image_width)
        base_scales = _create_base_scale(disparity, disparity_scale_factor)

        # Identity rotation, broadcast over batch / layers / grid.
        base_quaternions = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
        base_quaternions = base_quaternions[None, :, None, None, None]

        # Spread opacity across layers, capped at 0.5.
        base_opacities = torch.tensor([min(1.0 / self.num_layers, 0.5)], device=device)
        # Default to mid-gray; color options below may overwrite layers.
        base_colors = torch.empty(
            batch_size, 3, self.num_layers, base_height, base_width, device=device
        ).fill_(0.5)

        if self.color_option == "none":
            pass
        elif self.color_option == "first_layer":
            # Only the first layer takes the (downsampled) image colors.
            base_colors[:, :, 0] = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
        elif self.color_option == "all_layers":
            temp = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
            base_colors = temp[:, :, None, :, :].repeat(1, 1, self.num_layers, 1, 1)
        else:
            raise ValueError(f"Unknown color init option: {self.color_option}.")

        features_in = self.prepare_feature_input(image, depth)
        base_gaussians = GaussianBaseValues(
            mean_x_ndc=base_x_ndc,
            mean_y_ndc=base_y_ndc,
            mean_inverse_z_ndc=disparity,
            scales=base_scales,
            quaternions=base_quaternions,
            colors=base_colors,
            opacities=base_opacities,
        )

        return InitializerOutput(
            gaussian_base_values=base_gaussians,
            feature_input=features_in,
            global_scale=global_scale,
        )
|
|
|
|
| def _create_base_xy( |
| depth: torch.Tensor, stride: int, num_layers: int |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Create base x and y coordinates for the gaussians in NDC space.""" |
| device = depth.device |
| batch_size, _, image_height, image_width = depth.shape |
| xx = torch.arange(0.5 * stride, image_width, stride, device=device) |
| yy = torch.arange(0.5 * stride, image_height, stride, device=device) |
| xx = 2 * xx / image_width - 1.0 |
| yy = 2 * yy / image_height - 1.0 |
|
|
| xx, yy = torch.meshgrid(xx, yy, indexing="xy") |
| base_x_ndc = xx[None, None, None].repeat(batch_size, 1, num_layers, 1, 1) |
| base_y_ndc = yy[None, None, None].repeat(batch_size, 1, num_layers, 1, 1) |
|
|
| return base_x_ndc, base_y_ndc |
|
|
|
|
| def _create_base_scale(disparity: torch.Tensor, disparity_scale_factor: float) -> torch.Tensor: |
| """Create base scale for the gaussians.""" |
| inverse_disparity = torch.ones_like(disparity) / disparity |
| base_scales = inverse_disparity * disparity_scale_factor |
| return base_scales |
|
|
|
|
| def _rescale_depth( |
| depth: torch.Tensor, depth_min: float = 1.0, depth_max: float = 1e2 |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Rescale a depth image tensor. |
| |
| Args: |
| depth: The depth tensor to transform. |
| depth_min: The min depth to scale depth to. |
| depth_max: The max clamp depth after scaling. |
| |
| Returns: |
| The rescaled depth and rescale factor. |
| """ |
| current_depth_min = depth.flatten(depth.ndim - 3).min(dim=-1).values |
| depth_factor = depth_min / (current_depth_min + 1e-6) |
| depth = (depth * depth_factor[..., None, None, None]).clamp(max=depth_max) |
| return depth, depth_factor |
|
|