| |
| """PyTorch FLM-Audio model, based on LLAMA implementation.""" |
|
|
| import math |
| import warnings |
| from typing import List, Optional, Tuple, Union |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers.activations import ACT2FN |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.cache_utils import Cache, DynamicCache, StaticCache |
| from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
| from transformers.modeling_outputs import ( |
| ModelOutput, |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| logging, |
| replace_return_docstrings, |
| ) |
| from .configuration_flmaudio import FLMAudioConfig |
| from .depth_gpt import DepthGPT, DepthGPTConfig |
|
|
| if is_flash_attn_2_available(): |
| from flash_attn import flash_attn_func, flash_attn_varlen_func |
| from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
|
|
# Module-level logger obtained through transformers' logging wrapper.
logger = logging.get_logger(__name__)

# Name of the configuration class; presumably consumed by docstring helpers
# (e.g. `replace_return_docstrings`) later in the file — confirm.
_CONFIG_FOR_DOC = "FLMAudioConfig"
|
|
|
|
| def _get_unpad_data(attention_mask): |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| max_seqlen_in_batch = seqlens_in_batch.max().item() |
| cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| return ( |
| indices, |
| cu_seqlens, |
| max_seqlen_in_batch, |
| ) |
|
|
|
|
class FLMAudioRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature scale (equivalent to T5LayerNorm).

    The statistics are computed in float32; the normalised activations are cast back to the input
    dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (as_fp32 * inverse_rms).to(original_dtype)
        return self.weight * normalised
|
|
|
|
# Register the custom norm in transformers' ALL_LAYERNORM_LAYERS list so utilities that
# special-case norm layers also recognise FLMAudioRMSNorm (presumably for weight-decay
# exclusion in the Trainer — confirm).
ALL_LAYERNORM_LAYERS.append(FLMAudioRMSNorm)
|
|
class FLMAudioRotaryEmbedding(nn.Module):
    """Rotary position embedding producing one cos/sin table per multimodal position stream.

    Inverse frequencies come from transformers' ``ROPE_INIT_FUNCTIONS`` selected by ``rope_type``.
    ``forward`` returns ``(cos, sin)`` with a leading axis of size 3, scaled by
    ``attention_scaling`` and cast to ``x.dtype``.
    """

    def __init__(self, config, device=None):
        super().__init__()

        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            # Accept both the current "rope_type" key and the legacy "type" key.
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))

        # Length bookkeeping; presumably read by the `dynamic_rope_update` decorator — confirm.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Build ``(cos, sin)`` tables for ``position_ids``.

        ``position_ids`` is indexed as ``[:, :, None, :]`` and its axis 1 is used as the expand size,
        presumably a ``(3, batch, seq_len)`` tensor — TODO confirm against callers.
        """
        freq_col = self.inv_freq[None, None, :, None].float()
        freq_col = freq_col.expand(3, position_ids.shape[1], -1, 1)
        pos_row = position_ids[:, :, None, :].float()

        dev_kind = x.device.type
        autocast_kind = dev_kind if isinstance(dev_kind, str) and dev_kind != "mps" else "cpu"
        # Compute the angle table in full float32 regardless of any active autocast context.
        with torch.autocast(device_type=autocast_kind, enabled=False):
            angles = torch.matmul(freq_col.float(), pos_row.float()).transpose(2, 3)
            doubled = torch.cat((angles, angles), dim=-1)
            cos = doubled.cos() * self.attention_scaling
            sin = doubled.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last axis at ``size // 2`` into ``(x1, x2)`` and returns ``cat(-x2, x1)``.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Apply multimodal rotary embeddings to queries and keys.

    ``cos``/``sin`` carry a leading axis of position streams. Their last axis is cut into chunks
    sized by ``mrope_section`` repeated twice (the tables hold the frequency block twice), and chunk
    ``i`` is taken from stream ``i % 3``. The merged tables get a singleton axis at ``unsqueeze_dim``
    so they broadcast against ``q``/``k`` (presumably over the head axis — confirm).

    Returns:
        ``(q_embed, k_embed)`` computed as ``t * cos + rotate_half(t) * sin``.
    """
    chunk_sizes = mrope_section * 2

    def _merge_streams(table):
        pieces = table.split(chunk_sizes, dim=-1)
        selected = [piece[idx % 3] for idx, piece in enumerate(pieces)]
        return torch.cat(selected, dim=-1).unsqueeze(unsqueeze_dim)

    cos = _merge_streams(cos)
    sin = _merge_streams(sin)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
