| |
| """Distilled Audio State-Space Model (DASS) configuration""" |
|
|
| from typing import Any, Dict |
|
|
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
class DASSConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DASSModel`]. It is used to instantiate a DASS
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [DASS-small](https://github.com/Saurabhbhati/DASS/) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch.
        embed_dim (`int`, *optional*, defaults to 96):
            Dimensionality of patch embedding.
        depths (`list[int]`, *optional*, defaults to `[2, 2, 8, 2]`):
            Depth of each layer in the DASS encoder. When `None` is passed, a fresh `[2, 2, 8, 2]` list is used.
            A caller-supplied sequence is copied into a new list.
        dims (`list[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
            Dimensionality of each layer in the DASS encoder. When `None` is passed, a fresh
            `[96, 192, 384, 768]` list is used. A caller-supplied sequence is copied into a new list.
        drop_path_rate (`float`, *optional*, defaults to 0.2):
            Stochastic depth rate.
        num_classes (`int`, *optional*, defaults to 527):
            Number of classes for classification.
        max_length (`int`, *optional*, defaults to 1024):
            Temporal dimension of the spectrograms.
        num_mel_bins (`int`, *optional*, defaults to 128):
            Frequency dimension of the spectrograms (number of Mel-frequency bins).
        use_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to use checkpointing to save memory.

    Example:

    ```python
    >>> from transformers import DASSConfig, DASSModel

    >>> # Initializing a DASS small style configuration
    >>> configuration = DASSConfig()

    >>> # Initializing a model (with random weights) from the DASS small style configuration
    >>> model = DASSModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dass"

    def __init__(
        self,
        patch_size: int = 4,
        embed_dim: int = 96,
        depths: list = None,
        dims: list = None,
        drop_path_rate: float = 0.2,
        num_classes: int = 527,
        max_length: int = 1024,
        num_mel_bins: int = 128,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Build the list-valued settings per instance: a list literal used as a
        # default argument would be one shared object, so mutating one config's
        # `depths`/`dims` in place would silently change every later default.
        self.depths = [2, 2, 8, 2] if depths is None else list(depths)
        self.dims = [96, 192, 384, 768] if dims is None else list(dims)
        self.drop_path_rate = drop_path_rate
        self.num_classes = num_classes
        self.max_length = max_length
        self.num_mel_bins = num_mel_bins
        self.use_checkpoint = use_checkpoint

    def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
        """Return an empty dict, i.e. report no generation parameters for this configuration.

        NOTE(review): presumably overrides a `PretrainedConfig` hook so that no
        generation-related attributes are reported for this model - confirm
        against the installed `transformers` version.
        """
        return {}
|
|
|
|
| __all__ = ["DASSConfig"] |
|
|