| """ |
| AuriStream Configuration for HuggingFace Transformers. |
| |
| AuriStream is a speech language model by Greta Tuckute and Klemen Kotar. |
| """ |
|
|
| from transformers import PretrainedConfig |
|
|
|
|
class AuriStreamConfig(PretrainedConfig):
    """
    Hyperparameter container for the AuriStream speech language model family.

    A single configuration class covers the different model sizes and
    prediction-head setups; each variant is expressed purely through the
    values of the arguments below.

    Args:
        vocab_size (`int`, *optional*, defaults to 8192):
            Vocabulary size of the model (number of cochlear tokens).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_pred_steps (`int`, *optional*, defaults to 1):
            Number of future prediction steps (multi-token prediction heads).
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for all fully connected layers.
        bias (`bool`, *optional*, defaults to False):
            Whether to use bias in linear layers.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base theta for RoPE embeddings.
        input_conv_kernel_size (`int`, *optional*, defaults to 0):
            Kernel size for input convolution layer (0 means no input conv).
        **kwargs:
            Any further keyword arguments; they are passed through unchanged
            to `PretrainedConfig.__init__`.
    """

    model_type = "AuriStream"

    def __init__(
        self,
        vocab_size: int = 8192,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_pred_steps: int = 1,
        dropout: float = 0.0,
        bias: bool = False,
        rope_theta: float = 10000.0,
        input_conv_kernel_size: int = 0,
        **kwargs,
    ) -> None:
        # Vocabulary and transformer geometry.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head

        # Prediction heads and regularisation.
        self.n_pred_steps = n_pred_steps
        self.dropout = dropout
        self.bias = bias

        # Positional encoding and optional input convolution.
        self.rope_theta = rope_theta
        self.input_conv_kernel_size = input_conv_kernel_size

        # The model-specific fields are stored first; the base class only
        # receives the remaining keyword arguments.
        super().__init__(**kwargs)

    @classmethod
    def from_local_config(cls, local_cfg) -> "AuriStreamConfig":
        """
        Build an AuriStreamConfig from a local dataclass-style config object.

        Only the attributes this class knows about are read from
        ``local_cfg``. Any that are absent are simply not passed on, so the
        constructor default applies (``n_pred_steps`` is explicitly set to 1
        when missing).

        Args:
            local_cfg: A config object exposing some or all of the supported
                fields as attributes (e.g., AuriStream100M20PredConfig).

        Returns:
            AuriStreamConfig instance
        """
        field_names = (
            "vocab_size",
            "n_embd",
            "n_layer",
            "n_head",
            "n_pred_steps",
            "dropout",
            "bias",
            "rope_theta",
            "input_conv_kernel_size",
        )

        # Copy across only what the source object actually defines.
        overrides = {
            name: getattr(local_cfg, name)
            for name in field_names
            if hasattr(local_cfg, name)
        }

        # Configs without multi-step prediction fall back to a single step.
        overrides.setdefault("n_pred_steps", 1)

        return cls(**overrides)
|
|