| """Contains decoder head for direct prediction of delta values. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from torch import nn |
|
|
| from .gaussian_decoder import ImageFeatures |
|
|
|
|
class DirectPredictionHead(nn.Module):
    """Decodes image features into per-layer delta values using 1x1 convolutions.

    Two independent 1x1 convolutions are used: one reads the geometry features
    and one reads the texture features. Each output is reshaped to separate the
    attribute axis from the layer axis, and the two results are concatenated
    along the attribute axis, giving ``NUM_DELTA_CHANNELS`` deltas per layer
    and pixel. Both convolutions are zero-initialized, so an untrained head
    predicts all-zero deltas.
    """

    # Delta channels per layer predicted from the geometry features
    # (presumably the 3D position offsets -- NOTE(review): confirm with caller).
    NUM_GEOMETRY_CHANNELS = 3
    # Total delta channels per layer (geometry + texture).
    NUM_DELTA_CHANNELS = 14
    # Delta channels per layer predicted from the texture features.
    NUM_TEXTURE_CHANNELS = NUM_DELTA_CHANNELS - NUM_GEOMETRY_CHANNELS

    def __init__(self, feature_dim: int, num_layers: int) -> None:
        """Initialize DirectPredictionHead.

        Args:
            feature_dim: Number of channels in both the geometry and the texture
                feature maps.
            num_layers: The number of layers of Gaussians to predict.
        """
        super().__init__()
        self.num_layers = num_layers

        # Build geometry first, then texture, so the default-init RNG draws
        # happen in the same order as before (keeps seeded runs reproducible).
        self.geometry_prediction_head = nn.Conv2d(
            feature_dim, self.NUM_GEOMETRY_CHANNELS * num_layers, kernel_size=1
        )
        self.texture_prediction_head = nn.Conv2d(
            feature_dim, self.NUM_TEXTURE_CHANNELS * num_layers, kernel_size=1
        )

        # Zero-init both heads so every predicted delta starts at exactly 0.
        for head in (self.geometry_prediction_head, self.texture_prediction_head):
            nn.init.zeros_(head.weight)
            # Conv2d is constructed with the default bias=True; the assert only
            # narrows Optional[Tensor] for type checkers.
            assert head.bias is not None
            nn.init.zeros_(head.bias)

    def forward(self, image_features: ImageFeatures) -> torch.Tensor:
        """Predict deltas for 3D Gaussians.

        Args:
            image_features: Image features from the decoder. Only its
                ``geometry_features`` and ``texture_features`` are read. Each
                must be a ``(B, feature_dim, H, W)`` batched map or a
                ``(feature_dim, H, W)`` unbatched map, as accepted by
                ``nn.Conv2d``.

        Returns:
            Deltas of shape ``(B, NUM_DELTA_CHANNELS, num_layers, H, W)``, or
            ``(NUM_DELTA_CHANNELS, num_layers, H, W)`` for unbatched input. The
            first ``NUM_GEOMETRY_CHANNELS`` entries of the attribute axis come
            from the geometry head; the rest come from the texture head.
        """
        delta_values_geometry = self.geometry_prediction_head(image_features.geometry_features)
        delta_values_texture = self.texture_prediction_head(image_features.texture_features)

        # Conv output channels are read attribute-major:
        # channel index = attribute * num_layers + layer.
        # Negative dims address the channel axis for both batched (dim 1) and
        # unbatched (dim 0) inputs.
        delta_values_geometry = delta_values_geometry.unflatten(
            -3, (self.NUM_GEOMETRY_CHANNELS, self.num_layers)
        )
        delta_values_texture = delta_values_texture.unflatten(
            -3, (self.NUM_TEXTURE_CHANNELS, self.num_layers)
        )
        # After unflatten the attribute axis sits at -4 (before layer, H, W).
        return torch.cat([delta_values_geometry, delta_values_texture], dim=-4)
|
|