| | from transformers import AutoConfig |
| | from transformers.configuration_utils import PretrainedConfig |
| | from transformers.models.dinov2.configuration_dinov2 import Dinov2Config |
| |
|
| |
|
class VisionConfig(PretrainedConfig):
    """Configuration container for the vision encoder.

    The constructor adds no fields of its own; every keyword argument is
    forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def from_exp_config(vision_config: dict):
        """Build a ``VisionConfig`` from an experiment-level settings dict.

        Args:
            vision_config: Settings for the vision encoder. Must contain
                ``"model_type"``. For every supported type other than
                ``"biomedclip"`` and ``"m3ae"`` it must also contain
                ``"pretrained_name_or_path"``, which is passed to
                ``AutoConfig.from_pretrained``. The dict is not modified.

        Returns:
            A ``VisionConfig`` built from the settings. For the
            pretrained-backed types the loaded configuration is merged on
            top of the settings, so loaded values win on key collisions.

        Raises:
            KeyError: If a required key is missing from ``vision_config``.
            NotImplementedError: If ``model_type`` is not supported.
        """
        # Work on a shallow copy so the caller's dict is never altered by the
        # merge below (it may be reused, logged or serialized afterwards).
        settings = dict(vision_config)
        model_type = settings["model_type"]

        pretrained_backed = (
            "siglip_vision_model",
            "clip_vision_model",
            "dinov2",
            "sam",
            "raddino",
            "xrayclip",
        )
        # These types need no pretrained lookup; settings are used as given.
        settings_only = ("biomedclip", "m3ae")

        if model_type in pretrained_backed:
            pretrained = AutoConfig.from_pretrained(
                settings["pretrained_name_or_path"]
            ).to_dict()
            if model_type == "xrayclip":
                # Keep the experiment's own type tag instead of whatever the
                # loaded configuration reports.
                pretrained["model_type"] = "xrayclip"
            settings.update(pretrained)
        elif model_type not in settings_only:
            raise NotImplementedError(
                f"Unsupported vision model_type: {model_type!r}"
            )

        return VisionConfig(**settings)
| |
|
| |
|
class TextConfig(PretrainedConfig):
    """Configuration container for the text encoder.

    ``model_type`` is stored on the instance; all remaining keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(self, model_type, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer has run, as a plain instance
        # attribute.
        self.model_type = model_type

    @staticmethod
    def from_exp_config(text_config: dict):
        """Build a ``TextConfig`` from an experiment-level settings dict.

        Args:
            text_config: Settings for the text encoder. Must contain
                ``"model_type"``; every entry is passed to the constructor
                as a keyword argument.

        Returns:
            A ``TextConfig`` built from ``text_config``.

        Raises:
            KeyError: If ``"model_type"`` is missing.
            NotImplementedError: If ``model_type`` is not supported.
        """
        supported_types = (
            "siglip_text_model",
            "clip_text_model",
            "mpnet",
            "biomedclip",
            "bioclinicalmpbert",
        )

        # Guard clause: reject unknown encoders before constructing anything.
        if text_config["model_type"] not in supported_types:
            raise NotImplementedError()

        return TextConfig(**text_config)
| |
|
| |
|
class AlignTransformerConfig(PretrainedConfig):
    """Configuration container for the alignment transformer.

    Stores ``model_type`` and ``projector_config`` on the instance and
    forwards all remaining keyword arguments to
    ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_type: str = "align_transformer",
        projector_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type
        self.projector_config = projector_config

    @staticmethod
    def from_exp_config(
        align_transformer_config: dict,
    ):
        """Build an ``AlignTransformerConfig`` from experiment settings.

        The settings (minus ``"projector_config"``) are first passed through
        ``Dinov2Config`` so that every field it defines is present in the
        result; values supplied by the caller win on key collisions.

        Args:
            align_transformer_config: Settings for the alignment transformer.
                May contain ``"projector_config"``, which is kept out of the
                ``Dinov2Config`` call and handed to the constructor
                separately. The dict is not modified.

        Returns:
            An ``AlignTransformerConfig`` holding the merged settings.
        """
        # Pop from a shallow copy so "projector_config" stays in the caller's
        # dict and the same settings can be reused for another build.
        settings = dict(align_transformer_config)
        projector_config = settings.pop("projector_config", None)

        dinov2_fields = Dinov2Config(**settings).to_dict()

        # The right-hand operand of `|` wins, so explicit settings override
        # the values reported by Dinov2Config.
        return AlignTransformerConfig(
            **(dinov2_fields | settings),
            projector_config=projector_config,
        )
| |
|
| |
|
class CxrAlignConfig(PretrainedConfig):
    """Top-level configuration composed of three sub-configurations.

    Each sub-configuration is built from its raw settings dict through the
    matching ``from_exp_config`` factory: vision, then text, then the
    alignment transformer.
    """

    is_composition = True

    def __init__(
        self,
        vision_config: dict,
        text_config: dict,
        align_transformer_config: dict,
        **kwargs,
    ):
        """Build the sub-configurations and keep the extra keyword arguments.

        Args:
            vision_config: Settings passed to ``VisionConfig.from_exp_config``.
            text_config: Settings passed to ``TextConfig.from_exp_config``.
            align_transformer_config: Settings passed to
                ``AlignTransformerConfig.from_exp_config``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` and also
                stored on the instance as ``self.kwargs``.
        """
        super().__init__(**kwargs)

        # (attribute name, config class, raw settings), in build order.
        sub_configs = (
            ("vision_config", VisionConfig, vision_config),
            ("text_config", TextConfig, text_config),
            (
                "align_transformer_config",
                AlignTransformerConfig,
                align_transformer_config,
            ),
        )
        for attr_name, config_cls, raw_settings in sub_configs:
            setattr(self, attr_name, config_cls.from_exp_config(raw_settings))

        self.kwargs = kwargs
| |
|