| """Contains definition of RGB-only gaussian predictor. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
|
|
| import torch |
| from torch import nn |
|
|
| from sharp.models.monodepth import MonodepthWithEncodingAdaptor |
| from sharp.utils.gaussians import Gaussians3D |
|
|
| from .composer import GaussianComposer |
|
|
# Logger for this module; the logger name follows the module's import path.
LOGGER: logging.Logger = logging.getLogger(__name__)
|
|
|
|
| class DepthAlignment(nn.Module): |
| """Depth alignment in a dedicated nn.Module. |
| |
| Wrap scale_map_estimator to perform the conditional logic in a separated torch |
| module outside the forward of RGBGaussianPredictor. This module can be then |
| excluded during symbolic tracing. |
| """ |
|
|
| def __init__(self, scale_map_estimator: nn.Module | None): |
| """Initialize DepthAlignmentWrapper. |
| |
| Args: |
| scale_map_estimator: Module to align monodepth to ground truth depth. |
| """ |
| super().__init__() |
| self.scale_map_estimator = scale_map_estimator |
|
|
| def forward( |
| self, |
| monodepth: torch.Tensor, |
| depth: torch.Tensor, |
| depth_decoder_features: torch.Tensor | None = None, |
| ): |
| """Optionally align monodepth to ground truth with a local scale map. |
| |
| Args: |
| monodepth: The monodepth model with intermediate features to use. |
| depth: Ground truth depth to align predicted depth to. |
| depth_decoder_features: The (optional) monodepth decoder features. |
| """ |
| if depth is not None and self.scale_map_estimator is not None: |
| depth_alignment_map = self.scale_map_estimator( |
| monodepth[:, 0:1], depth, depth_decoder_features |
| ) |
| monodepth = depth_alignment_map * monodepth |
| else: |
| |
| |
| depth_alignment_map = torch.ones_like(monodepth) |
| return monodepth, depth_alignment_map |
|
|
|
|
class RGBGaussianPredictor(nn.Module):
    """Predicts 3D Gaussians from images."""

    # Class-level annotation, presumably so type checkers see a concrete type
    # for this submodule.
    feature_model: nn.Module

    def __init__(
        self,
        init_model: nn.Module,
        monodepth_model: MonodepthWithEncodingAdaptor,
        feature_model: nn.Module,
        prediction_head: nn.Module,
        gaussian_composer: GaussianComposer,
        scale_map_estimator: nn.Module | None,
    ) -> None:
        """Initialize RGBGaussianPredictor.

        Args:
            init_model: A model mapping image and depth to base values.
            monodepth_model: The monodepth model with intermediate features to use.
            feature_model: The image2image model to predict Gaussians from.
            prediction_head: Head to decode image features.
            gaussian_composer: Module to compose final prediction from deltas and
                base values.
            scale_map_estimator: Module to align monodepth to ground truth depth.
                If None, depth alignment is disabled.

        Note:
        ----
            When monodepth_model is trainable, using local depth alignment can
            result in the monodepth model losing its ability to predict shapes.
            It is hence recommended to deactivate the corresponding flag.
        """
        super().__init__()
        # Registration order is kept as-is: it determines submodule / state_dict
        # ordering.
        self.init_model = init_model
        self.feature_model = feature_model
        self.monodepth_model = monodepth_model
        self.prediction_head = prediction_head
        self.gaussian_composer = gaussian_composer
        self.depth_alignment = DepthAlignment(scale_map_estimator)

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
        depth: torch.Tensor | None = None,
    ) -> Gaussians3D:
        """Predict 3D Gaussians.

        Args:
            image: The image to process.
            disparity_factor: Factor to convert depth to disparities. Assumed to
                be 1-D (one value per batch element), since it is expanded with
                ``[:, None, None, None]`` before broadcasting.
            depth: Ground truth depth to align predicted depth to. If None, the
                predicted depth is used without alignment.

        Returns:
            The predicted 3D Gaussians.

        Note:
        ----
            During training, it is recommended to feed an additional ground
            truth depth map to the network to align the predicted depth to.
            During inference, it is recommended to use depth=None and use the
            monodepth_disparity output from the model instead to compute depth.
        """
        # Monocular disparity prediction, together with intermediate features
        # reused further down.
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Convert disparity to depth. The clamp keeps the denominator away from
        # zero and bounds extreme values.
        disparity_factor = disparity_factor[:, None, None, None]
        monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)

        # Optional alignment to ground truth depth. It lives in a separate module
        # so that its data-dependent branching can be excluded from symbolic
        # tracing.
        monodepth, _ = self.depth_alignment(
            monodepth,
            depth,
            monodepth_output.decoder_features,
        )

        init_output = self.init_model(image, monodepth)
        image_features = self.feature_model(
            init_output.feature_input, encodings=monodepth_output.output_features
        )
        delta_values = self.prediction_head(image_features)
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )
        return gaussians

    def internal_resolution(self) -> int:
        """Return the internal resolution reported by the monodepth model."""
        return self.monodepth_model.internal_resolution()

    @property
    def output_resolution(self) -> int:
        """Output resolution of Gaussians (half the internal resolution)."""
        return self.internal_resolution() // 2
|
|