| from typing import Callable, Optional, Tuple, TypedDict, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache |
| from transformers.modeling_attn_mask_utils import ( |
| AttentionMaskConverter, |
| _prepare_4d_attention_mask, |
| _prepare_4d_attention_mask_for_sdpa, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions |
| from transformers.models.hubert.configuration_hubert import HubertConfig |
| from transformers.models.hubert.modeling_hubert import ( |
| HubertAttnAdapterLayer, |
| HubertFeedForward, |
| is_deepspeed_zero3_enabled, |
| ) |
| from transformers.utils import is_torchdynamo_compiling, logging |
| from typing_extensions import Unpack |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| |
class FlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for Flash Attention with Compile.

    The dictionary is declared with `total=False`, so any subset of the keys below may be supplied.

    Attributes:
        cumulative_seqlens_q (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the query states.
        cumulative_seqlens_k (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the key states.
        max_length_q (`int`, *optional*):
            Maximum sequence length for the query states.
        max_length_k (`int`, *optional*):
            Maximum sequence length for the key states.
    """

    # Cumulative sequence lengths (query side / key side).
    cumulative_seqlens_q: Union[torch.LongTensor, None]
    cumulative_seqlens_k: Union[torch.LongTensor, None]
    # Longest individual sequence (query side / key side).
    max_length_q: Union[int, None]
    max_length_k: Union[int, None]
|
|
|
|
class SinusoidalPositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal positional embedding table looked up by position index.

    The table has shape `(config.max_position_embeddings, config.hidden_size)` and is stored as the
    non-trainable buffer `position_embeddings`.
    """

    def __init__(self, config) -> None:
        super().__init__()
        weight = torch.empty(
            (
                config.max_position_embeddings,
                config.hidden_size,
            ),
            requires_grad=False,
        )
        self._init_sinusoidal_embedding(weight)
        self.register_buffer("position_embeddings", weight)

    def _init_sinusoidal_embedding(self, embeddings: torch.Tensor) -> None:
        """
        Fill `embeddings` in place: even columns get `sin`, odd columns get `cos` of
        `pos / 10000 ** (2 * (j // 2) / D)`.
        """
        T, D = embeddings.size()
        # Vectorised over the whole (T, D) grid instead of a Python double loop; the arithmetic is
        # still done in float64 and cast to float32 afterwards, as before.
        positions = np.arange(T, dtype=np.float64)[:, None]
        exponents = 2 * (np.arange(D) // 2) / D
        position_enc = positions / np.power(10000, exponents)[None, :]
        embeddings[:, 0::2] = torch.from_numpy(np.sin(position_enc[:, 0::2])).float()
        embeddings[:, 1::2] = torch.from_numpy(np.cos(position_enc[:, 1::2])).float()

    def forward(
        self,
        inputs: torch.Tensor,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Look up positional embeddings.

        Args:
            inputs (`torch.Tensor`):
                Only its first two dimensions `(bsz, seq_len)` are read, and only when `position_ids` is `None`.
            past_key_values_length (`int`, defaults to 0):
                Offset of the first position when `position_ids` is `None`.
            position_ids (`torch.LongTensor`, *optional*):
                Explicit position indices. A 0-D or 1-D tensor gets a leading batch dimension of size 1;
                a tensor that already has a batch dimension is used as is.

        Returns:
            `torch.Tensor`: `position_embeddings[position_ids]`, i.e. the index shape followed by `hidden_size`.
        """
        if position_ids is None:
            bsz, seq_len = inputs.shape[:2]
            position_ids = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_len,
                dtype=torch.long,
                device=self.position_embeddings.device,
            ).expand(bsz, -1)
        elif position_ids.dim() < 2:
            # Un-batched indices (e.g. a 1-D `cache_position`): add a broadcastable batch dimension.
            # Already-batched indices must not be unsqueezed, or the result would gain an extra dimension.
            position_ids = position_ids.unsqueeze(0)
        return self.position_embeddings[position_ids]
|
|
|
|
| |
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """
        Look up learned positional embeddings.

        Args:
            input_ids (`torch.Tensor`):
                Only its first two dimensions `(bsz, seq_len)` are read, and only when `position_ids` is `None`.
            past_key_values_length (`int`, defaults to 0):
                Offset of the first position when `position_ids` is `None`.
            position_ids (`torch.LongTensor`, *optional*):
                Explicit position indices. A 0-D or 1-D tensor gets a leading batch dimension of size 1;
                a tensor that already has a batch dimension is used as is.

        Returns:
            `torch.Tensor`: embeddings with the shape of the position indices followed by `embedding_dim`.
        """
        if position_ids is None:
            bsz, seq_len = input_ids.shape[:2]
            position_ids = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_len,
                dtype=torch.long,
                device=self.weight.device,
            ).expand(bsz, -1)
        elif position_ids.dim() < 2:
            # Un-batched indices (e.g. a 1-D `cache_position`): add a broadcastable batch dimension.
            # Already-batched indices must not be unsqueezed, or the result would gain an extra dimension.
            position_ids = position_ids.unsqueeze(0)

        return super().forward(position_ids)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Plain matmul/softmax scaled dot-product attention.

    Args:
        module (`nn.Module`): Only `module.training` is read, to decide whether dropout is active.
        query, key, value (`torch.Tensor`): 4-D tensors; `key` is transposed on dims 2 and 3 before the matmul.
        attention_mask (`torch.Tensor`, *optional*): Added to the raw scores before the softmax.
        scaling (`float`, *optional*): Score multiplier; defaults to `query.size(-1) ** -0.5`.
        dropout (`float`): Dropout probability applied to the attention probabilities.
        head_mask (`torch.Tensor`, *optional*): Per-head multiplier, reshaped to `(1, -1, 1, 1)`.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped and is contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    if head_mask is not None:
        probs = probs * head_mask.view(1, -1, 1, 1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
class AVHubertAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Acts as self-attention by default and as cross-attention when `key_value_states` is passed to `forward`.
    Keys/values can be cached through a `Cache` or an `EncoderDecoderCache`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[HubertConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source, `(bsz, tgt_len, embed_dim)`.
            key_value_states: When given, keys/values are projected from it (cross-attention);
                otherwise they come from `hidden_states`.
            past_key_value: Optional cache. For an `EncoderDecoderCache` the self- or cross-attention
                sub-cache is selected automatically; any other cache object is used directly.
            attention_mask: Additive mask forwarded to the attention function.
            layer_head_mask: Per-head mask forwarded as `head_mask`.
            output_attentions: Forwarded to the attention function.
            cache_position: Forwarded to the cache update for self-attention only.

        Returns:
            `(attn_output, attn_weights, past_key_value)`; `past_key_value` is the object that was passed in.
        """
        # Cross-attention is signalled by the presence of `key_value_states`.
        is_cross_attention = key_value_states is not None

        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        # Defaults keep the non-`EncoderDecoderCache` path well-defined: a plain cache has no
        # `is_updated` bookkeeping, so cross-attention keys/values are simply (re)computed.
        is_updated = False
        curr_past_key_value = None
        if past_key_value is not None:
            if isinstance(past_key_value, EncoderDecoderCache):
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    curr_past_key_value = past_key_value.cross_attention_cache
                else:
                    curr_past_key_value = past_key_value.self_attention_cache
            else:
                curr_past_key_value = past_key_value

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # Cross-attention keys/values were already stored for this layer: reuse them.
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if past_key_value is not None:
                # `cache_position` is only passed for self-attention updates.
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
                # Only `EncoderDecoderCache` tracks `is_updated`; mark the cross-attention entry as filled.
                if is_cross_attention and isinstance(past_key_value, EncoderDecoderCache):
                    past_key_value.is_updated[self.layer_idx] = True

        attention_interface: Callable = eager_attention_forward

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, past_key_value
|
|
|
|
class AVHubertDecoderLayer(nn.Module):
    """
    Decoder layer with layer normalization applied *after* each residual sum:
    self-attention, optional cross-attention, feed-forward, then an optional adapter.
    """

    def __init__(self, config: HubertConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # Arguments shared by the self- and cross-attention modules.
        shared_attn_kwargs = {
            "embed_dim": config.hidden_size,
            "num_heads": config.num_attention_heads,
            "dropout": config.attention_dropout,
            "is_decoder": True,
            "config": config,
            "layer_idx": layer_idx,
        }

        self.attention = AVHubertAttention(is_causal=True, **shared_attn_kwargs)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.encoder_attn = AVHubertAttention(**shared_attn_kwargs)
        self.encoder_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The adapter is only built when the config defines `adapter_attn_dim`.
        has_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = HubertAttnAdapterLayer(config) if has_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run the layer.

        Returns:
            A tuple starting with the new hidden states, followed by
            `(self_attn_weights, cross_attn_weights)` if `output_attentions`, followed by the cache
            returned by self-attention if `use_cache`.
        """
        # Self-attention -> dropout -> residual -> norm.
        skip = hidden_states
        attended, self_attn_weights, past_key_value = self.attention(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = self.layer_norm(skip + self.dropout(attended))

        # Cross-attention over the encoder output (skipped when there is none).
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            skip = hidden_states
            attended, cross_attn_weights, _ = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = self.encoder_layer_norm(self.dropout(attended) + skip)

        # Feed-forward with residual, then the final norm.
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights, cross_attn_weights)
        if use_cache:
            outputs = outputs + (past_key_value,)
        return outputs
|
|
|
|
class AVHubertDecoderLayerStableLayerNorm(nn.Module):
    """
    Decoder layer with layer normalization applied *before* each sub-block (the residual path is left
    un-normalized): self-attention, optional cross-attention, feed-forward, then an optional adapter.
    """

    def __init__(self, config: HubertConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # Arguments shared by the self- and cross-attention modules.
        shared_attn_kwargs = {
            "embed_dim": config.hidden_size,
            "num_heads": config.num_attention_heads,
            "dropout": config.attention_dropout,
            "is_decoder": True,
            "config": config,
            "layer_idx": layer_idx,
        }

        self.attention = AVHubertAttention(is_causal=True, **shared_attn_kwargs)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.encoder_attn = AVHubertAttention(**shared_attn_kwargs)
        self.encoder_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The adapter is only built when the config defines `adapter_attn_dim`.
        has_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = HubertAttnAdapterLayer(config) if has_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run the layer.

        Returns:
            A tuple starting with the new hidden states, followed by
            `(self_attn_weights, cross_attn_weights)` if `output_attentions`, followed by the cache
            returned by self-attention if `use_cache`.
        """
        # Norm -> self-attention -> dropout -> residual.
        skip = hidden_states
        attended, self_attn_weights, past_key_value = self.attention(
            hidden_states=self.layer_norm(hidden_states),
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = skip + self.dropout(attended)

        # Norm -> cross-attention over the encoder output -> dropout -> residual.
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            skip = hidden_states
            attended, cross_attn_weights, _ = self.encoder_attn(
                hidden_states=self.encoder_layer_norm(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = self.dropout(attended) + skip

        # Norm -> feed-forward -> residual.
        normed = self.final_layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(normed)

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights, cross_attn_weights)
        if use_cache:
            outputs = outputs + (past_key_value,)
        return outputs
|
|
|
|
class AVHubertDecoder(nn.Module):
    """
    Stack of `AVHubertDecoderLayer` blocks with positional embeddings, LayerDrop and a final layer norm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        if config.learned_pos:
            self.pos_embed = LearnedPositionalEmbedding(
                num_embeddings=config.max_position_embeddings,
                embedding_dim=config.hidden_size,
            )
        else:
            self.pos_embed = SinusoidalPositionalEmbedding(config=config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # Each layer needs its own `layer_idx` so that its attention modules address their own cache slot.
        self.layers = nn.ModuleList(
            [AVHubertDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        head_mask: torch.Tensor | None = None,
        cross_attn_head_mask: torch.Tensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        use_cache: bool | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        cache_position: torch.LongTensor | None = None,
    ):
        """
        Decode `inputs_embeds`, optionally attending to `encoder_hidden_states`.

        Args:
            inputs_embeds: Decoder input embeddings `(batch_size, seq_length, hidden_size)`. Required.
            attention_mask: Decoder mask; defaults to all ones over past + current positions
                (except while compiling with torchdynamo, where it is left as `None`).
            encoder_hidden_states: Encoder output used by cross-attention.
            encoder_attention_mask: Mask over the encoder positions.
            head_mask / cross_attn_head_mask: Per-layer head masks; first dimension must equal the number of layers.
            past_key_values: Cache, updated by the layers. Created when `use_cache` is set and none is given.
            use_cache: Whether to return the cache.
            output_attentions / output_hidden_states: Whether to collect per-layer attentions / hidden states.
            return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
            cache_position: Positions of the current tokens; defaults to a range starting at the cached length.

        Raises:
            ValueError: If `inputs_embeds` is missing or a head mask does not cover every layer.
            NotImplementedError: If gradient checkpointing is enabled while training.
        """
        if inputs_embeds is None:
            raise ValueError("`inputs_embeds` must be provided.")
        if self.gradient_checkpointing and self.training:
            # Fail before doing any work instead of after running a checkpointed layer.
            raise NotImplementedError("Currently, gradient checkpointing is not supported.")

        input_shape = inputs_embeds.shape[:-1]
        batch_size, seq_length = input_shape
        if use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache.from_legacy_cache()
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_length,
                device=inputs_embeds.device,
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        self_attn_cache = (
            past_key_values.self_attention_cache
            if isinstance(past_key_values, EncoderDecoderCache)
            else past_key_values
        )

        attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)
        encoder_attention_mask = self._update_cross_attn_mask(
            encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds
        )

        position_embeddings = self.pos_embed(inputs_embeds, past_key_values_length, position_ids=cache_position)
        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Both head masks must provide one entry per layer.
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # LayerDrop: during training each layer is skipped with probability `config.layerdrop`.
            dropout_probability = torch.rand([])
            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))

            layer_outputs = (None, None, None, None)
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # With DeepSpeed ZeRO-3 enabled the layer is still executed even when it is "skipped".
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None, None, None)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # The final, normalized state is recorded exactly once.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Layers hand back the very cache object they were given, so return it directly; this stays
        # valid even when the last layer was dropped.
        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    def _update_causal_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
    ):
        """
        Build the 4D causal mask for decoder self-attention, or return `None` when the SDPA helper
        reports that the explicit mask can be dropped.

        Raises:
            NotImplementedError: For the `flex_attention` and `flash_attention_2` implementations.
        """
        if self.config._attn_implementation == "flex_attention":
            raise NotImplementedError

        if self.config._attn_implementation == "flash_attention_2":
            raise NotImplementedError

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = isinstance(past_key_values, StaticCache)

        if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            # A static cache has a fixed size, so the mask spans the whole cache.
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
        ):
            # NOTE(review): presumably re-opens fully masked rows for SDPA kernels on accelerators;
            # confirm against `AttentionMaskConverter._unmask_unattended`.
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: Optional[torch.Tensor],
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`, *optional*):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Already in 4D form: used as is.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=cache_position.device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the fill value only for key positions strictly after each query's cache position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                # `expand` returns a view; clone before writing into it.
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                # A sum of zero means "causally visible (0) but padded out (0)": mask those positions too.
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

    def _update_cross_attn_mask(
        self,
        encoder_hidden_states: Union[torch.Tensor, None],
        encoder_attention_mask: Union[torch.Tensor, None],
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
    ):
        """
        Convert the encoder attention mask to the form used by the configured attention implementation.

        The mask is returned unchanged when either `encoder_hidden_states` or `encoder_attention_mask` is `None`.

        Raises:
            NotImplementedError: For the `flex_attention` implementation.
        """
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                # Keep the 2D mask only if it actually masks something.
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            elif self.config._attn_implementation == "flex_attention":
                raise NotImplementedError
            else:
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        return encoder_attention_mask
|
|
|
|
| class AVHubertDecoderStableLayerNorm(nn.Module): |
def __init__(self, config):
    """Build positional embeddings, the final layer norm, dropout and the indexed stack of decoder layers."""
    super().__init__()
    self.config = config

    # Learned or fixed sinusoidal positions, selected by `config.learned_pos`.
    self.pos_embed = (
        LearnedPositionalEmbedding(
            num_embeddings=config.max_position_embeddings,
            embedding_dim=config.hidden_size,
        )
        if config.learned_pos
        else SinusoidalPositionalEmbedding(config=config)
    )

    self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout)

    # Every layer receives its position in the stack as `layer_idx`.
    decoder_layers = []
    for depth in range(config.num_hidden_layers):
        decoder_layers.append(AVHubertDecoderLayerStableLayerNorm(config, layer_idx=depth))
    self.layers = nn.ModuleList(decoder_layers)

    self.gradient_checkpointing = False
|
|
def forward(
    self,
    inputs_embeds: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    encoder_hidden_states: torch.Tensor | None = None,
    encoder_attention_mask: torch.Tensor | None = None,
    head_mask: torch.Tensor | None = None,
    cross_attn_head_mask: torch.Tensor | None = None,
    past_key_values: EncoderDecoderCache | None = None,
    use_cache: bool | None = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    return_dict: bool = True,
    cache_position: torch.LongTensor | None = None,
):
    """
    Decode `inputs_embeds`, optionally attending to `encoder_hidden_states`.

    Args:
        inputs_embeds: Decoder input embeddings `(batch_size, seq_length, hidden_size)`. Required.
        attention_mask: Decoder mask; defaults to all ones over past + current positions
            (except while compiling with torchdynamo, where it is left as `None`).
        encoder_hidden_states: Encoder output used by cross-attention.
        encoder_attention_mask: Mask over the encoder positions.
        head_mask / cross_attn_head_mask: Per-layer head masks; first dimension must equal the number of layers.
        past_key_values: Cache, updated by the layers. Created when `use_cache` is set and none is given.
        use_cache: Whether to return the cache.
        output_attentions / output_hidden_states: Whether to collect per-layer attentions / hidden states.
        return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
        cache_position: Positions of the current tokens; defaults to a range starting at the cached length.

    Raises:
        ValueError: If `inputs_embeds` is missing or a head mask does not cover every layer.
        NotImplementedError: If gradient checkpointing is enabled while training.
    """
    if inputs_embeds is None:
        raise ValueError("`inputs_embeds` must be provided.")
    if self.gradient_checkpointing and self.training:
        # Fail before doing any work instead of after running a checkpointed layer.
        raise NotImplementedError("Currently, gradient checkpointing is not supported.")

    input_shape = inputs_embeds.shape[:-1]
    batch_size, seq_length = input_shape
    if use_cache and past_key_values is None:
        past_key_values = EncoderDecoderCache.from_legacy_cache()
    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
    if cache_position is None:
        cache_position = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_length,
            device=inputs_embeds.device,
        )

    if attention_mask is None and not is_torchdynamo_compiling():
        mask_seq_length = past_key_values_length + seq_length
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

    self_attn_cache = (
        past_key_values.self_attention_cache
        if isinstance(past_key_values, EncoderDecoderCache)
        else past_key_values
    )

    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, self_attn_cache)
    encoder_attention_mask = self._update_cross_attn_mask(
        encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds
    )

    position_embeddings = self.pos_embed(inputs_embeds, past_key_values_length, position_ids=cache_position)
    hidden_states = inputs_embeds + position_embeddings
    hidden_states = self.dropout(hidden_states)

    deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

    # Both head masks must provide one entry per layer.
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                f" {attn_mask.size()[0]}."
            )

    for idx, layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # LayerDrop: during training each layer is skipped with probability `config.layerdrop`.
        dropout_probability = torch.rand([])
        skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))

        layer_outputs = (None, None, None, None)
        if not skip_the_layer or deepspeed_zero3_is_enabled:
            # With DeepSpeed ZeRO-3 enabled the layer is still executed even when it is "skipped".
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                ),
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]

        if skip_the_layer:
            layer_outputs = (None, None, None, None)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

    hidden_states = self.layer_norm(hidden_states)

    # The final, normalized state is recorded exactly once.
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # Layers hand back the very cache object they were given, so return it directly; this stays
    # valid even when the last layer was dropped.
    next_cache = past_key_values if use_cache else None

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_cache,
                all_hidden_states,
                all_self_attns,
                all_cross_attentions,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )
|
|
def _update_causal_mask(
    self,
    attention_mask: Optional[torch.Tensor],
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
):
    """
    Build the 4D causal mask for decoder self-attention, or return `None` when the SDPA helper
    reports that the explicit mask can be dropped.

    Raises:
        NotImplementedError: For the `flex_attention` and `flash_attention_2` implementations.
    """
    if self.config._attn_implementation == "flex_attention":
        raise NotImplementedError
    if self.config._attn_implementation == "flash_attention_2":
        raise NotImplementedError

    if past_key_values is None:
        past_seen_tokens = 0
    else:
        past_seen_tokens = past_key_values.get_seq_length()
    has_static_cache = isinstance(past_key_values, StaticCache)
    uses_sdpa = self.config._attn_implementation == "sdpa"

    # Without a static cache, SDPA may not need an explicit mask at all.
    if (
        uses_sdpa
        and not has_static_cache
        and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        )
    ):
        return None

    mask_dtype = input_tensor.dtype
    query_length = input_tensor.shape[1]

    # Key/value length covered by the mask.
    if has_static_cache:
        kv_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        kv_length = attention_mask.shape[-1]
    else:
        kv_length = past_seen_tokens + query_length + 1

    mask_4d = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=query_length,
        target_length=kv_length,
        dtype=mask_dtype,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    on_accelerator = attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"]
    if uses_sdpa and on_accelerator:
        # NOTE(review): presumably re-opens fully masked rows for SDPA kernels on accelerators;
        # confirm against `AttentionMaskConverter._unmask_unattended`.
        mask_4d = AttentionMaskConverter._unmask_unattended(mask_4d, torch.finfo(mask_dtype).min)

    return mask_4d
|
|
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Build an additive causal mask of shape `(batch_size, 1, sequence_length, target_length)`.

    Allowed positions hold `0` and blocked positions hold `torch.finfo(dtype).min`. A mask that is already 4D
    is returned untouched.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            Either a 2D padding mask of shape `(batch_size, key_value_length)`, where `0` marks keys that must
            not be attended to, or an already prepared 4D mask. `None` means no padding information.
        sequence_length (`int`):
            Number of query positions being processed.
        target_length (`int`):
            Number of key/value positions the mask must cover (e.g. the full length of a static cache).
        dtype (`torch.dtype`):
            Floating point dtype of the returned mask.
        cache_position (`torch.Tensor`):
            Index of each query token in the sequence; keys located after that index are blocked.
        batch_size (`int`):
            Batch size the mask is expanded to.

    Returns:
        `torch.Tensor`: the 4D mask.
    """
    # An already 4D mask is handed back as-is, without any inversion or slicing.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    device = cache_position.device
    blocked = torch.finfo(dtype).min

    # Start fully blocked, then keep only the strictly-future keys blocked.
    mask_2d = torch.full((sequence_length, target_length), fill_value=blocked, dtype=dtype, device=device)
    if sequence_length != 1:
        mask_2d = torch.triu(mask_2d, diagonal=1)
    key_positions = torch.arange(target_length, device=device)
    mask_2d *= key_positions > cache_position.reshape(-1, 1)

    mask_4d = mask_2d[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is None:
        return mask_4d

    # `expand` returns a view; clone it so the in-place write below gets its own memory.
    mask_4d = mask_4d.clone()
    width = attention_mask.shape[-1]
    key_mask = attention_mask[:, None, None, :].to(mask_4d.device)
    # A sum of zero means: causally visible (0) but marked as padding (0) -> block it.
    is_padding = (mask_4d[:, :, :, :width] + key_mask) == 0
    mask_4d[:, :, :, :width] = mask_4d[:, :, :, :width].masked_fill(is_padding, blocked)

    return mask_4d
|
|
def _update_cross_attn_mask(
    self,
    encoder_hidden_states: Union[torch.Tensor, None],
    encoder_attention_mask: Union[torch.Tensor, None],
    input_shape: torch.Size,
    inputs_embeds: torch.Tensor,
):
    """
    Convert the encoder padding mask into the form expected by the configured attention implementation.

    Args:
        encoder_hidden_states (`torch.Tensor`, *optional*):
            Encoder outputs. When `None`, the mask is returned untouched.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Encoder padding mask. When `None`, `None` is returned.
        input_shape (`torch.Size`):
            Shape of the decoder inputs; its last entry is used as the target (query) length.
        inputs_embeds (`torch.Tensor`):
            Decoder input embeddings; only their dtype is used.

    Returns:
        The mask unchanged (or `None`) when either optional input is missing. For `"flash_attention_2"`, the
        original mask if it contains a `0`, else `None`. Otherwise the mask expanded by
        `_prepare_4d_attention_mask_for_sdpa` (`"sdpa"`) or `_prepare_4d_attention_mask` (any other
        implementation).

    Raises:
        NotImplementedError: if the attention implementation is `"flex_attention"` and both optional inputs
            are provided.
    """
    # Nothing to expand unless both the encoder states and their mask are present.
    if encoder_hidden_states is None or encoder_attention_mask is None:
        return encoder_attention_mask

    attn_impl = self.config._attn_implementation
    if attn_impl == "flex_attention":
        raise NotImplementedError

    if attn_impl == "flash_attention_2":
        # Keep the mask only when it contains at least one 0 entry.
        has_zero = 0 in encoder_attention_mask
        return encoder_attention_mask if has_zero else None

    tgt_len = input_shape[-1]
    if attn_impl == "sdpa":
        return _prepare_4d_attention_mask_for_sdpa(
            encoder_attention_mask,
            inputs_embeds.dtype,
            tgt_len=tgt_len,
        )

    # Every other implementation: expand the 2D mask with the generic helper.
    return _prepare_4d_attention_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=tgt_len)
|
|