from functools import partial
from typing import Any
from typing import cast as type_cast

import torch
from transformers.cache_utils import DynamicCache
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLCausalLMOutputWithPast,
    Qwen2_5_VLForConditionalGeneration,
)

from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig
from .image_encoder import Qwen25VLEncoder
from .language_qwen2_5vl_casa import (
    Qwen2_5_VLAttention_CASA,
    QwenCASAAttention,
    QwenCASAAttentionHandler,
    add_casa_layers,
)


class V2Qwen2_5VL(Qwen2_5_VLForConditionalGeneration):  # pyright: ignore[reportIncompatibleMethodOverride]
    config_class = Qwen2_5_VLCASAConfig

    def __init__(self, config: Qwen2_5_VLCASAConfig, **kwargs: Any) -> None:
        del kwargs
        super().__init__(config)
        # Move the vision tower into the image encoder wrapper and drop the
        # original reference, so the base forward path never consumes it.
        self.image_prefix = Qwen25VLEncoder(self.visual)  # type: ignore[assignment]
        self.visual = None
        self.model.apply(partial(add_casa_layers, xa_layers=self.config.xa_layers))

    def get_device(self) -> str:
        """Return the device type of the model."""
        return next(self.parameters()).device.type

    @property
    def token_dim(self) -> int:
        """Return the dimensionality of the token representation."""
        return self.config.hidden_size

    def _update_model_kwargs_for_generation(
        self,
        outputs: Any,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ):
        """Required to handle multiple generate calls for subtitles."""
        # Call the parent to get the default updates.
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        # Used by prepare_inputs_for_generation.
        model_kwargs["__is_first_gen_call__"] = False
        return model_kwargs

    def prepare_inputs_for_generation(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        **kwargs: Any,
    ):
        """Required to handle cache_position = None with QwenVL."""
        __is_first_gen_call__ = kwargs.pop("__is_first_gen_call__", True)
        if past_key_values is not None and (
            kwargs.get("cache_position") is None
            or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
        ):
            # We are continuing from a cached state.
            past_length = past_key_values._seen_tokens
            kwargs["cache_position"] = torch.arange(
                past_length,
                past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
                dtype=torch.long,
                device=input_ids.device,
            )
        return super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    def prepare_multimodal_inputs(
        self,
        # text-only training
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        # image values
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **_kwargs: Any,
    ) -> dict:
        """Build a batch mixing text and image data."""
        del _kwargs
        processed_inputs: dict = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_embeds_insertion_points": image_embeds_insertion_points,
        }
        if pixel_values is not None:
            processed_inputs.update(self.image_prefix(pixel_values))
            assert "image_embeds" in processed_inputs
            # Image embeddings are either a padded 3D tensor or a list of
            # per-image 2D tensors.
            assert (
                isinstance(processed_inputs["image_embeds"], torch.Tensor)
                and processed_inputs["image_embeds"].ndim == 3
            ) or (
                isinstance(processed_inputs["image_embeds"], list)
                and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
            )
        # Add the kwargs needed to compute cu_seqlens windows for CASA.
        processed_inputs["casa_windows_info"] = {
            "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
            "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
        }
        return processed_inputs
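    # Output sketch (illustrative values, not from a real run): with two
    # pre-image wrapper tokens and one post-image token, the returned dict
    # would contain
    #
    #     processed_inputs["casa_windows_info"] == {
    #         "num_post_image_tokens": 1,
    #         "num_pre_image_tokens": 2,
    #     }
    #
    # plus the "image_embeds" (and, for Qwen, "grid_thw") entries produced by
    # self.image_prefix when pixel_values is given.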
    def forward(  # type: ignore[override]  # pylint: disable=W0221
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        reinit_casa_handler: bool = True,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **kwargs: Any,
    ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
        """Multi-modal forward pass."""
        if reinit_casa_handler:
            processed_inputs = self.prepare_multimodal_inputs(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                image_embeds_insertion_points=image_embeds_insertion_points,
                pixel_values=pixel_values,
                labels=labels,
                post_image_tokens=post_image_tokens,
                pre_image_tokens=pre_image_tokens,
            )
            inputs_embeds = type_cast(
                torch.Tensor, self.model.embed_tokens(processed_inputs["input_ids"])
            )
            casa_attention_handler: QwenCASAAttentionHandler | None = None
            image_embeds = processed_inputs.get("image_embeds", None)
            attention_mask = processed_inputs["attention_mask"]
            inst_points = processed_inputs.get("image_embeds_insertion_points", None)
            if image_embeds is None:
                inst_points = None
            casa_attention_handler = QwenCASAAttentionHandler(
                # For text tokens we don't need the actual values.
                inputs_embeds=torch.zeros_like(inputs_embeds),
                # For image embeddings we pass the real inputs, as these stay fixed.
                image_embeds=[] if image_embeds is None else image_embeds,
                image_embeds_insertion_points=inst_points,
                # The attention mask is only needed at inference / with left padding.
                attention_mask=None if self.training else processed_inputs["attention_mask"],
                rope_fn=self.model.rotary_emb,
                windows=self.config.casa_windows,
                casa_windows_info=processed_inputs.pop("casa_windows_info", None),
                use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
                # Extra arguments specific to Qwen.
                get_rope_index=self.get_rope_index,
                grid_thw=processed_inputs.get("grid_thw", None),
            )
            self.update_casa_states(casa_attention_handler)
        else:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Run Qwen with the attention layers replaced to use CASA.
        assert inputs_embeds is not None, "Could not compute input embeddings!"
        out = super().forward(
            inputs_embeds=inputs_embeds,  # type: ignore[arg-type]
            attention_mask=attention_mask,
            pixel_values=None,
            **kwargs,
        )
        return out
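    # Call sketch (hypothetical tensors): one forward pass with images rebuilds
    # the CASA handler; a text-only batch works the same with pixel_values=None.
    # Note that `labels` feeds prepare_multimodal_inputs but is not forwarded to
    # the base model above, so loss computation is left to the training loop:
    #
    #     logits = model(
    #         input_ids=batch_ids,        # (B, T) token ids
    #         attention_mask=batch_mask,  # (B, T)
    #         pixel_values=batch_pixels,  # or None for text-only batches
    #         image_embeds_insertion_points=insertion_points,
    #     ).logits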
    @torch.no_grad()
    def generate_from_image(  # pyright: ignore[reportInconsistentOverload]
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        position_ids_offset: int | None = None,
        reset_streaming: bool = True,
        **kwargs: Any,
    ) -> GenerateOutput | torch.LongTensor:
        """Custom generate function."""
        assert input_ids is not None and inputs_embeds is None, (
            "Input IDs must be provided for generation"
        )
        # Initialize the self-attention KV cache.
        if kwargs.get("past_key_values", None) is None:
            kwargs["past_key_values"] = DynamicCache()
        # Avoid a warning from generate().
        if kwargs.get("pad_token_id", None) is None:
            kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
            if isinstance(kwargs["pad_token_id"], (list, tuple)):
                kwargs["pad_token_id"] = kwargs["pad_token_id"][0]

        # Initialize the CASA states.
        processed_inputs = self.prepare_multimodal_inputs(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            image_embeds_insertion_points=image_embeds_insertion_points,
            pixel_values=pixel_values,
            labels=None,
            pre_image_tokens=pre_image_tokens,
            post_image_tokens=post_image_tokens,
        )
        if pixel_values is not None:
            assert (image_embeds := processed_inputs.get("image_embeds", None)) is not None
            assert (
                insrt_pts := processed_inputs.get("image_embeds_insertion_points", None)
            ) is not None
            casa_attention_handler = QwenCASAAttentionHandler(
                inputs_embeds=torch.empty(
                    (input_ids.shape[0], input_ids.shape[1], image_embeds[0].shape[-1]),
                    dtype=image_embeds[0].dtype,
                    device=image_embeds[0].device,
                ),
                image_embeds=image_embeds,
                image_embeds_insertion_points=insrt_pts,
                attention_mask=attention_mask,
                rope_fn=self.model.rotary_emb,
                windows=self.config.casa_windows,
                casa_windows_info=processed_inputs.pop("casa_windows_info", None),
                use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
                get_rope_index=self.get_rope_index,
                grid_thw=processed_inputs.get("grid_thw", None),
                position_ids_offset=position_ids_offset
                or kwargs["past_key_values"]._seen_tokens,
            )
            self.update_casa_states(casa_attention_handler)
            self.start_casa_streaming_states()
            pixel_values = None

        # Generate.
        outputs = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            use_cache=True,
            reinit_casa_handler=False,
            **kwargs,
        )
        if reset_streaming:
            self.reset_casa_streaming_states()
        return outputs

    def update_casa_states(self, handler: QwenCASAAttentionHandler | None):
        """Set the given handler on all CASA attention layers."""

        def __update__(m: torch.nn.Module):
            nonlocal handler
            if isinstance(m, Qwen2_5_VLAttention_CASA):
                m.casa_attention_handler = handler

        self.apply(__update__)

    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        """Disable streaming and optionally drop the cached K/V in all CASA layers."""

        def __reset__(m: torch.nn.Module):
            if isinstance(m, QwenCASAAttention):
                m._set_streaming(False, ())
                m.reset_streaming()
                if clean_cache:
                    del m.streaming_state.k
                    del m.streaming_state.v
                    m.streaming_state.k = None  # pyright: ignore[reportAttributeAccessIssue]
                    m.streaming_state.v = None  # pyright: ignore[reportAttributeAccessIssue]
            elif isinstance(m, Qwen2_5_VLAttention_CASA):
                del m.casa_attention_handler
                m.casa_attention_handler = None

        self.apply(__reset__)
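    # Streaming lifecycle (sketch): generate_from_image() calls
    # start_casa_streaming_states() before generate() so the CASA layers keep
    # their K/V across decoding steps, then reset_casa_streaming_states() once
    # generation finishes. Callers chaining several generate calls (e.g. the
    # subtitle use case handled by _update_model_kwargs_for_generation) can
    # keep the streaming state alive between calls:
    #
    #     model.generate_from_image(..., reset_streaming=False)
    #     model.generate_from_image(..., reset_streaming=False)  # continues from cache
    #     model.reset_casa_streaming_states()  # explicit cleanup when done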
    def start_casa_streaming_states(self) -> None:
        """Enable streaming in all CASA attention layers."""

        def __start__(m: torch.nn.Module):
            if isinstance(m, QwenCASAAttention):
                m._set_streaming(True, ())

        self.apply(__start__)
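
# ---------------------------------------------------------------------------
# Usage sketch (not part of the model code). The checkpoint path, tensor
# shapes, and preprocessing below are assumptions for illustration; the real
# pixel_values and insertion points come from the project's data pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    ckpt = "path/to/casa-checkpoint"  # hypothetical checkpoint path
    model = V2Qwen2_5VL.from_pretrained(ckpt)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    prompt = tokenizer("Describe the image.", return_tensors="pt")
    # Hypothetical image features: one image whose embeddings are inserted
    # after the first prompt token. The patch layout is assumed; use the
    # project's image processor to build real pixel values.
    pixel_values = [torch.randn(256, 1176)]
    insertion_points = [torch.tensor([1])]

    out = model.generate_from_image(
        input_ids=prompt.input_ids,
        attention_mask=prompt.attention_mask,
        pixel_values=pixel_values,
        image_embeds_insertion_points=insertion_points,
        max_new_tokens=32,
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))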