from functools import partial
from typing import Any
from typing import cast as type_cast

import torch
from transformers.cache_utils import DynamicCache
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLCausalLMOutputWithPast,
    Qwen2_5_VLForConditionalGeneration,
)

from .image_encoder import Qwen25VLEncoder
from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig
from .language_qwen2_5vl_casa import (
    Qwen2_5_VLAttention_CASA,
    QwenCASAAttention,
    QwenCASAAttentionHandler,
    add_casa_layers,
)


class V2Qwen2_5VL(Qwen2_5_VLForConditionalGeneration):
    config_class = Qwen2_5_VLCASAConfig

    def __init__(self, config: Qwen2_5_VLCASAConfig, **kwargs: Any) -> None:
        del kwargs
        super().__init__(config)
        # Hand the vision tower over to the CASA image encoder and clear the
        # direct reference on the base model.
        self.image_prefix = Qwen25VLEncoder(self.visual)
        self.visual = None
        # Inject CASA cross-attention layers into the decoder layers configured
        # in `config.xa_layers`.
        self.model.apply(partial(add_casa_layers, xa_layers=self.config.xa_layers))

    def get_device(self) -> str:
        """Return the device type of the model"""
        return next(self.parameters()).device.type

    @property
    def token_dim(self) -> int:
        """Returns the number of dimensions for the token representation"""
        return self.config.hidden_size

    def _update_model_kwargs_for_generation(
        self,
        outputs: Any,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ):
        """Required to handle multiple generate calls for subtitles"""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        # After the first decoding step, later steps feed only one new token;
        # prepare_inputs_for_generation uses this flag to size cache_position.
        model_kwargs["__is_first_gen_call__"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        **kwargs: Any,
    ):
        """Required to handle cache_position = None with QwenVL"""
        __is_first_gen_call__ = kwargs.pop("__is_first_gen_call__", True)
        if past_key_values is not None and (
            kwargs.get("cache_position") is None
            or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
        ):
            # Rebuild cache_position from the number of tokens already in the
            # cache: the full prompt length on the first call, one token afterwards.
            past_length = past_key_values._seen_tokens
            kwargs["cache_position"] = torch.arange(
                past_length,
                past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
                dtype=torch.long,
                device=input_ids.device,
            )
        return super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    def prepare_multimodal_inputs(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **_kwargs: Any,
    ) -> dict:
        """Build a batch of inputs mixing text and image data"""
        del _kwargs

        processed_inputs: dict = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_embeds_insertion_points": image_embeds_insertion_points,
        }
        if pixel_values is not None:
            processed_inputs.update(self.image_prefix(pixel_values))
            assert "image_embeds" in processed_inputs
            # The encoder must return either a single 3-D tensor or a list of
            # 2-D per-image tensors.
            assert (
                isinstance(processed_inputs["image_embeds"], torch.Tensor)
                and processed_inputs["image_embeds"].ndim == 3
            ) or (
                isinstance(processed_inputs["image_embeds"], list)
                and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
            )

        # Counts of text tokens placed before/after each image, consumed by the
        # CASA attention handler.
        processed_inputs["casa_windows_info"] = {
            "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
            "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
        }

        return processed_inputs

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        reinit_casa_handler: bool = True,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **kwargs: Any,
    ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
        """Multi-modal forward pass"""
        if reinit_casa_handler:
            processed_inputs = self.prepare_multimodal_inputs(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                image_embeds_insertion_points=image_embeds_insertion_points,
                pixel_values=pixel_values,
                labels=labels,
                post_image_tokens=post_image_tokens,
                pre_image_tokens=pre_image_tokens,
            )
            inputs_embeds = type_cast(
                torch.Tensor, self.model.embed_tokens(processed_inputs["input_ids"])
            )
            casa_attention_handler: QwenCASAAttentionHandler | None = None
            image_embeds = processed_inputs.get("image_embeds", None)
            attention_mask = processed_inputs["attention_mask"]
            inst_points = processed_inputs.get("image_embeds_insertion_points", None)
            if image_embeds is None:
                inst_points = None
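            # Build the CASA cross-attention handler from the encoded image
            # embeddings; the attention mask is passed only at inference time
            # (None while training).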
            casa_attention_handler = QwenCASAAttentionHandler(
                inputs_embeds=torch.zeros_like(inputs_embeds),
                image_embeds=[] if image_embeds is None else image_embeds,
                image_embeds_insertion_points=inst_points,
                attention_mask=None if self.training else processed_inputs["attention_mask"],
                rope_fn=self.model.rotary_emb,
                windows=self.config.casa_windows,
                casa_windows_info=processed_inputs.pop("casa_windows_info", None),
                use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
                get_rope_index=self.get_rope_index,
                grid_thw=processed_inputs.get("grid_thw", None),
            )
            self.update_casa_states(casa_attention_handler)
        else:
            inputs_embeds = self.model.embed_tokens(input_ids)

        assert inputs_embeds is not None, "Could not compute input embeddings!"
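        # Delegate to the base forward with pixel_values=None: image information
        # reaches the decoder only through the CASA attention handler.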
        out = super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            pixel_values=None,
            **kwargs,
        )

        return out

    @torch.no_grad()
    def generate_from_image(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        position_ids_offset: int | None = None,
        reset_streaming: bool = True,
        **kwargs: Any,
    ) -> GenerateOutput | torch.LongTensor:
        """Custom generate function that encodes images into the CASA handler before decoding"""
        assert input_ids is not None and inputs_embeds is None, (
            "Input IDs must be provided for generation"
        )
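        # Ensure a KV cache exists; its current length also serves as the default
        # position offset for the CASA handler below.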
if kwargs.get("past_key_values", None) is None: |
|
|
kwargs["past_key_values"] = DynamicCache() |
|
|
|
|
|
|
|
|
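        # Default the pad token to the EOS token (taking the first one if a list
        # of EOS ids is given).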
if kwargs.get("pad_token_id", None) is None: |
|
|
kwargs["pad_token_id"] = kwargs.get("eos_token_id", None) |
|
|
if isinstance(kwargs["pad_token_id"], (list, tuple)): |
|
|
kwargs["pad_token_id"] = kwargs["pad_token_id"][0] |
|
|
|
|
|
|
|
|
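        # Encode the prompt images and collect the CASA window metadata.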
        processed_inputs = self.prepare_multimodal_inputs(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            image_embeds_insertion_points=image_embeds_insertion_points,
            pixel_values=pixel_values,
            labels=None,
            pre_image_tokens=pre_image_tokens,
            post_image_tokens=post_image_tokens,
        )

        if pixel_values is not None:
            assert (image_embeds := processed_inputs.get("image_embeds", None)) is not None
            assert (
                insrt_pts := processed_inputs.get("image_embeds_insertion_points", None)
            ) is not None
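            # Build the CASA handler for this prompt. A placeholder tensor with the
            # prompt's shape stands in for inputs_embeds, and positions are offset
            # by the number of tokens already in the KV cache.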
            casa_attention_handler = QwenCASAAttentionHandler(
                inputs_embeds=torch.empty(
                    (input_ids.shape[0], input_ids.shape[1], image_embeds[0].shape[-1]),
                    dtype=image_embeds[0].dtype,
                    device=image_embeds[0].device,
                ),
                image_embeds=image_embeds,
                image_embeds_insertion_points=insrt_pts,
                attention_mask=attention_mask,
                rope_fn=self.model.rotary_emb,
                windows=self.config.casa_windows,
                casa_windows_info=processed_inputs.pop("casa_windows_info", None),
                use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
                get_rope_index=self.get_rope_index,
                grid_thw=processed_inputs.get("grid_thw", None),
                position_ids_offset=position_ids_offset or kwargs["past_key_values"]._seen_tokens,
            )
            self.update_casa_states(casa_attention_handler)
            self.start_casa_streaming_states()
            pixel_values = None
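        # Run the stock generate loop: reinit_casa_handler=False keeps the handler
        # built above, and pixel_values was cleared so forward does not re-encode
        # the images.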
        outputs = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            use_cache=True,
            reinit_casa_handler=False,
            **kwargs,
        )

        if reset_streaming:
            self.reset_casa_streaming_states()
        return outputs

    def update_casa_states(self, handler: QwenCASAAttentionHandler | None):
        """Attach the handler to all CASA attention layers"""

        def __update__(m: torch.nn.Module):
            nonlocal handler

            if isinstance(m, Qwen2_5_VLAttention_CASA):
                m.casa_attention_handler = handler

        self.apply(__update__)

    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        """Stop CASA streaming and detach the handler from all CASA layers"""

        def __reset__(m: torch.nn.Module):
            if isinstance(m, QwenCASAAttention):
                m._set_streaming(False, ())
                m.reset_streaming()
                if clean_cache:
                    # Drop any cached cross-attention keys/values.
                    del m.streaming_state.k
                    del m.streaming_state.v
                    m.streaming_state.k = None
                    m.streaming_state.v = None

            elif isinstance(m, Qwen2_5_VLAttention_CASA):
                del m.casa_attention_handler
                m.casa_attention_handler = None

        self.apply(__reset__)

    def start_casa_streaming_states(self) -> None:
        """Enable streaming mode in all CASA attention layers"""

        def __start__(m: torch.nn.Module):
            if isinstance(m, QwenCASAAttention):
                m._set_streaming(True, ())

        self.apply(__start__)