| """Contains factory function for loading/creating monodepth decoder. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
|
|
| from __future__ import annotations |
|
|
| from sharp.models.presets import ( |
| MONODEPTH_ENCODER_DIMS_MAP, |
| ViTPreset, |
| ) |
|
|
| from .multires_conv_decoder import MultiresConvDecoder |
|
|
|
|
def create_monodepth_decoder(
    patch_encoder_preset: ViTPreset,
    dims_decoder=None,
) -> MultiresConvDecoder:
    """Build the multi-resolution convolutional decoder used for monodepth.

    Args:
        patch_encoder_preset: Key used to look up the encoder feature
            dimensions in ``MONODEPTH_ENCODER_DIMS_MAP``.
        dims_decoder: Decoder dimensions, given either as a single int or as
            an indexable sequence. When ``None``, the first encoder dimension
            of the preset is used.

    Returns:
        A ``MultiresConvDecoder`` whose encoder dimensions are the first
        decoder dimension followed by the preset's encoder dimensions.
    """
    encoder_dims = MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]

    # Fall back to the first encoder dimension when no decoder dims are given.
    decoder_dims = encoder_dims[0] if dims_decoder is None else dims_decoder

    # Wrap a scalar into a one-element list so it can be indexed below.
    if isinstance(decoder_dims, int):
        decoder_dims = [decoder_dims]

    # The decoder's first input level reuses the first decoder dimension,
    # followed by every encoder dimension of the preset.
    return MultiresConvDecoder(
        dims_encoder=[decoder_dims[0], *encoder_dims],
        dims_decoder=decoder_dims,
    )
|
|