| """Defines module to compose final Gaussians from base values and delta values. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from sharp.models.initializer import GaussianBaseValues |
| from sharp.utils import math as math_utils |
| from sharp.utils.color_space import ColorSpace, sRGB2linearRGB |
| from sharp.utils.gaussians import Gaussians3D |
|
|
| from .params import DeltaFactor |
|
|
|
|
def _get_scale_activation_constant(max_scale: float, min_scale: float) -> tuple[float, float]:
    """Return constants for scale activation function.

    The two constants are consumed by ``GaussianComposer._scale_activation`` as the
    slope (``constant_a``) and offset (``constant_b``) of the sigmoid argument.

    Args:
        max_scale: Upper bound of the scale factor. Must be greater than 1.
        min_scale: Lower bound of the scale factor. Must be smaller than 1.

    Returns:
        The tuple ``(constant_a, constant_b)``.

    Raises:
        ValueError: If ``min_scale < 1 < max_scale`` does not hold.
    """
    # Both formulas below divide by (1 - min_scale), (max_scale - 1) or
    # (max_scale - min_scale), and the inverse sigmoid is fed
    # (1 - min_scale) / (max_scale - min_scale), which only lies strictly inside
    # (0, 1) when 1 is strictly between the two bounds. Fail early with a clear
    # message instead of a bare ZeroDivisionError or a silently invalid constant.
    # The negated chained comparison also rejects NaN bounds.
    if not min_scale < 1.0 < max_scale:
        raise ValueError(
            "Scale activation requires min_scale < 1 < max_scale, "
            f"got min_scale={min_scale} and max_scale={max_scale}."
        )

    constant_a = (max_scale - min_scale) / (1 - min_scale) / (max_scale - 1)
    constant_b = math_utils.inverse_sigmoid(
        torch.tensor((1.0 - min_scale) / (max_scale - min_scale))
    ).item()
    return constant_a, constant_b
|
|
|
|
class GaussianComposer(nn.Module):
    """Converts base values and deltas into Gaussians.

    Every Gaussian attribute (mean, scale, quaternion, color, opacity) is produced by
    combining a base value with a slice of the predicted delta tensor, weighted by the
    matching field of ``delta_factor``.
    """

    color_activation_type: math_utils.ActivationType
    opacity_activation_type: math_utils.ActivationType

    def __init__(
        self,
        delta_factor: DeltaFactor,
        min_scale: float,
        max_scale: float,
        color_activation_type: math_utils.ActivationType,
        opacity_activation_type: math_utils.ActivationType,
        color_space: ColorSpace,
        base_scale_on_predicted_mean: bool,
        scale_factor: int = 1,
    ) -> None:
        """Initialize GaussianComposer.

        Args:
            delta_factor: Multiply delta offsets by this factor.
            min_scale: The minimal scale factor for gaussian scale activation.
            max_scale: The maximal scale factor for gaussian scale activation.
            color_activation_type: Which activation function to use for colors.
            opacity_activation_type: Which activation function to use for opacities.
            color_space: Which color space is used in training.
            base_scale_on_predicted_mean: Whether to account z offsets for estimating base scale.
            scale_factor: The scale factor to upsample the delta_values before composition.
        """
        super().__init__()
        self.delta_factor = delta_factor
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.color_activation_type = color_activation_type
        self.opacity_activation_type = opacity_activation_type
        self.color_space = color_space
        self.scale_factor = scale_factor
        self.base_scale_on_predicted_mean = base_scale_on_predicted_mean

    def upsample_delta_value(self, delta: torch.Tensor, scale_factor: int = 1) -> torch.Tensor:
        """Upsample the delta value.

        Args:
            delta: The delta values predicted by gaussian predictor, a 5D tensor
                unpacked as (batch, channels, layers, height, width).
            scale_factor: The scale factor to upsample the delta_values.

        Returns:
            The delta values with height and width multiplied by ``scale_factor``.
        """
        batch_size, num_channels, num_layers, image_height, image_width = delta.shape
        new_height = image_height * scale_factor
        new_width = image_width * scale_factor

        # F.interpolate only resizes the trailing spatial dimensions of a 4D input, so
        # channels and layers are merged for the call and split again afterwards.
        # reshape is used instead of view so that non-contiguous deltas are accepted.
        merged = delta.reshape(batch_size, num_channels * num_layers, image_height, image_width)
        upsampled = F.interpolate(merged, scale_factor=scale_factor)
        return upsampled.reshape(batch_size, num_channels, num_layers, new_height, new_width)

    def forward(
        self,
        delta: torch.Tensor,
        base_values: GaussianBaseValues,
        global_scale: torch.Tensor | None = None,
        flatten_output: bool = True,
    ) -> Gaussians3D:
        """Combine predicted delta values with base gaussian values and apply activation function.

        Args:
            delta: The delta values predicted by gaussian predictor. Channels (dim 1) are
                read as 0:3 mean, 3:6 scale, 6:10 quaternion, 10:13 color, 13 opacity.
            base_values: The gaussian base values.
            global_scale: Global scale of Gaussians. Indexed along its first dimension
                only, so it is expected to hold one value per batch element.
            flatten_output: Flatten the gaussian parameters. When set, the layer and
                spatial dimensions are merged and channels are moved last.

        Returns:
            The computed 3D Gaussians.
        """
        scale_factor = self.scale_factor
        # Ratio between the width of the base values and the width of the delta.
        actual_scale_factor = base_values.mean_x_ndc.shape[-1] // delta.shape[-1]
        # Only upsample when upsampling is configured AND the delta is really smaller
        # than the base values; a delta that already matches is left untouched.
        if scale_factor != 1 and actual_scale_factor != 1:
            delta = self.upsample_delta_value(delta, scale_factor)

        mean_vectors = self._forward_mean(base_values, delta)

        # mean_vectors[:, 2:3] is the predicted depth, so multiplying the base scales by
        # (base inverse depth * predicted depth) rescales them by the depth change.
        base_scales = (
            (base_values.scales * base_values.mean_inverse_z_ndc * mean_vectors[:, 2:3, ...])
            if self.base_scale_on_predicted_mean
            else base_values.scales
        )
        singular_values = self._scale_activation(
            base_scales,
            delta[:, 3:6],
            self.min_scale,
            self.max_scale,
        )
        quaternions = self._quaternion_activation(base_values.quaternions, delta[:, 6:10])
        colors = self._color_activation(base_values.colors, delta[:, 10:13])
        opacities = self._opacity_activation(base_values.opacities, delta[:, 13])

        if flatten_output:
            # Move channels last and merge dims 1..3 into a single Gaussian dimension.
            # Opacities were sliced with an integer index and have no channel dimension.
            mean_vectors = mean_vectors.permute(0, 2, 3, 4, 1).flatten(1, 3)
            singular_values = singular_values.permute(0, 2, 3, 4, 1).flatten(1, 3)
            quaternions = quaternions.permute(0, 2, 3, 4, 1).flatten(1, 3)
            colors = colors.permute(0, 2, 3, 4, 1).flatten(1, 3)
            opacities = opacities.flatten(1, 3)

        if global_scale is not None:
            mean_vectors = self._apply_global_scale(global_scale, mean_vectors)
            singular_values = self._apply_global_scale(global_scale, singular_values)

        return Gaussians3D(
            mean_vectors=mean_vectors,
            singular_values=singular_values,
            quaternions=quaternions,
            colors=colors,
            opacities=opacities,
        )

    @staticmethod
    def _apply_global_scale(global_scale: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Multiply ``values`` by ``global_scale`` along the leading (batch) dimension.

        A fixed ``global_scale[:, None, None]`` view only lines up with 3D (flattened)
        values; against unflattened 5D values it would broadcast the batch axis against
        a non-batch axis. One singleton axis is therefore added per trailing dimension
        of ``values``, which is identical to ``[:, None, None]`` in the 3D case.

        Args:
            global_scale: Tensor indexed along its first dimension, expected to hold
                one scale per batch element.
            values: Tensor whose first dimension is the batch dimension.

        Returns:
            The scaled tensor, with the same shape as ``values``.
        """
        expand_index = (slice(None),) + (None,) * (values.ndim - 1)
        return global_scale[expand_index] * values

    def _forward_mean(self, base_values: GaussianBaseValues, delta: torch.Tensor) -> torch.Tensor:
        """Compose the mean vectors from the base positions and the first three delta channels.

        Args:
            base_values: The gaussian base values; ``mean_x_ndc``, ``mean_y_ndc`` and
                ``mean_inverse_z_ndc`` are read.
            delta: The delta values predicted by gaussian predictor.

        Returns:
            The mean vectors produced by ``_mean_activation``.
        """
        # Per-channel weights for the (x, y, z) deltas, broadcastable against 5D tensors.
        delta_factor = torch.tensor(
            [self.delta_factor.xy, self.delta_factor.xy, self.delta_factor.z],
            device=delta.device,
        )[None, :, None, None, None]

        dtype = base_values.mean_x_ndc.dtype
        device = base_values.mean_x_ndc.device
        target_shape = (1, 3, 1, 1, 1)
        mean_x_mask = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device).reshape(
            target_shape
        )
        mean_y_mask = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device).reshape(
            target_shape
        )
        mean_z_mask = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).reshape(
            target_shape
        )

        # Each component is tiled three times along dim 1 and masked to its own channel.
        # The sum then assembles a 3-channel (x, y, inverse z) tensor while still
        # allowing the three components to broadcast against each other.
        mean_vectors_ndc = (
            base_values.mean_x_ndc.repeat(target_shape) * mean_x_mask
            + base_values.mean_y_ndc.repeat(target_shape) * mean_y_mask
            + base_values.mean_inverse_z_ndc.repeat(target_shape) * mean_z_mask
        )

        mean_vectors = self._mean_activation(mean_vectors_ndc, delta_factor * delta[:, :3])
        return mean_vectors

    def _mean_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Mean activation function.

        Args:
            base: Tensor with 3 channels in dim 1, where the first two channels
                (x, y) are in normalized device coordinates (NDC) where (-1, -1) is
                the top, while the third channel is inverse depth.
            learned_delta: Tensor with 3 channels in dim 1 holding predicted delta values.

        Returns:
            The final mean vector after combining base and delta and applying
            nonlinearities, with channels (z * x, z * y, z) in dim 1.
        """
        xx = base[:, 0:1] + learned_delta[:, 0:1]
        yy = base[:, 1:2] + learned_delta[:, 1:2]

        base_inverse_z = base[:, 2:3]
        delta_inverse_z = learned_delta[:, 2:3]

        # The delta is added in the pre-softplus domain, so the resulting inverse depth
        # is never negative.
        inverse_zz = F.softplus(math_utils.inverse_softplus(base_inverse_z) + delta_inverse_z)
        # The epsilon avoids a division by zero and caps the depth at 1 / 1e-3.
        zz = 1.0 / (inverse_zz + 1e-3)

        mean_vectors = torch.cat([zz * xx, zz * yy, zz], dim=1)
        return mean_vectors

    def _scale_activation(
        self,
        base: torch.Tensor,
        learned_delta: torch.Tensor,
        min_scale: float,
        max_scale: float,
    ) -> torch.Tensor:
        """Scale the base scales by a bounded factor derived from the delta.

        Args:
            base: The base scales.
            learned_delta: The predicted scale deltas.
            min_scale: Lower bound of the multiplicative factor.
            max_scale: Upper bound of the multiplicative factor.

        Returns:
            ``base`` multiplied by a sigmoid-shaped factor in (min_scale, max_scale).
        """
        # NOTE(review): the constants look chosen so that a zero delta yields a factor
        # of 1 -- this relies on math_utils.inverse_sigmoid being the logit; confirm.
        constant_a, constant_b = _get_scale_activation_constant(max_scale, min_scale)
        scale_factor = (max_scale - min_scale) * torch.sigmoid(
            constant_a * self.delta_factor.scale * learned_delta + constant_b
        ) + min_scale
        return base * scale_factor

    def _quaternion_activation(
        self, base: torch.Tensor, learned_delta: torch.Tensor
    ) -> torch.Tensor:
        """Add the weighted quaternion delta to the base quaternions.

        No normalization is applied here.
        """
        return base + self.delta_factor.quaternion * learned_delta

    def _color_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Combine base colors and color deltas in the pre-activation domain.

        Args:
            base: The base colors.
            learned_delta: The predicted color deltas.

        Returns:
            The activated colors, converted with ``sRGB2linearRGB`` when the
            configured color space is ``"linearRGB"``.
        """
        # Clamp the base before inverting the activation; presumably this keeps the
        # inverse finite at the ends of the activation's range -- confirm against
        # math_utils.create_activation_pair.
        if self.color_activation_type == "sigmoid":
            base = torch.clamp(base, min=0.01, max=0.99)
        elif self.color_activation_type in ("exp", "softplus"):
            base = torch.clamp(base, min=0.01)

        activation = math_utils.create_activation_pair(self.color_activation_type)
        colors: torch.Tensor = activation.forward(
            activation.inverse(base) + self.delta_factor.color * learned_delta
        )

        if self.color_space == "linearRGB":
            colors = sRGB2linearRGB(colors)
        return colors

    def _opacity_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Combine base opacities and opacity deltas in the pre-activation domain.

        Args:
            base: The base opacities.
            learned_delta: The predicted opacity deltas.

        Returns:
            The activated opacities.
        """
        activation = math_utils.create_activation_pair(self.opacity_activation_type)
        return activation.forward(
            activation.inverse(base) + self.delta_factor.opacity * learned_delta
        )
|
|