| | """Contains definition of RGB-only gaussian predictor. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import logging |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from sharp.models.monodepth import MonodepthWithEncodingAdaptor |
| | from sharp.utils.gaussians import Gaussians3D |
| |
|
| | from .composer import GaussianComposer |
| |
|
| | LOGGER = logging.getLogger(__name__) |
| |
|
| |
|
| | class DepthAlignment(nn.Module): |
| | """Depth alignment in a dedicated nn.Module. |
| | |
| | Wrap scale_map_estimator to perform the conditional logic in a separated torch |
| | module outside the forward of RGBGaussianPredictor. This module can be then |
| | excluded during symbolic tracing. |
| | """ |
| |
|
| | def __init__(self, scale_map_estimator: nn.Module | None): |
| | """Initialize DepthAlignmentWrapper. |
| | |
| | Args: |
| | scale_map_estimator: Module to align monodepth to ground truth depth. |
| | """ |
| | super().__init__() |
| | self.scale_map_estimator = scale_map_estimator |
| |
|
| | def forward( |
| | self, |
| | monodepth: torch.Tensor, |
| | depth: torch.Tensor, |
| | depth_decoder_features: torch.Tensor | None = None, |
| | ): |
| | """Optionally align monodepth to ground truth with a local scale map. |
| | |
| | Args: |
| | monodepth: The monodepth model with intermediate features to use. |
| | depth: Ground truth depth to align predicted depth to. |
| | depth_decoder_features: The (optional) monodepth decoder features. |
| | """ |
| | if depth is not None and self.scale_map_estimator is not None: |
| | depth_alignment_map = self.scale_map_estimator( |
| | monodepth[:, 0:1], depth, depth_decoder_features |
| | ) |
| | monodepth = depth_alignment_map * monodepth |
| | else: |
| | |
| | |
| | depth_alignment_map = torch.ones_like(monodepth) |
| | return monodepth, depth_alignment_map |
| |
|
| |
|
class RGBGaussianPredictor(nn.Module):
    """Predicts 3D Gaussians from images."""

    feature_model: nn.Module

    def __init__(
        self,
        init_model: nn.Module,
        monodepth_model: MonodepthWithEncodingAdaptor,
        feature_model: nn.Module,
        prediction_head: nn.Module,
        gaussian_composer: GaussianComposer,
        scale_map_estimator: nn.Module | None,
    ) -> None:
        """Initialize RGBGaussianPredictor.

        Args:
            init_model: A model mapping image and depth to base values.
            monodepth_model: The monodepth model with intermediate features to use.
            feature_model: The image2image model to predict Gaussians from.
            prediction_head: Head to decode image features.
            gaussian_composer: Module to compose final prediction from deltas and
                base values.
            scale_map_estimator: Module to align monodepth to ground truth depth,
                or None to disable depth alignment.

        Note:
        ----
            When monodepth_model is trainable, using local depth alignment can
            result in the monodepth model losing its ability to predict shapes.
            It is hence recommended to deactivate the corresponding flag.
        """
        super().__init__()
        self.init_model = init_model
        self.feature_model = feature_model
        self.monodepth_model = monodepth_model
        self.prediction_head = prediction_head
        self.gaussian_composer = gaussian_composer
        # Wrapped in a dedicated module so the conditional alignment logic
        # can be excluded during symbolic tracing.
        self.depth_alignment = DepthAlignment(scale_map_estimator)

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
        depth: torch.Tensor | None = None,
    ) -> Gaussians3D:
        """Predict 3D Gaussians.

        Args:
            image: The image to process.
            disparity_factor: Factor to convert depth to disparities;
                assumed to be one scalar per batch sample (shape (B,)).
            depth: Ground truth depth to align predicted depth to, or None
                to skip alignment.

        Returns:
            The predicted 3D Gaussians.

        Note:
        ----
            During training, it is recommended to feed an additional ground
            truth depth map (``depth``) to the network to align the predicted
            depth to. During inference, it is recommended to use ``depth=None``
            and use the monodepth_disparity output from the model instead to
            compute depth.
        """
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Broadcast the per-sample factor to (B, 1, 1, 1); clamping the
        # disparity keeps the division stable for near-zero or huge values.
        disparity_factor = disparity_factor[:, None, None, None]
        monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)

        # The alignment map is discarded; only the aligned depth is used.
        monodepth, _ = self.depth_alignment(
            monodepth,
            depth,
            monodepth_output.decoder_features,
        )

        init_output = self.init_model(image, monodepth)
        image_features = self.feature_model(
            init_output.feature_input, encodings=monodepth_output.output_features
        )
        delta_values = self.prediction_head(image_features)
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )
        return gaussians

    def internal_resolution(self) -> int:
        """Internal resolution, delegated to the monodepth model."""
        return self.monodepth_model.internal_resolution()

    @property
    def output_resolution(self) -> int:
        """Output resolution of Gaussians (half the internal resolution)."""
        return self.internal_resolution() // 2
| |
|