from functools import partial
from typing import Any, Callable, Literal, Optional

import torch
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLDecoderLayer,
    Qwen2_5_VLFlashAttention2,
    rotate_half,
)

from .casa_attention import CASAAttention, CASAAttentionHandler
from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig


class QwenCASAAttentionHandler(CASAAttentionHandler):
    """Overrides CASAAttentionHandler with the position-embedding computation Qwen's M-RoPE requires."""

    def __init__(
        self,
        *args: Any,
        get_rope_index: Callable | None = None,
        grid_thw: torch.Tensor | None = None,
        position_ids_offset: int = 0,
        **kwargs: Any,
    ):
        assert get_rope_index is not None, "get_rope_index must be given for QwenCASA"
        self.get_rope_index = partial(get_rope_index, image_grid_thw=grid_thw)
        self.position_ids_offset = position_ids_offset
        super().__init__(*args, **kwargs)

    def compute_position_embeddings(
        self,
        rope_fn: Callable,
        sample_lengths: list[int],
        dummy_for_dtype_and_device: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute info required for position embeddings. Can be overridden, e.g. for Qwen."""
        # Build dummy input ids so that Qwen's get_rope_index can recover the 3D (M-RoPE)
        # positions: fill image positions with the <|image_pad|> token id (151655) and mark
        # the position right before each image run with <|vision_start|> (151652).
        dummy_input_ids = torch.zeros(
            (int(sum(sample_lengths)),), device=dummy_for_dtype_and_device.device, dtype=torch.long
        )
        dummy_input_ids[self.image_tokens_mask[:, 0]] = 151655
        start_of_images = torch.logical_and(
            dummy_input_ids == 0,
            torch.nn.functional.pad(dummy_input_ids[1:] == 151655, (0, 1), value=0),
        )
        dummy_input_ids[start_of_images] = 151652

        # Re-batch the flat dummy ids: split per sample and pad every sample to the longest one.
        padding_side = "left" if self.attention_mask is not None else "right"
        s = list(torch.split(dummy_input_ids, self.full_batch_lengths))
        mlen = max(_s.shape[0] for _s in s)
        trims = [mlen - _s.shape[0] for _s in s]
        dummy_input_ids = torch.stack(
            [
                torch.nn.functional.pad(
                    _s,
                    (
                        trims[i] if padding_side == "left" else 0,
                        trims[i] if padding_side == "right" else 0,
                    ),
                    value=-1,
                )
                for i, _s in enumerate(s)
            ],
            dim=0,
        )

        # Mask out the padded positions so get_rope_index ignores them.
        attention_mask = torch.ones_like(dummy_input_ids)
        for i, t in enumerate(trims):
            if padding_side == "right":
                attention_mask[i, attention_mask.shape[-1] - t :] = 0
            else:
                attention_mask[i, :t] = 0

        position_ids = (
            self.get_rope_index(dummy_input_ids, attention_mask=attention_mask)[0]
            + self.position_ids_offset
        )

        cos, sin = rope_fn(dummy_for_dtype_and_device, position_ids)

        # Drop the padded positions again and concatenate back to one flat sequence.
        if padding_side == "right":
            cos = torch.cat(
                [cos[:, i : i + 1, : cos.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
            )
            sin = torch.cat(
                [sin[:, i : i + 1, : sin.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
            )
        else:
            cos = torch.cat([cos[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
            sin = torch.cat([sin[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
        return cos, sin

    def get_position_embedding(
        self,
        key: Literal["q", "kv"],
        num_queries: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor] | None:
        if self.position_embeds is None:
            return None
        cos, sin = self.position_embeds

        # Queries exclude the image-token positions; keys/values keep the full sequence.
        if key == "q":
            cos, sin = (
                cos[:, :, ~self.image_tokens_mask[:, 0]],
                sin[:, :, ~self.image_tokens_mask[:, 0]],
            )
        elif key != "kv":
            raise ValueError(f"Unknown key for position embedding {key}")

        if num_queries == 0:
            return cos, sin

        # When only the last num_queries positions are needed, slice them out per sample.
        bls = self.full_batch_lengths if key == "kv" else self.batch_lengths
        cos = [x[:, :, -num_queries:] for x in torch.split(cos, bls, dim=2)]
        sin = [x[:, :, -num_queries:] for x in torch.split(sin, bls, dim=2)]
        return torch.cat(cos, dim=2), torch.cat(sin, dim=2)


class QwenCASAAttention(CASAAttention):
    """A CASA attention layer compatible with Qwen."""

    def __init__(
        self,
        config: Qwen2_5_VLCASAConfig,
        layer_idx: int | None,
        self_attn: torch.nn.Module | None = None,
        input_layernorm_fn: Callable | None = None,
    ):
        super().__init__(config, layer_idx, self_attn, input_layernorm_fn)
        assert config.rope_scaling is not None
        # mrope_section describes half of the head dimension; duplicate it so the cos/sin
        # splits in apply_position_embeddings cover the full last dimension.
        self.mrope_section = config.rope_scaling["mrope_section"] * 2

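    # Note on mrope_section (illustrative values from the published Qwen2.5-VL configs, not
    # something this module enforces): rope_scaling["mrope_section"] == [16, 24, 24] sums to
    # 64, i.e. half of the 128-dim head, so the doubled list stored above,
    # [16, 24, 24, 16, 24, 24], lets the splits in apply_position_embeddings below cover the
    # full last dimension of cos/sin with two runs of temporal/height/width sections.
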
    def apply_position_embeddings(
        self,
        key: Literal["q", "kv"],
        x: torch.Tensor,
        casa_handler: CASAAttentionHandler | None,
        num_queries: int = 0,
        unsqueeze_dim: int = 1,
    ) -> torch.Tensor:
        """Apply position embeddings to query and key states."""
        if casa_handler is not None:
            posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)

            if posemb is not None:
                x = x.transpose(1, 2).to(torch.float32)
                cos, sin = posemb
                # Interleave the multimodal RoPE sections: i % 3 picks the matching position
                # axis (temporal/height/width) for each section of the head dimension.
                cos = torch.cat(
                    [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1
                ).unsqueeze(unsqueeze_dim)
                sin = torch.cat(
                    [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1
                ).unsqueeze(unsqueeze_dim)

                x = (x * cos) + (rotate_half(x) * sin)
                return x.transpose(1, 2)
        return x

    def init_from_config_proj(
        self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
    ) -> torch.nn.Linear:
        """Follows the projection initialization in modeling_qwen2_5_vl.py."""
        head_dim = config.hidden_size // config.num_attention_heads
        if key == "q":
            return torch.nn.Linear(
                config.hidden_size, config.num_attention_heads * head_dim, bias=True
            )
        if key in {"k", "v"}:
            return torch.nn.Linear(
                config.hidden_size, config.num_key_value_heads * head_dim, bias=True
            )
        if key == "o":
            return torch.nn.Linear(
                config.num_attention_heads * head_dim, config.hidden_size, bias=False
            )
        raise NotImplementedError(f"Unknown key {key}")


class Qwen2_5_VLAttention_CASA(Qwen2_5_VLFlashAttention2):
    """Qwen attention with an additional CASA attention layer."""

    def __init__(
        self,
        config: Qwen2_5_VLCASAConfig,
        layer_idx: Optional[int] = None,
        input_layernorm: torch.nn.Module | None = None,
    ):
        super().__init__(config, layer_idx)
        self.casa_attn = QwenCASAAttention(
            config,
            layer_idx=layer_idx,
            self_attn=self,
            input_layernorm_fn=input_layernorm.forward if input_layernorm is not None else None,
        )
        self.casa_attention_handler: CASAAttentionHandler | None = None

    @classmethod
    def from_qwen2_5_vl_attention(
        cls, attention: Qwen2_5_VLFlashAttention2, input_layernorm: torch.nn.Module | None
    ):
        """Initialize this layer from a Qwen attention layer."""
        layer_idx = attention.layer_idx
        assert layer_idx is not None
        new_attention = cls(attention.config, layer_idx=layer_idx, input_layernorm=input_layernorm)
        # strict=False: the CASA sub-module has parameters the source layer does not.
        new_attention.load_state_dict(attention.state_dict(), strict=False)
        return new_attention

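    # Usage sketch for the classmethod above (hypothetical access path; adjust to wherever
    # the decoder layers live in your model, e.g. model.model.layers):
    #
    #   layer = model.model.layers[0]
    #   layer.self_attn = Qwen2_5_VLAttention_CASA.from_qwen2_5_vl_attention(
    #       layer.self_attn, input_layernorm=layer.input_layernorm
    #   )
    #
    # add_casa_layers at the bottom of this file does the same for every selected layer.
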
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        # xa_order controls how CASA combines with self-attention:
        #   "parallel": the CASA output is added to the self-attention output,
        #   "ca_first": the CASA output replaces the hidden states fed to self-attention,
        #   "instead": self-attention is skipped entirely.
        casa_out: None | torch.Tensor = None
        if self.casa_attn is not None and self.config.xa_order in {
            "parallel",
            "ca_first",
            "instead",
        }:
            casa_out = self.casa_attn(
                hidden_states=hidden_states,
                casa_handler=self.casa_attention_handler,
            )

        if self.config.xa_order == "instead":
            return casa_out, None, None

        if self.config.xa_order == "ca_first" and casa_out is not None:
            hidden_states, casa_out = casa_out, None

        attn_output, attn_weights, past_key_values = super().forward(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            cache_position,
            position_embeddings,
        )
        if self.config.xa_order == "parallel" and casa_out is not None:
            attn_output = casa_out + attn_output
        return attn_output, attn_weights, past_key_values


def add_casa_layers(m: torch.nn.Module, xa_layers: tuple[int, ...] | None):
    """Replace a decoder layer's self-attention with a CASA attention layer as needed.

    If xa_layers is None or empty, every decoder layer is converted; otherwise only the
    layers whose index appears in xa_layers are converted.
    """
    if isinstance(m, Qwen2_5_VLDecoderLayer):
        layer_idx = m.self_attn.layer_idx
        assert layer_idx is not None
        if xa_layers is None or len(xa_layers) == 0 or layer_idx in xa_layers:
            m.self_attn = Qwen2_5_VLAttention_CASA.from_qwen2_5_vl_attention(
                m.self_attn, input_layernorm=m.input_layernorm
            )
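
# Usage sketch: add_casa_layers operates on a single module, so one way to apply it is to
# map it over a loaded model with torch.nn.Module.apply (the model variable and the layer
# indices below are assumptions for illustration only; partial is already imported at the
# top of this module):
#
#   model.apply(partial(add_casa_layers, xa_layers=(0, 8, 16)))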