| """Contains Dense Transformer Prediction architecture. |
| |
| Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413 |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import copy |
| from typing import NamedTuple, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from sharp.models import normalizers |
| from sharp.models.decoders import MultiresConvDecoder, create_monodepth_decoder |
| from sharp.models.encoders import ( |
| SlidingPyramidNetwork, |
| create_monodepth_encoder, |
| ) |
| from sharp.utils import module_surgery |
|
|
| from .params import MonodepthAdaptorParams, MonodepthParams |
|
|
# Five-element tuple of decoder channel dimensions, one per resolution level
# (presumably mirrors MonodepthParams.dims_decoder — confirm in params.py).
DimsDecoder = Tuple[int, int, int, int, int]
|
|
|
|
class MonodepthDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for monodepth.

    Attach the disparity prediction head for monodepth prediction.
    """

    def __init__(
        self,
        encoder: SlidingPyramidNetwork,
        decoder: MultiresConvDecoder,
        last_dims: tuple[int, int],
    ):
        """Initialize Dense Prediction Transformer.

        Args:
            encoder: The SlidingPyramidTransformer backbone.
            decoder: The MultiresConvDecoder decoder.
            last_dims: Channel widths of the two final head convolutions.
        """
        super().__init__()

        # Map input pixels from [0, 1] to [-1, 1] before encoding.
        self.normalizer = normalizers.AffineRangeNormalizer(
            input_range=(0, 1), output_range=(-1, 1)
        )
        self.encoder = encoder
        self.decoder = decoder

        half_dim = decoder.dim_out // 2
        # Keep the layer order/indices stable: head[4] (the final 1x1
        # projection) is addressed directly below and by external callers.
        head_layers = [
            nn.Conv2d(decoder.dim_out, half_dim, kernel_size=3, stride=1, padding=1),
            # Learned 2x upsampling of the fused decoder features.
            nn.ConvTranspose2d(
                in_channels=half_dim,
                out_channels=half_dim,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
            ),
            nn.Conv2d(half_dim, last_dims[0], kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # Final 1x1 projection to the output (disparity) channels.
            nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        ]
        self.head = nn.Sequential(*head_layers)

        # Start the final projection from a zero bias.
        self.head[4].bias.data.fill_(0)

        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled=True):
        """Enable grad checkpointing."""
        self.grad_checkpointing = is_enabled
        self.encoder.set_grad_checkpointing(is_enabled)
        self.decoder.set_grad_checkpointing(is_enabled)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Decode by projection and fusion of multi-resolution encodings."""
        normalized = self.normalizer(image)
        encodings = self.encoder(normalized)
        # The decoder consumes only the first len(dims_encoder) pyramid
        # levels; the encoder may append further intermediate maps.
        num_levels = len(self.encoder.dims_encoder)
        fused = self.decoder(encodings[:num_levels])
        return self.head(fused)

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.encoder.internal_resolution()
|
|
|
|
def create_monodepth_dpt(
    params: MonodepthParams | None = None,
) -> MonodepthDensePredictionTransformer:
    """Create and configure a MonodepthDensePredictionTransformer.

    Args:
        params: Parameters of the monodepth network; defaults to
            ``MonodepthParams()`` when omitted.

    Returns:
        The configured monodepth DPT.
    """
    params = MonodepthParams() if params is None else params

    encoder: SlidingPyramidNetwork = create_monodepth_encoder(
        params.patch_encoder_preset,
        params.image_encoder_preset,
        use_patch_overlap=params.use_patch_overlap,
        last_encoder=params.dims_decoder[0],
    )
    decoder: MultiresConvDecoder = create_monodepth_decoder(
        params.patch_encoder_preset, params.dims_decoder
    )

    model = MonodepthDensePredictionTransformer(
        encoder=encoder, decoder=decoder, last_dims=(32, 1)
    )

    # Freeze everything first, then selectively unfreeze per the params.
    model.requires_grad_(False)
    model.encoder.set_requires_grad_(
        patch_encoder=params.unfreeze_patch_encoder,
        image_encoder=params.unfreeze_image_encoder,
    )
    model.decoder.requires_grad_(params.unfreeze_decoder)
    model.head.requires_grad_(params.unfreeze_head)

    if not params.unfreeze_norm_layers:
        module_surgery.freeze_norm_layer(model)

    model.set_grad_checkpointing(params.grad_checkpointing)

    return model
|
|
|
|
class MonodepthOutput(NamedTuple):
    """Output of the monodepth model."""

    # Disparity map produced by the DPT prediction head.
    disparity: torch.Tensor
    # Multi-resolution feature maps produced by the encoder.
    encoder_features: list[torch.Tensor]
    # Fused feature map produced by the decoder.
    decoder_features: torch.Tensor
    # Subset of the above selected for downstream consumers (encoder and/or
    # decoder features, depending on the adaptor configuration).
    output_features: list[torch.Tensor]
    # Extra encoder outputs beyond the decoder's pyramid levels.
    # NOTE(review): a mutable `[]` default on a NamedTuple is a single
    # class-level list shared by every instance that omits the field —
    # safe only as long as no caller mutates it in place.
    intermediate_features: list[torch.Tensor] = []
|
|
|
|
class MonodepthWithEncodingAdaptor(nn.Module):
    """Monodepth model with feature maps.

    Wraps a MonodepthDensePredictionTransformer and, in addition to the
    disparity prediction, exposes the intermediate encoder/decoder feature
    maps for downstream consumers.
    """

    def __init__(
        self,
        monodepth_predictor: MonodepthDensePredictionTransformer,
        return_encoder_features: bool,
        return_decoder_features: bool,
        num_monodepth_layers: int,
        sorting_monodepth: bool,
    ):
        """Initialize MonodepthWithEncodingAdaptor.

        Args:
            monodepth_predictor: The monodepth model.
            return_encoder_features: Whether to return encoder features from monodepth model.
            return_decoder_features: Whether to return decoder features from monodepth model.
            num_monodepth_layers: How many layers the monodepth model predicts.
            sorting_monodepth: Whether to sort the monodepth output (for two layer monodepth).
        """
        super().__init__()
        self.monodepth_predictor = monodepth_predictor
        self.return_encoder_features = return_encoder_features
        self.return_decoder_features = return_decoder_features
        self.num_monodepth_layers = num_monodepth_layers
        self.sorting_monodepth = sorting_monodepth

    def forward(self, image: torch.Tensor) -> MonodepthOutput:
        """Process image and return disparity and feature maps.

        Args:
            image: Input image batch in the [0, 1] range expected by the
                predictor's normalizer.

        Returns:
            A MonodepthOutput with the disparity and the requested features.
        """
        inputs = self.monodepth_predictor.normalizer(image)
        encoder_output = self.monodepth_predictor.encoder(inputs)

        num_encoder_features = len(self.monodepth_predictor.encoder.dims_encoder)

        # The encoder returns the pyramid levels consumed by the decoder
        # first, followed by any extra intermediate feature maps.
        encoder_features = encoder_output[:num_encoder_features]
        intermediate_features = encoder_output[num_encoder_features:]
        decoder_features = self.monodepth_predictor.decoder(encoder_features)
        disparity = self.monodepth_predictor.head(decoder_features)

        # For two-layer monodepth, order the channels so that channel 0 holds
        # the per-pixel maximum disparity and channel 1 the minimum.
        # Fix: use the documented `keepdim` kwarg — the original `keepdims`
        # only worked through torch's numpy-compat argument aliasing.
        if self.num_monodepth_layers == 2 and self.sorting_monodepth:
            first_layer_disparity = disparity.max(dim=1, keepdim=True).values
            second_layer_disparity = disparity.min(dim=1, keepdim=True).values
            disparity = torch.cat([first_layer_disparity, second_layer_disparity], dim=1)

        output_features = []
        if self.return_encoder_features:
            output_features.extend(encoder_features)

        if self.return_decoder_features:
            output_features.append(decoder_features)

        return MonodepthOutput(
            disparity=disparity,
            encoder_features=encoder_features,
            decoder_features=decoder_features,
            output_features=output_features,
            intermediate_features=intermediate_features,
        )

    def get_feature_dims(self) -> list[int]:
        """Return dimensions of output feature maps."""
        dims = []
        if self.return_encoder_features:
            dims.extend(self.monodepth_predictor.encoder.dims_encoder)

        if self.return_decoder_features:
            dims.append(self.monodepth_predictor.decoder.dim_out)

        return dims

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.monodepth_predictor.internal_resolution()

    def replicate_head(self, num_repeat: int) -> None:
        """Replicate the last convolution layer (head[4] in DPT) for multi layer depth.

        Repeats the trained single-layer weights/bias `num_repeat` times along
        the output-channel axis so each depth layer starts from the same
        initialization.
        """
        conv_last = copy.deepcopy(self.monodepth_predictor.head[4])
        self.monodepth_predictor.head[4].out_channels = num_repeat
        self.monodepth_predictor.head[4].weight = nn.Parameter(
            conv_last.weight.repeat(num_repeat, 1, 1, 1)
        )
        self.monodepth_predictor.head[4].bias = nn.Parameter(conv_last.bias.repeat(num_repeat))
|
|
|
|
def create_monodepth_adaptor(
    monodepth_predictor: MonodepthDensePredictionTransformer,
    params: MonodepthAdaptorParams,
    num_monodepth_layers: int,
    sorting_monodepth: bool,
) -> MonodepthWithEncodingAdaptor:
    """Create an adaptor that returns both disparity and features.

    Args:
        monodepth_predictor: The monodepth model to wrap.
        params: Adaptor parameters selecting which features to expose.
        num_monodepth_layers: How many layers the monodepth model predicts.
        sorting_monodepth: Whether to sort two-layer monodepth output.

    Returns:
        The configured MonodepthWithEncodingAdaptor.
    """
    return MonodepthWithEncodingAdaptor(
        monodepth_predictor=monodepth_predictor,
        return_encoder_features=params.encoder_features,
        return_decoder_features=params.decoder_features,
        num_monodepth_layers=num_monodepth_layers,
        sorting_monodepth=sorting_monodepth,
    )
|
|