| | """Contains Dense Transformer Prediction architecture. |
| | |
| | Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413 |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import copy |
| | from typing import NamedTuple, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from sharp.models import normalizers |
| | from sharp.models.decoders import MultiresConvDecoder, create_monodepth_decoder |
| | from sharp.models.encoders import ( |
| | SlidingPyramidNetwork, |
| | create_monodepth_encoder, |
| | ) |
| | from sharp.utils import module_surgery |
| |
|
| | from .params import MonodepthAdaptorParams, MonodepthParams |
| |
|
# Per-stage channel widths of the decoder (five stages); presumably mirrors
# MonodepthParams.dims_decoder — TODO confirm against params module.
DimsDecoder = Tuple[int, int, int, int, int]
| |
|
| |
|
class MonodepthDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for monodepth.

    Attach the disparity prediction head for monodepth prediction.
    """

    def __init__(
        self,
        encoder: SlidingPyramidNetwork,
        decoder: MultiresConvDecoder,
        last_dims: tuple[int, int],
    ):
        """Initialize Dense Prediction Transformer.

        Args:
            encoder: The SlidingPyramidTransformer backbone.
            decoder: The MultiresConvDecoder decoder.
            last_dims: The dimension for the last convolution layers.
        """
        super().__init__()

        # Inputs arrive in [0, 1]; the backbone consumes [-1, 1].
        self.normalizer = normalizers.AffineRangeNormalizer(
            input_range=(0, 1), output_range=(-1, 1)
        )
        self.encoder = encoder
        self.decoder = decoder

        half_dim = decoder.dim_out // 2
        # Last 1x1 projection is kept as a named variable so its bias can be
        # zero-initialized below.
        last_conv = nn.Conv2d(
            last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0
        )
        # Head: 3x3 conv -> 2x transposed-conv upsample -> 3x3 conv -> ReLU
        # -> 1x1 conv -> ReLU.
        self.head = nn.Sequential(
            nn.Conv2d(decoder.dim_out, half_dim, kernel_size=3, stride=1, padding=1),
            nn.ConvTranspose2d(
                in_channels=half_dim,
                out_channels=half_dim,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
            ),
            nn.Conv2d(half_dim, last_dims[0], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            last_conv,
            nn.ReLU(),
        )

        # Start the final projection from a zero bias.
        last_conv.bias.data.fill_(0)

        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled=True):
        """Enable grad checkpointing."""
        self.grad_checkpointing = is_enabled
        self.encoder.set_grad_checkpointing(is_enabled)
        self.decoder.set_grad_checkpointing(is_enabled)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Decode by projection and fusion of multi-resolution encodings."""
        normalized = self.normalizer(image)
        encodings = self.encoder(normalized)
        # Feed the decoder only the pyramid levels it is configured for.
        num_levels = len(self.encoder.dims_encoder)
        fused = self.decoder(encodings[:num_levels])
        return self.head(fused)

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.encoder.internal_resolution()
| |
|
| |
|
def create_monodepth_dpt(
    params: MonodepthParams | None = None,
) -> MonodepthDensePredictionTransformer:
    """Creates DepthDensePredictionTransformer model.

    Args:
        params: Parameters of monodepth network.

    Returns:
        The configured monodepth DPT.
    """
    params = MonodepthParams() if params is None else params

    encoder: SlidingPyramidNetwork = create_monodepth_encoder(
        params.patch_encoder_preset,
        params.image_encoder_preset,
        use_patch_overlap=params.use_patch_overlap,
        last_encoder=params.dims_decoder[0],
    )
    decoder: MultiresConvDecoder = create_monodepth_decoder(
        params.patch_encoder_preset, params.dims_decoder
    )

    model = MonodepthDensePredictionTransformer(
        encoder=encoder, decoder=decoder, last_dims=(32, 1)
    )

    # Freeze everything first, then selectively unfreeze per configuration.
    model.requires_grad_(False)
    model.encoder.set_requires_grad_(
        patch_encoder=params.unfreeze_patch_encoder,
        image_encoder=params.unfreeze_image_encoder,
    )
    model.decoder.requires_grad_(params.unfreeze_decoder)
    model.head.requires_grad_(params.unfreeze_head)

    # Norm layers are frozen independently of the blocks that contain them.
    if not params.unfreeze_norm_layers:
        module_surgery.freeze_norm_layer(model)

    model.set_grad_checkpointing(params.grad_checkpointing)

    return model
| |
|
| |
|
class MonodepthOutput(NamedTuple):
    """Output of the monodepth model.

    Bundles the disparity prediction with the feature maps produced while
    computing it, so adaptors can expose them to downstream consumers.
    """

    # Disparity map produced by the prediction head.
    disparity: torch.Tensor

    # Multi-resolution encoder features fed to the decoder.
    encoder_features: list[torch.Tensor]

    # Fused feature map produced by the decoder (the head's input).
    decoder_features: torch.Tensor

    # Selection of the above, per the adaptor's return_* configuration.
    output_features: list[torch.Tensor]

    # Encoder outputs beyond the decoder's pyramid levels.
    # NOTE(review): a mutable default on a NamedTuple field is one list shared
    # by every instance that omits the argument — confirm it is never mutated.
    intermediate_features: list[torch.Tensor] = []
| |
|
| |
|
class MonodepthWithEncodingAdaptor(nn.Module):
    """Monodepth model with feature maps.

    Wraps a MonodepthDensePredictionTransformer and, besides the disparity
    map, exposes the encoder/decoder feature maps to downstream consumers.
    """

    def __init__(
        self,
        monodepth_predictor: MonodepthDensePredictionTransformer,
        return_encoder_features: bool,
        return_decoder_features: bool,
        num_monodepth_layers: int,
        sorting_monodepth: bool,
    ):
        """Initialize MonodepthWithEncodingAdaptor.

        Args:
            monodepth_predictor: The monodepth model.
            return_encoder_features: Whether to return encoder features from monodepth model.
            return_decoder_features: Whether to return decoder features from monodepth model.
            num_monodepth_layers: How many layers the monodepth model predicts.
            sorting_monodepth: Whether to sort the monodepth output (for two layer monodepth).
        """
        super().__init__()
        self.monodepth_predictor = monodepth_predictor
        self.return_encoder_features = return_encoder_features
        self.return_decoder_features = return_decoder_features
        self.num_monodepth_layers = num_monodepth_layers
        self.sorting_monodepth = sorting_monodepth

    def forward(self, image: torch.Tensor) -> MonodepthOutput:
        """Process image and return disparity and feature maps.

        Args:
            image: Input image batch (normalized internally by the predictor).

        Returns:
            MonodepthOutput with disparity and the configured feature maps.
        """
        inputs = self.monodepth_predictor.normalizer(image)
        encoder_output = self.monodepth_predictor.encoder(inputs)

        num_encoder_features = len(self.monodepth_predictor.encoder.dims_encoder)

        # The encoder may emit extra maps beyond the pyramid levels the
        # decoder consumes; split those off as intermediate features.
        encoder_features = encoder_output[:num_encoder_features]
        intermediate_features = encoder_output[num_encoder_features:]
        decoder_features = self.monodepth_predictor.decoder(encoder_features)
        disparity = self.monodepth_predictor.head(decoder_features)

        # For two-layer monodepth, reorder channels so channel 0 holds the
        # per-pixel maximum disparity and channel 1 the minimum.
        if self.num_monodepth_layers == 2 and self.sorting_monodepth:
            # Fix: use the documented `keepdim` kwarg; the previous `keepdims`
            # spelling only works through torch's undocumented numpy-compat
            # alias.
            first_layer_disparity = disparity.max(dim=1, keepdim=True).values
            second_layer_disparity = disparity.min(dim=1, keepdim=True).values
            disparity = torch.cat([first_layer_disparity, second_layer_disparity], dim=1)

        output_features = []
        if self.return_encoder_features:
            output_features.extend(encoder_features)

        if self.return_decoder_features:
            output_features.append(decoder_features)

        return MonodepthOutput(
            disparity=disparity,
            encoder_features=encoder_features,
            decoder_features=decoder_features,
            output_features=output_features,
            intermediate_features=intermediate_features,
        )

    def get_feature_dims(self) -> list[int]:
        """Return dimensions of output feature maps (parallel to `output_features`)."""
        dims = []
        if self.return_encoder_features:
            dims.extend(self.monodepth_predictor.encoder.dims_encoder)

        if self.return_decoder_features:
            dims.append(self.monodepth_predictor.decoder.dim_out)

        return dims

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.monodepth_predictor.internal_resolution()

    def replicate_head(self, num_repeat: int):
        """Replicate the last convolution layer (head[4] in DPT) for multi layer depth.

        Args:
            num_repeat: Number of output channels (depth layers) after replication.
        """
        conv_last = copy.deepcopy(self.monodepth_predictor.head[4])
        self.monodepth_predictor.head[4].out_channels = num_repeat
        # Tile weight and bias along the output-channel axis so every
        # replicated channel starts from the original parameters.
        self.monodepth_predictor.head[4].weight = nn.Parameter(
            conv_last.weight.repeat(num_repeat, 1, 1, 1)
        )
        self.monodepth_predictor.head[4].bias = nn.Parameter(conv_last.bias.repeat(num_repeat))
| |
|
| |
|
def create_monodepth_adaptor(
    monodepth_predictor: MonodepthDensePredictionTransformer,
    params: MonodepthAdaptorParams,
    num_monodepth_layers: int,
    sorting_monodepth: bool,
) -> MonodepthWithEncodingAdaptor:
    """Create an adaptor that returns both disparity and features."""
    return MonodepthWithEncodingAdaptor(
        monodepth_predictor=monodepth_predictor,
        return_encoder_features=params.encoder_features,
        return_decoder_features=params.decoder_features,
        num_monodepth_layers=num_monodepth_layers,
        sorting_monodepth=sorting_monodepth,
    )
| |
|