| | """Contains modules for different types of alignment. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | from sharp.models.decoders import UNetDecoder |
| | from sharp.models.encoders import UNetEncoder |
| | from sharp.utils import math as math_utils |
| |
|
| | from .params import AlignmentParams |
| |
|
| |
|
def create_alignment(
    params: AlignmentParams, depth_decoder_dim: int | None = None
) -> nn.Module | None:
    """Build the learned depth-alignment module described by ``params``.

    Args:
        params: Alignment configuration. The fields read here are
            ``depth_decoder_features``, ``steps``, ``stride``, ``base_width``,
            ``activation_type`` and ``frozen``.
        depth_decoder_dim: Channel dimension of the depth decoder features.
            Must be provided; it is forwarded to ``LearnedAlignment``.

    Returns:
        A ``LearnedAlignment`` instance. When ``params.frozen`` is truthy, all
        of its parameters have ``requires_grad`` disabled.

    Raises:
        ValueError: If ``depth_decoder_dim`` is ``None``.
    """
    if depth_decoder_dim is None:
        raise ValueError("Requires depth_decoder_dim for LearnedAlignment.")

    module_kwargs = {
        "depth_decoder_features": params.depth_decoder_features,
        "depth_decoder_dim": depth_decoder_dim,
        "steps": params.steps,
        "stride": params.stride,
        "base_width": params.base_width,
        "activation_type": params.activation_type,
    }
    module = LearnedAlignment(**module_kwargs)

    # A frozen alignment module is excluded from gradient updates.
    if params.frozen:
        module.requires_grad_(False)

    return module
| |
|
| |
|
class LearnedAlignment(nn.Module):
    """Aligns tensors using a UNet.

    The network consumes the element-wise reciprocals of a source and a target
    tensor (optionally concatenated with upsampled depth decoder features) and
    predicts a single-channel alignment map at the spatial resolution of the
    source tensor.

    The output convolution is initialized with zero weights and a bias equal
    to ``activation.inverse(1.0)``, so the pre-activation output is constant at
    initialization (presumably yielding an alignment map of all ones, assuming
    the activation pair's ``forward`` undoes ``inverse`` -- confirm against
    ``math_utils.create_activation_pair``).
    """

    def __init__(
        self,
        steps: int = 4,
        stride: int = 8,
        base_width: int = 16,
        depth_decoder_features: bool = False,
        depth_decoder_dim: int = 256,
        activation_type: math_utils.ActivationType = "exp",
    ) -> None:
        """Initialize LearnedAlignment.

        Args:
            steps: Number of steps in the UNet.
            stride: Effective downsampling of the alignment module. Must be a
                positive power of two.
            base_width: Base width of the UNet. The width doubles at every
                step and is capped at 1024.
            depth_decoder_features: Whether to use depth decoder features.
            depth_decoder_dim: Dimension of the depth decoder features.
            activation_type: Activation type for the alignment output.

        Raises:
            ValueError: If ``stride`` is not a power of two, or if
                ``steps - log2(stride)`` is smaller than 1.
        """
        super().__init__()
        self.activation = math_utils.create_activation_pair(activation_type)
        # Pre-activation value corresponding to an output of 1.0; used as the
        # initial bias of the output convolution. ``nn.init.constant_`` takes a
        # Python scalar, hence the conversion.
        bias_value = float(self.activation.inverse(torch.tensor(1.0)))

        self.depth_decoder_features = depth_decoder_features
        # Two input channels for (source, target), plus the decoder features
        # when they are enabled.
        dim_in = 2 + depth_decoder_dim if depth_decoder_features else 2

        if not self._is_power_of_two(stride):
            raise ValueError(f"Stride {stride} is not a power of two.")

        # The decoder runs log2(stride) fewer steps than the encoder.
        steps_decoder = steps - int(math.log2(stride))
        if steps_decoder < 1:
            raise ValueError(f"{steps_decoder} must be greater or equal to 1.")

        widths = [min(base_width << i, 1024) for i in range(steps + 1)]
        self.encoder = UNetEncoder(dim_in=dim_in, width=widths, steps=steps, norm_num_groups=4)
        self.decoder = UNetDecoder(
            dim_out=widths[0], width=widths, steps=steps_decoder, norm_num_groups=4
        )
        self.conv_out = nn.Conv2d(widths[0], 1, 1, bias=True)
        nn.init.zeros_(self.conv_out.weight)
        nn.init.constant_(self.conv_out.bias, bias_value)

    @staticmethod
    def _is_power_of_two(n: int) -> bool:
        """Check if a number is a positive power of two."""
        return n > 0 and (n & (n - 1)) == 0

    def forward(
        self,
        tensor_src: torch.Tensor,
        tensor_tgt: torch.Tensor,
        depth_decoder_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute alignment map.

        Args:
            tensor_src: Source tensor. Its last two dimensions define the
                spatial size of the returned map. Assumed to be a
                single-channel ``(B, 1, H, W)`` tensor, since the encoder is
                built for two channels from source + target -- TODO confirm.
            tensor_tgt: Target tensor, concatenated with the source along
                ``dim=1``; must therefore match ``tensor_src`` in every other
                dimension.
            depth_decoder_features: Depth decoder features. Required when the
                module was constructed with ``depth_decoder_features=True``;
                ignored otherwise. They are bilinearly resized to the spatial
                size of ``tensor_src``.

        Returns:
            The alignment map, with the same spatial size as ``tensor_src``.

        Raises:
            ValueError: If the module uses depth decoder features but
                ``depth_decoder_features`` is ``None``.
        """
        if self.depth_decoder_features and depth_decoder_features is None:
            raise ValueError(
                "depth_decoder_features must be provided when the module is "
                "configured to use depth decoder features."
            )

        target_size = tensor_src.shape[-2:]

        # The network operates on reciprocals of the inputs; the clamp avoids
        # division by zero (and by negative values).
        tensor_src = 1.0 / tensor_src.clamp(min=1e-4)
        tensor_tgt = 1.0 / tensor_tgt.clamp(min=1e-4)
        tensor_input = torch.cat([tensor_src, tensor_tgt], dim=1)

        if self.depth_decoder_features:
            upsampled_encodings = F.interpolate(
                depth_decoder_features,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            tensor_input = torch.cat([tensor_input, upsampled_encodings], dim=1)

        features = self.encoder(tensor_input)
        output = self.conv_out(self.decoder(features))
        alignment_map_lowres = self.activation.forward(output)

        # Already at the requested resolution: nothing to resample.
        if alignment_map_lowres.shape[-2:] == target_size:
            return alignment_map_lowres

        return F.interpolate(
            alignment_map_lowres,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
| |
|