from typing import Any, Callable
from typing import cast as type_cast

import torch
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)

from .image_encoder import Qwen25VLEncoder
from .configuration_helium1_casa import Helium1CASAConfig
from .language_helium1_casa import (
    CausalHeliumOutput,
    Helium1CASAAttention,
    Helium1ForCausalLM,
    Helium1RMSNorm,
)


def meta_project(
    logits: torch.Tensor | list[torch.Tensor],
    projector: torch.nn.Module,
    norm: torch.nn.Module | None = None,
) -> torch.Tensor | list[torch.Tensor]:
    """Projection operation that handles both tensors and list of tensors

    Outputs either a (N, S, D) tensors (same resolution images) or a list of N (S, D) tensors (where
    S can be a different sequence length per image)
    """
    split_sizes: list[int] | None = None
    if not isinstance(logits, torch.Tensor):
        split_sizes = [_x.shape[0] for _x in logits]
        logits = torch.cat(logits, dim=0)[None, :, :]
    logits = type_cast(torch.Tensor, logits)
    logits = projector(logits)

    assert isinstance(logits, torch.Tensor)
    if norm is not None:
        logits = norm(logits)
    if split_sizes is not None:
        return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
    return logits
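

def _example_meta_project() -> None:
    """Illustrative sketch only (not called anywhere in the model): shows the
    two input forms accepted by ``meta_project``."""
    proj = torch.nn.Linear(8, 16)
    # Same-resolution images: one (N, S, D) tensor in, one (N, S, 16) tensor out.
    batched = meta_project(torch.randn(2, 5, 8), proj)
    assert type_cast(torch.Tensor, batched).shape == (2, 5, 16)
    # Mixed-resolution images: a list of (S_i, D) tensors in, a list out, with
    # the per-image sequence lengths preserved.
    ragged = meta_project([torch.randn(5, 8), torch.randn(7, 8)], proj)
    assert [tuple(t.shape) for t in ragged] == [(5, 16), (7, 16)]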


class ImageProjection(torch.nn.Module):
    """Takes in a batch or sequence of images and returns embeddings
    which are then fed to the LM.

    :param config: model configuration (a `PretrainedConfig` carrying a `vision_config`)
    :param lm_model_dim: Output dimension (number of channels) for this module
    """

    def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
        super().__init__()
        self.config = config
        self.out_dim = lm_model_dim
        visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)

        self.enc = Qwen25VLEncoder(visual=visual)
        # Projection layer
        self.proj_extra = self.init_proj_module()
        # Output normalizations
        self.norm_extra = Helium1RMSNorm(self.out_dim)

    def init_proj_module(self) -> torch.nn.Module:
        """Init the project module for the inserted and/or cross-attended image tokens"""
        if self.config.vision_config.out_dim == self.out_dim:
            return torch.nn.Identity()
        return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)

    def forward(
        self, x: torch.Tensor | list[torch.Tensor]
    ) -> dict[
        str,
        torch.Tensor | list[torch.Tensor],
    ]:
        """Image embedding mapping

        :param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
        with shape (C, H, W) (or (H, W, C) in the case of Qwen)

        :return: Either a tensor with shape (num_total_images, S, D) or, if images
        can have different seq lengths, a list of `num_total_images` tensors with shape
        (S, D)
        """

        # Apply image encoder
        og_dtype = x[0].dtype
        encoded = self.enc(x)["image_embeds"]
        encoded = [_x.to(og_dtype) for _x in encoded]
        if all(_x.shape[0] == encoded[0].shape[0] for _x in encoded):
            encoded = torch.stack(encoded, dim=0)

        # Projection for the inserted (and/or cross-attended) image tokens
        image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)

        return {"image_embeds": image_embeds}


class V2Helium1(Helium1ForCausalLM):  # pyright: ignore[reportIncompatibleMethodOverride]
    config_class = Helium1CASAConfig

    def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
        del kwargs
        super().__init__(config)
        self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)

    def get_device(self) -> str:
        """Return the device type of the model"""
        return next(self.parameters()).device.type

    @property
    def token_dim(self) -> int:
        """Returns the number of dimensions for the token representation"""
        return self.config.hidden_size

    @property
    def rotary_embed(self) -> Callable:
        """Returns the rotary embedding function of the underlying model"""
        return self.model.rotary_emb

    def _update_model_kwargs_for_generation(
        self,
        outputs: Any,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ):
        """This is required to handle multiple gen calls for subtitles"""
        # Call parent to get default updates
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        # Used by prepare_inputs_for_generation
        model_kwargs["__is_first_gen_call__"] = False
        return model_kwargs

    def prepare_inputs_for_generation(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        **kwargs: Any,
    ):
        __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
        if past_key_values is not None and (
            kwargs.get("cache_position") is None
            or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
        ):
            # We're continuing from a cached state
            past_length = past_key_values._seen_tokens
            kwargs["cache_position"] = torch.arange(
                past_length,
                past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
                dtype=torch.long,
                device=input_ids.device,
            )

        return super().prepare_inputs_for_generation(
            type_cast(torch.LongTensor, input_ids),
            past_key_values=past_key_values,
            **kwargs,
        )
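
    # Sketch of the bookkeeping above (illustrative numbers): resuming from a
    # cache holding 12 tokens with a 3-token prompt, the first generate() call
    # sets cache_position = torch.arange(12, 15); each later decode step covers
    # a single new position, i.e. torch.arange(past_length, past_length + 1).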

    def prepare_multimodal_inputs(
        self,
        # text only training
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        # image values
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **_kwargs: Any,
    ) -> dict:
        """Get a batch data mixing text and image data"""
        del _kwargs

        processed_inputs = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_embeds_insertion_points": image_embeds_insertion_points,
        }
        if pixel_values is not None:
            processed_inputs.update(self.image_prefix(pixel_values))
            assert "image_embeds" in processed_inputs
            assert (
                isinstance(processed_inputs["image_embeds"], torch.Tensor)
                and processed_inputs["image_embeds"].ndim == 3
            ) or (
                isinstance(processed_inputs["image_embeds"], list)
                and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
            )

        # Add kwargs necessary to compute cu_seqlens windows for CASA
        processed_inputs["casa_windows_info"] = {
            "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
            "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
        }

        return processed_inputs
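
    # The mapping returned above feeds directly into forward(); with images it
    # carries input_ids, the projected image_embeds (either a (num_images, S, D)
    # tensor or a ragged list of (S_i, D) tensors), the per-image insertion
    # points, and casa_windows_info used to derive cu_seqlens windows for CASA.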

    def forward(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        return_loss: bool = True,
        labels: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **kwargs: Any,
    ) -> CausalHeliumOutput:
        """Multi modal forward pass"""
        assert input_ids is not None or inputs_embeds is not None

        if self.training:
            assert return_loss is True, (
                "Helium models always compute their own labels/losses in train mode"
            )

        # Case 1: on the first generation call we encode the pixel values and build the CASA states
        if kwargs.get("__is_first_gen_call__", True):
            processed_inputs = self.prepare_multimodal_inputs(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                image_embeds_insertion_points=image_embeds_insertion_points,
                pixel_values=pixel_values,
                labels=labels,
                pre_image_tokens=pre_image_tokens,
                post_image_tokens=post_image_tokens,
            )
            processed_inputs.pop("inputs_embeds", None)
        else:
            # Case 2: later decode steps only need embeddings for the new tokens
            processed_inputs = {
                "inputs_embeds": self.model.embed_tokens(input_ids),
                "attention_mask": attention_mask,
            }

        # For Helium prefix, we need to update the positions by the number
        # of image tokens inserted in the first call
        if (
            not self.config.casa_attention
            and (cp := kwargs.get("cache_position", None)) is not None
            and pixel_values is not None
        ):
            start = cp[0].item()
            # One LM token per 2x2 merged patch window (Qwen2.5-VL's spatial
            # merge), hence the division by 4.
            num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
            num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]  # type: ignore
            kwargs["cache_position"] = torch.arange(
                start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
                start + num_tokens + num_image_tokens,
                dtype=cp.dtype,
                device=cp.device,
            )

        # Drop the internal flag before delegating to the base forward
        kwargs.pop("__is_first_gen_call__", None)
        out = super().forward(
            **processed_inputs,  # type: ignore
            **kwargs,
        )

        return out

    @torch.no_grad()
    def generate_from_image(  # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride]
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        reset_streaming: bool = True,
        **kwargs: Any,
    ) -> "GenerateOutput | torch.LongTensor":
        assert input_ids is not None and inputs_embeds is None, (
            "Input IDs must be provided for generation"
        )

        # init self-attention KVCache
        if kwargs.get("past_key_values", None) is None:
            kwargs["past_key_values"] = DynamicCache()

        # To avoid generate warning
        if kwargs.get("pad_token_id", None) is None:
            kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
            if isinstance(kwargs["pad_token_id"], (list, tuple)):
                kwargs["pad_token_id"] = kwargs["pad_token_id"][0]

        self.start_casa_streaming_states()
        outputs = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_embeds_insertion_points=image_embeds_insertion_points,
            use_cache=True,
            **kwargs,
        )
        if reset_streaming:
            self.reset_casa_streaming_states()
        return outputs
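
    # CASA streaming lifecycle: start_casa_streaming_states() flips every
    # Helium1CASAAttention module into streaming mode before generation, and
    # reset_casa_streaming_states() tears that per-module state down afterwards
    # (unless the caller keeps it alive by passing reset_streaming=False).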

    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        def __reset__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(False, ())
                m.reset_streaming()
                if clean_cache:
                    del m.streaming_state.k
                    del m.streaming_state.v
                    del m.streaming_state.casa_handler

        self.apply(__reset__)

    def start_casa_streaming_states(self) -> None:
        def __start__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(True, ())

        self.apply(__start__)
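

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the repo id, processor behaviour and prompt
# format below are assumptions that depend on how this checkpoint is published):
#
#     from transformers import AutoModelForCausalLM, AutoProcessor
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "<repo_id>", trust_remote_code=True, torch_dtype=torch.bfloat16
#     ).eval()
#     processor = AutoProcessor.from_pretrained("<repo_id>", trust_remote_code=True)
#     inputs = processor(text="Describe this image.", images=[image], return_tensors="pt")
#     output_ids = model.generate_from_image(**inputs, max_new_tokens=64)
# ---------------------------------------------------------------------------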