from typing import Any, Callable
from typing import cast as type_cast
import torch
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)

from .configuration_helium1_casa import Helium1CASAConfig
from .image_encoder import Qwen25VLEncoder
from .language_helium1_casa import (
    CausalHeliumOutput,
    Helium1CASAAttention,
    Helium1ForCausalLM,
    Helium1RMSNorm,
)


def meta_project(
    logits: torch.Tensor | list[torch.Tensor],
    projector: torch.nn.Module,
    norm: torch.nn.Module | None = None,
) -> torch.Tensor | list[torch.Tensor]:
    """Projection that handles both a single tensor and a list of tensors.

    Returns either an (N, S, D) tensor (same-resolution images) or a list of
    N (S, D) tensors, where S may be a different sequence length per image.
    """
    split_sizes: list[int] | None = None
    if not isinstance(logits, torch.Tensor):
        # Variable sequence lengths: concatenate, project once, then split back
        split_sizes = [_x.shape[0] for _x in logits]
        logits = torch.cat(logits, dim=0)[None, :, :]
    logits = type_cast(torch.Tensor, logits)
    logits = projector(logits)
    assert isinstance(logits, torch.Tensor)
    if norm is not None:
        logits = norm(logits)
    if split_sizes is not None:
        return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
    return logits
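

# Hedged usage sketch for `meta_project` (not part of the model code; all
# dimensions below are illustrative, not taken from this repo's configs).
def _demo_meta_project() -> None:
    proj = torch.nn.Linear(1152, 2048)
    # Same-resolution path: an (N, S, D) tensor is projected in one shot
    batched = torch.randn(2, 64, 1152)
    out = meta_project(batched, proj)
    assert isinstance(out, torch.Tensor) and out.shape == (2, 64, 2048)
    # Mixed-resolution path: a list of (S_i, D) tensors is concatenated,
    # projected once, then split back into per-image tensors
    mixed = [torch.randn(64, 1152), torch.randn(100, 1152)]
    outs = meta_project(mixed, proj)
    assert isinstance(outs, list) and [o.shape[0] for o in outs] == [64, 100]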


class ImageProjection(torch.nn.Module):
    """Takes in a batch or sequence of images and returns embeddings
    which are then fed to the LM.

    :param config: model configuration; must expose a `vision_config`
    :param lm_model_dim: output dimension (number of channels) for this module
    """

    def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
        super().__init__()
        self.config = config
        self.out_dim = lm_model_dim
        visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.enc = Qwen25VLEncoder(visual=visual)
        # Projection layer
        self.proj_extra = self.init_proj_module()
        # Output normalization
        self.norm_extra = Helium1RMSNorm(self.out_dim)

    def init_proj_module(self) -> torch.nn.Module:
        """Initialize the projection module for the inserted and/or cross-attended
        image tokens: identity if the encoder width already matches the LM width,
        otherwise a learned linear map."""
        if self.config.vision_config.out_dim == self.out_dim:
            return torch.nn.Identity()
        return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)

    def forward(
        self, x: torch.Tensor | list[torch.Tensor]
    ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Image embedding mapping.

        :param x: either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
            with shape (C, H, W) (or (H, W, C) in the case of Qwen)
        :return: either a tensor with shape (num_total_images, S, D) or, if images
            can have different sequence lengths, a list of `num_total_images`
            tensors with shape (S, D)
        """
        # Apply the image encoder, casting back to the input dtype
        og_dtype = x[0].dtype
        encoded = self.enc(x)["image_embeds"]
        encoded = [_x.to(og_dtype) for _x in encoded]
        # Batch the per-image features only when they all share a sequence length
        if all(_x.shape[0] == encoded[0].shape[0] for _x in encoded):
            encoded = torch.stack(encoded, dim=0)
        # Extra projection (+ RMS norm) into the LM embedding space
        image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
        return {"image_embeds": image_embeds}
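

# Hedged sketch of the batching decision in `ImageProjection.forward` (shapes
# are illustrative): per-image features are stacked into a single tensor only
# when every image yields the same number of tokens; otherwise they stay a list.
def _demo_stack_when_uniform() -> None:
    feats = [torch.randn(64, 1152), torch.randn(64, 1152)]
    if all(_x.shape[0] == feats[0].shape[0] for _x in feats):
        stacked = torch.stack(feats, dim=0)  # -> (2, 64, 1152)
        assert stacked.shape == (2, 64, 1152)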


class V2Helium1(Helium1ForCausalLM):  # pyright: ignore[reportIncompatibleMethodOverride]
    config_class = Helium1CASAConfig

    def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
        del kwargs
        super().__init__(config)
        self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)

    def get_device(self) -> str:
        """Return the device type of the model"""
        return next(self.parameters()).device.type

    @property
    def token_dim(self) -> int:
        """Returns the number of dimensions for the token representation"""
        return self.config.hidden_size

    @property
    def rotary_embed(self) -> Callable:
        """Returns the rotary embedding function of the underlying model"""
        return self.model.rotary_emb

    def _update_model_kwargs_for_generation(
        self,
        outputs: Any,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> dict[str, Any]:
        """Required to handle multiple generation calls (e.g. for subtitles)"""
        # Call the parent to get the default updates
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        # Read back by prepare_inputs_for_generation on subsequent steps
        model_kwargs["__is_first_gen_call__"] = False
        return model_kwargs

    def prepare_inputs_for_generation(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        **kwargs: Any,
    ):
        __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
        if past_key_values is not None and (
            kwargs.get("cache_position") is None
            or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
        ):
            # We are continuing from a cached state: rebuild cache_position from
            # the number of tokens already seen (note that `_seen_tokens` is a
            # private attribute of DynamicCache)
            past_length = past_key_values._seen_tokens
            kwargs["cache_position"] = torch.arange(
                past_length,
                past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
                dtype=torch.long,
                device=input_ids.device,
            )
        return super().prepare_inputs_for_generation(
            type_cast(torch.LongTensor, input_ids),
            past_key_values=past_key_values,
            **kwargs,
        )
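
    # Hedged sketch of the cache_position bookkeeping above (values are
    # illustrative). With 10 tokens already in the KV cache:
    #   first call, 5-token prompt: torch.arange(10, 15) -> [10, 11, 12, 13, 14]
    #   later decode steps:         torch.arange(10, 11) -> [10]
    # i.e. the whole prompt gets fresh positions on the first generate() call,
    # then exactly one new position per decoding step.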

    def prepare_multimodal_inputs(
        self,
        # text-only training
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        # image values
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **_kwargs: Any,
    ) -> dict:
        """Build a batch mixing text and image data"""
        del _kwargs
        processed_inputs = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_embeds_insertion_points": image_embeds_insertion_points,
        }
        if pixel_values is not None:
            processed_inputs.update(self.image_prefix(pixel_values))
            assert "image_embeds" in processed_inputs
            # image_embeds must be either an (N, S, D) tensor (same-resolution
            # images) or a list of (S, D) tensors (mixed resolutions)
            assert (
                isinstance(processed_inputs["image_embeds"], torch.Tensor)
                and processed_inputs["image_embeds"].ndim == 3
            ) or (
                isinstance(processed_inputs["image_embeds"], list)
                and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
            )
        # Add kwargs necessary to compute the cu_seqlens windows for CASA
        processed_inputs["casa_windows_info"] = {
            "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
            "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
        }
        return processed_inputs
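
    # Hedged sketch of the CASA window bookkeeping above (token names and
    # counts are illustrative): with pre_image_tokens = [bos] and
    # post_image_tokens = [sep, newline], the method emits
    #   casa_windows_info = {"num_pre_image_tokens": 1, "num_post_image_tokens": 2}
    # which downstream code uses to compute the cu_seqlens windows for CASA.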

    def forward(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        return_loss: bool = True,
        labels: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **kwargs: Any,
    ) -> CausalHeliumOutput:
        """Multimodal forward pass"""
        assert input_ids is not None or inputs_embeds is not None
        if self.training:
            assert return_loss is True, (
                "Helium models always compute their own labels/losses in train mode"
            )
        # Case 1: on the first generation call, compute the image embeddings and CASA states
        if kwargs.get("__is_first_gen_call__", True):
            processed_inputs = self.prepare_multimodal_inputs(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                image_embeds_insertion_points=image_embeds_insertion_points,
                pixel_values=pixel_values,
                labels=labels,
                pre_image_tokens=pre_image_tokens,
                post_image_tokens=post_image_tokens,
            )
            processed_inputs.pop("inputs_embeds", None)
        # Case 2: on subsequent decoding steps, only embed the new text tokens
        else:
            processed_inputs = {
                "inputs_embeds": self.model.embed_tokens(input_ids),
                "attention_mask": attention_mask,
            }
        # For the Helium prefix (non-CASA) mode, we need to shift the positions
        # by the number of image tokens inserted in the first call
        if (
            not self.config.casa_attention
            and (cp := kwargs.get("cache_position", None)) is not None
            and pixel_values is not None
        ):
            start = cp[0].item()
            # One token per 2x2 patch-merge window of the vision grid
            num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
            num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]  # type: ignore
            kwargs["cache_position"] = torch.arange(
                start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
                start + num_tokens + num_image_tokens,
                dtype=cp.dtype,
                device=cp.device,
            )
        kwargs.pop("__is_first_gen_call__", None)
        return super().forward(
            **processed_inputs,  # type: ignore
            **kwargs,
        )
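
    # Hedged sketch of the position shift above (numbers are illustrative, and
    # the 2x2 merge behind `// 4` is an assumption). For a 32x32 patch grid one
    # image costs (32 * 32) // 4 = 256 positions, so with `start` taken from
    # cache_position:
    #   first call:  arange(start, start + num_text_tokens + 256)
    #   decode step: arange(start + 256, start + num_text_tokens + 256)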

    @torch.no_grad()
    def generate_from_image(  # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        reset_streaming: bool = True,
        **kwargs: Any,
    ) -> "GenerateOutput | torch.LongTensor":
        assert input_ids is not None and inputs_embeds is None, (
            "Input IDs must be provided for generation"
        )
        # Init the self-attention KV cache
        if kwargs.get("past_key_values", None) is None:
            kwargs["past_key_values"] = DynamicCache()
        # Default pad_token_id to eos_token_id to avoid the generate() warning
        if kwargs.get("pad_token_id", None) is None:
            kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
            if isinstance(kwargs["pad_token_id"], (list, tuple)):
                kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
        self.start_casa_streaming_states()
        outputs = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_embeds_insertion_points=image_embeds_insertion_points,
            use_cache=True,
            **kwargs,
        )
        if reset_streaming:
            self.reset_casa_streaming_states()
        return outputs
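
    # Hedged usage sketch (argument names follow this method; the `processor`
    # producing `batch` is assumed, not defined in this file):
    #   batch = processor(text=prompt, images=[image], return_tensors="pt")
    #   out = model.generate_from_image(
    #       input_ids=batch["input_ids"],
    #       attention_mask=batch.get("attention_mask"),
    #       pixel_values=batch["pixel_values"],
    #       image_embeds_insertion_points=batch.get("image_embeds_insertion_points"),
    #       max_new_tokens=64,
    #   )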

    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        """Disable streaming on all CASA attention layers, optionally freeing
        their cached keys/values and CASA handlers."""

        def __reset__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(False, ())
                m.reset_streaming()
                if clean_cache:
                    del m.streaming_state.k
                    del m.streaming_state.v
                    del m.streaming_state.casa_handler

        self.apply(__reset__)

    def start_casa_streaming_states(self) -> None:
        """Enable streaming on all CASA attention layers."""

        def __start__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(True, ())

        self.apply(__start__)
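

# Hedged sketch of the CASA streaming lifecycle wrapped by
# `generate_from_image` (the generation call itself is elided):
def _demo_casa_streaming_lifecycle(model: V2Helium1) -> None:
    model.start_casa_streaming_states()  # enable streaming on CASA layers
    try:
        ...  # model.generate(..., use_cache=True)
    finally:
        model.reset_casa_streaming_states(clean_cache=True)  # free streaming buffers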