| | """Contains decoder head for direct prediction of delta values. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from .gaussian_decoder import ImageFeatures |
| |
|
| |
|
class DirectPredictionHead(nn.Module):
    """Decodes features into delta values using 1x1 convolutions.

    Geometry and texture features are decoded by two separate 1x1 convolutions.
    Both convolutions have their weights and biases initialised to zero, so a
    freshly constructed head predicts all-zero deltas.
    """

    # Number of delta values per Gaussian produced by each branch
    # (3 + 11 = 14 in total). NOTE(review): the meaning of the individual
    # channels is defined by the consumer of these deltas, not in this module.
    NUM_GEOMETRY_CHANNELS = 3
    NUM_TEXTURE_CHANNELS = 14 - 3

    def __init__(self, feature_dim: int, num_layers: int) -> None:
        """Initialize DirectPredictionHead.

        Args:
            feature_dim: Number of input feature channels (shared by the
                geometry and texture branches).
            num_layers: The number of layers of Gaussians to predict.
        """
        super().__init__()
        self.num_layers = num_layers

        self.geometry_prediction_head = self._make_zero_conv(
            feature_dim, self.NUM_GEOMETRY_CHANNELS * num_layers
        )
        self.texture_prediction_head = self._make_zero_conv(
            feature_dim, self.NUM_TEXTURE_CHANNELS * num_layers
        )

    @staticmethod
    def _make_zero_conv(in_channels: int, out_channels: int) -> nn.Conv2d:
        """Create a 1x1 convolution whose weight and bias start at zero."""
        conv = nn.Conv2d(in_channels, out_channels, 1)
        nn.init.zeros_(conv.weight)
        # Conv2d is built with its default bias=True; the assert only narrows
        # the Optional type for static checkers.
        assert conv.bias is not None
        nn.init.zeros_(conv.bias)
        return conv

    def forward(self, image_features: ImageFeatures) -> torch.Tensor:
        """Predict deltas for 3D Gaussians.

        Args:
            image_features: Image features from decoder. Only its
                ``geometry_features`` and ``texture_features`` are read.

        Returns:
            The predicted deltas for Gaussian attributes. The geometry deltas
            occupy the first NUM_GEOMETRY_CHANNELS entries of dim 1, and the
            texture deltas the remaining NUM_TEXTURE_CHANNELS. Assuming 4-D
            ``(B, C, H, W)`` feature maps (TODO confirm against the decoder),
            the result has shape ``(B, 14, num_layers, H, W)``.
        """
        delta_values_geometry = self.geometry_prediction_head(image_features.geometry_features)
        delta_values_texture = self.texture_prediction_head(image_features.texture_features)
        # Conv output channels are laid out attribute-major:
        # channel index = attribute * num_layers + layer.
        delta_values_geometry = delta_values_geometry.unflatten(
            1, (self.NUM_GEOMETRY_CHANNELS, self.num_layers)
        )
        delta_values_texture = delta_values_texture.unflatten(
            1, (self.NUM_TEXTURE_CHANNELS, self.num_layers)
        )
        return torch.cat([delta_values_geometry, delta_values_texture], dim=1)
| |
|