from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class LossConfiguration:
    """Settings for the training loss.

    Only ``num_classes`` is required; every other field has a default.
    How each value is consumed is decided by the loss code, which is not
    visible here, so the per-field notes below are deliberately hedged.
    """

    num_classes: int

    # Presumably the relative weights of the cross-entropy and Dice terms
    # -- TODO confirm against the loss implementation.
    xent_weight: float = 1.0
    dice_weight: float = 1.0

    # Focal-loss switch and its gamma; presumably gamma only matters when
    # focal_loss is True -- verify in the consumer.
    focal_loss: bool = False
    focal_loss_gamma: float = 2.0

    # NOTE(review): "frustrum" looks like a misspelling of "frustum". The
    # name is kept as-is because it is part of the public field interface.
    requires_frustrum: bool = True
    requires_flood_mask: bool = False

    # Typed as Optional[Any]; None is the default.
    class_weights: Optional[Any] = None
    label_smoothing: float = 0.1


@dataclass
class BackboneConfigurationBase:
    """Fields shared by every backbone configuration.

    All three fields are required here; subclasses may redeclare them
    with defaults.
    """

    pretrained: bool
    frozen: bool
    # Declared as int (it was previously annotated bool): the value is a
    # dimension, and subclasses redeclare it as an int with a numeric default.
    output_dim: int


@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone configuration for a DINO encoder.

    Redeclares the three fields of ``BackboneConfigurationBase`` with
    defaults, so the class can be instantiated without arguments.
    """

    pretrained: bool = True
    frozen: bool = False
    output_dim: int = 128


@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone configuration for a ResNet-style encoder.

    Adds the fields below to those of ``BackboneConfigurationBase``.
    None of them has a default, so every value must be supplied.
    """

    input_dim: int
    # NOTE(review): presumably the encoder architecture name -- confirm
    # the accepted values against the model-building code.
    encoder: str
    remove_stride_from_first_conv: bool
    # Optional: None is accepted; how None is interpreted is up to the consumer.
    num_downsample: Optional[int]
    decoder_norm: str
    do_average_pooling: bool
    checkpointed: bool


@dataclass
class ImageEncoderConfiguration:
    """Configuration of the image encoder: a name plus a backbone config."""

    name: str
    # Typed as Any, so any backbone configuration object is accepted.
    backbone: Any


@dataclass
class ModelConfiguration:
    """Top-level model configuration.

    Groups the sub-configurations (segmentation head, image encoder,
    loss) with scalar model settings. Field order is significant: it
    defines the positional order of the generated ``__init__``, and the
    two defaulted fields must stay last.
    """

    segmentation_head: Dict[str, Any]
    image_encoder: ImageEncoderConfiguration

    name: str
    num_classes: int
    latent_dim: int
    # NOTE(review): units of z_max / x_max are not stated here
    # (presumably meters, given pixel_per_meter) -- TODO confirm.
    z_max: int
    x_max: int

    pixel_per_meter: int
    num_scale_bins: int

    loss: LossConfiguration

    # default_factory gives each instance its own list rather than a shared one.
    scale_range: list[int] = field(default_factory=lambda: [0, 9])
    z_min: Optional[int] = None