from typing import Any, Callable from typing import cast as type_cast import torch from transformers.cache_utils import DynamicCache from transformers.configuration_utils import PretrainedConfig from transformers.generation.utils import GenerateOutput from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionTransformerPretrainedModel, ) from .image_encoder import Qwen25VLEncoder from .configuration_helium1_casa import Helium1CASAConfig from .language_helium1_casa import ( CausalHeliumOutput, Helium1CASAAttention, Helium1ForCausalLM, Helium1RMSNorm, ) def meta_project( logits: torch.Tensor | list[torch.Tensor], projector: torch.nn.Module, norm: torch.nn.Module | None = None, ) -> torch.Tensor | list[torch.Tensor]: """Projection operation that handles both tensors and list of tensors Outputs either a (N, S, D) tensors (same resolution images) or a list of N (S, D) tensors (where S can be a different sequence length per image) """ split_sizes: list[int] | None = None if not isinstance(logits, torch.Tensor): split_sizes = [_x.shape[0] for _x in logits] logits = torch.cat(logits, dim=0)[None, :, :] logits = type_cast(torch.Tensor, logits) logits = projector(logits) assert isinstance(logits, torch.Tensor) if norm is not None: logits = norm(logits) if split_sizes is not None: return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0)) return logits class ImageProjection(torch.nn.Module): """Takes in a batch or sequence of images and returns embeddings which are then fed to the LM. :param config: KyuteyeConfig object :param lm_model_dim: Output dimension (number of channels) for this module """ def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None: super().__init__() self.config = config self.out_dim = lm_model_dim visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.enc = Qwen25VLEncoder(visual=visual) # Projection layer self.proj_extra = self.init_proj_module() # Output normalizations self.norm_extra = Helium1RMSNorm(self.out_dim) def init_proj_module(self) -> torch.nn.Module: """Init the project module for the inserted and/or cross-attended image tokens""" if self.config.vision_config.out_dim == self.out_dim: return torch.nn.Identity() return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim) def forward( self, x: torch.Tensor | list[torch.Tensor] ) -> dict[ str, torch.Tensor | list[torch.Tensor], ]: """Image embedding mapping :param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors with shape (C, H, W) (or (H, W, C) in the case of Qwen) :return: Either a tensor with shape (num_total_image, S, D) or, if images can have different seq length, a list of `num_total_images` Tensors with shape (S, D) """ # Apply image encoder og_dtype = x[0].dtype encoded = self.enc(x)["image_embeds"] encoded = [_x.to(og_dtype) for _x in encoded] if all(x.shape[0] == encoded[0].shape[0] for x in encoded): encoded = torch.stack(encoded, dim=0) # Extra projection image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra) # Apply different projection for extra vs cross attended tokens return {"image_embeds": image_embeds} class V2Helium1(Helium1ForCausalLM): # pyright: ignore[reportIncompatibleMethodOverride] config_class = Helium1CASAConfig def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None: del kwargs super().__init__(config) self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim) def get_device(self) -> str: """Return the device type of the model""" return next(self.parameters()).device.type @property def token_dim(self) -> int: """Returns the number of dimensions for the token representation""" return self.config.hidden_size @property def rotary_embed(self) -> Callable: """Returns the rotary embedding function of the underlying model""" return self.model.rotary_emb def _update_model_kwargs_for_generation( self, outputs: Any, model_kwargs: dict[str, Any], is_encoder_decoder: bool = False, num_new_tokens: int = 1, ): """This is required to handle multiple gen calls for subtitles""" # Call parent to get default updates model_kwargs = super()._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder, num_new_tokens ) # Used by prepare_inputs_for_generation model_kwargs["__is_first_gen_call__"] = False return model_kwargs def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride] self, input_ids: torch.Tensor, past_key_values: DynamicCache | None = None, **kwargs: Any, ): __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True) if past_key_values is not None and ( kwargs.get("cache_position") is None or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0 ): # We're continuing from a cached state past_length = past_key_values._seen_tokens kwargs["cache_position"] = torch.arange( past_length, past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1), dtype=torch.long, device=input_ids.device, ) return super().prepare_inputs_for_generation( type_cast(torch.LongTensor, input_ids), past_key_values=past_key_values, **kwargs, ) def prepare_multimodal_inputs( self, # text only training input_ids: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, image_embeds_insertion_points: list[torch.Tensor] | None = None, labels: torch.Tensor | None = None, # image values pixel_values: torch.Tensor | list[torch.Tensor] | None = None, pre_image_tokens: list[int] | None = None, post_image_tokens: list[int] | None = None, **_kwargs: Any, ) -> dict: """Get a batch data mixing text and image data""" del _kwargs processed_inputs = { "input_ids": input_ids, "inputs_embeds": inputs_embeds, "labels": labels, "attention_mask": attention_mask, "image_embeds_insertion_points": image_embeds_insertion_points, } if pixel_values is not None: processed_inputs.update(self.image_prefix(pixel_values)) assert "image_embeds" in processed_inputs assert ( isinstance(processed_inputs["image_embeds"], torch.Tensor) and processed_inputs["image_embeds"].ndim == 3 ) or ( isinstance(processed_inputs["image_embeds"], list) and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"]) ) # Add kwargs necessary to compute cu_seqlens windows for CASA processed_inputs["casa_windows_info"] = { "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens), "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens), } return processed_inputs def forward( # pyright: ignore[reportIncompatibleMethodOverride] self, input_ids: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, pixel_values: torch.Tensor | list[torch.Tensor] | None = None, return_loss: bool = True, labels: torch.Tensor | None = None, image_embeds_insertion_points: list[torch.Tensor] | None = None, pre_image_tokens: list[int] | None = None, post_image_tokens: list[int] | None = None, **kwargs: Any, ) -> CausalHeliumOutput: """Multi modal forward pass""" assert input_ids is not None or inputs_embeds is not None if self.training: assert return_loss is True, ( "Helium models always compute its own labels/losses in train mode" ) # Case 1: For first generation call we need to compute pixel values and CASA states if kwargs.get("__is_first_gen_call__", True): processed_inputs = self.prepare_multimodal_inputs( input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, image_embeds_insertion_points=image_embeds_insertion_points, pixel_values=pixel_values, labels=labels, pre_image_tokens=pre_image_tokens, post_image_tokens=post_image_tokens, ) processed_inputs.pop("inputs_embeds", None) else: processed_inputs = { "inputs_embeds": self.model.embed_tokens(input_ids), "attention_mask": attention_mask, } # For Helium prefix, we need to update the positions by the number # of image tokens inserted in the first call if ( not self.config.casa_attention and (cp := kwargs.get("cache_position", None)) is not None and pixel_values is not None ): start = kwargs["cache_position"][0].item() num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4 num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] # type: ignore kwargs["cache_position"] = torch.arange( start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens), start + num_tokens + num_image_tokens, dtype=cp.dtype, device=cp.device, ) kwargs.pop("__is_first_gen_call__", True) out = super().forward( **processed_inputs, # type: ignore **kwargs, ) return out @torch.no_grad() def generate_from_image( # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride] self, input_ids: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, image_embeds_insertion_points: list[torch.Tensor] | None = None, pixel_values: torch.Tensor | list[torch.Tensor] | None = None, reset_streaming: bool = True, **kwargs: Any, ) -> "GenerateOutput | torch.LongTensor": assert input_ids is not None and inputs_embeds is None, ( "Input IDs must be provided for generation" ) # init self-attention KVCache if kwargs.get("past_key_values", None) is None: kwargs["past_key_values"] = DynamicCache() # To avoid generate warning if kwargs.get("pad_token_id", None) is None: kwargs["pad_token_id"] = kwargs.get("eos_token_id", None) if isinstance(kwargs["pad_token_id"], (list, tuple)): kwargs["pad_token_id"] = kwargs["pad_token_id"][0] self.start_casa_streaming_states() outputs = self.generate( input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_embeds_insertion_points=image_embeds_insertion_points, use_cache=True, **kwargs, ) if reset_streaming: self.reset_casa_streaming_states() return outputs def reset_casa_streaming_states(self, clean_cache: bool = True) -> None: def __reset__(m: torch.nn.Module): if isinstance(m, Helium1CASAAttention): m._set_streaming(False, ()) m.reset_streaming() if clean_cache: del m.streaming_state.k del m.streaming_state.v del m.streaming_state.casa_handler self.apply(__reset__) def start_casa_streaming_states(self) -> None: def __start__(m: torch.nn.Module): if isinstance(m, Helium1CASAAttention): m._set_streaming(True, ()) self.apply(__start__)