| | """Contains utility code for gsplat renderer. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from pathlib import Path |
| | from typing import NamedTuple |
| |
|
| | import gsplat |
| | import torch |
| | from torch import nn |
| |
|
| | from sharp.utils import color_space as cs_utils |
| | from sharp.utils import io, vis |
| | from sharp.utils.gaussians import BackgroundColor, Gaussians3D |
| |
|
| |
|
class RenderingOutputs(NamedTuple):
    """Per-view tensors produced by the 3D Gaussians renderer."""

    # Rendered RGB image, channel-first (B, 3, H, W).
    color: torch.Tensor
    # Accumulated depth map, channel-first (B, 1, H, W).
    depth: torch.Tensor
    # Accumulated opacity/coverage, channel-first (B, 1, H, W).
    alpha: torch.Tensor
| |
|
| |
|
def write_renderings(rendering: RenderingOutputs, output_folder: Path, filename: str):
    """Write rendered color/depth/alpha to files."""
    if len(rendering.color) != 1:
        raise RuntimeError("We only support saving rendering of batch size = 1")

    base_path = output_folder / filename

    def _write_chw(image_chw: torch.Tensor, suffix: str) -> None:
        # Convert channel-first (C, H, W) to HWC before handing off to the writer.
        io.save_image(image_chw.permute(1, 2, 0).numpy(), base_path.with_suffix(suffix))

    # Quantize color to 8-bit; colorize depth and alpha for visualization.
    _write_chw((rendering.color[0].cpu() * 255.0).to(dtype=torch.uint8), ".color.png")
    _write_chw(vis.colorize_depth(rendering.depth[0], val_max=100.0), ".depth.png")
    _write_chw(vis.colorize_alpha(rendering.alpha[0]), ".alpha.png")
| |
|
| |
|
class GSplatRenderer(nn.Module):
    """Module to render 3D Gaussians to images using gsplat."""

    color_space: cs_utils.ColorSpace
    background_color: BackgroundColor

    def __init__(
        self,
        color_space: cs_utils.ColorSpace = "sRGB",
        background_color: BackgroundColor = "black",
        low_pass_filter_eps: float = 0.0,
    ) -> None:
        """Initialize gsplat renderer.

        Args:
            color_space: The color space the gaussian colors live in; outputs
                are always converted to sRGB.
            background_color: The background color to use for rendering.
            low_pass_filter_eps: The epsilon value for the low pass filter
                (forwarded to gsplat as ``eps2d``).
        """
        super().__init__()
        self.color_space = color_space
        self.background_color = background_color
        self.low_pass_filter_eps = low_pass_filter_eps

    def forward(
        self,
        gaussians: Gaussians3D,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        image_width: int,
        image_height: int,
    ) -> RenderingOutputs:
        """Predict images from gaussians.

        Args:
            gaussians: The Gaussians to render (batched; rendered one item at a time).
            extrinsics: The extrinsics of the camera to render to in OpenCV format.
            intrinsics: The intrinsics of the camera to render to in OpenCV format.
            image_width: The desired output image width.
            image_height: The desired output image height.

        Returns:
            RenderingOutputs with color (B, 3, H, W), depth (B, 1, H, W) and
            alpha (B, 1, H, W).

        Raises:
            ValueError: If ``self.color_space`` is not a supported value.
        """
        batch_size = len(gaussians.mean_vectors)
        outputs_list: list[RenderingOutputs] = []

        for ib in range(batch_size):
            colors, alphas, meta = gsplat.rendering.rasterization(
                means=gaussians.mean_vectors[ib],
                quats=gaussians.quaternions[ib],
                scales=gaussians.singular_values[ib],
                opacities=gaussians.opacities[ib],
                colors=gaussians.colors[ib],
                viewmats=extrinsics[ib : ib + 1],
                Ks=intrinsics[ib : ib + 1, :3, :3],
                width=image_width,
                height=image_height,
                render_mode="RGB+D",
                rasterize_mode="classic",
                absgrad=False,
                packed=False,
                eps2d=self.low_pass_filter_eps,
            )

            # gsplat returns channel-last (1, H, W, C); convert to channel-first.
            # "RGB+D" packs depth as the 4th channel of `colors`.
            rendered_color = colors[..., 0:3].permute([0, 3, 1, 2])
            rendered_depth_unnormalized = colors[..., 3:4].permute([0, 3, 1, 2])
            rendered_alpha = alphas.permute([0, 3, 1, 2])

            rendered_color = self.compose_with_background(
                rendered_color, rendered_alpha, self.background_color
            )

            if self.color_space == "sRGB":
                pass
            elif self.color_space == "linearRGB":
                rendered_color = cs_utils.linearRGB2sRGB(rendered_color)
            else:
                # BUGFIX: the ValueError was previously constructed but never
                # raised, silently passing unsupported color spaces through.
                raise ValueError("Unsupported ColorSpace type.")

            # NOTE(review): cov2d is computed but not used in the outputs below;
            # kept for parity with the original code (debug/future use?).
            cov2d = self._conics_to_covars2d(meta["conics"])
            # Reset 2D covariances of splats too close to / behind the camera.
            # BUGFIX: the previous chained indexing (cov2d[mask][..., 0, 0] = 1)
            # wrote into a temporary copy and had no effect; single-step
            # advanced indexing assigns in place.
            invisible = ~(meta["depths"] > 1e-2)
            cov2d[invisible, 0, 0] = 1
            cov2d[invisible, 1, 1] = 1
            cov2d[invisible, 0, 1] = 0

            # Accumulated depth is alpha-weighted; divide by alpha to recover
            # the expected depth (clip avoids division by zero at empty pixels).
            rendered_depth = rendered_depth_unnormalized / torch.clip(rendered_alpha, min=1e-8)

            outputs = RenderingOutputs(
                color=rendered_color,
                depth=rendered_depth,
                alpha=rendered_alpha,
            )
            outputs_list.append(outputs)

        return RenderingOutputs(
            color=torch.cat([item.color for item in outputs_list], dim=0).contiguous(),
            depth=torch.cat([item.depth for item in outputs_list], dim=0).contiguous(),
            alpha=torch.cat([item.alpha for item in outputs_list], dim=0).contiguous(),
        )

    @staticmethod
    def compose_with_background(
        rendered_rgb: torch.Tensor,
        rendered_alpha: torch.Tensor,
        background_color: BackgroundColor,
    ) -> torch.Tensor:
        """Compose rendered RGB with background color.

        Args:
            rendered_rgb: Premultiplied RGB, channel-first (B, 3, H, W).
            rendered_alpha: Coverage in [0, 1], channel-first (B, 1, H, W).
            background_color: One of "black", "white", "random_color"
                (one random color per call) or "random_pixel" (per-pixel noise).

        Raises:
            ValueError: If ``background_color`` is not a supported value.
        """
        if background_color == "black":
            # Premultiplied colors over black need no compositing.
            return rendered_rgb
        elif background_color == "white":
            return rendered_rgb + (1.0 - rendered_alpha)
        elif background_color == "random_color":
            return (
                rendered_rgb
                + (1.0 - rendered_alpha)
                * torch.rand(3, dtype=rendered_rgb.dtype, device=rendered_rgb.device)[
                    None, :, None, None
                ]
            )
        elif background_color == "random_pixel":
            return rendered_rgb + (1.0 - rendered_alpha) * torch.rand_like(rendered_rgb)
        else:
            raise ValueError("Unsupported BackgroundColor type.")

    @staticmethod
    def _conics_to_covars2d(conics: torch.Tensor, eps=1e-8) -> torch.Tensor:
        """Convert conics to covariance matrices.

        A conic (a, b, c) packs the inverse 2D covariance [[a, b], [b, c]];
        this inverts it analytically via the 2x2 adjugate formula.

        Args:
            conics: Tensor of shape (..., 3) holding (a, b, c) per splat.
            eps: Numerical guard against a degenerate (zero) determinant.

        Returns:
            Tensor of shape (..., 2, 2) with the covariance matrices;
            non-finite entries are zeroed out.
        """
        a = conics[..., 0]
        b = conics[..., 1]
        c = conics[..., 2]

        # Reciprocal of the conic determinant; clamped so degenerate conics
        # do not blow up the result.
        inv_det = 1 / (a * c - b**2 + eps)
        inv_det = inv_det.clamp(min=eps)

        # BUGFIX: match the input dtype explicitly (previously defaulted to
        # float32 regardless of `conics.dtype`).
        covars2d = torch.zeros(
            *conics.shape[:-1], 2, 2, dtype=conics.dtype, device=conics.device
        )
        covars2d[..., 1, 1] = a * inv_det
        covars2d[..., 0, 0] = c * inv_det
        covars2d[..., 0, 1] = -b * inv_det
        covars2d[..., 1, 0] = -b * inv_det
        covars2d = torch.nan_to_num(covars2d, nan=0.0, posinf=0.0, neginf=0.0)
        return covars2d
| |
|