|
|
from typing import Any, Literal |
|
|
|
|
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig |
|
|
|
|
|
|
|
|
class Qwen2_5_VLCASAConfig(Qwen2_5_VLConfig):
    """Qwen2.5-VL configuration extended with the ``xa_*`` / ``casa_*`` options.

    All positional and remaining keyword arguments are handed to
    ``Qwen2_5_VLConfig`` unchanged.  Each extra option below is stored on the
    instance under its own name without validation or conversion.  In
    addition, ``head_dim`` is (re)computed as
    ``hidden_size // num_attention_heads`` once the parent initializer has run.

    NOTE(review): the meaning of the individual options is not visible in this
    class; the descriptions below only state what is stored.  Confirm the
    semantics against the modeling code that reads this config.

    Args:
        *args: Positional arguments forwarded to ``Qwen2_5_VLConfig``.
        xa_layers: Optional tuple (default ``None``), stored as
            ``self.xa_layers``.  Presumably selects the layers that use the
            "xa" modules -- TODO confirm.
        xa_order: One of ``"ca_first"``, ``"parallel"`` or ``"instead"``
            (default ``"ca_first"``), stored as ``self.xa_order``.
        xa_norm_on_images: Flag (default ``False``), stored as
            ``self.xa_norm_on_images``.
        xa_update_image_embeds: Flag (default ``False``), stored as
            ``self.xa_update_image_embeds``.
        casa_attention: Flag (default ``False``), stored as
            ``self.casa_attention``.
        casa_delta_w: Flag (default ``False``), stored as
            ``self.casa_delta_w``.
        casa_windows: One of ``"batch"``, ``"squashed"``, ``"images"`` or
            ``"turn_based"`` (default ``"batch"``), stored as
            ``self.casa_windows``.
        casa_use_asymetric_qkv: Flag (default ``True``), stored as
            ``self.casa_use_asymetric_qkv`` (the spelling is part of the
            public name).
        **kwargs: Keyword arguments forwarded to ``Qwen2_5_VLConfig``.
    """

    model_type = "qwen2_5vl_casa"

    def __init__(
        self,
        *args: Any,
        xa_layers: None | tuple = None,
        xa_order: Literal["ca_first", "parallel", "instead"] = "ca_first",
        xa_norm_on_images: bool = False,
        xa_update_image_embeds: bool = False,
        casa_attention: bool = False,
        casa_delta_w: bool = False,
        casa_windows: Literal["batch", "squashed", "images", "turn_based"] = "batch",
        casa_use_asymetric_qkv: bool = True,
        **kwargs: Any,
    ):
        # Run the parent initializer first: ``hidden_size`` and
        # ``num_attention_heads`` are never set by this class, so they are
        # expected to be available only after this call.
        super().__init__(*args, **kwargs)

        # Per-head width, derived by integer division.  Because this runs
        # after the parent initializer, it replaces any ``head_dim`` the
        # parent may already have stored.
        width = self.hidden_size
        n_heads = self.num_attention_heads
        self.head_dim = width // n_heads

        # "xa" options, stored as given.
        self.xa_layers = xa_layers
        self.xa_order = xa_order
        self.xa_norm_on_images = xa_norm_on_images
        self.xa_update_image_embeds = xa_update_image_embeds

        # "casa" options, stored as given.
        self.casa_attention = casa_attention
        self.casa_delta_w = casa_delta_w
        self.casa_windows = casa_windows
        self.casa_use_asymetric_qkv = casa_use_asymetric_qkv
|
|
|