|
|
|
class FLMAudioMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    When ``config.pretraining_tp > 1`` the projection weights are split into ``pretraining_tp``
    slices and applied slice by slice. The result matches the single-matmul path up to
    floating-point rounding, for inputs of any rank ``(..., hidden_size)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        tp = self.config.pretraining_tp
        if tp <= 1:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        chunk = self.intermediate_size // tp
        gate_proj_slices = self.gate_proj.weight.split(chunk, dim=0)
        up_proj_slices = self.up_proj.weight.split(chunk, dim=0)
        down_proj_slices = self.down_proj.weight.split(chunk, dim=1)

        gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(tp)], dim=-1)
        up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(tp)], dim=-1)

        # Split on the feature (last) axis, matching the concatenation above; the previous
        # hard-coded dim=2 only worked for 3-D inputs.
        intermediate_states = (self.act_fn(gate_proj) * up_proj).split(chunk, dim=-1)
        return sum(F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(tp))
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With ``n_rep == 1`` the input
    tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
|
|
class FLMAudioAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper (eager implementation).

    Uses grouped-query attention: ``num_key_value_heads`` key/value heads are repeated
    ``num_attention_heads // num_key_value_heads`` times to match the query heads. Multimodal rotary
    embeddings are applied with ``config.rope_scaling["mrope_section"]``.
    """

    def __init__(self, config: FLMAudioConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.rope_scaling = config.rope_scaling

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        # The output projection's bias can be switched off independently via `disable_att_o_bias`.
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=config.attention_bias and not config.disable_att_o_bias
        )

    def _tp_project(self, hidden_states, linear, chunk_size):
        """Apply ``linear`` slice by slice along its output features (``pretraining_tp`` emulation).

        ``F.linear`` is called on weight slices only, so the layer's bias (if any) is added once
        afterwards. Without this, ``pretraining_tp > 1`` silently dropped the q/k/v biases and
        disagreed with the ``pretraining_tp == 1`` path.
        """
        weight_slices = linear.weight.split(chunk_size, dim=0)
        projected = torch.cat([F.linear(hidden_states, w) for w in weight_slices], dim=-1)
        if linear.bias is not None:
            projected = projected + linear.bias
        return projected

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention for ``hidden_states`` of shape ``(bsz, q_len, hidden_size)``.

        Args:
            attention_mask: additive mask added to the attention scores and trimmed to the current
                key length (last axis), or ``None``.
            past_key_value: cache updated in place via ``update(...)`` for ``self.layer_idx``.
            position_embeddings: ``(cos, sin)`` tables; required (unpacked unconditionally).

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()
        tp = self.config.pretraining_tp

        if tp > 1:
            kv_chunk = (self.num_key_value_heads * self.head_dim) // tp
            q_chunk = (self.num_heads * self.head_dim) // tp
            query_states = self._tp_project(hidden_states, self.q_proj, q_chunk)
            key_states = self._tp_project(hidden_states, self.k_proj, kv_chunk)
            value_states = self._tp_project(hidden_states, self.v_proj, kv_chunk)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Additive mask, trimmed to the current key length.
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Softmax in float32, then back to the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if tp > 1:
            o_chunk = self.hidden_size // tp
            attn_slices = attn_output.split(o_chunk, dim=-1)
            o_proj_slices = self.o_proj.weight.split(o_chunk, dim=1)
            # Pair every activation slice with its weight slice so no trailing slice is dropped.
            attn_output = sum(F.linear(part, weight) for part, weight in zip(attn_slices, o_proj_slices))
            if self.o_proj.bias is not None:
                attn_output = attn_output + self.o_proj.bias
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
|
|
|
class FLMAudioFlashAttention2(FLMAudioAttention):
    """
    FLM-Audio flash attention module. This module inherits from `FLMAudioAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # When the installed flash-attn fails `is_flash_attn_greater_or_equal_2_10`, its causal mask is
        # treated as top-left aligned. `_flash_attention_forward` then turns causal masking off for
        # single-token queries, where a top-left mask would presumably hide all but the first key.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention via flash-attn kernels.

        Args:
            attention_mask: 2-D padding mask ``(batch_size, seq_len)`` (1 = token, 0 = padding) or ``None``.
            position_embeddings: ``(cos, sin)`` tables; required.

        Returns:
            ``(attn_output, None, past_key_value)``. Attention weights are never returned.
        """
        # Attention probabilities are not materialised on this path.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        # A cache stored on the module itself (as done by `_setup_cache`) overrides the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # flash-attn takes (batch, seq_len, num_heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # Activations upcast to float32 (e.g. by float32 norms/embeddings) are cast back to a
        # half-precision dtype before calling flash-attn (which presumably rejects float32).
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With a top-left aligned mask, a single query must not be causally masked.
            causal = self.is_causal and query_length != 1

        if attention_mask is not None:
            # Padded batch: pack the real tokens, run the varlen kernel, then scatter back.
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Drop padded positions from q/k/v and build the varlen metadata.

        Returns:
            ``(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_q, max_seqlen_k))``.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Queries share the key padding layout.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query per sequence: offsets are simply 0..batch_size.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Queries cover only the trailing `query_length` positions of the mask.
            attention_mask = attention_mask[:, -query_length:]
            # flash-attn >= 2.7 returns an extra trailing value from `unpad_input`; ignore it so
            # both old (4-tuple) and new (5-tuple) releases work.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
class FLMAudioSdpaAttention(FLMAudioAttention):
    """
    FLM-Audio attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `FLMAudioAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """SDPA-backed attention.

        Falls back to the eager parent when ``output_attentions`` is requested. Extra keyword
        arguments (the decoder layer forwards its ``**kwargs``) are accepted and passed on to that
        fallback.

        Returns:
            ``(attn_output, None, past_key_value)`` on the SDPA path.
        """
        if output_attentions:
            logger.warning_once(
                "FLMAudioModel is using FLMAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            # `position_embeddings` must be forwarded: the parent unpacks it unconditionally.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # Presumably a workaround for an SDPA backend issue with non-contiguous inputs combined
        # with a custom mask on CUDA — confirm whether still needed.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # With no mask, SDPA would attend bidirectionally, unlike the flash-attn path
        # (`causal=self.is_causal`). Ask SDPA for causal masking in that case. This is restricted to
        # q_len == kv_len, where SDPA's built-in (top-left aligned) causal mask is the correct one.
        is_causal = bool(
            self.is_causal
            and causal_mask is None
            and q_len > 1
            and q_len == key_states.shape[-2]
        )

        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
|
|
|
|
# Lookup from `config._attn_implementation` to the attention class used by each decoder layer.
FLMAUDIO_ATTENTION_CLASSES = dict(
    eager=FLMAudioAttention,
    flash_attention_2=FLMAudioFlashAttention2,
    sdpa=FLMAudioSdpaAttention,
)
|
|
|
|
class FLMAudioDecoderLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> self-attention -> residual add, then RMSNorm -> MLP -> residual add."""

    def __init__(self, config: FLMAudioConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Unknown implementation names fall back to the eager attention class.
        attention_cls = FLMAUDIO_ATTENTION_CLASSES.get(config._attn_implementation, FLMAudioAttention)
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)
        self.mlp = FLMAudioMLP(config)
        self.input_layernorm = FLMAudioRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = FLMAudioRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                `(batch_size, sequence_length)` padding mask for flash attention, or an additive
                `(batch_size, 1, query_sequence_length, key_sequence_length)` mask otherwise.
            output_attentions (`bool`, *optional*):
                When set, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*):
                When set, the (possibly updated) cache object is appended to the returned tuple.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value states handed to the attention module.

        Returns:
            `(hidden_states,)` followed by the attention weights if `output_attentions`, then the
            cache if `use_cache`. Extra `kwargs` are forwarded to the attention module.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Self-attention sub-block with residual connection.
        attn_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)
        return (hidden_states, *extras)
|
|
|
|
# Shared introductory docstring text; passed to `add_start_docstrings` on the
# pre-trained model class defined later in this file.
FLMAUDIO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`FLMAudioConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
@add_start_docstrings(
    "The bare FLM-Audio Model outputting raw hidden-states without any specific head on top.",
    FLMAUDIO_START_DOCSTRING,
)
class FLMAudioPreTrainedModel(PreTrainedModel):
    # Class attributes read by transformers' PreTrainedModel machinery. No class docstring is added
    # here because `add_start_docstrings` would append it to the generated `__doc__`.
    config_class = FLMAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FLMAudioDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, initializer_range); zero Linear biases and Embedding padding rows."""
        std = self.config.initializer_range
        linear = isinstance(module, nn.Linear)
        if not (linear or isinstance(module, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _setup_cache(
        self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None
    ):
        """Attach a fresh ``cache_cls`` instance to every decoder layer's attention module.

        Raises:
            ValueError: if a ``StaticCache`` is requested together with flash_attention_2.
        """
        flash_backend = self.config._attn_implementation == "flash_attention_2"
        if flash_backend and cache_cls is StaticCache:
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        quantized = hasattr(self.config, "_pre_quantization_dtype")
        for decoder_layer in self.model.layers:
            # Place each cache on the same device as its layer's weights.
            layer_device = decoder_layer.input_layernorm.weight.device
            cache_dtype = (
                self.config._pre_quantization_dtype
                if quantized
                else decoder_layer.self_attn.o_proj.weight.dtype
            )
            decoder_layer.self_attn.past_key_value = cache_cls(
                self.config, max_batch_size, max_cache_len, device=layer_device, dtype=cache_dtype
            )

    def _reset_cache(self):
        """Detach the per-layer caches installed by ``_setup_cache``."""
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.past_key_value = None
|
|
|
|
class MultiModalEmbedding(nn.Module):
    """Sum of text token embeddings and per-channel audio (speak + listen) embeddings.

    One ``nn.Embedding`` table is kept for text, and ``aud_channel`` tables each for the listen and
    speak audio streams. When ``config.use_mup`` is set, the result is scaled by ``config.input_mult``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_mup = config.use_mup
        self.input_mult = config.input_mult
        self.hidden_size = config.hidden_size

        self.vocab_size = config.vocab_size
        self.aud_vocab_size = config.aud_vocab_size
        self.aud_channel = config.aud_channel
        self.aud_emp_token_id = config.mm_token_info.aud_emp_token_id

        # Creation order (text, listen, speak) is kept so parameter init and state_dict order are unchanged.
        self.text_embeddings = nn.Embedding(self.vocab_size, self.hidden_size)
        self.aud_listen_embeddings = nn.ModuleList(
            nn.Embedding(self.aud_vocab_size, self.hidden_size) for _ in range(self.aud_channel)
        )
        self.aud_speak_embeddings = nn.ModuleList(
            nn.Embedding(self.aud_vocab_size, self.hidden_size) for _ in range(self.aud_channel)
        )

    @staticmethod
    def merge_multichannel_embeddings(
        token_ids, embedding_layer, emp_token_id, embeddings
    ):
        """Add per-channel embeddings of ``token_ids`` into ``embeddings`` in place and return it.

        Channel ``c`` uses ``embedding_layer[c]`` on ``token_ids[:, :, c]``; positions equal to
        ``emp_token_id`` contribute zero. If either ``token_ids`` or ``embedding_layer`` is ``None``,
        ``embeddings`` is returned untouched.
        """
        if token_ids is None or embedding_layer is None:
            return embeddings
        assert token_ids.shape[2] == len(embedding_layer)
        for channel in range(token_ids.shape[2]):
            channel_ids = token_ids[:, :, channel]
            is_empty = (channel_ids == emp_token_id).unsqueeze(-1)
            contribution = embedding_layer[channel](channel_ids).masked_fill(is_empty, 0.0)
            embeddings += contribution
            del contribution
        return embeddings

    def forward(
        self,
        text_ids,
        speak_ids,
        listen_ids,
    ):
        """Embed ``text_ids`` and add audio embeddings at the non-pad text positions.

        The audio embeddings are squeezed on axis 0 and added to ``embeddings[mask]``, where
        ``mask`` marks ``text_ids != pad_token_id``. ``speak_ids``/``listen_ids`` therefore
        presumably hold one row per non-pad text token with a leading batch axis of 1 — confirm
        against callers.
        """
        assert text_ids is not None
        embeddings = self.text_embeddings(text_ids)
        non_pad = ~(text_ids == self.config.pad_token_id)

        for channel in range(self.aud_channel):
            speak_layer = self.aud_speak_embeddings[channel]
            listen_layer = self.aud_listen_embeddings[channel]
            speak_vec = speak_layer(speak_ids[..., channel]).squeeze(0)
            listen_vec = listen_layer(listen_ids[..., channel]).squeeze(0)
            embeddings[non_pad] += speak_vec + listen_vec

        return embeddings * self.input_mult if self.use_mup else embeddings
|
|
|
|
# Shared "Args" documentation injected into the `forward` docstrings of the
# FLM-Audio models via `add_start_docstrings_to_model_forward`.
FLMAUDIO_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input text tokens in the vocabulary. Padding will be ignored by default should you provide
            it. [`FLMAudioModel`] receives the same tensor under the name `text_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        listen_ids (`torch.LongTensor`, *optional*):
            Audio token ids of the listen stream, with one audio channel per entry of the last dimension. Their
            embeddings are added only at positions whose text token is not `config.pad_token_id`, so the ids must
            provide one entry per such position (e.g. shape `(num_non_pad_tokens, num_audio_channels)`).
        speak_ids (`torch.LongTensor`, *optional*):
            Audio token ids of the speak stream, laid out like `listen_ids`. When both `listen_ids` and
            `speak_ids` are omitted, [`FLMAudioForCausalLM`] fills them with the audio padding token.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists in the `past_key_values` returned by the model at a
            previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            Unless a [`~cache_utils.StaticCache`] is passed, the returned cache uses the legacy tuple format.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
        rope_deltas (`torch.LongTensor`, *optional*), second_per_grid_ts (`torch.Tensor`, *optional*):
            Accepted for interface compatibility; currently ignored.
"""
|
|
|
|
@add_start_docstrings(
    "The bare FLM-Audio Model outputting raw hidden-states without any specific head on top.",
    FLMAUDIO_START_DOCSTRING,
)
class FLMAudioModel(FLMAudioPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FLMAudioDecoderLayer`]

    Text and audio token ids are combined by a [`MultiModalEmbedding`] before the decoder stack. The model keeps
    `self.rope_deltas` between calls so that incremental (cached) decoding can continue the position ids computed
    by the first call.

    Args:
        config: FLMAudioConfig
    """

    def __init__(self, config: FLMAudioConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = MultiModalEmbedding(config)
        self.layers = nn.ModuleList(
            [
                FLMAudioDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = FLMAudioRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = FLMAudioRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Per-sample position offsets stored by `get_rope_index`; reused by
        # later cached decoding steps in `forward`.
        self.rope_deltas = None

        self.post_init()

    def get_input_embeddings(self) -> MultiModalEmbedding:
        return self.embed_tokens

    def set_input_embeddings(self, value: MultiModalEmbedding):
        self.embed_tokens = value

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3-axis position ids and per-sample position deltas.

        With an ``attention_mask`` (assumed 2-D ``(batch, seq)``), positions
        count only unmasked tokens. Masked slots get position 1. Each sample's
        delta is ``max_position + 1 - mask_length``, with shape ``(batch, 1)``.
        Without a mask, every sample gets positions ``0..seq-1`` (taken from
        ``input_ids.shape``) and a zero delta. ``second_per_grid_ts`` is
        accepted but not used.

        Returns:
            ``(position_ids, deltas)``, where ``position_ids`` has shape
            ``(3, batch, seq)``.
        """
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas

    @add_start_docstrings_to_model_forward(FLMAUDIO_INPUTS_DOCSTRING)
    def forward(
        self,
        text_ids: torch.LongTensor = None,
        listen_ids: torch.LongTensor = None,
        speak_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Exactly one of `text_ids` / `inputs_embeds` must be given.
        if (text_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                text_ids,
                speak_ids,
                listen_ids,
            )

        # Any non-static cache (including a legacy tuple or None) is wrapped
        # into a DynamicCache.
        past_seen_tokens = 0
        if use_cache:
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError(
                    "cache_position is a required argument when using StaticCache."
                )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                # First (prefill) step: compute positions from scratch and
                # remember the per-sample deltas for later steps.
                position_ids, rope_deltas = self.get_rope_index(
                    text_ids,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            else:
                # Cached step: continue from cache_position[0], shifted by the
                # deltas stored during the prefill step.
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            # Layer outputs are (hidden, [attn_weights], [cache]).
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache
            # A StaticCache is returned as-is: it is not a DynamicCache and has
            # no `to_legacy_cache` to convert with (calling it used to crash
            # here). Every other cache is converted to legacy tuples as before.
            if isinstance(next_cache, Cache) and not isinstance(next_cache, StaticCache):
                next_cache = next_cache.to_legacy_cache()
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """Build the additive attention mask passed to every decoder layer.

        For flash_attention_2, the 2-D mask is returned only if it contains a
        zero; otherwise ``None`` is returned. For other implementations, a
        ``(batch, 1, seq, target_length)`` float mask is built. It holds
        ``finfo(dtype).min`` at disallowed positions (future tokens and padded
        keys) and 0 elsewhere.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # NOTE(review): looks like a legacy static-cache check (attention module
        # owning a `past_key_value`); otherwise the key length comes from the
        # mask or from cache_position.
        if hasattr(
            getattr(self.layers[0], "self_attn", {}), "past_key_value"
        ):
            target_length = self.config.max_position_embeddings
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[-1] + 1
            )

        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Keep masking only key positions beyond each query's cache position.
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(
            input_tensor.shape[0], 1, -1, -1
        )
        if attention_mask is not None:
            # Copy before in-place edits: expand() returned a view.
            causal_mask = causal_mask.clone()
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                    :, None, None, :
                ].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[
                    ..., :mask_length
                ].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                # A 4-D mask shorter than the current window is placed at the
                # cache offset.
                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                    offset = cache_position[0]
                else:
                    offset = 0
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
                causal_mask[
                    : mask_shape[0],
                    : mask_shape[1],
                    offset : mask_shape[2] + offset,
                    : mask_shape[3],
                ] = mask_slice

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # NOTE(review): upstream transformers uses this helper so that
            # fully-masked rows (e.g. left padding) do not break the SDPA
            # kernels; confirm it is still needed here.
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask
|
|
|
|
@dataclass
class FLMAudioCausalLMOutputWithPast(ModelOutput):
    """Output container returned by `FLMAudioForCausalLM` when `return_dict` is enabled.

    Field order matters: it is the positional order of the `ModelOutput` fields.
    """

    # Loss value; only set when a loss is computed.
    loss: Optional[torch.FloatTensor] = None
    # Text-vocabulary logits produced by `lm_head`.
    logits: Optional[torch.FloatTensor] = None
    # Audio logits; not populated by `FLMAudioForCausalLM._forward_text`.
    audio_logits: Optional[torch.FloatTensor] = None
    # Cache returned by the backbone for incremental decoding.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # NOTE: `FLMAudioForCausalLM._forward_text` stores the backbone's final
    # (normed) hidden state tensor here, not a per-layer tuple.
    hidden_states: Optional[Union[torch.FloatTensor, Tuple[torch.FloatTensor]]] = None
    # Per-layer attention weights when `output_attentions` is requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Position deltas; not populated by `FLMAudioForCausalLM._forward_text`.
    rope_deltas: Optional[torch.LongTensor] = None
|
|
|
|
class FLMAudioForCausalLM(FLMAudioPreTrainedModel):
    """FLM-Audio backbone with a text LM head (`lm_head`) and a DepthGPT audio head.

    `forward` runs the backbone and the text head only; the audio head is
    invoked separately through `forward_audio`.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = FLMAudioModel(config)
        self.vocab_size = config.vocab_size
        self.output_mult = config.output_mult

        # Depth transformer for audio tokens; its block_size is config.aud_channel.
        depth_config = DepthGPTConfig(
            block_size=config.aud_channel,
            vocab_size=config.aud_vocab_size,
            n_layer=config.aud_depthgpt.n_layer,
            n_head=config.aud_depthgpt.n_head,
            n_embd=config.aud_depthgpt.n_embd,
            dropout=config.aud_depthgpt.dropout,
            bias=config.aud_depthgpt.bias,
            main_hidden_size=config.hidden_size,
            pad_token_id=config.mm_token_info.aud_emp_token_id,
            use_cmlp=config.aud_depthgpt.use_cmlp,
            use_rmsnorm=config.aud_depthgpt.use_rmsnorm,
            use_swiglu=config.aud_depthgpt.use_swiglu,
        )

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.aud_output_layers = DepthGPT(depth_config)

        self.use_mup = config.use_mup
        if self.use_mup:
            # With muP the text logits are scaled by output_mult / mup_scale_factor.
            self.mup_scale_factor = config.mup_scale_factor
            self.output_mult = config.output_mult / self.mup_scale_factor

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def _forward_text(self, outputs, labels, return_dict):
        """Apply the text LM head to the backbone output and package the result.

        Logits are cast to float32 and, under muP, scaled by `self.output_mult`.

        Raises:
            NotImplementedError: if `labels` is given (no loss is implemented).
        """
        logits = self.lm_head(outputs[0])
        logits = logits.float()

        if self.use_mup:
            logits = logits * self.output_mult

        loss = None
        if labels is not None:
            raise NotImplementedError(
                "FLMAudioForCausalLM does not compute a loss; pass labels=None."
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        # NOTE: `hidden_states` deliberately carries the backbone's final hidden
        # state tensor (not the per-layer tuple). Presumably it is the input
        # consumed by `forward_audio`; confirm before changing.
        return FLMAudioCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )

    def forward_audio(self, transformer_output_states, audio_input_ids):
        """Run the DepthGPT audio head (`self.aud_output_layers`)."""
        return self.aud_output_layers(transformer_output_states, audio_input_ids)

    def _default_audio_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Build audio ids filled with `config.mm_token_info.aud_pad_token_id`.

        The result has shape `(num_non_pad_text_tokens, config.aud_channel)`:
        one row per text position that is not `config.pad_token_id`, since
        `MultiModalEmbedding` adds audio embeddings only at those positions.
        """
        num_rows = int((input_ids != self.config.pad_token_id).sum())
        return torch.full(
            (num_rows, self.config.aud_channel),
            self.config.mm_token_info.aud_pad_token_id,
            device=input_ids.device,
            dtype=input_ids.dtype,
        )

    @add_start_docstrings_to_model_forward(FLMAUDIO_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        listen_ids: torch.LongTensor = None,
        speak_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
                Loss computation is not implemented yet; passing labels raises `NotImplementedError`.
            listen_ids / speak_ids (`torch.LongTensor`, *optional*):
                Audio token ids. When omitted alongside `input_ids`, they are filled with the audio padding token,
                one row per non-padding text token and `config.aud_channel` columns.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FLMAudioForCausalLM

        >>> model = FLMAudioForCausalLM.from_pretrained("CofeAI/FLM-Audio")
        >>> tokenizer = AutoTokenizer.from_pretrained("CofeAI/FLM-Audio")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Text-only calls: supply audio-pad ids for any missing audio stream.
        # Sizes follow config.aud_channel (previously hard-coded to 8) and the
        # number of non-pad text tokens (previously batch*seq, which mismatched
        # the embedding whenever padding was present). Skipped when only
        # `inputs_embeds` is given, since the ids are then unused.
        if input_ids is not None:
            if listen_ids is None:
                listen_ids = self._default_audio_ids(input_ids)
            if speak_ids is None:
                speak_ids = self._default_audio_ids(input_ids)

        outputs = self.model(
            text_ids=input_ids,
            listen_ids=listen_ids,
            speak_ids=speak_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            second_per_grid_ts=second_per_grid_ts,
        )
        return self._forward_text(outputs, labels, return_dict)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder a legacy tuple cache along batch axis 0 by `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

    def _get_initial_token(self) -> torch.Tensor:
        """Return the initial token column of shape `(1, 1 + 2 * aud_channel, 1)`.

        Row 0 holds `text_wait_token_id`; the remaining `2 * aud_channel` rows
        hold `aud_pad_token_id`. The tensor is a long tensor on the model's
        device.
        """
        device = next(iter(self.parameters())).device
        zero = torch.full([1, 1, 1], 0, device=device, dtype=torch.long)
        audio_pad = torch.full_like(zero, self.config.mm_token_info.aud_pad_token_id)
        text_token = torch.full_like(
            zero, self.config.mm_token_info.text_wait_token_id
        )
        audio_token = audio_pad.expand(-1, 2 * self.config.aud_channel, -1).clone()
        return torch.cat([text_token, audio_token], dim=1)
|
|