| from transformers import PretrainedConfig |
| from transformers.utils import logging |
| from transformers.models.esm import EsmConfig |
|
|
# Module logger from transformers' logging utilities; ProtSTConfig uses it for informational messages.
logger = logging.get_logger(name=__name__)
|
|
|
|
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Configuration options for the protein model (used to initialize [`EsmForProteinRepresentation`]).
            A dictionary is passed as keyword arguments to [`EsmConfig`]; a config object is first converted with
            its `to_dict()` method. When `None`, an [`EsmConfig`] with default values is created.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, PretrainedConfig):
            # Accept a config object as well as a mapping: `EsmConfig(**obj)` would raise TypeError on a
            # non-mapping. Going through `to_dict()` gives a fresh copy, matching `from_protein_text_configs`.
            protein_config = protein_config.to_dict()

        # Always store a dedicated EsmConfig built from a plain dict (also the path taken when loading from JSON).
        self.protein_config = EsmConfig(**protein_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein model. It is copied through its `to_dict()` method, so later changes
                to the passed object do not affect the returned config.
            kwargs (*optional*):
                Extra keyword arguments forwarded to the constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """
        return cls(protein_config=protein_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serialize this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: The output of [`PretrainedConfig.to_dict`], with the `protein_config` entry
            guaranteed to be a plain dictionary (so the result is JSON-serializable).
        """
        output = super().to_dict()
        # NOTE(review): depending on the transformers release, the base `to_dict` may leave nested config objects
        # untouched, which would break `to_json_string` / `save_pretrained`. When the base class already
        # converted the nested config, this branch is a no-op.
        nested = output.get("protein_config")
        if isinstance(nested, PretrainedConfig):
            output["protein_config"] = nested.to_dict()
        return output