from typing import Any, Literal

from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig


class Qwen2_5_VLCASAConfig(Qwen2_5_VLConfig):
    """Qwen2.5-VL configuration extended with fusion (``xa_*``) and CASA options.

    Positional arguments and unrecognised keyword arguments are passed
    through unchanged to ``Qwen2_5_VLConfig.__init__``. The extra options
    declared here are stored verbatim as attributes of the same name; no
    validation is performed on them.

    In addition, ``head_dim`` is set to
    ``hidden_size // num_attention_heads`` (integer division), using the
    values available on the instance once the parent initialiser has run.

    Args:
        xa_layers: Optional tuple stored as ``self.xa_layers``
            (presumably the layer indices receiving the fusion
            mechanism -- TODO confirm against the modelling code).
        xa_order: One of ``"ca_first"``, ``"parallel"`` or ``"instead"``.
        xa_norm_on_images: Boolean flag, stored as-is.
        xa_update_image_embeds: Boolean flag, stored as-is.
        casa_attention: Boolean flag, stored as-is.
        casa_delta_w: Boolean flag, stored as-is.
        casa_windows: One of ``"batch"``, ``"squashed"``, ``"images"``
            or ``"turn_based"``.
        casa_use_asymetric_qkv: Boolean flag, stored as-is (the spelling
            of the name is part of the public interface).
    """

    model_type = "qwen2_5vl_casa"

    def __init__(
        self,
        *args: Any,
        # Options shared by every fusion mechanism
        xa_layers: None | tuple = None,
        xa_order: Literal["ca_first", "parallel", "instead"] = "ca_first",
        xa_norm_on_images: bool = False,
        xa_update_image_embeds: bool = False,
        # CASA-specific options
        casa_attention: bool = False,
        casa_delta_w: bool = False,
        casa_windows: Literal["batch", "squashed", "images", "turn_based"] = "batch",
        casa_use_asymetric_qkv: bool = True,
        **kwargs: Any,
    ):
        # The parent initialiser must run first: the head_dim computation
        # below reads hidden_size and num_attention_heads from the instance.
        super().__init__(*args, **kwargs)

        # Per-head width, derived rather than passed in.
        self.head_dim = self.hidden_size // self.num_attention_heads

        # --- fusion options (common to all mechanisms) ---
        self.xa_layers = xa_layers
        self.xa_order: Literal["ca_first", "parallel", "instead"] = xa_order
        self.xa_norm_on_images = xa_norm_on_images
        self.xa_update_image_embeds = xa_update_image_embeds

        # --- CASA options ---
        self.casa_attention = casa_attention
        self.casa_delta_w = casa_delta_w
        self.casa_windows: Literal["batch", "squashed", "images", "turn_based"] = (
            casa_windows
        )
        self.casa_use_asymetric_qkv = casa_use_asymetric_qkv