| | from typing import Optional, Union |
| |
|
| | from transformers import Qwen2Config |
| | from transformers.configuration_utils import PretrainedConfig |
| |
|
| |
|
class StepAudio2EncoderConfig(PretrainedConfig):
    """Configuration holding the audio-encoder hyper-parameters for Step-Audio-2.

    Stores the mel-spectrogram / encoder geometry plus the adapter settings
    used to project encoder features into the language model's hidden space.
    """

    model_type = "step_audio_2_encoder"

    def __init__(
        self,
        n_mels=128,
        n_audio_ctx=1500,
        n_audio_state=512,
        n_audio_head=8,
        n_audio_layer=6,
        llm_dim=4096,
        kernel_size=3,
        adapter_stride=2,
        **kwargs,
    ):
        # Record every encoder hyper-parameter on the instance.
        fields = (
            ("n_mels", n_mels),
            ("n_audio_ctx", n_audio_ctx),
            ("n_audio_state", n_audio_state),
            ("n_audio_head", n_audio_head),
            ("n_audio_layer", n_audio_layer),
            ("llm_dim", llm_dim),
            ("kernel_size", kernel_size),
            ("adapter_stride", adapter_stride),
        )
        for attr_name, attr_value in fields:
            setattr(self, attr_name, attr_value)
        # Hand the remaining generic options to PretrainedConfig last,
        # preserving the original initialisation order.
        super().__init__(**kwargs)
| |
|
class StepAudio2TextConfig(PretrainedConfig):
    """Configuration for the Step-Audio-2 text (LLM) backbone.

    Besides storing the flat hyper-parameters, this builds a ``Qwen2Config``
    on ``self.text_config`` so downstream code can instantiate a standard
    ``Qwen2ForCausalLM``-shaped model.

    Raises:
        ValueError: if ``num_attention_groups`` differs from
            ``num_key_value_heads``.
    """

    model_type = "step_audio_2_text"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        eos_token_id=None,
        **kwargs,
    ):
        # Token ids that must always terminate generation for this model family.
        base_eos = [151643, 151645, 151665]
        # Merge user-supplied eos ids with the mandatory ones. dict.fromkeys
        # de-duplicates while keeping a deterministic order; the previous
        # list(set(...)) produced an unspecified ordering, which made the
        # serialized config non-deterministic.
        if eos_token_id is None:
            eos_token_id = base_eos
        elif isinstance(eos_token_id, list):
            eos_token_id = list(dict.fromkeys(base_eos + eos_token_id))
        else:
            eos_token_id = list(dict.fromkeys(base_eos + [eos_token_id]))

        super().__init__(
            eos_token_id=eos_token_id,
            **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        # Was a bare `assert`, which is silently stripped under `python -O`;
        # raise so the invariant is always enforced.
        if self.num_attention_groups != self.num_key_value_heads:
            raise ValueError("num_attention_groups must be equal to num_key_value_heads")
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Prefer an explicit torch_dtype kwarg, then whatever super() set,
        # then fall back to bfloat16.
        torch_dtype = kwargs.get("torch_dtype", getattr(self, "torch_dtype", "bfloat16"))

        # Mirror the flat hyper-parameters into a standard Qwen2 config.
        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=torch_dtype,
        )
| |
|
class StepAudio2Config(PretrainedConfig):
    """Top-level configuration for Step-Audio-2, combining an audio-encoder
    config (``audio_encoder_config``) with a text-LLM config (``text_config``).

    Both sub-configs may be given as ``None`` (defaults / rebuilt from flat
    kwargs), as a ``dict``, or as an already-constructed config object.
    """

    model_type = "step_audio_2"
    architectures = ["StepAudio2ForCausalLM"]

    def __init__(
        self,
        audio_encoder_config: Optional[Union[dict, StepAudio2EncoderConfig]] = None,
        text_config: Optional[Union[dict, StepAudio2TextConfig]] = None,
        use_sliding_window: bool = False,
        sliding_window: Optional[int] = 2048,
        max_window_layers: Optional[int] = None,
        **kwargs,
    ):
        # Forward sliding-window options through kwargs without clobbering
        # values already present (e.g. loaded from a config.json).
        kwargs.setdefault("use_sliding_window", use_sliding_window)
        kwargs.setdefault("sliding_window", sliding_window)
        if max_window_layers is None:
            max_window_layers = kwargs.get("num_hidden_layers", None)
        kwargs.setdefault("max_window_layers", max_window_layers)

        if "torch_dtype" in kwargs:
            self.torch_dtype = kwargs["torch_dtype"]

        super().__init__(**kwargs)

        if text_config is None:
            # Rebuild the text config from flat kwargs (legacy config layout
            # where LLM hyper-parameters live at the top level).
            flat_text_params = {}
            text_param_names = [
                'vocab_size', 'hidden_size', 'intermediate_size', 'num_hidden_layers',
                'num_attention_heads', 'num_attention_groups', 'num_key_value_heads',
                'hidden_act', 'max_position_embeddings', 'initializer_range',
                'rms_norm_eps', 'rope_theta', 'rope_scaling', 'eos_token_id', 'pad_token_id'
            ]
            for param_name in text_param_names:
                if param_name in kwargs:
                    flat_text_params[param_name] = kwargs[param_name]

            flat_text_params.setdefault('hidden_act', 'silu')
            flat_text_params.setdefault('initializer_range', 0.02)

            if 'torch_dtype' in kwargs:
                flat_text_params['torch_dtype'] = kwargs['torch_dtype']

            # NOTE: flat_text_params is never empty here (hidden_act is always
            # injected above), so the old `if flat_text_params: ... else ...`
            # split was dead code — a single construction is equivalent.
            text_config = StepAudio2TextConfig(**flat_text_params).text_config
        elif isinstance(text_config, dict):
            text_config = StepAudio2TextConfig(**text_config).text_config
        # else: already a config object — use it as-is.

        self.text_config = text_config

        if audio_encoder_config is None:
            # Fall back to a dict smuggled through kwargs, else defaults.
            if 'audio_encoder_config' in kwargs and isinstance(kwargs['audio_encoder_config'], dict):
                self.audio_encoder_config = StepAudio2EncoderConfig(**kwargs['audio_encoder_config'])
            else:
                self.audio_encoder_config = StepAudio2EncoderConfig()
        elif isinstance(audio_encoder_config, dict):
            self.audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
        else:
            # BUG FIX: an already-constructed StepAudio2EncoderConfig was
            # previously dropped, leaving self.audio_encoder_config unset
            # and crashing any later attribute access.
            self.audio_encoder_config = audio_encoder_config
|