# CASA-Qwen2_5-VL-3B / language_qwen2_5vl_casa.py
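"""CASA-augmented language-model attention for Qwen2.5-VL.

This module provides a Qwen-specific CASA attention handler (position-embedding
computation), a CASA attention layer, a drop-in replacement for
Qwen2_5_VLFlashAttention2 that runs a CASA branch next to regular self-attention,
and a helper that patches existing Qwen2.5-VL decoder layers in place.
"""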
from functools import partial
from typing import Any, Callable, Literal, Optional
import torch
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLDecoderLayer,
Qwen2_5_VLFlashAttention2,
rotate_half,
)
from .casa_attention import CASAAttention, CASAAttentionHandler
from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig
class QwenCASAAttentionHandler(CASAAttentionHandler):
"""Overrides CASAAttention with the right pos embedding computation for Qwen"""
def __init__(
self,
*args: Any,
get_rope_index: Callable | None = None,
grid_thw: torch.Tensor | None = None,
position_ids_offset: int = 0,
**kwargs: Any,
):
assert get_rope_index is not None, "get_rope_index should be given for QwenCASA"
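# Bind the image grid to the model's get_rope_index so later calls only need
# input ids and an attention mask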
self.get_rope_index = partial(get_rope_index, image_grid_thw=grid_thw)
self.position_ids_offset = position_ids_offset
super().__init__(*args, **kwargs)
def compute_position_embeddings(
self,
rope_fn: Callable,
sample_lengths: list[int],
dummy_for_dtype_and_device: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute info required for position embeddings. Can be overriden e.g. for Qwen"""
# Here rope_fn is the "get_rope_index" function from the original mode
dummy_input_ids = torch.zeros(
(int(sum(sample_lengths)),), device=dummy_for_dtype_and_device.device, dtype=torch.long
)
# Set image token ids
dummy_input_ids[self.image_tokens_mask[:, 0]] = 151655
# get_rope_index locates image blocks via the vision-start token (151652), so one is
# inserted wherever an image token (151655) follows a regular token.
# It is highly recommended to surround images with pre/post image tokens when using Qwen.
# Add vision start token ids (wherever a 151655 follows a 0)
start_of_images = torch.logical_and(
dummy_input_ids == 0,
torch.nn.functional.pad(dummy_input_ids[1:] == 151655, (0, 1), value=0),
)
dummy_input_ids[start_of_images] = 151652
# rebatch dummy input ids
padding_side = "left" if self.attention_mask is not None else "right"
s = list(torch.split(dummy_input_ids, self.full_batch_lengths))
mlen = max(_s.shape[0] for _s in s)
trims = [mlen - _s.shape[0] for _s in s]
dummy_input_ids = torch.stack(
[
torch.nn.functional.pad(
_s,
(
trims[i] if padding_side == "left" else 0,
trims[i] if padding_side == "right" else 0,
),
value=-1,
)
for i, _s in enumerate(s)
],
dim=0,
)
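# e.g. with full_batch_lengths = [3, 5] and right padding, the flat ids are split into
# chunks of 3 and 5, padded to a common length of 5, and stacked into a (2, 5) tensor;
# the padded positions are trimmed away again after the rope computation below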
# The attention mask must be passed to get_rope_index so padded positions are handled
# correctly (especially with left padding)
attention_mask = torch.ones_like(dummy_input_ids)
for i, t in enumerate(trims):
if padding_side == "right":
attention_mask[i, attention_mask.shape[-1] - t :] = 0
else:
attention_mask[i, :t] = 0
# Compute position ids of shape (3, bs, seq)
position_ids = (
self.get_rope_index(dummy_input_ids, attention_mask=attention_mask)[0]
+ self.position_ids_offset
)
# Compute position embeddings and recover the flattened, unpadded shape
cos, sin = rope_fn(dummy_for_dtype_and_device, position_ids)
# reflatten seq
if padding_side == "right":
cos = torch.cat(
[cos[:, i : i + 1, : cos.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
)
sin = torch.cat(
[sin[:, i : i + 1, : sin.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
)
else:
cos = torch.cat([cos[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
sin = torch.cat([sin[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
return cos, sin
def get_position_embedding(
self,
key: Literal["q", "kv"],
num_queries: int = 0,
) -> tuple[torch.Tensor, torch.Tensor] | None:
if self.position_embeds is None:
return None
cos, sin = self.position_embeds
# For Q, keep only the position embeddings of text tokens
if key == "q":
cos, sin = (
cos[:, :, ~self.image_tokens_mask[:, 0]],
sin[:, :, ~self.image_tokens_mask[:, 0]],
)
elif key != "kv":
raise ValueError(f"Unknown key for position embedding {key}")
# Easy case (training, or the first step at inference): use all the position embeddings
if num_queries == 0:
return cos, sin
# If num_queries is given, trim to the last num_queries positions for *every* sample in the batch
bls = self.full_batch_lengths if key == "kv" else self.batch_lengths
cos = [x[:, :, -num_queries:] for x in torch.split(cos, bls, dim=2)]
sin = [x[:, :, -num_queries:] for x in torch.split(sin, bls, dim=2)]
return torch.cat(cos, dim=2), torch.cat(sin, dim=2)
class QwenCASAAttention(CASAAttention):
"""A CASA Attention layer compatible with Qwen"""
def __init__(
self,
config: Qwen2_5_VLCASAConfig,
layer_idx: int | None,
self_attn: torch.nn.Module | None = None,
input_layernorm_fn: Callable | None = None,
):
# This __init__ only exists to give the config a more specific type for type checkers
super().__init__(config, layer_idx, self_attn, input_layernorm_fn) # pyright: ignore[reportArgumentType]
assert config.rope_scaling is not None
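# mrope_section splits the rotary dimensions into temporal/height/width sections; it is
# doubled because cos/sin span the full head dimension (the frequencies are concatenated
# twice), matching apply_multimodal_rotary_pos_emb in modeling_qwen2_5_vl.py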
self.mrope_section = config.rope_scaling["mrope_section"] * 2
def apply_position_embeddings(
self,
key: Literal["q", "kv"],
x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
casa_handler: CASAAttentionHandler | None,
num_queries: int = 0,
unsqueeze_dim: int = 1,
) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
"""Apply position embeddings to query and key states"""
if casa_handler is not None:
posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)
if posemb is not None:
x = x.transpose(1, 2).to(torch.float32)
cos, sin = posemb
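# Interleave the three mrope axes (temporal/height/width) back into a single
# rotary embedding, as in Qwen's apply_multimodal_rotary_pos_emb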
cos = torch.cat(
[m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1
).unsqueeze(unsqueeze_dim)
sin = torch.cat(
[m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1
).unsqueeze(unsqueeze_dim)
x = (x * cos) + (rotate_half(x) * sin)
return x.transpose(1, 2)
return x
def init_from_config_proj(
self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
) -> torch.nn.Linear:
"""Follows modeling_qwen2_5_vl.py initialization"""
head_dim = config.hidden_size // config.num_attention_heads
if key == "q":
return torch.nn.Linear(
config.hidden_size, config.num_attention_heads * head_dim, bias=True
)
if key in {"k", "v"}:
return torch.nn.Linear(
config.hidden_size, config.num_key_value_heads * head_dim, bias=True
)
if key == "o":
return torch.nn.Linear(
config.num_attention_heads * head_dim, config.hidden_size, bias=False
)
raise NotImplementedError(f"Unknown key {key}")
class Qwen2_5_VLAttention_CASA(Qwen2_5_VLFlashAttention2):
"""
Qwen Attention with extra CASA Attention layer
"""
def __init__(
self,
config: Qwen2_5_VLCASAConfig,
layer_idx: Optional[int] = None,
input_layernorm: torch.nn.Module | None = None,
):
super().__init__(config, layer_idx) # pyright: ignore[reportArgumentType]
self.casa_attn = QwenCASAAttention(
config,
layer_idx=layer_idx,
self_attn=self,
input_layernorm_fn=input_layernorm.forward if input_layernorm is not None else None,
)
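# The handler is expected to be attached by the surrounding model at runtime; it carries
# the batch-specific CASA state (image-token mask, batch lengths, position embeddings)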
self.casa_attention_handler: CASAAttentionHandler | None = None
@classmethod
def from_qwen2_5_vl_attention(
cls, attention: Qwen2_5_VLFlashAttention2, input_layernorm: torch.nn.Module | None
):
"""Init this layer from Qwen Attention layer"""
layer_idx = attention.layer_idx
assert layer_idx is not None
new_attention = cls(attention.config, layer_idx=layer_idx, input_layernorm=input_layernorm) # pyright: ignore
new_attention.load_state_dict(attention.state_dict(), strict=False)
return new_attention
def forward( # pyright: ignore[reportIncompatibleMethodOverride]
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
):
casa_out: None | torch.Tensor = None
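# xa_order controls how the CASA branch combines with self-attention: "instead" replaces
# self-attention entirely, "ca_first" runs CASA before it, and "parallel" adds both
# outputs together (handled below)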
if self.casa_attn is not None and self.config.xa_order in {
"parallel",
"ca_first",
"instead",
}:
casa_out = self.casa_attn(
hidden_states=hidden_states,
casa_handler=self.casa_attention_handler,
)
if self.config.xa_order == "instead":
return casa_out, None, None
if self.config.xa_order == "ca_first" and casa_out is not None:
hidden_states, casa_out = casa_out, None
attn_output, attn_weights, past_key_values = super().forward(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
if self.config.xa_order == "parallel" and casa_out is not None:
attn_output = casa_out + attn_output
return attn_output, attn_weights, past_key_values
def add_casa_layers(m: torch.nn.Module, xa_layers: tuple[int, ...] | None):
"""Replace Attention layer by CASA Attention layer as needed"""
if isinstance(m, Qwen2_5_VLDecoderLayer):
layer_idx = m.self_attn.layer_idx
assert layer_idx is not None
if xa_layers is None or len(xa_layers) == 0 or layer_idx in xa_layers:
m.self_attn = Qwen2_5_VLAttention_CASA.from_qwen2_5_vl_attention(
m.self_attn, input_layernorm=m.input_layernorm
)
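# Hypothetical usage sketch (not part of the original file), assuming `model` is a
# Qwen2.5-VL conditional-generation model whose text decoder lives under `model.model`;
# the attribute path and layer indices below are illustrative only:
#
#     from functools import partial
#     model.model.apply(partial(add_casa_layers, xa_layers=(0, 8, 16)))
#
# add_casa_layers is written for torch.nn.Module.apply: it visits every submodule and
# swaps the self-attention of each targeted Qwen2_5_VLDecoderLayer for the CASA variant
# (passing xa_layers=None or an empty tuple converts every decoder layer).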