# CASA-Qwen2_5-VL-3B / modeling_qwen2_5vl_casa.py
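"""Qwen2.5-VL conditional generation with CASA cross-attention image layers.

Illustrative loading sketch (a minimal, untested assumption; this is a
``custom_code`` checkpoint, and the repo id below is a placeholder):

    from transformers import AutoModelForImageTextToText, AutoProcessor

    model = AutoModelForImageTextToText.from_pretrained(
        "<repo-id>", trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained("<repo-id>", trust_remote_code=True)
"""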
from functools import partial
from typing import Any
from typing import cast as type_cast
import torch
from transformers.cache_utils import DynamicCache
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLCausalLMOutputWithPast,
Qwen2_5_VLForConditionalGeneration,
)
from .image_encoder import Qwen25VLEncoder
from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig
from .language_qwen2_5vl_casa import (
Qwen2_5_VLAttention_CASA,
QwenCASAAttention,
QwenCASAAttentionHandler,
add_casa_layers,
)
class V2Qwen2_5VL(Qwen2_5_VLForConditionalGeneration): # pyright: ignore[reportIncompatibleMethodOverride]
config_class = Qwen2_5_VLCASAConfig
def __init__(self, config: Qwen2_5_VLCASAConfig, **kwargs: Any) -> None:
del kwargs
super().__init__(config)
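        # Wrap the stock Qwen2.5-VL vision tower in the CASA image encoder and
        # detach it from the base model: images are consumed through the CASA
        # cross-attention path rather than as in-context image tokens.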
self.image_prefix = Qwen25VLEncoder(self.visual) # type: ignore[assignment]
self.visual = None
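        # Inject CASA attention into the decoder layers listed in config.xa_layers.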
self.model.apply(partial(add_casa_layers, xa_layers=self.config.xa_layers))
def get_device(self) -> str:
"""Return the device type of the model"""
return next(self.parameters()).device.type
@property
def token_dim(self) -> int:
"""Returns the number of dimensions for the token representation"""
return self.config.hidden_size
def _update_model_kwargs_for_generation(
self,
outputs: Any,
model_kwargs: dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
):
"""This is required to handle multiple gen calls for subtitles"""
# Call parent to get default updates
model_kwargs = super()._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder, num_new_tokens
)
# Used by prepare_inputs_for_generation
model_kwargs["__is_first_gen_call__"] = False
return model_kwargs
def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride]
self,
input_ids: torch.Tensor,
past_key_values: DynamicCache | None = None,
**kwargs: Any,
):
"""Required to handle cache_position = None with QwenVL"""
__is_first_gen_call__ = kwargs.pop("__is_first_gen_call__", True)
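        # True only on the first generate() step of a call; later steps append a
        # single token at a time, so only one new cache position is needed below.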
if past_key_values is not None and (
kwargs.get("cache_position") is None
or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
):
# We're continuing from a cached state
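            # DynamicCache tracks how many tokens it has already seen; use that
            # count as the offset for the new cache positions.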
past_length = past_key_values._seen_tokens
kwargs["cache_position"] = torch.arange(
past_length,
past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
dtype=torch.long,
device=input_ids.device,
)
return super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
**kwargs,
)
def prepare_multimodal_inputs(
self,
# text only training
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
image_embeds_insertion_points: list[torch.Tensor] | None = None,
labels: torch.Tensor | None = None,
# image values
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
pre_image_tokens: list[int] | None = None,
post_image_tokens: list[int] | None = None,
**_kwargs: Any,
) -> dict:
"""Get a batch data mixing text and image data"""
del _kwargs
processed_inputs: dict = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"labels": labels,
"attention_mask": attention_mask,
"image_embeds_insertion_points": image_embeds_insertion_points,
}
if pixel_values is not None:
processed_inputs.update(self.image_prefix(pixel_values))
assert "image_embeds" in processed_inputs
assert (
isinstance(processed_inputs["image_embeds"], torch.Tensor)
and processed_inputs["image_embeds"].ndim == 3
) or (
isinstance(processed_inputs["image_embeds"], list)
and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
)
# Add kwargs necessary to compute cu_seqlens windows for CASA
processed_inputs["casa_windows_info"] = {
"num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
"num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
}
return processed_inputs
def forward( # type: ignore[override] # pylint: disable=W0221
self,
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
labels: torch.Tensor | None = None,
image_embeds_insertion_points: list[torch.Tensor] | None = None,
reinit_casa_handler: bool = True,
pre_image_tokens: list[int] | None = None,
post_image_tokens: list[int] | None = None,
**kwargs: Any,
) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
"""Multi-modal forward pass"""
if reinit_casa_handler:
processed_inputs = self.prepare_multimodal_inputs(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
image_embeds_insertion_points=image_embeds_insertion_points,
pixel_values=pixel_values,
labels=labels,
post_image_tokens=post_image_tokens,
pre_image_tokens=pre_image_tokens,
)
inputs_embeds = type_cast(
torch.Tensor, self.model.embed_tokens(processed_inputs["input_ids"])
)
casa_attention_handler: QwenCASAAttentionHandler | None = None
image_embeds = processed_inputs.get("image_embeds", None)
attention_mask = processed_inputs["attention_mask"]
inst_points = processed_inputs.get("image_embeds_insertion_points", None)
if image_embeds is None:
inst_points = None
casa_attention_handler = QwenCASAAttentionHandler(
# for text tokens, we don't need the actual values
inputs_embeds=torch.zeros_like(inputs_embeds),
# for image embeddings, we put real inputs as this will be fixed
image_embeds=[] if image_embeds is None else image_embeds,
image_embeds_insertion_points=inst_points,
# attention mask is only needed at inference / left padding
attention_mask=None if self.training else processed_inputs["attention_mask"],
rope_fn=self.model.rotary_emb,
windows=self.config.casa_windows,
casa_windows_info=processed_inputs.pop("casa_windows_info", None),
use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
# extra for Qwen
get_rope_index=self.get_rope_index,
grid_thw=processed_inputs.get("grid_thw", None),
)
self.update_casa_states(casa_attention_handler)
else:
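            # CASA handler was already initialised (e.g. by generate_from_image),
            # so only the fresh token embeddings are computed here.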
inputs_embeds = self.model.embed_tokens(input_ids)
# Run Qwen with the attention layers replaced to use CASA
assert inputs_embeds is not None, "Could not compute input embeddings!"
out = super().forward(
inputs_embeds=inputs_embeds, # type: ignore[arg-type]
attention_mask=attention_mask,
pixel_values=None,
**kwargs,
)
return out
@torch.no_grad()
def generate_from_image( # pyright: ignore[reportInconsistentOverload]
self,
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
image_embeds_insertion_points: list[torch.Tensor] | None = None,
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
pre_image_tokens: list[int] | None = None,
post_image_tokens: list[int] | None = None,
position_ids_offset: int | None = None,
reset_streaming: bool = True,
**kwargs: Any,
) -> GenerateOutput | torch.LongTensor:
"""Custom generate function"""
assert input_ids is not None and inputs_embeds is None, (
"Input IDs must be provided for generation"
)
# init self-attention KVCache
if kwargs.get("past_key_values", None) is None:
kwargs["past_key_values"] = DynamicCache()
# To avoid generate warning
if kwargs.get("pad_token_id", None) is None:
kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
if isinstance(kwargs["pad_token_id"], (list, tuple)):
kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
# Init CASA states
processed_inputs = self.prepare_multimodal_inputs(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
image_embeds_insertion_points=image_embeds_insertion_points,
pixel_values=pixel_values,
labels=None,
pre_image_tokens=pre_image_tokens,
post_image_tokens=post_image_tokens,
)
if pixel_values is not None:
assert (image_embeds := processed_inputs.get("image_embeds", None)) is not None
assert (
insrt_pts := processed_inputs.get("image_embeds_insertion_points", None)
) is not None
casa_attention_handler = QwenCASAAttentionHandler(
inputs_embeds=torch.empty(
(input_ids.shape[0], input_ids.shape[1], image_embeds[0].shape[-1]),
dtype=image_embeds[0].dtype,
device=image_embeds[0].device,
),
image_embeds=image_embeds,
image_embeds_insertion_points=insrt_pts,
attention_mask=attention_mask,
rope_fn=self.model.rotary_emb,
windows=self.config.casa_windows,
casa_windows_info=processed_inputs.pop("casa_windows_info", None),
use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
get_rope_index=self.get_rope_index,
grid_thw=processed_inputs.get("grid_thw", None),
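                # NB: with `or`, an explicit offset of 0 also falls back to the
                # number of tokens already seen by the KV cache.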
position_ids_offset=position_ids_offset or kwargs["past_key_values"]._seen_tokens,
)
self.update_casa_states(casa_attention_handler)
self.start_casa_streaming_states()
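            # Images are now owned by the CASA handler, so drop pixel_values to
            # keep the base Qwen generate() path text-only.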
pixel_values = None
# Generate
outputs = self.generate(
input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
use_cache=True,
reinit_casa_handler=False,
**kwargs,
)
if reset_streaming:
self.reset_casa_streaming_states()
return outputs
def update_casa_states(self, handler: QwenCASAAttentionHandler | None):
"""Update handler in all layers"""
def __update__(m: torch.nn.Module):
nonlocal handler
if isinstance(m, Qwen2_5_VLAttention_CASA):
m.casa_attention_handler = handler
self.apply(__update__)
    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        """Stop CASA streaming in all layers, optionally freeing cached K/V state"""
def __reset__(m: torch.nn.Module):
if isinstance(m, QwenCASAAttention):
m._set_streaming(False, ())
m.reset_streaming()
if clean_cache:
del m.streaming_state.k
del m.streaming_state.v
m.streaming_state.k = None # pyright: ignore[reportAttributeAccessIssue]
m.streaming_state.v = None # pyright: ignore[reportAttributeAccessIssue]
elif isinstance(m, Qwen2_5_VLAttention_CASA):
del m.casa_attention_handler
m.casa_attention_handler = None
self.apply(__reset__)
    def start_casa_streaming_states(self) -> None:
        """Enable CASA streaming mode in all layers"""
def __start__(m: torch.nn.Module):
if isinstance(m, QwenCASAAttention):
m._set_streaming(True, ())
self.apply(__start__)
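# Illustrative call sketch (untested assumptions: the processor outputs and the
# shape of `image_embeds_insertion_points` are inferred from the signatures
# above, not from this repo's docs):
#
#   inputs = processor(text=prompt, images=[image], return_tensors="pt")
#   output_ids = model.generate_from_image(
#       input_ids=inputs["input_ids"],
#       attention_mask=inputs["attention_mask"],
#       pixel_values=inputs["pixel_values"],
#       image_embeds_insertion_points=[torch.tensor([0])],  # hypothetical indices
#       max_new_tokens=64,
#   )
#   print(processor.batch_decode(output_ids, skip_special_tokens=True))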