Image-Text-to-Text · Transformers · Safetensors · English · Helium1_VL_2B · custom_code
ameroyer committed · verified · Commit 1126ea7 · 0 Parent(s):

Super-squash branch 'main' using huggingface_hub

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Notice ADDED
@@ -0,0 +1,2 @@
+ Helium1-VL-2B's image encoder is finetuned from the image encoder of Qwen2.5-VL-3B.
+ Qwen is licensed under the Qwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ language:
+ - en
+ base_model:
+ - kyutai/helium-1-2b
+ pipeline_tag: image-text-to-text
+ license: cc-by-nc-sa-4.0
+ datasets:
+ - HuggingFaceM4/FineVision
+ - mvp-lab/LLaVA-OneVision-1.5-Instruct-Data
+ ---
+ Please refer to the [main model card](https://huggingface.co/kyutai/CASA-Helium1-VL-2B) for more information and instructions on how to run the model.
+
+ This model page contains the weights of `Helium1-VL-2B`, an instruction-tuned Helium1-2B model further trained to handle visual inputs using a pretrained image encoder from Qwen2.5-VL.
+ It is released as part of our CASA model release; weights for all CASA models are provided in the associated model collection.
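
For convenience, a minimal loading sketch using the remote-code path declared in `config.json` (`auto_map`). The repository id, dtype choice, and the absence of a dedicated processor call are assumptions; the main model card linked above has the authoritative run instructions.

```python
# Minimal sketch (assumptions: repository id and remote-code loading path;
# see the main model card for the official run instructions).
import torch
from transformers import AutoConfig, AutoModel

repo_id = "kyutai/CASA-Helium1-VL-2B"  # hypothetical id, taken from the link above

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,  # the checkpoint is serialised in bfloat16
    trust_remote_code=True,      # resolves auto_map -> modeling_helium1_casa.V2Helium1
)
model.eval()
```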
__init__.py ADDED
File without changes
casa_attention.py ADDED
@@ -0,0 +1,1010 @@
1
+ """CASA layers"""
2
+
3
+ import bisect
4
+ from dataclasses import dataclass
5
+ from itertools import accumulate
6
+ from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload
7
+ from typing import cast as type_cast
8
+
9
+ import torch
10
+ from transformers.configuration_utils import PretrainedConfig
11
+
12
+ from .utils import StreamingModule, StreamingState, delta_w_factory
13
+
14
+ if TYPE_CHECKING:
15
+ from transformers.configuration_utils import PretrainedConfig
16
+
17
+ try:
18
+ from flash_attn import flash_attn_varlen_func
19
+ except ImportError:
20
+ flash_attn_varlen_func = None # type: ignore
21
+
22
+
23
+ WindowsComputeKwargs = TypedDict(
24
+ "WindowsComputeKwargs",
25
+ {
26
+ "num_post_image_tokens": int,
27
+ "num_pre_image_tokens": int,
28
+ },
29
+ total=False,
30
+ )
31
+
32
+
33
+ def __split_n_merge__(
34
+ x: torch.Tensor,
35
+ sample_lengths: list[int],
36
+ padding_side: Literal["left", "right"] = "right",
37
+ pad_value: int | float | bool = 0,
38
+ ) -> torch.Tensor:
39
+ max_sample_length = max(sample_lengths)
40
+ pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2))
41
+ return torch.stack(
42
+ [
43
+ torch.nn.functional.pad(
44
+ _x,
45
+ pad_tuple + (0, max_sample_length - _x.shape[0])
46
+ if padding_side == "right"
47
+ else pad_tuple + (max_sample_length - _x.shape[0], 0),
48
+ value=pad_value,
49
+ )
50
+ for _x in torch.split(x, sample_lengths, dim=0)
51
+ ],
52
+ dim=0,
53
+ )
54
+
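
As a toy illustration of the helper above (a standalone sketch, not part of the file): splitting a flattened `(sum(lengths), D)` tensor back into a right-padded `(B, max_len, D)` batch.

```python
import torch

# Toy illustration of the split-and-pad behaviour: split along dim 0,
# then right-pad each chunk to the longest sample before stacking.
x = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # flattened (5, D=2)
sample_lengths = [2, 3]

chunks = torch.split(x, sample_lengths, dim=0)
max_len = max(sample_lengths)
padded = torch.stack(
    [torch.nn.functional.pad(c, (0, 0, 0, max_len - c.shape[0])) for c in chunks]
)
print(padded.shape)  # torch.Size([2, 3, 2]); the last row of sample 0 is zero padding
```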
55
+
56
+ @overload
57
+ def insert_image_tokens(
58
+ inputs_embeds: torch.Tensor,
59
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
60
+ image_embeds_insertion_points: list[torch.Tensor],
61
+ recover_batch_dim: Literal[True],
62
+ attention_mask: torch.Tensor | None = None,
63
+ padding_side: Literal["left", "right"] = "right",
64
+ keep_only_attended: bool = False,
65
+ pad_output: int | float | bool = 0.0,
66
+ ) -> tuple[
67
+ torch.Tensor,
68
+ None,
69
+ torch.Tensor | None,
70
+ torch.Tensor,
71
+ ]: ...
72
+ @overload
73
+ def insert_image_tokens(
74
+ inputs_embeds: torch.Tensor,
75
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
76
+ image_embeds_insertion_points: list[torch.Tensor],
77
+ recover_batch_dim: Literal[False],
78
+ attention_mask: torch.Tensor | None = None,
79
+ padding_side: Literal["left", "right"] = "right",
80
+ keep_only_attended: bool = False,
81
+ pad_output: int | float | bool = 0.0,
82
+ ) -> tuple[
83
+ torch.Tensor,
84
+ list[int],
85
+ torch.Tensor | None,
86
+ torch.Tensor,
87
+ ]: ...
88
+ def insert_image_tokens(
89
+ inputs_embeds: torch.Tensor,
90
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
91
+ image_embeds_insertion_points: list[torch.Tensor],
92
+ recover_batch_dim: bool = True,
93
+ attention_mask: torch.Tensor | None = None,
94
+ padding_side: Literal["left", "right"] = "right",
95
+ keep_only_attended: bool = False,
96
+ pad_output: int | float | bool = 0.0,
97
+ ) -> tuple[
98
+ torch.Tensor | torch.Tensor,
99
+ list[int] | None,
100
+ torch.Tensor | torch.Tensor | None,
101
+ torch.Tensor | torch.Tensor,
102
+ ]:
103
+ """
104
+ Insert image embeddings into text embeddings
105
+
106
+ Args:
+ inputs_embeds (torch.Tensor): (B, S, D) input token embeddings.
+ image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings.
+ image_embeds_insertion_points (list[torch.Tensor]): Insertion indices.
+ recover_batch_dim (bool): Whether to return a batched (padded) fused sequence or a flattened one together with per-sample lengths.
+ attention_mask (torch.Tensor, optional): (B, S) attention mask.
+ padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images.
+ keep_only_attended (bool): Only applicable when recover_batch_dim is False; whether to
+ remove any non-attended tokens from the flattened array. In this case, the attention
+ mask returned is **still the original one**, so we can remember which indices have been
+ removed.
+ Returns:
+ output (torch.Tensor): (B, S + Ni * Nt, D) fused sequence, flattened to (sum of sample lengths, D) when recover_batch_dim is False.
+ sample_lengths (list[int] | None): Per-sample lengths (text + image tokens) of the fused sequence.
+ attention_mask (torch.Tensor): Same shape as the fused sequence (without D); 1 for real tokens, 0 for image and text padding.
+ image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions.
+ """
123
+ if isinstance(image_embeds, list) and len(image_embeds) == 0:
124
+ batch_size, text_seq_length, token_dim = inputs_embeds.shape
125
+ if recover_batch_dim:
126
+ return (
127
+ inputs_embeds,
128
+ None,
129
+ attention_mask,
130
+ torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool),
131
+ )
132
+ else:
133
+ flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1]
134
+ return (
135
+ torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])),
136
+ [text_seq_length] * inputs_embeds.shape[0],
137
+ attention_mask.flatten() if attention_mask is not None else None,
138
+ torch.zeros((flattened_seq_length, 1), dtype=torch.bool),
139
+ )
140
+
141
+ # Sanity checks
142
+ if isinstance(image_embeds, torch.Tensor):
143
+ assert inputs_embeds.shape[-1] == image_embeds.shape[-1]
144
+ else:
145
+ assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds)
146
+
147
+ batch_size, text_seq_length, token_dim = inputs_embeds.shape
148
+ image_seq_length = [x.shape[0] for x in image_embeds]
149
+
150
+ # Flatten insertion points
151
+ insertion_offset = []
152
+ counter, offset_from_text, offset_from_image = 0, 0, 0
153
+ for sample in image_embeds_insertion_points:
154
+ for pt in sample:
155
+ insertion_offset.append(pt + offset_from_image + offset_from_text)
156
+ offset_from_image += image_seq_length[counter]
157
+ counter += 1
158
+ offset_from_text += text_seq_length
159
+ image_insert_positions = [
160
+ x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx])
161
+ ]
162
+
163
+ # Flatten image embeds
164
+ if isinstance(image_embeds, list):
165
+ image_embeds = torch.cat(image_embeds, dim=0)
166
+ else:
167
+ image_embeds = type_cast(torch.Tensor, image_embeds)
168
+ image_embeds = torch.reshape(image_embeds, (-1, token_dim))
169
+
170
+ # Flatten text embeds across batch dim (B x S, D)
171
+ inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim))
172
+ flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length)
173
+ text_insert_positions = sorted(
174
+ set(range(flattened_seq_length)).difference(set(image_insert_positions))
175
+ )
176
+
177
+ # Scatter image embeds into the flattened output tensor
+ # scatter text-related values
179
+ output = torch.empty(
180
+ (flattened_seq_length, token_dim),
181
+ device=inputs_embeds.device,
182
+ dtype=inputs_embeds.dtype,
183
+ )
184
+ txt_positions_tensor = torch.Tensor(text_insert_positions).to(
185
+ dtype=torch.long, device=inputs_embeds.device
186
+ )
187
+ output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds)
188
+ attention_mask_new: torch.Tensor | None = None
189
+ if attention_mask is not None:
190
+ attention_mask_new = torch.ones(
191
+ (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
192
+ )
193
+ attention_mask_new.scatter_(
194
+ 0, txt_positions_tensor, attention_mask.flatten().to(torch.bool)
195
+ )
196
+
197
+ # scatter image related stuff
198
+ image_tokens_mask = torch.zeros(
199
+ (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
200
+ )
201
+ img_positions_tensor = torch.Tensor(image_insert_positions).to(
202
+ device=inputs_embeds.device, dtype=torch.long
203
+ )
204
+ output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds)
205
+ image_tokens_mask.scatter_(0, img_positions_tensor, True)
206
+
207
+ # Compute expected sample length, taking into account the real batch
208
+ # i.e. recover the batch dimension of image embeddings
209
+ sample_lengths = []
210
+ counter = 0
211
+ for sample_idx, pts in enumerate(image_embeds_insertion_points):
212
+ num_image_tokens = 0
213
+ for _ in pts:
214
+ num_image_tokens += image_seq_length[counter]
215
+ counter += 1
216
+ if keep_only_attended and attention_mask is not None:
217
+ attended_seq_length = torch.sum(attention_mask[sample_idx]).cpu().item()
218
+ sample_lengths.append(attended_seq_length + num_image_tokens)
219
+ else:
220
+ sample_lengths.append(text_seq_length + num_image_tokens)
221
+
222
+ # For CASA attention, we can keep everything flattened and return
+ # the sample_lengths for the blockwise attention
224
+ if not recover_batch_dim:
225
+ if keep_only_attended and attention_mask_new is not None:
226
+ output = output[attention_mask_new]
227
+ image_tokens_mask = image_tokens_mask[attention_mask_new]
228
+ return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None]
229
+
230
+ # Otherwise, time to (pad) and reshape
231
+ # Easy case: everything has the same length
232
+ if all(x == sample_lengths[0] for x in sample_lengths):
233
+ output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim))
234
+ image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1))
235
+ if attention_mask_new is not None:
236
+ attention_mask_new = torch.reshape(attention_mask_new, (batch_size, sample_lengths[0]))
237
+ # if there is any size mismatch we break into a
238
+ # list and pad again
239
+ else:
240
+ # split and merge
241
+ output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output)
242
+ # note that the extra padding tokens are also marked as image tokens to be removed later
243
+ image_tokens_mask = __split_n_merge__(
244
+ image_tokens_mask, sample_lengths, padding_side, True
245
+ )[:, :, None]
246
+ if attention_mask_new is not None:
247
+ attention_mask_new = __split_n_merge__(
248
+ attention_mask_new, sample_lengths, padding_side, 0
249
+ )
250
+ # Return
251
+ return output, sample_lengths, attention_mask_new, image_tokens_mask
252
+
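
A standalone sketch of the interleaving performed by `insert_image_tokens` (toy sizes, not part of the file): one text sequence of length 4 with a single 2-token image inserted at text position 2.

```python
import torch

text = torch.arange(4).float().unsqueeze(-1)  # (4, 1) stand-in for text embeddings
image = torch.full((2, 1), -1.0)              # (2, 1) stand-in for image embeddings
insertion_point = 2                            # image tokens go before text token 2

total_len = text.shape[0] + image.shape[0]     # 6
image_pos = list(range(insertion_point, insertion_point + image.shape[0]))  # [2, 3]
text_pos = sorted(set(range(total_len)) - set(image_pos))                   # [0, 1, 4, 5]

fused = torch.empty(total_len, 1)
fused[torch.tensor(text_pos)] = text
fused[torch.tensor(image_pos)] = image
print(fused.squeeze(-1))  # tensor([ 0.,  1., -1., -1.,  2.,  3.])
```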
253
+
254
+ def get_sample_lengths_from_insertion_points(
255
+ image_embeds_insertion_points: list[torch.Tensor],
256
+ image_embeds: torch.Tensor | list[torch.Tensor] | None,
257
+ total_seq_len: int | None = None,
258
+ attention_mask: torch.Tensor | None = None,
259
+ **kwargs: WindowsComputeKwargs,
260
+ ) -> tuple[list[tuple[int, bool]], list[int]]:
261
+ """Compute sample lengths as if each image insertion point defines a
262
+ new document (ex document ID)
263
+ """
264
+ num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0))
265
+ num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0))
266
+ squashed_samples_lengths = type_cast(
267
+ list[list[int]] | None, kwargs.get("squashed_samples_lengths", None)
268
+ )
269
+ if squashed_samples_lengths is not None:
270
+ assert len(squashed_samples_lengths) == len(image_embeds_insertion_points)
271
+
272
+ def __insert_next_sample__(
273
+ batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False
274
+ ) -> None:
275
+ nonlocal attention_mask
276
+ nonlocal text_sample_lengths, full_sample_lengths
277
+ nonlocal cum_samples_lengths, current_image_offset
278
+ nonlocal last_image_idx, current_image_idx, current_length
279
+ # Add the sample between [last_insrt_pt, insrt_pt] with breaks in
280
+ # between any squashed samples we find on the way
281
+ start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt)
282
+ added_sample = False
283
+ for end_of_sample in cum_samples_lengths[start_pt:]:
284
+ # we will break the loop at the end when end_of_sample = insrt_pt
285
+ end_of_sample = min(end_of_sample, insrt_pt)
286
+
287
+ # Add between [last_insrt_pt, end_of_sample]
288
+ current_length = end_of_sample - last_insrt_pt
289
+ if attention_mask is not None:
290
+ current_length -= int(
291
+ torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item()
292
+ )
293
+ if current_length > 0:
294
+ added_sample = True
295
+ text_sample_lengths.append(
296
+ (current_length, end_of_batch_sample and insrt_pt == end_of_sample)
297
+ )
298
+ # add image tokens to current_length
299
+ if current_image_idx > 0 and image_embeds is not None:
300
+ images_in_sample = [
301
+ img_idx
302
+ for img_idx in range(last_image_idx, current_image_idx)
303
+ if img_idx < len(image_embeds_insertion_points[batch_idx])
304
+ and last_insrt_pt
305
+ <= image_embeds_insertion_points[batch_idx][img_idx]
306
+ < end_of_sample
307
+ ]
308
+ if len(images_in_sample) > 0:
309
+ num_image_tokens = sum(
310
+ _x.shape[0]
311
+ for _x in image_embeds[
312
+ current_image_offset + images_in_sample[0] : current_image_offset
313
+ + images_in_sample[-1]
314
+ + 1
315
+ ]
316
+ )
317
+ current_length += num_image_tokens
318
+ full_sample_lengths.append(current_length)
319
+
320
+ # prepare for next loop
321
+ last_insrt_pt = end_of_sample
322
+ if end_of_sample == insrt_pt:
323
+ break
324
+ # End of loop: catch a weird edge case where we may end up on a span
+ # full of padding tokens which does not get added due to the current_length > 0 check
+ if end_of_batch_sample:
+ assert added_sample, "Weird edge case. Don't do that, thank you"
+ text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
335
+
336
+ current_image_offset = 0
337
+ text_sample_lengths, full_sample_lengths = [], []
338
+ cum_samples_lengths: list[int] = []
339
+ current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
340
+ for batch_idx, pts in enumerate(image_embeds_insertion_points):
341
+ if squashed_samples_lengths is not None:
342
+ cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx]))
343
+ else:
344
+ assert total_seq_len is not None
345
+ cum_samples_lengths = [total_seq_len]
346
+
347
+ for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()):
348
+ # check if the images are consecutive, in which case we want
+ # them to belong to the same window
350
+ if current_image_idx >= 1 and insrt_pt == (
351
+ image_embeds_insertion_points[batch_idx][current_image_idx - 1]
352
+ + num_pre_image_tokens
353
+ + num_post_image_tokens
354
+ ):
355
+ continue
356
+ # Otherwise, we found a new sample
357
+ # not very important but for completeness: the insertion points come *after*
358
+ # the pre-image tokens per design but for the document-id mask it is more consistent to
359
+ # have them correspond to the same image
360
+ insrt_pt -= num_pre_image_tokens
361
+
362
+ # Update text and full sample lengths
363
+ if insrt_pt > last_insrt_pt:
364
+ __insert_next_sample__(
365
+ batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False
366
+ )
367
+ last_image_idx = current_image_idx
368
+ last_insrt_pt = insrt_pt
369
+
370
+ # End of batch: add sample in progress and reset
371
+ current_image_idx += 1
372
+ if cum_samples_lengths[-1] > last_insrt_pt:
373
+ __insert_next_sample__(
374
+ batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True
375
+ )
376
+ current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
377
+ current_image_offset += len(pts)
378
+
379
+ # Sanity check that the is_eob flags are correctly placed
380
+ assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), (
381
+ f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs"
382
+ f" from original batch size ({len(image_embeds_insertion_points)})"
383
+ )
384
+ return text_sample_lengths, full_sample_lengths
385
+
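
A small worked example of the windowing above (assuming the file is importable locally, no pre/post image tokens, and no padding): a single 10-token text sequence with two 3-token images inserted at positions 2 and 6. Each insertion point opens a new window, and each image is counted in the window that starts at its insertion point.

```python
import torch
from casa_attention import get_sample_lengths_from_insertion_points  # this file

text_lens, full_lens = get_sample_lengths_from_insertion_points(
    image_embeds_insertion_points=[torch.tensor([2, 6])],
    image_embeds=[torch.zeros(3, 8), torch.zeros(3, 8)],  # two 3-token images
    total_seq_len=10,
)
# Expected windows:
#   [0, 2)            -> text 2, full 2
#   [2, 6)  + image 0 -> text 4, full 4 + 3 = 7
#   [6, 10) + image 1 -> text 4, full 4 + 3 = 7   (end-of-batch flag set)
# i.e. text_lens == [(2, False), (4, False), (4, True)] and full_lens == [2, 7, 7]
```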
386
+
387
+ class CASAAttentionHandler:
388
+ def __init__(
389
+ self,
390
+ inputs_embeds: torch.Tensor,
391
+ image_embeds: torch.Tensor | list[torch.Tensor],
392
+ image_embeds_insertion_points: list[torch.Tensor],
393
+ attention_mask: torch.Tensor | None = None,
394
+ rope_fn: Callable | None = None,
395
+ windows: Literal["batch", "squashed", "images", "turn_based"] = "images",
396
+ use_asymetric_q_kv: bool = True,
397
+ casa_windows_info: None | dict = None,
398
+ ):
399
+ """Initialize the structure holding the query buffer for CASA attention layers
+ (i.e. the **flattened** text + inserted image tokens).
+ Note that this structure is shared across all CASA layers and gets updated
+ with the current hidden states at every layer; it is merely a buffer to keep
+ scatter_ operations in-place as much as possible.
+
+ In this module, the embedding-related values (image_tokens_mask,
+ text_sample_lengths, etc.) are stored under the assumption of a tensor
+ which is *flattened* and *without padding tokens*.
+ Only the attention mask is kept as-is (text-only, batched, padded) to
+ be able to recover the original shapes when needed.
+ """
411
+ super().__init__()
412
+ assert windows == "images" # for inference code release
413
+ # Note 1: Unless overridden, text/full_sample_lengths are defined such that one
+ # document = one sample in the batch
415
+ if attention_mask is None:
416
+ text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds]
417
+ else:
418
+ text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask]
419
+ (
420
+ full_inputs_embeds,
421
+ full_sample_lengths,
422
+ # Full attention mask is only needed at inference to
423
+ # flatten the KV-Cache and remove padding tokens
424
+ _,
425
+ self.image_tokens_mask,
426
+ ) = insert_image_tokens(
427
+ inputs_embeds=inputs_embeds,
428
+ image_embeds=image_embeds,
429
+ image_embeds_insertion_points=image_embeds_insertion_points,
430
+ attention_mask=attention_mask,
431
+ recover_batch_dim=False,
432
+ keep_only_attended=attention_mask is not None,
433
+ )
434
+ assert self.image_tokens_mask.ndim == 2
435
+ self.image_embeds = image_embeds
436
+ self.image_embeds_insertion_points = image_embeds_insertion_points
437
+ self.attention_mask = None if attention_mask is None else attention_mask.bool()
438
+ self.use_asymetric_qkv = use_asymetric_q_kv
439
+ # At inference, we have to use asymmetric QKV for efficiency
440
+ if self.attention_mask is not None:
441
+ self.use_asymetric_qkv = True
442
+
443
+ # Build CASA windows
444
+ assert casa_windows_info is not None
445
+ text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points(
446
+ image_embeds_insertion_points=image_embeds_insertion_points,
447
+ image_embeds=image_embeds,
448
+ total_seq_len=inputs_embeds.shape[1],
449
+ attention_mask=self.attention_mask,
450
+ **casa_windows_info, # pyright: ignore
451
+ )
452
+
453
+ # Sanity checks on the sample lengths
454
+ self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0]
455
+ self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0]
456
+
457
+ assert len(self.text_sample_lengths) == len(self.full_sample_lengths), (
458
+ f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}"
459
+ f" != full sample lengths {len(self.full_sample_lengths)}"
460
+ )
461
+ if self.attention_mask is None:
462
+ num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1]
463
+ else:
464
+ num_unpadded_text_tokens = int(
465
+ torch.sum(type_cast(torch.Tensor, attention_mask)).item()
466
+ )
467
+ assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
468
+ f"Sanity check failed; text sample lengths {sum(x[0] for x in self.text_sample_lengths)} != {num_unpadded_text_tokens}"
469
+ )
470
+ assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], (
471
+ f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
472
+ )
473
+
474
+ # Finally we can compute cu_seqlen based on sample lengths
475
+ self.max_seqlen_q = max(self.text_sample_lengths)[0]
476
+ self.cu_seqlens_q = self.get_cu_seqlens(
477
+ [x[0] for x in self.text_sample_lengths], device=inputs_embeds.device
478
+ )
479
+
480
+ self.max_seqlen_kv = max(self.full_sample_lengths)
481
+ self.cu_seqlens_kv = self.get_cu_seqlens(
482
+ self.full_sample_lengths, device=inputs_embeds.device
483
+ )
484
+
485
+ # For inference: We save the length of the current document
486
+ # to trim the KV cache appropriately
487
+ self.current_doc_lengths = self.full_sample_lengths
488
+
489
+ # Precompute position embeddings
490
+ self.position_embeds = None
491
+ self.rope_fn = rope_fn
492
+ if self.rope_fn is not None:
493
+ self.position_embeds = self.compute_position_embeddings(
494
+ self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds
495
+ )
496
+
497
+ @property
498
+ def batch_lengths(self) -> list[int]:
499
+ """Return a (batch_size,) list of integers containing the
500
+ number of (non-padded) text tokens for each sample in the batch"""
501
+ bls = [0]
502
+ for ln, eob in self.text_sample_lengths:
503
+ bls[-1] += ln
504
+ if eob:
505
+ bls.append(0)
506
+ return bls[:-1]
507
+
508
+ @property
509
+ def full_batch_lengths(self) -> list[int]:
510
+ """Same as batch_lengths for text+image tokens"""
511
+ bls = [0]
512
+ for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths):
513
+ bls[-1] += ln
514
+ if eob:
515
+ bls.append(0)
516
+ return bls[:-1]
517
+
518
+ def get_cu_seqlens(
519
+ self, sample_lengths: list[int], device: torch.device | None
520
+ ) -> torch.Tensor:
521
+ """Compute cu_seqlens from the given sample_lengths"""
522
+ return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to(
523
+ dtype=torch.int32, device=device
524
+ )
525
+
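
The `cu_seqlens` layout built here follows the `flash_attn_varlen_func` convention: cumulative sample lengths with a leading zero, so document `i` spans `[cu[i], cu[i+1])` in the flattened batch. A quick illustration:

```python
import torch
from itertools import accumulate

sample_lengths = [2, 7, 7]
cu_seqlens = torch.tensor(list(accumulate(sample_lengths, initial=0)), dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  2,  9, 16], dtype=torch.int32)
```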
526
+ def compute_position_embeddings(
527
+ self,
528
+ rope_fn: Callable,
529
+ sample_lengths: list[int],
530
+ dummy_for_dtype_and_device: torch.Tensor,
531
+ ) -> tuple[torch.Tensor, torch.Tensor]:
532
+ """Compute info required for position embeddings. Can be overridden, e.g. for Qwen."""
533
+ # option 1: Standard range
534
+ # position_ids = torch.arange(0, full_inputs_embeds.shape[0])
535
+ # option 2: Follows document boundary
536
+ position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
537
+ return rope_fn(
538
+ dummy_for_dtype_and_device,
539
+ position_ids.to(dummy_for_dtype_and_device.device)[None, ...],
540
+ )
541
+
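
"Option 2" above restarts position ids at every document boundary, so RoPE treats each (text + image) window as its own sequence. For the sample lengths `[2, 7, 7]`:

```python
import torch

sample_lengths = [2, 7, 7]
position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
print(position_ids)
# tensor([0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6])
```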
542
+ def get_position_embedding(
543
+ self,
544
+ key: Literal["q", "kv"],
545
+ num_queries: int = 0,
546
+ ) -> tuple[torch.Tensor, torch.Tensor] | None:
547
+ if self.position_embeds is None:
548
+ return None
549
+ cos, sin = self.position_embeds
550
+ bls = self.full_batch_lengths
551
+ # For Q, we only want the text-only posembeds
552
+ if key == "q" and self.use_asymetric_qkv:
553
+ bls = self.batch_lengths
554
+ cos, sin = cos[:, ~self.image_tokens_mask[:, 0]], sin[:, ~self.image_tokens_mask[:, 0]]
555
+ elif key not in {"q", "kv"}:
556
+ raise ValueError(f"Unknown position embedding key: {key}")
557
+
558
+ # Easy case: training or first step at inference: we use all the posembeds
559
+ if num_queries == 0:
560
+ return cos, sin
561
+ # If num queries is given, we need to trim for *every sample in the batch*
562
+ cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)]
563
+ sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)]
564
+ return torch.cat(cos, dim=1), torch.cat(sin, dim=1)
565
+
566
+ def get_full_embeds(
567
+ self, hidden_states: torch.Tensor, norm_fn: Callable | None
568
+ ) -> torch.Tensor:
569
+ """Update attended hidden states in the current query buffer
570
+
571
+ :param hidden_states: (b, s, d) Tensor input to the CASA attention layer
572
+ """
573
+ assert self.image_embeds is not None
574
+ return insert_image_tokens(
575
+ inputs_embeds=hidden_states,
576
+ image_embeds=self.image_embeds
577
+ if norm_fn is None
578
+ else norm_fn(self.image_embeds)
579
+ if isinstance(self.image_embeds, torch.Tensor)
580
+ else [norm_fn(_x) for _x in self.image_embeds],
581
+ image_embeds_insertion_points=self.image_embeds_insertion_points,
582
+ attention_mask=self.attention_mask,
583
+ recover_batch_dim=False,
584
+ keep_only_attended=self.attention_mask is not None,
585
+ )[0][None, :, :]
586
+
587
+ def recover_text_embeds(
588
+ self,
589
+ hidden_states_out: torch.Tensor,
590
+ hidden_states_in: torch.Tensor,
591
+ update_image_embeddings: bool = False,
592
+ ) -> torch.Tensor:
593
+ """Returns text embeddings from the query buffer, including non-attended tokens at inference"""
594
+ if update_image_embeddings and not self.use_asymetric_qkv:
595
+ raise NotImplementedError("Implement image embeddings updates for asymetric QKV")
596
+ # Remove image tokens in the symmetric case
597
+ if not self.use_asymetric_qkv:
598
+ hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]]
599
+
600
+ # if there is no attention mask, we are in the right-padded case
+ # (keep_only_attended = False) and we can directly return the query
+ # outputs (which don't contain the image tokens)
603
+ if self.attention_mask is None:
604
+ return hidden_states_out
605
+
606
+ # Otherwise, we need to "scatter" back only the text-attended tokens to the original
607
+ # hidden states, which contain the paddings
608
+ num_queries = hidden_states_in.shape[1]
609
+
610
+ # Case 1: the padded hidden_states_in is larger than hidden_states_out
611
+ # we rebatch+pad hidden_state_out before doing the scattering
612
+ if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]:
613
+ s = torch.split(hidden_states_out, self.batch_lengths, dim=0)
614
+ assert max(_s.shape[0] for _s in s) <= num_queries # sanity check
615
+ s = [
616
+ torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0)
617
+ for _s in s
618
+ ]
619
+ return torch.where(
620
+ self.attention_mask[:, -num_queries:, None],
621
+ torch.stack(s),
622
+ hidden_states_in,
623
+ )
624
+ # If both have the same shape, it means hidden_states_in contained no padding
+ # so we can directly return hidden_states_out
626
+ return hidden_states_out
627
+
628
+ def extend(self, num_tokens: int, offset: int = 0):
629
+ """Extend all necessary values of the Handler for inference.
+ Note: this implementation currently assumes a single conversation at a time
+ (otherwise the image tokens mask would have to change) and that the added tokens are
+ attended to"""
633
+ # image embeds is inserted in the first step and stored in the KV cache
634
+ self.image_embeds = None
635
+
636
+ # Update attention mask (non-flattened) (assumes all new tokens are attended to)
637
+ if self.attention_mask is not None:
638
+ self.attention_mask = torch.nn.functional.pad(
639
+ self.attention_mask, (0, num_tokens), value=1
640
+ )
641
+
642
+ # Update image token mask (assumes only one image/conversation
643
+ # is started at once so that we always extend by zero)
644
+ # Note that the mask is stored flattened to avoid padding so we have to
645
+ # do something a bit ugly and inefficient here
646
+ imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0)
647
+ imtokmask = [torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask]
648
+ self.image_tokens_mask = torch.cat(imtokmask, dim=0)
649
+
650
+ # Recompute cumulative document lengths after assigning the new
651
+ # number of tokens to each sample in the batch
652
+ for idx, (ln, is_eob) in enumerate(self.text_sample_lengths):
653
+ if is_eob:
654
+ self.text_sample_lengths[idx] = (num_tokens + ln, is_eob)
655
+ self.full_sample_lengths[idx] += num_tokens
656
+
657
+ # Recompute cu sequlen
658
+ # First step: Technically this never occurs, but we keep it for completeness
659
+ if offset == 0:
660
+ self.max_seqlen_q = max(self.text_sample_lengths)[0]
661
+ self.cu_seqlens_q = self.get_cu_seqlens(
662
+ [x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device
663
+ )
664
+
665
+ self.max_seqlen_kv = max(self.full_sample_lengths)
666
+ self.cu_seqlens_kv = self.get_cu_seqlens(
667
+ self.full_sample_lengths, device=self.cu_seqlens_kv.device
668
+ )
669
+ # Step > 0: the annoying part is that, since flash_attn_varlen does not accept
+ # 0-length documents, we need to remove documents from the KV cache once they're past
+ # their windows. In our current setting, this means we only want to keep the latest
+ # documents
673
+ else:
674
+ self.max_seqlen_q = num_tokens
675
+ self.cu_seqlens_q = self.get_cu_seqlens(
676
+ [num_tokens for (_, eob) in self.text_sample_lengths if eob],
677
+ device=self.cu_seqlens_q.device,
678
+ )
679
+
680
+ final_doc_lengths = [
681
+ ln
682
+ for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths)
683
+ if eob
684
+ ]
685
+ self.current_doc_lengths = final_doc_lengths
686
+ self.max_seqlen_kv = max(self.current_doc_lengths)
687
+ self.cu_seqlens_kv = self.get_cu_seqlens(
688
+ final_doc_lengths,
689
+ device=self.cu_seqlens_kv.device,
690
+ )
691
+ # Update position embeddings
692
+ if self.rope_fn is not None and self.position_embeds is not None:
693
+ self.position_embeds = self.compute_position_embeddings(
694
+ self.rope_fn,
695
+ self.full_sample_lengths,
696
+ dummy_for_dtype_and_device=self.position_embeds[0],
697
+ )
698
+
699
+
700
+ @dataclass
701
+ class CASAAttentionStreamingState(StreamingState):
702
+ """Streaming state for the CASA attention module. Keeps the cached key/value states and a reference to the shared CASA handler."""
703
+
704
+ k: torch.Tensor = None # pyright: ignore[reportAssignmentType]
705
+ v: torch.Tensor = None # pyright: ignore[reportAssignmentType]
706
+ recover_batched_trims: list[int] = None # pyright: ignore[reportAssignmentType]
707
+ casa_handler: CASAAttentionHandler = None # pyright: ignore[reportAssignmentType]
708
+
709
+ def maybe_get_casa_handler(
710
+ self,
711
+ casa_handler: CASAAttentionHandler | None,
712
+ is_first_casa_layer: bool = False,
713
+ num_queries: int = -1,
714
+ ) -> CASAAttentionHandler | None:
715
+ # Set given Casa Handler the first time we reach this
716
+ if self.casa_handler is None:
717
+ self.casa_handler = casa_handler # pyright: ignore
718
+ # subsequent calls: we need to extend the shapes to accommodate new tokens;
+ # however, because the CASA handler is shared across layers, we only need to do it once
720
+ if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer:
721
+ # since CasaHandler is shared, we only use its extend step once
722
+ self.casa_handler.extend(num_queries, offset=self.offset)
723
+ return self.casa_handler
724
+
725
+ def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor:
726
+ """Recover batched key/value states with left padding"""
727
+ s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1)
728
+ mlen = max(_s.shape[1] for _s in s)
729
+ # Remember the added padding so that we can re-flatten KV later
730
+ if self.recover_batched_trims is None:
731
+ self.recover_batched_trims = [mlen - _s.shape[1] for _s in s]
732
+ s = [torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0) for _s in s]
733
+ return torch.cat(s, dim=0)
734
+
735
+ def __get_flattened_kv__(
736
+ self, k: torch.Tensor | None = None, v: torch.Tensor | None = None
737
+ ) -> tuple[torch.Tensor, torch.Tensor]:
738
+ """
739
+ Flatten and remove padding for use with flash_attn_varlen_func
740
+ """
741
+ k = self.k if k is None else k
742
+ v = self.v if v is None else v
743
+ assert k is not None and v is not None
744
+
745
+ # Since every batch at least contributes one document,
746
+ # we can use this to check whether we are in streaming mode with dropped docs.
747
+ # If so, we should trim the kv cache accordingly
748
+ if len(self.casa_handler.current_doc_lengths) == len(k):
749
+ k = torch.cat(
750
+ [
751
+ _k[self.recover_batched_trims[idx] :][-doc_len:]
752
+ for idx, _k, doc_len in zip(
753
+ range(len(k)), k, self.casa_handler.current_doc_lengths
754
+ )
755
+ ]
756
+ )
757
+ v = torch.cat(
758
+ [
759
+ _v[self.recover_batched_trims[idx] :][-doc_len:]
760
+ for idx, _v, doc_len in zip(
761
+ range(len(k)), v, self.casa_handler.current_doc_lengths
762
+ )
763
+ ]
764
+ )
765
+ return k[None, ...], v[None, ...]
766
+
767
+ k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)])
768
+ v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)])
769
+ return k[None, ...], v[None, ...]
770
+
771
+ def extend_kv(
772
+ self, key_states: torch.Tensor, value_states: torch.Tensor
773
+ ) -> tuple[torch.Tensor, torch.Tensor]:
774
+ """
775
+ Extend the KV cache with the new key/value states and return flattened, unpadded views.
776
+ """
777
+ assert self.casa_handler is not None
778
+ if self.k is None and self.v is None:
779
+ # Init with batch-padded key and value states
780
+ self.k = self.__recover_batched_kv__(key_states)
781
+ self.v = self.__recover_batched_kv__(value_states)
782
+ return self.__get_flattened_kv__()
783
+ if self.k is not None and self.v is not None:
784
+ # this is during generation; normally there is no padding at this stage
785
+ # so we can directly reshape the flattened key states
786
+ rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3])
787
+ self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1)
788
+ self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1)
789
+ return self.__get_flattened_kv__()
790
+
791
+ raise ValueError("Impossible configuration (k and v updates are desynchronized )")
792
+
793
+
794
+ class CASAAttention(StreamingModule[CASAAttentionStreamingState]):
795
+ def __init__(
796
+ self,
797
+ config: "PretrainedConfig",
798
+ layer_idx: int | None,
799
+ self_attn: torch.nn.Module | None = None,
800
+ input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
801
+ ):
802
+ super().__init__(CASAAttentionStreamingState)
803
+ self.head_dim = config.head_dim
804
+ self.config = config
805
+
806
+ self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0)
807
+ self.use_delta_w = config.casa_delta_w
808
+
809
+ self.q_proj_casa = self.init_from_config_proj("q", config)
810
+ self.k_proj_casa = self.init_from_config_proj("k", config)
811
+ self.v_proj_casa = self.init_from_config_proj("v", config)
812
+ self.o_proj_casa = self.init_from_config_proj("o", config)
813
+
814
+ # Delta_w
815
+ self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
816
+ self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
817
+ self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
818
+ self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
819
+
820
+ if config.casa_delta_w:
821
+ assert self_attn is not None
822
+ self.set_delta_w(self_attn)
823
+
824
+ # Layer norm
825
+ self.norm_fn: Callable | None = None
826
+ if config.xa_norm_on_images:
827
+ assert input_layernorm_fn is not None
828
+ self.norm_fn = input_layernorm_fn
829
+
830
+ def init_from_mha(self, self_attn: torch.nn.Module):
831
+ assert self_attn is not None
832
+ with torch.no_grad():
833
+ assert hasattr(self_attn, "q_proj")
834
+ for key in ["q", "k", "v", "o"]:
835
+ src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj"))
836
+ tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa"))
837
+ tgt.weight.copy_(src.weight)
838
+ if tgt.bias is not None and src.bias is not None:
839
+ tgt.bias.copy_(src.bias)
840
+
841
+ def set_delta_w(self, self_attn: torch.nn.Module):
842
+ """Delta w setup"""
843
+ self.override_q_proj = delta_w_factory(
844
+ self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj)
845
+ )
846
+ self.override_k_proj = delta_w_factory(
847
+ self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj)
848
+ )
849
+ self.override_v_proj = delta_w_factory(
850
+ self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj)
851
+ )
852
+ self.override_o_proj = delta_w_factory(
853
+ self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj)
854
+ )
855
+
856
+ with torch.no_grad():
857
+ torch.nn.init.zeros_(self.q_proj_casa.weight)
858
+ torch.nn.init.zeros_(self.k_proj_casa.weight)
859
+ torch.nn.init.zeros_(self.v_proj_casa.weight)
860
+ torch.nn.init.zeros_(self.o_proj_casa.weight)
861
+ if self.q_proj_casa.bias is not None:
862
+ torch.nn.init.zeros_(self.q_proj_casa.bias)
863
+ if self.k_proj_casa.bias is not None:
864
+ torch.nn.init.zeros_(self.k_proj_casa.bias)
865
+ if self.v_proj_casa.bias is not None:
866
+ torch.nn.init.zeros_(self.v_proj_casa.bias)
867
+ if self.o_proj_casa.bias is not None:
868
+ torch.nn.init.zeros_(self.o_proj_casa.bias)
869
+
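
`delta_w_factory` lives in `.utils` and is not part of this diff; the sketch below shows the parameterisation this zero-init implies (a learnable delta applied on top of the frozen self-attention projection), as an assumption rather than the actual implementation. With zero-initialised deltas, the CASA projections start out identical to the pretrained projections and only drift during finetuning.

```python
import torch
from typing import Callable

def delta_w_like(delta: torch.nn.Linear, base: torch.nn.Linear) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical stand-in for delta_w_factory: frozen base projection plus a learnable delta."""
    def proj(x: torch.Tensor) -> torch.Tensor:
        return base(x) + delta(x)
    return proj

base = torch.nn.Linear(8, 8, bias=False)
delta = torch.nn.Linear(8, 8, bias=False)
torch.nn.init.zeros_(delta.weight)  # mirrors the zero-init in set_delta_w above

x = torch.randn(2, 8)
assert torch.allclose(delta_w_like(delta, base)(x), base(x))  # identity at initialisation
```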
870
+ def init_from_config_proj(
871
+ self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
872
+ ) -> torch.nn.Linear:
873
+ """Initialize the Linear proj in this module"""
874
+ raise NotImplementedError("Abstract class.")
875
+
876
+ def apply_position_embeddings(
877
+ self,
878
+ key: Literal["q", "kv"],
879
+ x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
880
+ casa_handler: CASAAttentionHandler | None,
881
+ num_queries: int = 0,
882
+ unsqueeze_dim: int = 1,
883
+ ) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
884
+ """Apply position embeddings to query and key states"""
885
+ raise NotImplementedError("Abstract class.")
886
+
887
+ def forward(
888
+ self,
889
+ hidden_states: torch.Tensor,
890
+ casa_handler: CASAAttentionHandler | None,
891
+ ) -> torch.Tensor | None:
892
+ """Generic forward for CASA, used for instance in `helium1_attention`."""
893
+ og_dtype = hidden_states.dtype
894
+ if self.is_streaming:
895
+ casa_handler = self.streaming_state.maybe_get_casa_handler(
896
+ casa_handler,
897
+ is_first_casa_layer=self.is_first_casa_layer,
898
+ num_queries=hidden_states.shape[1],
899
+ )
900
+
901
+ # Case of text-only samples at training (or inference when no handler was cached)
902
+ # in this case we just skip CASA so we return None (no casa_update)
903
+ if casa_handler is None:
904
+ return None
905
+
906
+ if self.is_streaming:
907
+ assert casa_handler.use_asymetric_qkv, (
908
+ "You should set `use_asymetric_qkv` to True during inference"
909
+ )
910
+
911
+ og_shape = hidden_states.shape
912
+
913
+ # Build Q inputs
914
+ if casa_handler.use_asymetric_qkv:
915
+ q_inputs = hidden_states.flatten(0, 1)[None, ...]
916
+ if casa_handler.attention_mask is not None:
917
+ q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()]
918
+ else:
919
+ q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
920
+
921
+ # Case 1: Training or first inference step
922
+ if not self.is_streaming or self.streaming_state.offset == 0:
923
+ kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
924
+ else:
925
+ # during streaming, the KV cache including image embeddings
926
+ # will be inserted later so for now we only update the incoming queries
927
+ kv_inputs = q_inputs
928
+
929
+ # Compute QKV for the blockwise attention
930
+ bs, total_seq_len = kv_inputs.shape[:2]
931
+ hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim)
932
+ hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim)
933
+
934
+ if self.override_q_proj is None:
935
+ query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q)
936
+ else:
937
+ query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q)
938
+
939
+ if self.override_k_proj is None:
940
+ key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv)
941
+ else:
942
+ key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv)
943
+
944
+ if self.override_v_proj is None:
945
+ value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv)
946
+ else:
947
+ value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv)
948
+
949
+ # Apply position embedding at the right offset
950
+ num_queries = 0
951
+ if self.streaming and self.streaming_state.offset > 0:
952
+ num_queries = og_shape[1]
953
+
954
+ query_states = self.apply_position_embeddings(
955
+ "q", query_states, num_queries=num_queries, casa_handler=casa_handler
956
+ )
957
+ key_states = self.apply_position_embeddings(
958
+ "kv", key_states, num_queries=num_queries, casa_handler=casa_handler
959
+ )
960
+ assert flash_attn_varlen_func is not None, (
961
+ "flash_attention is not installed but required for block-wise attention"
962
+ )
963
+
964
+ # Flash attention has a different, more efficient implementation for streaming.
+ # In that case, the KV cache has to be batched and has been extended
+ # to accommodate the shape of the new updates
967
+ if self.is_streaming:
968
+ key_states, value_states = self.streaming_state.extend_kv(
969
+ key_states=key_states, value_states=value_states
970
+ )
971
+ if casa_handler.use_asymetric_qkv:
972
+ cu_seqlens_q = casa_handler.cu_seqlens_q
973
+ max_seqlen_q = casa_handler.max_seqlen_q
974
+ else:
975
+ cu_seqlens_q = casa_handler.cu_seqlens_kv
976
+ max_seqlen_q = casa_handler.max_seqlen_kv
977
+ assert cu_seqlens_q[-1] == query_states.shape[1], (
978
+ f"{cu_seqlens_q[-1]} != {query_states.shape[1]}"
979
+ )
980
+ assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], (
981
+ f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}"
982
+ )
984
+ attn_output: torch.Tensor = flash_attn_varlen_func(
985
+ query_states[0].to(torch.bfloat16),
986
+ key_states[0].to(torch.bfloat16),
987
+ value_states[0].to(torch.bfloat16),
988
+ cu_seqlens_q=cu_seqlens_q,
989
+ cu_seqlens_k=casa_handler.cu_seqlens_kv,
990
+ max_seqlen_q=max_seqlen_q,
991
+ max_seqlen_k=casa_handler.max_seqlen_kv,
992
+ dropout_p=0.0,
993
+ # softmax_scale=None, # defaults to 1/sqrt(d)
994
+ causal=True,
995
+ ).to(og_dtype)
996
+
997
+ attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous()
998
+ if self.override_o_proj is None:
999
+ attn_output = self.o_proj_casa(attn_output)
1000
+ else:
1001
+ attn_output = self.override_o_proj(attn_output)
1002
+
1003
+ attn_output = casa_handler.recover_text_embeds(
1004
+ attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds
1005
+ )
1006
+ attn_output = attn_output.reshape(og_shape)
1007
+
1008
+ if self.is_streaming:
1009
+ self.streaming_state.offset += attn_output.shape[1]
1010
+ return attn_output
config.json ADDED
@@ -0,0 +1,77 @@
1
+ {
2
+ "attention_bias": false,
3
+ "attention_dropout": 0.0,
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_helium1_casa.Helium1CASAConfig",
6
+ "AutoModel": "modeling_helium1_casa.V2Helium1"
7
+ },
8
+ "bos_token_id": 1,
9
+ "casa_attention": false,
10
+ "casa_delta_w": true,
11
+ "casa_use_asymetric_qkv": true,
12
+ "casa_windows": "images",
13
+ "eos_token_id": null,
14
+ "head_dim": 128,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2048,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 8192,
19
+ "mask_squash_blockwise": false,
20
+ "max_position_embeddings": 4096,
21
+ "mlp_bias": false,
22
+ "model_type": "Helium1_VL_2B",
23
+ "num_attention_heads": 16,
24
+ "num_hidden_layers": 28,
25
+ "num_key_value_heads": 8,
26
+ "pad_token_id": 3,
27
+ "post_image_tokens": [],
28
+ "pre_image_tokens": [],
29
+ "pretraining_tp": 1,
30
+ "rms_norm_eps": 1e-08,
31
+ "rope_scaling": null,
32
+ "rope_theta": 20000.0,
33
+ "tie_word_embeddings": false,
34
+ "torch_dtype": "bfloat16",
35
+ "transformers_version": "4.51.3",
36
+ "use_cache": true,
37
+ "vision_config": {
38
+ "depth": 32,
39
+ "fullatt_block_indexes": [
40
+ 7,
41
+ 15,
42
+ 23,
43
+ 31
44
+ ],
45
+ "hidden_act": "silu",
46
+ "hidden_size": 1280,
47
+ "image_mean": [
48
+ 0.48145466,
49
+ 0.4578275,
50
+ 0.40821073
51
+ ],
52
+ "image_std": [
53
+ 0.26862954,
54
+ 0.26130258,
55
+ 0.27577711
56
+ ],
57
+ "in_channels": 3,
58
+ "in_chans": 3,
59
+ "intermediate_size": 3420,
60
+ "model_type": "qwen2_5_vl",
61
+ "num_heads": 16,
62
+ "out_dim": 2048,
63
+ "out_hidden_size": 2048,
64
+ "patch_size": 14,
65
+ "spatial_merge_size": 2,
66
+ "spatial_patch_size": 14,
67
+ "temporal_patch_size": 1,
68
+ "tokens_per_second": 2,
69
+ "window_size": 112
70
+ },
71
+ "vocab_size": 64000,
72
+ "xa_custom_norm": false,
73
+ "xa_layers": [],
74
+ "xa_norm_on_images": false,
75
+ "xa_order": "ca_first",
76
+ "xa_update_image_embeds": false
77
+ }
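
A rough back-of-the-envelope for the visual token budget implied by `vision_config` above (`patch_size=14`, `spatial_merge_size=2`, `temporal_patch_size=1`), assuming the standard Qwen2.5-VL scheme of 14x14 patches merged 2x2 before entering the language model:

```python
def num_visual_tokens(height: int, width: int, patch: int = 14, merge: int = 2) -> int:
    # Patches along each side, then a merge x merge reduction by the spatial merger.
    return (height // patch) * (width // patch) // (merge * merge)

print(num_visual_tokens(448, 448))  # 1024 patches -> 256 visual tokens after the 2x2 merge
```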
configuration_helium1_casa.py ADDED
@@ -0,0 +1,270 @@
1
+ from typing import Any, Literal
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
5
+
6
+
7
+ class Helium1CASAConfig(PretrainedConfig):
8
+ r"""
9
+ Helium1 Config augmented with CASA options
10
+
11
+
12
+ Args:
13
+ vocab_size (`int`, *optional*, defaults to 32000):
14
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
15
+ `inputs_ids` passed when calling [`Helium1Model`]
16
+ hidden_size (`int`, *optional*, defaults to 4096):
17
+ Dimension of the hidden representations.
18
+ intermediate_size (`int`, *optional*, defaults to 11008):
19
+ Dimension of the MLP representations.
20
+ num_hidden_layers (`int`, *optional*, defaults to 32):
21
+ Number of hidden layers in the Transformer decoder.
22
+ num_attention_heads (`int`, *optional*, defaults to 32):
23
+ Number of attention heads for each attention layer in the Transformer decoder.
24
+ num_key_value_heads (`int`, *optional*):
25
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
26
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
27
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
28
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
29
+ by meanpooling all the original heads within that group. For more details checkout [this
30
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
31
+ `num_attention_heads`.
32
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
33
+ The non-linear activation function (function or string) in the decoder.
34
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
35
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
36
+ Llama 2 up to 4096, CodeLlama up to 16384.
37
+ initializer_range (`float`, *optional*, defaults to 0.02):
38
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
39
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
40
+ The epsilon used by the rms normalization layers.
41
+ use_cache (`bool`, *optional*, defaults to `True`):
42
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
43
+ relevant if `config.is_decoder=True`.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id.
46
+ bos_token_id (`int`, *optional*, defaults to 1):
47
+ Beginning of stream token id.
48
+ eos_token_id (`int`, *optional*, defaults to 2):
49
+ End of stream token id.
50
+ pretraining_tp (`int`, *optional*, defaults to 1):
51
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
52
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
53
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
54
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
55
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
56
+ Whether to tie weight embeddings
57
+ rope_theta (`float`, *optional*, defaults to 10000.0):
58
+ The base period of the RoPE embeddings.
59
+ rope_scaling (`Dict`, *optional*):
60
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
61
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
62
+ accordingly.
63
+ Expected contents:
64
+ `rope_type` (`str`):
65
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
66
+ 'llama3'], with 'default' being the original RoPE implementation.
67
+ `factor` (`float`, *optional*):
68
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
69
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
70
+ original maximum pre-trained length.
71
+ `original_max_position_embeddings` (`int`, *optional*):
72
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
73
+ pretraining.
74
+ `attention_factor` (`float`, *optional*):
75
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
76
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
77
+ `factor` field to infer the suggested value.
78
+ `beta_fast` (`float`, *optional*):
79
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
80
+ ramp function. If unspecified, it defaults to 32.
81
+ `beta_slow` (`float`, *optional*):
82
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
83
+ ramp function. If unspecified, it defaults to 1.
84
+ `short_factor` (`List[float]`, *optional*):
85
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
86
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
87
+ size divided by the number of attention heads divided by 2
88
+ `long_factor` (`List[float]`, *optional*):
89
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
90
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
91
+ size divided by the number of attention heads divided by 2
92
+ `low_freq_factor` (`float`, *optional*):
93
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
94
+ `high_freq_factor` (`float`, *optional*):
95
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
96
+ attention_bias (`bool`, *optional*, defaults to `False`):
97
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
98
+ attention_dropout (`float`, *optional*, defaults to 0.0):
99
+ The dropout ratio for the attention probabilities.
100
+ mlp_bias (`bool`, *optional*, defaults to `False`):
101
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
102
+ head_dim (`int`, *optional*):
103
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
104
+
105
+ """
106
+
107
+ model_type = "helium1_casa"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ # Default tensor parallel plan for base model `Helium1Model`
110
+ base_model_tp_plan = {
111
+ "layers.*.self_attn.q_proj": "colwise",
112
+ "layers.*.self_attn.k_proj": "colwise",
113
+ "layers.*.self_attn.v_proj": "colwise",
114
+ "layers.*.self_attn.o_proj": "rowwise",
115
+ "layers.*.mlp.gate_proj": "colwise",
116
+ "layers.*.mlp.up_proj": "colwise",
117
+ "layers.*.mlp.down_proj": "rowwise",
118
+ }
119
+ base_model_pp_plan = { # pyright: ignore[reportAssignmentType]
120
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
121
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
122
+ "norm": (["hidden_states"], ["hidden_states"]),
123
+ }
124
+
125
+ def __init__(
126
+ self,
127
+ vocab_size: int = 32000,
128
+ hidden_size: int = 4096,
129
+ intermediate_size: int = 11008,
130
+ num_hidden_layers: int = 32,
131
+ num_attention_heads: int = 32,
132
+ num_key_value_heads: None | int = None,
133
+ head_dim: None | int = None,
134
+ hidden_act: str = "silu",
135
+ attention_dropout: float = 0.0,
136
+ max_position_embeddings: int = 2048,
137
+ initializer_range: float = 0.02,
138
+ rms_norm_eps: float = 1e-6,
139
+ use_cache: bool = True,
140
+ tie_word_embeddings: bool = False,
141
+ rope_theta: float = 10000.0,
142
+ pad_token_id: int = 3,
143
+ eos_token_id: int = 2,
144
+ bos_token_id: int = 1,
145
+ pretraining_tp: int = 1,
146
+ rope_scaling: None | dict = None,
147
+ attention_bias: bool = False,
148
+ mlp_bias: bool = False,
149
+ # Our fusion mechanisms
150
+ # Common to all fusion mechanisms
151
+ xa_layers: None | tuple = None,
152
+ xa_order: Literal["ca_first", "parallel", "instead"] = "ca_first",
153
+ xa_norm_on_images: bool = False,
154
+ xa_update_image_embeds: bool = False,
155
+ mask_squash_blockwise: bool = False,
156
+ # CASA
157
+ casa_attention: bool = False,
158
+ casa_delta_w: bool = False,
159
+ casa_windows: Literal["batch", "squashed", "images", "turn_based"] = "batch",
160
+ casa_use_asymetric_qkv: bool = True,
161
+ xa_custom_norm: bool = False,
162
+ # Qwen2.5-VL vision config
163
+ vision_config: dict[str, Any] | None = None,
164
+ **kwargs: Any,
165
+ ):
166
+ from transformers.modeling_rope_utils import rope_config_validation
167
+
168
+ self.vocab_size = vocab_size
169
+ self.max_position_embeddings = max_position_embeddings
170
+ self.hidden_size = hidden_size
171
+ self.intermediate_size = intermediate_size
172
+ self.num_hidden_layers = num_hidden_layers
173
+ self.num_attention_heads = num_attention_heads
174
+
175
+ # for backward compatibility
176
+ if num_key_value_heads is None:
177
+ num_key_value_heads = num_attention_heads
178
+
179
+ self.num_key_value_heads = num_key_value_heads
180
+ self.head_dim = (
181
+ head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
182
+ )
183
+ self.hidden_act = hidden_act
184
+ self.initializer_range = initializer_range
185
+ self.rms_norm_eps = rms_norm_eps
186
+ self.pretraining_tp = pretraining_tp
187
+ self.use_cache = use_cache
188
+ self.rope_theta = rope_theta
189
+ self.rope_scaling = rope_scaling
190
+ self.attention_bias = attention_bias
191
+ self.attention_dropout = attention_dropout
192
+ self.mlp_bias = mlp_bias
193
+ # Validate the correctness of rotary position embeddings parameters
194
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
195
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
196
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
197
+ rope_config_validation(self)
198
+
199
+ self.head_dim = self.hidden_size // self.num_attention_heads
200
+ self.xa_layers = xa_layers
201
+ self.xa_order: Literal["ca_first", "parallel", "instead"] = xa_order
202
+ self.xa_norm_on_images = xa_norm_on_images
203
+ self.xa_update_image_embeds = xa_update_image_embeds
204
+ self.mask_squash_blockwise = mask_squash_blockwise
205
+ # CASA config
206
+ self.casa_attention = casa_attention
207
+ self.casa_delta_w = casa_delta_w
208
+ self.casa_windows: Literal["batch", "squashed", "images", "turn_based"] = casa_windows
209
+ self.casa_use_asymetric_qkv = casa_use_asymetric_qkv
210
+ self.xa_custom_norm = xa_custom_norm
211
+
212
+ if vision_config is None:
213
+ vision_config = dict()
214
+ self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
215
+ self.vision_config.temporal_patch_size = 1
216
+ self.vision_config.image_mean = [0.48145466, 0.4578275, 0.40821073]
217
+ self.vision_config.image_std = [0.26862954, 0.26130258, 0.27577711]
218
+ self.vision_config.out_dim = 2048
219
+
220
+ self.pre_image_tokens = []
221
+ self.post_image_tokens = []
222
+
223
+ super().__init__(
224
+ pad_token_id=pad_token_id,
225
+ bos_token_id=bos_token_id,
226
+ eos_token_id=eos_token_id,
227
+ tie_word_embeddings=tie_word_embeddings,
228
+ **kwargs,
229
+ )
230
+
231
+
232
+ if __name__ == "__main__":
233
+ import argparse
234
+ from pathlib import Path
235
+
236
+ import rich
237
+ import yaml
238
+ from transformers.models.auto.configuration_auto import AutoConfig
239
+
240
+ parser = argparse.ArgumentParser()
241
+ parser.add_argument("--out_dir", type=str, default="./saved_config/")
242
+ parser.add_argument(
243
+ "--ckpt_path",
244
+ type=str,
245
+ default="/lustre/scwpod02/client/kyutai/juliette/experiments/finext_casa_896_xtxt_up_b20_64gpu/fdf76e6774",
246
+ )
247
+ args = parser.parse_args()
248
+ path = Path(args.ckpt_path) / "kyuteye_config.yml"
249
+
250
+ helium_config = AutoConfig.from_pretrained("kyutai/helium-1-2b")
251
+ vision_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct").vision_config
252
+
253
+ # Create our config by merging the Helium parameters with the Qwen vision config
254
+ config = Helium1CASAConfig(
255
+ **helium_config.to_dict(), # all helium parameters
256
+ vision_config=vision_config.to_dict(), # override or add vision_config
257
+ )
258
+
259
+ with open(path) as stream:
260
+ kconfig = yaml.safe_load(stream)
261
+
262
+ # Overwrite config entries with values from kconfig for keys present in both
263
+ for key in set(kconfig.keys()).intersection(set(config.to_dict().keys())):
264
+ rich.print(f"Overwriting [bold green]{key:>50s}[/]: [bold red]{kconfig[key]}")
265
+ setattr(config, key, kconfig[key])
266
+ # TODO: handle casa_own_norm -> xa_custom_norm
267
+ print("Configuration successfully loaded.")
268
+ # Save config to json
269
+ config.save_pretrained(args.out_dir)
270
+ print(f"Configuration saved to {args.out_dir}/config.json")
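For reference, the hedged sketch below shows how the options documented above fit together: an example `rope_scaling` dictionary plus a few CASA fusion flags passed to `Helium1CASAConfig`. It assumes the class is importable from `configuration_helium1_casa` (the module name used by the import in `language_helium1_casa.py`); every value is illustrative and does not describe the released checkpoint.

```python
# Hedged, illustrative sketch: example values only, not the released checkpoint's
# configuration. Assumes this repository's configuration_helium1_casa module is importable.
from configuration_helium1_casa import Helium1CASAConfig

config = Helium1CASAConfig(
    hidden_size=2048,
    num_hidden_layers=24,
    num_attention_heads=16,
    # rope_scaling follows the schema documented in the class docstring.
    rope_scaling={"rope_type": "linear", "factor": 2.0},
    # CASA fusion flags (illustrative choices).
    casa_attention=True,
    xa_order="ca_first",
    casa_windows="batch",
)
print(config.head_dim)                           # hidden_size // num_attention_heads -> 128
print(config.vision_config.temporal_patch_size)  # forced to 1 by the config above
```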
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "pad_token_id": 3,
5
+ "eos_token_id": [
6
+ 3,
7
+ 103
8
+ ],
9
+ "transformers_version": "4.51.3"
10
+ }
image_encoder.py ADDED
@@ -0,0 +1,57 @@
1
+ """Qwen2.5VL encoder with delayed normalization"""
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
6
+ Qwen2_5_VisionTransformerPretrainedModel,
7
+ )
8
+
9
+
10
+ def prepare_for_qwen_encoder(
11
+ x: torch.Tensor | list[torch.Tensor], mean: torch.Tensor, std: torch.Tensor
12
+ ) -> tuple[torch.Tensor, torch.Tensor]:
13
+ """
14
+ Preprocessing for Qwen encoder
15
+ Image mean and std come from processor.image_processor.image_mean and image_std
16
+ """
17
+ grid_thw = torch.Tensor([[1, img.shape[0], img.shape[1]] for img in x]).to(x[0].device)
18
+ hws_flatten_shape = torch.prod(grid_thw, dim=-1)
19
+ x = torch.cat(
20
+ [img.reshape((int(hws_flatten_shape[idx].item()), -1)) for idx, img in enumerate(x)],
21
+ dim=0,
22
+ )
23
+ assert x.min() >= 0.0 and x.max() <= 1.0
24
+ og_shape = x.shape
25
+ x = rearrange(x, "L (c d) -> L c d", c=3)
26
+ x = (x - mean) / std
27
+ x = x.view(og_shape).to(torch.bfloat16)
28
+ return x, grid_thw
29
+
30
+
31
+ class Qwen25VLEncoder(torch.nn.Module):
32
+ """Qwen2.5 VL encoder with pre/post processing to be compatible with
33
+ our CASA attention implementation"""
34
+
35
+ def __init__(
36
+ self,
37
+ visual: "Qwen2_5_VisionTransformerPretrainedModel",
38
+ ):
39
+ super().__init__()
40
+ self.visual = visual
41
+ self.image_mean = torch.tensor(self.visual.config.image_mean).view(1, 3, 1)
42
+ self.image_std = torch.tensor(self.visual.config.image_std).view(1, 3, 1)
43
+
44
+ def forward(
45
+ self, x: torch.Tensor | list[torch.Tensor]
46
+ ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
47
+ x, grid_thw = prepare_for_qwen_encoder(
48
+ x, mean=self.image_mean.to(x[0].device), std=self.image_std.to(x[0].device)
49
+ )
50
+
51
+ grid_thw = grid_thw.type(torch.int)
52
+ assert len(x) == grid_thw.prod(dim=1).sum()
53
+ out = self.visual(x, grid_thw=grid_thw)
54
+
55
+ split_sizes = (grid_thw.prod(dim=-1) // self.visual.spatial_merge_size**2).tolist()
56
+ embeds = list(torch.split(out, split_sizes, dim=0)) # Ni * (seq, C)
57
+ return {"image_embeds": embeds, "grid_thw": grid_thw}
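To make the encoder's expected input layout concrete, here is a hedged sketch of `prepare_for_qwen_encoder` on random, pre-patchified images. The per-image shape `(grid_h, grid_w, 3 * patch_size**2)` and the 14x14 patch size are assumptions inferred from the function body and Qwen2.5-VL defaults, not a documented contract; it also assumes the repository files are on the Python path.

```python
# Sketch of the assumed input layout for prepare_for_qwen_encoder (not a spec):
# each image is pre-patchified to (grid_h, grid_w, 3 * patch_size**2), values in [0, 1].
import torch
from image_encoder import prepare_for_qwen_encoder

patch_dim = 3 * 14 * 14  # assumes 14x14 patches and temporal_patch_size = 1
images = [torch.rand(32, 46, patch_dim), torch.rand(24, 24, patch_dim)]
# Normalization statistics set on the vision config in Helium1CASAConfig.
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1)

pixels, grid_thw = prepare_for_qwen_encoder(images, mean=mean, std=std)
print(pixels.shape)  # (32*46 + 24*24, patch_dim), cast to bfloat16
print(grid_thw)      # [[1., 32., 46.], [1., 24., 24.]]
```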
language_helium1_casa.py ADDED
@@ -0,0 +1,1077 @@
1
+ # ADAPTED FROM https://github.com/huggingface/transformers/blob/main/src/transformers/models/helium/modeling_helium.py
2
+ # GIT HASH 1b222903c3e1cfd9492d75e4b2548aa8bd458674
3
+ import logging
4
+ import math
5
+ from dataclasses import dataclass
6
+ from functools import partial
7
+ from typing import Any, Callable, Literal, Optional
8
+ from typing import cast as type_cast
9
+
10
+ import torch
11
+ from torch import nn
12
+ from transformers import (
13
+ ROPE_INIT_FUNCTIONS, # pyright: ignore[reportPrivateImportUsage]
14
+ dynamic_rope_update, # pyright: ignore[reportPrivateImportUsage]
15
+ )
16
+ from transformers.activations import ACT2FN
17
+ from transformers.cache_utils import Cache, DynamicCache
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.generation.utils import GenerationMixin
20
+ from transformers.loss.loss_utils import ForCausalLMLoss
21
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
22
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
23
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
24
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
+ from transformers.processing_utils import Unpack
26
+ from transformers.utils.generic import LossKwargs, can_return_tuple
27
+ from transformers.utils.import_utils import is_torch_flex_attn_available
28
+
29
+ from .casa_attention import CASAAttention, CASAAttentionHandler, insert_image_tokens
30
+ from .configuration_helium1_casa import Helium1CASAConfig
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ if is_torch_flex_attn_available():
35
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
36
+
37
+
38
+ def remove_image_tokens(
39
+ inputs_embeds: torch.Tensor,
40
+ image_tokens_mask: torch.Tensor,
41
+ ) -> torch.Tensor:
42
+ """Remove the image tokens from inputs_embeds as indicated by image_tokens_mask
43
+
44
+ :param inputs_embeds: Tokens of shape (Batch, Seqlen, Dims) containing image tokens
45
+ :param image_tokens_mask: 1-0 mask indicating where image tokens are; (Batch, Seqlen, 1)
46
+
47
+ :return: Tokens tensor of shape (Batch, S' < Seqlen, Dims)
48
+ """
49
+ image_seq_lengths = torch.sum(image_tokens_mask, dim=1)[:, 0]
50
+ image_seq_length = int(image_seq_lengths[0].item())
51
+ assert torch.all(image_seq_lengths == image_seq_length)
52
+ new_shape = (
53
+ inputs_embeds.shape[0],
54
+ inputs_embeds.shape[1] - image_seq_length,
55
+ inputs_embeds.shape[-1],
56
+ )
57
+ tokens = torch.masked_select(
58
+ inputs_embeds,
59
+ torch.logical_not(image_tokens_mask).expand((-1, -1, inputs_embeds.shape[-1])),
60
+ )
61
+ return tokens.reshape(new_shape)
62
+
63
+
64
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
65
+ """
66
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
67
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
68
+ """
69
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
70
+ if n_rep == 1:
71
+ return hidden_states
72
+ hidden_states = hidden_states[:, :, None, :, :].expand(
73
+ batch, num_key_value_heads, n_rep, slen, head_dim
74
+ )
75
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
76
+
77
+
78
+ def eager_attention_forward(
79
+ module: "HeliumAttention",
80
+ query: torch.Tensor,
81
+ key: torch.Tensor,
82
+ value: torch.Tensor,
83
+ attention_mask: None | torch.Tensor,
84
+ scaling: float,
85
+ dropout: float = 0.0,
86
+ **kwargs: Any,
87
+ ):
88
+ del kwargs # unused
89
+ key_states = repeat_kv(key, module.num_key_value_groups)
90
+ value_states = repeat_kv(value, module.num_key_value_groups)
91
+
92
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
93
+ if attention_mask is not None:
94
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
95
+ attn_weights = attn_weights + causal_mask
96
+
97
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
98
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
99
+ attn_output = torch.matmul(attn_weights, value_states)
100
+ attn_output = attn_output.transpose(1, 2).contiguous()
101
+
102
+ return attn_output, attn_weights
103
+
104
+
105
+ # Different Attention Classes
106
+ class HeliumAttention(torch.nn.Module):
107
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
108
+
109
+ def __init__(self, config: Helium1CASAConfig, layer_idx: None | int = None):
110
+ super().__init__()
111
+ self.config = config
112
+ assert layer_idx is not None
113
+ self.layer_idx: int = layer_idx
114
+
115
+ self.apply_rotary_fn = ApplyRotaryPosEmbHelium1()
116
+ self.head_dim = getattr(
117
+ config, "head_dim", config.hidden_size // config.num_attention_heads
118
+ )
119
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
120
+ self.scaling = 1 / math.sqrt(self.head_dim)
121
+ self.attention_dropout = config.attention_dropout
122
+ self.is_causal = True
123
+
124
+ self.q_proj = nn.Linear(
125
+ config.hidden_size,
126
+ config.num_attention_heads * self.head_dim,
127
+ bias=config.attention_bias,
128
+ )
129
+ self.k_proj = nn.Linear(
130
+ config.hidden_size,
131
+ config.num_key_value_heads * self.head_dim,
132
+ bias=config.attention_bias,
133
+ )
134
+ self.v_proj = nn.Linear(
135
+ config.hidden_size,
136
+ config.num_key_value_heads * self.head_dim,
137
+ bias=config.attention_bias,
138
+ )
139
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
140
+
141
+ def forward(
142
+ self,
143
+ hidden_states: torch.Tensor,
144
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
145
+ attention_mask: None | torch.Tensor,
146
+ past_key_values: None | Cache = None,
147
+ cache_position: None | torch.LongTensor = None,
148
+ **kwargs: Unpack[FlashAttentionKwargs],
149
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
150
+ # del (cache_position, past_key_value) # we use our own generate/caching
151
+ bs, seq_len, _ = hidden_states.shape
152
+ # Get QKV
153
+ hidden_shape = (bs, seq_len, -1, self.head_dim)
154
+
155
+ # Embed Queries
156
+ # Shape: (batch_size, num_heads, seq_len, head_dim)
157
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
158
+ num_queries = query_states.shape[2]
159
+
160
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
161
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
162
+
163
+ # Applies rotation
164
+ cos, sin = position_embeddings
165
+ query_states, key_states = self.apply_rotary_fn(
166
+ query_states, key_states, cos, sin, num_queries=num_queries
167
+ )
168
+ assert key_states is not None and query_states is not None
169
+
170
+ attention_interface: Callable = eager_attention_forward
171
+
172
+ if self.config._attn_implementation != "eager":
173
+ if self.config._attn_implementation == "sdpa" and kwargs.get(
174
+ "output_attentions", False
175
+ ):
176
+ print(
177
+ "`torch.nn.functional.scaled_dot_product_attention` does not support"
178
+ " `output_attentions=True`. Falling back to "
179
+ "eager attention. This warning can be removed using the argument"
180
+ ' `attn_implementation="eager"` when loading the model.'
181
+ )
182
+ else:
183
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
184
+
185
+ if past_key_values is not None:
186
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
187
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
188
+ key_states, value_states = past_key_values.update(
189
+ key_states, value_states, self.layer_idx, cache_kwargs
190
+ )
191
+ attn_output, attn_weights = attention_interface(
192
+ self,
193
+ query_states,
194
+ key_states,
195
+ value_states,
196
+ attention_mask,
197
+ dropout=0.0 if not self.training else self.attention_dropout,
198
+ scaling=self.scaling,
199
+ **kwargs,
200
+ )
201
+ attn_output = attn_output.reshape(bs, num_queries, -1).contiguous()
202
+ attn_output = self.o_proj(attn_output)
203
+
204
+ assert isinstance(attn_output, torch.Tensor)
205
+ return attn_output, attn_weights
206
+
207
+
208
+ class ApplyRotaryPosEmbHelium1:
209
+ @staticmethod
210
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
211
+ """Rotates half the hidden dims of the input."""
212
+ x1 = x[..., : x.shape[-1] // 2]
213
+ x2 = x[..., x.shape[-1] // 2 :]
214
+ return torch.cat((-x2, x1), dim=-1)
215
+
216
+ @staticmethod
217
+ def __call__(
218
+ q: torch.Tensor,
219
+ k: torch.Tensor,
220
+ cos: torch.Tensor,
221
+ sin: torch.Tensor,
222
+ position_ids: torch.Tensor | None = None,
223
+ unsqueeze_dim: int = 1,
224
+ num_queries: int | None = None,
225
+ ) -> tuple[torch.Tensor, torch.Tensor]:
226
+ """Applies Rotary Position Embedding to the query and key tensors.
227
+
228
+ Args:
229
+ q (`torch.Tensor`): The query tensor.
230
+ k (`torch.Tensor`): The key tensor.
231
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
232
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
233
+ position_ids (`torch.Tensor`, *optional*):
234
+ Deprecated and unused.
235
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
236
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
237
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
238
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
239
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
240
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
241
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
242
+ Returns:
243
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
244
+ """
245
+ del position_ids
246
+ cos = cos.unsqueeze(unsqueeze_dim)
247
+ sin = sin.unsqueeze(unsqueeze_dim)
248
+ if num_queries is None:
249
+ offset = 0
250
+ else:
251
+ offset = -num_queries
252
+
253
+ q_embed = (q * cos[:, :, offset:]) + (
254
+ ApplyRotaryPosEmbHelium1.rotate_half(q) * sin[:, :, offset:]
255
+ )
256
+ k_embed = (k * cos) + (ApplyRotaryPosEmbHelium1.rotate_half(k) * sin)
257
+
258
+ return q_embed, k_embed
259
+
260
+
261
+ class HeliumRotaryEmbedding(nn.Module):
262
+ def __init__(self, config: Helium1CASAConfig, device: None | torch.device | str = None):
263
+ super().__init__()
264
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
265
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
266
+ else:
267
+ self.rope_type = "default"
268
+ self.max_seq_len_cached = config.max_position_embeddings
269
+ self.original_max_seq_len = config.max_position_embeddings
270
+
271
+ self.config = config
272
+ assert self.rope_type in ROPE_INIT_FUNCTIONS, (
273
+ f"Invalid rope type {self.rope_type}. Supported types are: {list(ROPE_INIT_FUNCTIONS.keys())}"
274
+ )
275
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
276
+ inv_freq, self.attention_scaling = self.rope_init_fn(config, device=device)
277
+ self.inv_freq: torch.Tensor # only defined for typing
278
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
279
+ self.original_inv_freq = self.inv_freq
280
+
281
+ @torch.no_grad()
282
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
283
+ def forward(
284
+ self, x: torch.Tensor, position_ids: torch.Tensor
285
+ ) -> tuple[torch.Tensor, torch.Tensor]:
286
+ inv_freq_expanded = (
287
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
288
+ )
289
+ position_ids_expanded = position_ids[:, None, :].float()
290
+
291
+ device_type = (
292
+ x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
293
+ )
294
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
295
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
296
+ emb = torch.cat((freqs, freqs), dim=-1)
297
+ cos = emb.cos() * self.attention_scaling
298
+ sin = emb.sin() * self.attention_scaling
299
+
300
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
301
+
302
+
303
+ class Helium1CASAAttention(CASAAttention):
304
+ """A CASA Attention layer compatible with Qwen"""
305
+
306
+ def __init__(
307
+ self,
308
+ config: Helium1CASAConfig,
309
+ layer_idx: int | None,
310
+ self_attn: torch.nn.Module | None = None,
311
+ input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
312
+ ):
313
+ # This __init__ is overridden only to type the config argument as Helium1CASAConfig
314
+ super().__init__(config, layer_idx, self_attn, input_layernorm_fn) # pyright: ignore[reportArgumentType]
315
+
316
+ @staticmethod
317
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
318
+ """Rotates half the hidden dims of the input."""
319
+ x1 = x[..., : x.shape[-1] // 2]
320
+ x2 = x[..., x.shape[-1] // 2 :]
321
+ return torch.cat((-x2, x1), dim=-1)
322
+
323
+ def apply_position_embeddings(
324
+ self,
325
+ key: Literal["q", "kv"],
326
+ x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
327
+ casa_handler: CASAAttentionHandler | None,
328
+ num_queries: int = 0,
329
+ unsqueeze_dim: int = 1,
330
+ ) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
331
+ """Apply position embeddings to query and key states"""
332
+ if casa_handler is not None:
333
+ posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)
334
+
335
+ if posemb is not None:
336
+ x = x.transpose(1, 2).to(torch.float32)
337
+ x = (x * posemb[0].unsqueeze(dim=unsqueeze_dim)) + (
338
+ self.rotate_half(x) * posemb[1].unsqueeze(dim=unsqueeze_dim)
339
+ )
340
+ return x.transpose(1, 2)
341
+ return x
342
+
343
+ def init_from_config_proj(
344
+ self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
345
+ ) -> torch.nn.Linear:
346
+ """Initialize the Linear proj in this module"""
347
+ num_heads = config.num_key_value_heads if key in {"k", "v"} else config.num_attention_heads
348
+ return torch.nn.Linear(
349
+ config.hidden_size,
350
+ num_heads * config.head_dim,
351
+ bias=config.attention_bias if key != "o" else False,
352
+ )
353
+
354
+
355
+ # NORMALISATION LAYER
356
+ def __rms_norm_forward__(
357
+ hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6
358
+ ) -> torch.Tensor:
359
+ input_dtype = hidden_states.dtype
360
+ hidden_states = hidden_states.to(torch.float32)
361
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
362
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
363
+ return weight * hidden_states.to(input_dtype)
364
+
365
+
366
+ class Helium1RMSNorm(nn.Module):
367
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
368
+ """
369
+ Helium1RMSNorm is equivalent to T5LayerNorm
370
+ """
371
+ super().__init__()
372
+ self.weight = nn.Parameter(torch.ones(hidden_size))
373
+ self.variance_epsilon = eps
374
+
375
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
376
+ return __rms_norm_forward__(hidden_states, self.weight, self.variance_epsilon)
377
+
378
+ def extra_repr(self):
379
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
380
+
381
+
382
+ def delta_w_factory_rms_norm(
383
+ org_lin: Helium1RMSNorm, new_lin: Helium1RMSNorm
384
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
385
+ """Factory building an RMS norm whose weight is the sum of the two given layers' weights"""
386
+
387
+ def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
388
+ nonlocal org_lin, new_lin
389
+ return __rms_norm_forward__(
390
+ input, org_lin.weight + new_lin.weight, new_lin.variance_epsilon
391
+ )
392
+
393
+ return _delta_w_fwd
394
+
395
+
396
+ # FULLY CONNECTED LAYER
397
+
398
+
399
+ class HeliumMLP(nn.Module):
400
+ def __init__(self, config: Helium1CASAConfig) -> None:
401
+ super().__init__()
402
+ self.config = config
403
+ self.hidden_size = config.hidden_size
404
+ self.intermediate_size = config.intermediate_size
405
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
406
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
407
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
408
+ self.act_fn = ACT2FN[config.hidden_act]
409
+
410
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
411
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
412
+ return down_proj
413
+
414
+
415
+ class HeliumDecoderLayer(nn.Module):
416
+ def __init__(self, config: Helium1CASAConfig, layer_idx: None | int = None):
417
+ super().__init__()
418
+ self.hidden_size = config.hidden_size
419
+ self.config = config
420
+ self.mlp = HeliumMLP(config)
421
+ self.input_layernorm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
422
+ self.post_attention_layernorm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
423
+
424
+ # Self-attention
425
+ self.self_attn = HeliumAttention(config=config, layer_idx=layer_idx)
426
+
427
+ # Set up the norm for the fusion mechanisms; note that this norm is applied to the text tokens
428
+ is_xa_layer = layer_idx is None or not config.xa_layers or layer_idx in config.xa_layers
429
+ self.norm_cross: None | Helium1RMSNorm = None
430
+ self.override_norm_cross: Callable[[torch.Tensor], torch.Tensor] | None = None
431
+ if is_xa_layer and config.casa_attention:
432
+ # Custom normalization layer for the extra fusion module
433
+ if self.config.xa_custom_norm:
434
+ self.norm_cross = Helium1RMSNorm(config.hidden_size)
435
+ if config.casa_delta_w:
436
+ self.override_norm_cross = delta_w_factory_rms_norm(
437
+ self.input_layernorm, self.norm_cross
438
+ )
439
+ with torch.no_grad():
440
+ torch.nn.init.ones_(self.norm_cross.weight)
441
+
442
+ # Set up an additional norm for image tokens, which is set in each individual mechanism
443
+ norm_on_images_fn = (
444
+ None
445
+ if not self.config.xa_norm_on_images
446
+ else self.override_norm_cross
447
+ if self.override_norm_cross is not None
448
+ else self.norm_cross.forward
449
+ if self.norm_cross is not None
450
+ else self.input_layernorm.forward
451
+ )
452
+
453
+ # CASA
454
+ self.casa_attn: Helium1CASAAttention | None = None
455
+ if config.casa_attention and is_xa_layer:
456
+ self.casa_attn = Helium1CASAAttention(
457
+ config, layer_idx, self_attn=self.self_attn, input_layernorm_fn=norm_on_images_fn
458
+ )
459
+
460
+ def forward(
461
+ self,
462
+ hidden_states: torch.Tensor,
463
+ attention_mask: None | torch.Tensor = None,
464
+ position_ids: None | torch.LongTensor = None,
465
+ past_key_values: None | Cache = None,
466
+ output_attentions: None | bool = False,
467
+ use_cache: None | bool = False,
468
+ cache_position: None | torch.LongTensor = None,
469
+ position_embeddings: None
470
+ | tuple[torch.Tensor, torch.Tensor] = None, # necessary, but kept here for BC
471
+ # CASA
472
+ casa_handler: CASAAttentionHandler | None = None,
473
+ cu_seqlens: torch.Tensor | None = None,
474
+ **kwargs: Unpack[FlashAttentionKwargs],
475
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
476
+ # Image fusion mechanisms
477
+ apply_ca = self.casa_attn is not None
478
+ ca_update: torch.Tensor | None = None
479
+ if (
480
+ self.config.xa_order
481
+ in {
482
+ "parallel",
483
+ "ca_first",
484
+ "instead",
485
+ }
486
+ and apply_ca
487
+ ):
488
+ # Apply layer norm
489
+ assert self.norm_cross is not None
490
+ ca_input = (
491
+ self.override_norm_cross
492
+ if self.override_norm_cross is not None
493
+ else self.norm_cross
494
+ )(hidden_states)
495
+ # CASA
496
+ if self.casa_attn is not None:
497
+ ca_update = self.casa_attn(ca_input, casa_handler=casa_handler)
498
+
499
+ # If we're here, it's because we had proper inputs (no text-only samples)
500
+ # so the output should not be None
501
+ if ca_update is not None:
502
+ # `instead`: directly return the output of the CA module as residual
503
+ if self.config.xa_order == "instead":
504
+ outputs = (hidden_states + ca_update,)
505
+ if output_attentions:
506
+ outputs += (
507
+ torch.zeros((), device=ca_update.device, dtype=ca_update.dtype),
508
+ )
509
+ return outputs
510
+
511
+ # `ca_first`: update then continue with normal self-attention
512
+ if self.config.xa_order == "ca_first":
513
+ hidden_states = hidden_states + ca_update
514
+ ca_update = None
515
+
516
+ # Self Attention with initial input layer norm
517
+ residual = hidden_states
518
+ hidden_states, self_attn_weights = self.self_attn(
519
+ hidden_states=self.input_layernorm(hidden_states),
520
+ attention_mask=attention_mask,
521
+ position_ids=position_ids,
522
+ past_key_values=past_key_values,
523
+ output_attentions=output_attentions,
524
+ use_cache=use_cache,
525
+ cache_position=cache_position,
526
+ position_embeddings=position_embeddings,
527
+ cu_seqlens=cu_seqlens,
528
+ **kwargs,
529
+ )
530
+ hidden_states = residual + hidden_states
531
+
532
+ # parallel - residual update
533
+ if self.config.xa_order == "parallel" and apply_ca and ca_update is not None:
534
+ hidden_states = hidden_states + ca_update
535
+
536
+ # Fully Connected layer
537
+ residual = hidden_states
538
+ # MLP updates for image embeddings
539
+ if (
540
+ self.config.xa_update_image_embeds
541
+ and self.casa_attn is not None
542
+ and casa_handler is not None
543
+ and casa_handler.image_embeds is not None
544
+ ):
545
+ # Text flattening
546
+ hs = self.post_attention_layernorm(hidden_states).reshape(-1, hidden_states.shape[-1])
547
+ # Image flattening
548
+ img_seq_lengths = [_x.shape[0] for _x in casa_handler.image_embeds]
549
+ img_residual = torch.cat(list(casa_handler.image_embeds), dim=0)
550
+ update = self.mlp(torch.cat([hs, self.post_attention_layernorm(img_residual)], dim=0))
551
+ # update text
552
+ hidden_states = hidden_states + update[: hs.shape[0]].reshape(hidden_states.shape)
553
+ casa_handler.image_embeds = list(
554
+ torch.split(img_residual + update[hs.shape[0] :], img_seq_lengths)
555
+ )
556
+ else:
557
+ hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
558
+ hidden_states = residual + hidden_states
559
+
560
+ # Outputs
561
+ outputs = (hidden_states,)
562
+ if output_attentions:
563
+ outputs += (self_attn_weights,)
564
+
565
+ return outputs
566
+
567
+
568
+ # FULL HELIUM MODEL
569
+
570
+
571
+ @dataclass
572
+ class CausalHeliumOutput(CausalLMOutputWithPast):
573
+ attention_mask: Optional[torch.Tensor] = None
574
+ num_image_tokens_log: Optional[torch.Tensor] = None
575
+ num_text_tokens_log: Optional[torch.Tensor] = None
576
+
577
+
578
+ class Helium1PreTrainedModel(PreTrainedModel):
579
+ config_class = Helium1CASAConfig
580
+ base_model_prefix = "model"
581
+ supports_gradient_checkpointing = True
582
+ _no_split_modules = ["HeliumDecoderLayer"]
583
+ _skip_keys_device_placement = ["past_key_values"]
584
+ _supports_flash_attn_2 = True
585
+ _supports_sdpa = True
586
+ _supports_flex_attn = True
587
+ _supports_cache_class = True
588
+ _supports_quantized_cache = True
589
+ _supports_static_cache = True
590
+ _supports_attention_backend = True
591
+
592
+ def _init_weights(self, module: torch.nn.Module) -> None:
593
+ std = self.config.initializer_range
594
+ if isinstance(module, nn.Linear):
595
+ module.weight.data.normal_(mean=0.0, std=std)
596
+ if module.bias is not None:
597
+ module.bias.data.zero_()
598
+ elif isinstance(module, nn.Embedding):
599
+ module.weight.data.normal_(mean=0.0, std=std)
600
+ if module.padding_idx is not None:
601
+ module.weight.data[module.padding_idx].zero_()
602
+ elif isinstance(module, Helium1RMSNorm):
603
+ module.weight.data.fill_(1.0)
604
+
605
+
606
+ class Helium1Model(Helium1PreTrainedModel):
607
+ """
608
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HeliumDecoderLayer`]
609
+
610
+ Args:
611
+ config: Helium1CASAConfig
612
+ """
613
+
614
+ def __init__(self, config: Helium1CASAConfig):
615
+ Helium1PreTrainedModel.__init__(self, config)
616
+ self.training: bool
617
+ self._gradient_checkpointing_func: Callable
618
+ self.config = config
619
+ self.padding_idx = config.pad_token_id
620
+ self.vocab_size = config.vocab_size
621
+
622
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
623
+ self.layers = nn.ModuleList(
624
+ [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
625
+ )
626
+ self.norm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
627
+ self.rotary_emb = HeliumRotaryEmbedding(config=config)
628
+ self.gradient_checkpointing = False
629
+
630
+ # Initialize weights and apply final processing
631
+ self.post_init()
632
+
633
+ def get_input_embeddings(self):
634
+ return self.embed_tokens
635
+
636
+ def set_input_embeddings(self, value: nn.Module) -> None:
637
+ self.embed_tokens = value
638
+
639
+ @can_return_tuple
640
+ def forward(
641
+ self,
642
+ input_ids: None | torch.LongTensor = None,
643
+ attention_mask: None | torch.Tensor = None,
644
+ position_ids: None | torch.Tensor = None,
645
+ past_key_values: None | DynamicCache = None,
646
+ inputs_embeds: None | torch.Tensor = None,
647
+ use_cache: None | bool = None,
648
+ output_attentions: None | bool = None,
649
+ output_hidden_states: None | bool = None,
650
+ cache_position: None | torch.Tensor = None,
651
+ # Insertion
652
+ image_tokens_mask: torch.Tensor | None = None,
653
+ # CASA
654
+ casa_handler: CASAAttentionHandler | None = None,
655
+ cu_seqlens: torch.Tensor | None = None,
656
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
657
+ ) -> BaseModelOutputWithPast:
658
+ output_attentions = (
659
+ output_attentions if output_attentions is not None else self.config.output_attentions
660
+ )
661
+ output_hidden_states = (
662
+ output_hidden_states
663
+ if output_hidden_states is not None
664
+ else self.config.output_hidden_states
665
+ )
666
+ use_cache = not self.training and (
667
+ use_cache if use_cache is not None else self.config.use_cache
668
+ )
669
+
670
+ if (input_ids is None) ^ (inputs_embeds is not None):
671
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
672
+
673
+ if self.gradient_checkpointing and self.training and use_cache:
674
+ print(
675
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
676
+ )
677
+ use_cache = False
678
+
679
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
680
+ if not isinstance(past_key_values, (type(None), Cache)):
681
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
682
+
683
+ if inputs_embeds is None:
684
+ inputs_embeds = self.embed_tokens(input_ids)
685
+ assert inputs_embeds is not None
686
+
687
+ if use_cache and past_key_values is None:
688
+ past_key_values = DynamicCache()
689
+
690
+ if cache_position is None:
691
+ past_seen_tokens = 0 if past_key_values is None else past_key_values._seen_tokens
692
+ assert inputs_embeds is not None
693
+ cache_position = torch.arange(
694
+ past_seen_tokens,
695
+ past_seen_tokens + inputs_embeds.shape[1],
696
+ device=inputs_embeds.device,
697
+ )
698
+ assert cache_position is not None
699
+
700
+ if position_ids is None:
701
+ position_ids = cache_position.unsqueeze(0)
702
+
703
+ # Get attention mask
704
+ causal_mask: None | torch.Tensor = self._update_causal_mask(
705
+ attention_mask,
706
+ inputs_embeds,
707
+ cache_position,
708
+ past_key_values,
709
+ output_attentions,
710
+ force_mask=False,
711
+ )
712
+
713
+ # create position embeddings to be shared across the decoder layers
714
+ hidden_states = inputs_embeds
715
+ position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
716
+
717
+ # decoder layers
718
+ all_hidden_states = () if output_hidden_states else None
719
+ all_self_attns = () if output_attentions else None
720
+
721
+ for decoder_layer_idx, decoder_layer in enumerate(
722
+ self.layers[: self.config.num_hidden_layers]
723
+ ):
724
+ is_xa_layer = not self.config.xa_layers or decoder_layer_idx in self.config.xa_layers
725
+ if output_hidden_states:
726
+ if all_hidden_states is None:
727
+ all_hidden_states = ()
728
+ all_hidden_states += (hidden_states,)
729
+
730
+ if self.gradient_checkpointing and self.training:
731
+ layer_outputs = self._gradient_checkpointing_func(
732
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
733
+ hidden_states,
734
+ causal_mask,
735
+ position_ids,
736
+ past_key_values,
737
+ output_attentions,
738
+ use_cache,
739
+ cache_position,
740
+ position_embeddings,
741
+ casa_handler if is_xa_layer else None,
742
+ cu_seqlens,
743
+ )
744
+ else:
745
+ layer_outputs = decoder_layer(
746
+ hidden_states,
747
+ attention_mask=causal_mask,
748
+ position_ids=position_ids,
749
+ past_key_values=past_key_values,
750
+ output_attentions=output_attentions,
751
+ use_cache=use_cache,
752
+ cache_position=cache_position,
753
+ position_embeddings=position_embeddings,
754
+ casa_handler=casa_handler if is_xa_layer else None,
755
+ cu_seqlens=cu_seqlens,
756
+ **flash_attn_kwargs,
757
+ )
758
+
759
+ hidden_states = layer_outputs[0]
760
+
761
+ if output_attentions:
762
+ if all_self_attns is None:
763
+ all_self_attns = ()
764
+ all_self_attns += (layer_outputs[1],)
765
+
766
+ hidden_states = self.norm(hidden_states)
767
+
768
+ # add hidden states from the last decoder layer
769
+ if output_hidden_states:
770
+ if all_hidden_states is None:
771
+ all_hidden_states = ()
772
+ all_hidden_states += (hidden_states,)
773
+
774
+ return BaseModelOutputWithPast(
775
+ last_hidden_state=hidden_states,
776
+ past_key_values=past_key_values if use_cache else None, # pyright: ignore[reportArgumentType]
777
+ hidden_states=all_hidden_states, # pyright: ignore[reportArgumentType]
778
+ attentions=all_self_attns,
779
+ )
780
+
781
+ def _update_causal_mask(
782
+ self,
783
+ attention_mask: torch.Tensor | None,
784
+ input_tensor: torch.Tensor,
785
+ cache_position: torch.Tensor,
786
+ past_key_values: None | DynamicCache | Cache,
787
+ output_attentions: bool = False,
788
+ force_mask: bool = False,
789
+ ) -> torch.Tensor | None:
790
+ if self.config._attn_implementation == "flex_attention":
791
+ if isinstance(attention_mask, torch.Tensor):
792
+ attention_mask = make_flex_block_causal_mask(attention_mask) # type: ignore
793
+ return attention_mask
794
+
795
+ assert attention_mask is None or isinstance(attention_mask, torch.Tensor)
796
+ if self.config._attn_implementation == "flash_attention_2":
797
+ if attention_mask is not None and (force_mask or (attention_mask == 0.0).any()):
798
+ return attention_mask
799
+ return None
800
+
801
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
802
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
803
+ # to infer the attention mask.
804
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
805
+ using_compilable_cache = (
806
+ past_key_values.is_compileable if past_key_values is not None else False
807
+ )
808
+
809
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
810
+ if (
811
+ self.config._attn_implementation == "sdpa"
812
+ and not using_compilable_cache
813
+ and not output_attentions
814
+ ):
815
+ if not force_mask and AttentionMaskConverter._ignore_causal_mask_sdpa(
816
+ attention_mask,
817
+ inputs_embeds=input_tensor,
818
+ past_key_values_length=past_seen_tokens,
819
+ is_training=self.training,
820
+ ):
821
+ return None
822
+
823
+ dtype = input_tensor.dtype
824
+ sequence_length = input_tensor.shape[1]
825
+ if using_compilable_cache and past_key_values is not None:
826
+ target_length = past_key_values.get_max_cache_shape()
827
+ else:
828
+ target_length = (
829
+ attention_mask.shape[-1]
830
+ if isinstance(attention_mask, torch.Tensor)
831
+ else past_seen_tokens + sequence_length
832
+ )
833
+
834
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
835
+ assert target_length is not None
836
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
837
+ attention_mask,
838
+ sequence_length=sequence_length,
839
+ target_length=target_length,
840
+ dtype=dtype,
841
+ cache_position=cache_position,
842
+ batch_size=input_tensor.shape[0],
843
+ )
844
+
845
+ if (
846
+ self.config._attn_implementation == "sdpa"
847
+ and attention_mask is not None
848
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
849
+ and not output_attentions
850
+ ):
851
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
852
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
853
+ # Details: https://github.com/pytorch/pytorch/issues/110213
854
+ min_dtype = torch.finfo(dtype).min
855
+ causal_mask = AttentionMaskConverter._unmask_unattended(
856
+ type_cast(torch.FloatTensor, causal_mask), min_dtype
857
+ )
858
+
859
+ return causal_mask
860
+
861
+ @staticmethod
862
+ def _prepare_4d_causal_attention_mask_with_cache_position(
863
+ attention_mask: torch.Tensor | None,
864
+ sequence_length: int,
865
+ target_length: int,
866
+ dtype: torch.dtype,
867
+ cache_position: torch.Tensor,
868
+ batch_size: int,
869
+ **kwargs: Any,
870
+ ):
871
+ """
872
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
873
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
874
+
875
+ Args:
876
+ attention_mask (`torch.Tensor`):
877
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
878
+ `(batch_size, 1, query_length, key_value_length)`.
879
+ sequence_length (`int`):
880
+ The sequence length being processed.
881
+ target_length (`int`):
882
+ The target length: when generating with static cache, the mask should be as long as the static cache,
883
+ to account for the 0 padding, the part of the cache that is not filled yet.
884
+ dtype (`torch.dtype`):
885
+ The dtype to use for the 4D attention mask.
886
+ cache_position (`torch.Tensor`):
887
+ Indices depicting the position of the input sequence tokens in the sequence.
888
+ batch_size (`int`):
889
+ Batch size.
890
+ """
891
+ del kwargs
892
+ if attention_mask is not None and attention_mask.dim() == 4:
893
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
894
+ causal_mask = attention_mask
895
+ else:
896
+ min_dtype = torch.finfo(dtype).min
897
+ causal_mask = torch.full(
898
+ (sequence_length, target_length),
899
+ fill_value=min_dtype,
900
+ dtype=dtype,
901
+ device=cache_position.device,
902
+ )
903
+ if sequence_length != 1:
904
+ causal_mask = torch.triu(causal_mask, diagonal=1)
905
+ causal_mask *= torch.arange(
906
+ target_length, device=cache_position.device
907
+ ) > cache_position.reshape(-1, 1)
908
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
909
+ if attention_mask is not None:
910
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
911
+ mask_length = attention_mask.shape[-1]
912
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
913
+ :, None, None, :
914
+ ].to(causal_mask.device)
915
+ padding_mask = padding_mask == 0
916
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
917
+ padding_mask, min_dtype
918
+ )
919
+
920
+ return causal_mask
921
+
922
+
923
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
924
+
925
+
926
+ class Helium1ForCausalLM(Helium1PreTrainedModel, GenerationMixin):
927
+ _tied_weights_keys = ["lm_head.weight"]
928
+ _tp_plan = {"lm_head": "colwise_rep"}
929
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
930
+
931
+ def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
932
+ del kwargs
933
+ super().__init__(config)
934
+ self.model: Helium1Model
935
+ self.model = Helium1Model(config)
936
+ self.vocab_size = config.vocab_size
937
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
938
+ self._loss_function = ForCausalLMLoss
939
+
940
+ def get_input_embeddings(self) -> nn.Module:
941
+ return self.model.embed_tokens
942
+
943
+ def set_input_embeddings(self, value: nn.Module) -> None:
944
+ self.model.embed_tokens = value
945
+
946
+ def get_output_embeddings(self) -> nn.Module:
947
+ return self.lm_head
948
+
949
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
950
+ self.lm_head = new_embeddings
951
+
952
+ def set_decoder(self, decoder: Helium1Model) -> None:
953
+ self.model = decoder
954
+
955
+ def get_decoder(self) -> Helium1Model:
956
+ return self.model
957
+
958
+ @can_return_tuple
959
+ def forward(
960
+ self,
961
+ input_ids: None | torch.LongTensor = None,
962
+ attention_mask: None | torch.Tensor = None,
963
+ position_ids: None | torch.LongTensor = None,
964
+ past_key_values: None | Cache = None,
965
+ inputs_embeds: None | torch.Tensor = None,
966
+ image_embeds: None | torch.Tensor | list[torch.Tensor] = None,
967
+ image_embeds_insertion_points: None | list[torch.Tensor] = None,
968
+ labels: None | torch.LongTensor = None,
969
+ use_cache: None | bool = None,
970
+ output_attentions: None | bool = None,
971
+ output_hidden_states: None | bool = None,
972
+ cache_position: None | torch.LongTensor = None,
973
+ logits_to_keep: int | torch.Tensor = 0,
974
+ # CASA
975
+ casa_windows_info: None | dict = None,
976
+ **kwargs: Unpack[KwargsForCausalLM],
977
+ ) -> CausalHeliumOutput:
978
+ r"""
979
+ Helium1 augmented with CASA layers
980
+ """
981
+ output_attentions = (
982
+ output_attentions if output_attentions is not None else self.config.output_attentions
983
+ )
984
+ output_hidden_states = (
985
+ output_hidden_states
986
+ if output_hidden_states is not None
987
+ else self.config.output_hidden_states
988
+ )
989
+ if input_ids is not None:
990
+ assert inputs_embeds is None, (
991
+ "Need to provide only one of `input_ids` or `inputs_embeds`."
992
+ )
993
+ inputs_embeds = self.model.embed_tokens(input_ids)
994
+ assert inputs_embeds is not None
995
+
996
+ # Setup image + text token fusion
997
+ bs, og_seq_len, _ = inputs_embeds.shape
998
+ image_tokens_mask: torch.Tensor | None = None
999
+ casa_handler: CASAAttentionHandler | None = None
1000
+
1001
+ num_image_tokens = -1
1002
+ if image_embeds is not None:
1003
+ num_image_tokens = sum(_x.shape[0] for _x in image_embeds)
1004
+ assert image_embeds_insertion_points is not None, (
1005
+ "Missing image embeddings insertion points"
1006
+ )
1007
+ # B1. CASA layers: We need to init the shared Handler
1008
+ if self.model.config.casa_attention:
1009
+ casa_handler = CASAAttentionHandler(
1010
+ # for text tokens, we don't need the actual values
1011
+ inputs_embeds=torch.zeros_like(inputs_embeds),
1012
+ # for image embeddings, we put real inputs as this will be fixed
1013
+ image_embeds=image_embeds,
1014
+ image_embeds_insertion_points=image_embeds_insertion_points,
1015
+ # attention mask is only needed at inference / left padding
1016
+ attention_mask=None if self.training else attention_mask,
1017
+ rope_fn=self.model.rotary_emb,
1018
+ windows=self.model.config.casa_windows,
1019
+ use_asymetric_q_kv=self.model.config.casa_use_asymetric_qkv,
1020
+ # further params are fed to the function computing attention
1021
+ casa_windows_info=casa_windows_info,
1022
+ )
1023
+ # B2. Direct image insertion
1024
+ else:
1025
+ inputs_embeds, _, attention_mask, image_tokens_mask = insert_image_tokens(
1026
+ inputs_embeds=inputs_embeds,
1027
+ image_embeds=image_embeds,
1028
+ image_embeds_insertion_points=image_embeds_insertion_points,
1029
+ attention_mask=attention_mask,
1030
+ padding_side="right" if self.training else "left",
1031
+ recover_batch_dim=True,
1032
+ )
1033
+
1034
+ del image_embeds
1035
+ del input_ids
1036
+ outputs: BaseModelOutputWithPast = self.model(
1037
+ inputs_embeds=inputs_embeds,
1038
+ attention_mask=attention_mask,
1039
+ position_ids=position_ids,
1040
+ past_key_values=past_key_values,
1041
+ use_cache=use_cache,
1042
+ output_attentions=output_attentions,
1043
+ output_hidden_states=output_hidden_states,
1044
+ cache_position=cache_position,
1045
+ image_tokens_mask=image_tokens_mask,
1046
+ casa_handler=casa_handler,
1047
+ **kwargs,
1048
+ )
1049
+
1050
+ hidden_states = outputs.last_hidden_state
1051
+ assert hidden_states is not None
1052
+ if image_tokens_mask is not None:
1053
+ hidden_states = remove_image_tokens(hidden_states, image_tokens_mask)
1054
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1055
+ slice_indices = (
1056
+ slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1057
+ )
1058
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1059
+
1060
+ loss = None
1061
+ if labels is not None:
1062
+ loss = self.loss_function(
1063
+ logits=logits,
1064
+ labels=labels,
1065
+ vocab_size=self.config.vocab_size,
1066
+ **kwargs,
1067
+ )
1068
+ out = CausalHeliumOutput(
1069
+ loss=loss,
1070
+ logits=logits,
1071
+ past_key_values=outputs.past_key_values,
1072
+ hidden_states=outputs.hidden_states,
1073
+ attentions=outputs.attentions,
1074
+ num_image_tokens_log=torch.tensor(num_image_tokens).to(logits.device).to(torch.float),
1075
+ num_text_tokens_log=torch.tensor(og_seq_len).to(logits.device).to(torch.float),
1076
+ )
1077
+ return out
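The 4D mask construction above is easy to misread, so the sketch below re-implements its core logic standalone (copied for illustration rather than imported, to avoid the module's relative imports). It shows how a causal pattern and a 2D padding mask combine into the additive mask consumed by the attention layers; the helper name `build_causal_mask` and all values are illustrative, not part of the released code.

```python
# Standalone sketch of the 4D causal-mask construction used above
# (logic re-implemented for illustration; not imported from the model file).
import torch

def build_causal_mask(pad_mask_2d: torch.Tensor, seq_len: int, target_len: int,
                      dtype: torch.dtype = torch.float32) -> torch.Tensor:
    min_dtype = torch.finfo(dtype).min
    cache_position = torch.arange(seq_len)
    # Start fully blocked, then keep the large negative value only strictly above the diagonal.
    mask = torch.full((seq_len, target_len), min_dtype, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
    mask = mask[None, None].expand(pad_mask_2d.shape[0], 1, -1, -1).clone()
    # Additionally block padded key positions (pad_mask_2d: 1 = real token, 0 = padding).
    key_len = pad_mask_2d.shape[-1]
    padding = (mask[..., :key_len] + pad_mask_2d[:, None, None, :]) == 0
    mask[..., :key_len] = mask[..., :key_len].masked_fill(padding, min_dtype)
    return mask

pad = torch.tensor([[0, 1, 1, 1]])  # one left-padded position in a batch of one
print(build_causal_mask(pad, seq_len=4, target_len=4)[0, 0])
# 0.0 where attention is allowed, a large negative value where it is blocked
```

A fully blocked first row, as in this example, is exactly the case that `AttentionMaskConverter._unmask_unattended` later re-opens so the SDPA memory-efficient path does not produce NaNs.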
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa6d71d108b8e3d968936d7d61e5928a63a8967cb26dd2c88ee942d9c84164a7
3
+ size 4936992240
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1acdddfff9557107c017f04e8e6b0f47244b07a139cf1849b49c0ee983097968
3
+ size 4993844784
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6656ee483ebfc9903325da8a4418a9713429ca6d60c4ce554c3fbdfe953c4b8
3
+ size 836448912
model.safetensors.index.json ADDED
@@ -0,0 +1,653 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 10767208448
4
+ },
5
+ "weight_map": {
6
+ "image_prefix.enc.visual.blocks.0.attn.proj.bias": "model-00002-of-00003.safetensors",
7
+ "image_prefix.enc.visual.blocks.0.attn.proj.weight": "model-00002-of-00003.safetensors",
8
+ "image_prefix.enc.visual.blocks.0.attn.qkv.bias": "model-00002-of-00003.safetensors",
9
+ "image_prefix.enc.visual.blocks.0.attn.qkv.weight": "model-00002-of-00003.safetensors",
10
+ "image_prefix.enc.visual.blocks.0.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
11
+ "image_prefix.enc.visual.blocks.0.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
12
+ "image_prefix.enc.visual.blocks.0.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
13
+ "image_prefix.enc.visual.blocks.0.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
14
+ "image_prefix.enc.visual.blocks.0.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
15
+ "image_prefix.enc.visual.blocks.0.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
16
+ "image_prefix.enc.visual.blocks.0.norm1.weight": "model-00002-of-00003.safetensors",
17
+ "image_prefix.enc.visual.blocks.0.norm2.weight": "model-00002-of-00003.safetensors",
18
+ "image_prefix.enc.visual.blocks.1.attn.proj.bias": "model-00002-of-00003.safetensors",
19
+ "image_prefix.enc.visual.blocks.1.attn.proj.weight": "model-00002-of-00003.safetensors",
20
+ "image_prefix.enc.visual.blocks.1.attn.qkv.bias": "model-00002-of-00003.safetensors",
21
+ "image_prefix.enc.visual.blocks.1.attn.qkv.weight": "model-00002-of-00003.safetensors",
22
+ "image_prefix.enc.visual.blocks.1.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
23
+ "image_prefix.enc.visual.blocks.1.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
24
+ "image_prefix.enc.visual.blocks.1.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
25
+ "image_prefix.enc.visual.blocks.1.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
26
+ "image_prefix.enc.visual.blocks.1.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
27
+ "image_prefix.enc.visual.blocks.1.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
28
+ "image_prefix.enc.visual.blocks.1.norm1.weight": "model-00002-of-00003.safetensors",
29
+ "image_prefix.enc.visual.blocks.1.norm2.weight": "model-00002-of-00003.safetensors",
30
+ "image_prefix.enc.visual.blocks.10.attn.proj.bias": "model-00002-of-00003.safetensors",
31
+ "image_prefix.enc.visual.blocks.10.attn.proj.weight": "model-00002-of-00003.safetensors",
32
+ "image_prefix.enc.visual.blocks.10.attn.qkv.bias": "model-00002-of-00003.safetensors",
33
+ "image_prefix.enc.visual.blocks.10.attn.qkv.weight": "model-00002-of-00003.safetensors",
34
+ "image_prefix.enc.visual.blocks.10.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
35
+ "image_prefix.enc.visual.blocks.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
36
+ "image_prefix.enc.visual.blocks.10.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
37
+ "image_prefix.enc.visual.blocks.10.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
38
+ "image_prefix.enc.visual.blocks.10.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
39
+ "image_prefix.enc.visual.blocks.10.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
40
+ "image_prefix.enc.visual.blocks.10.norm1.weight": "model-00002-of-00003.safetensors",
41
+ "image_prefix.enc.visual.blocks.10.norm2.weight": "model-00002-of-00003.safetensors",
42
+ "image_prefix.enc.visual.blocks.11.attn.proj.bias": "model-00002-of-00003.safetensors",
43
+ "image_prefix.enc.visual.blocks.11.attn.proj.weight": "model-00002-of-00003.safetensors",
44
+ "image_prefix.enc.visual.blocks.11.attn.qkv.bias": "model-00002-of-00003.safetensors",
45
+ "image_prefix.enc.visual.blocks.11.attn.qkv.weight": "model-00002-of-00003.safetensors",
46
+ "image_prefix.enc.visual.blocks.11.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
47
+ "image_prefix.enc.visual.blocks.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
48
+ "image_prefix.enc.visual.blocks.11.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
49
+ "image_prefix.enc.visual.blocks.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
50
+ "image_prefix.enc.visual.blocks.11.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
51
+ "image_prefix.enc.visual.blocks.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
52
+ "image_prefix.enc.visual.blocks.11.norm1.weight": "model-00002-of-00003.safetensors",
53
+ "image_prefix.enc.visual.blocks.11.norm2.weight": "model-00002-of-00003.safetensors",
54
+ "image_prefix.enc.visual.blocks.12.attn.proj.bias": "model-00002-of-00003.safetensors",
55
+ "image_prefix.enc.visual.blocks.12.attn.proj.weight": "model-00002-of-00003.safetensors",
56
+ "image_prefix.enc.visual.blocks.12.attn.qkv.bias": "model-00002-of-00003.safetensors",
57
+ "image_prefix.enc.visual.blocks.12.attn.qkv.weight": "model-00002-of-00003.safetensors",
58
+ "image_prefix.enc.visual.blocks.12.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
59
+ "image_prefix.enc.visual.blocks.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
60
+ "image_prefix.enc.visual.blocks.12.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
61
+ "image_prefix.enc.visual.blocks.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
62
+ "image_prefix.enc.visual.blocks.12.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
63
+ "image_prefix.enc.visual.blocks.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
64
+ "image_prefix.enc.visual.blocks.12.norm1.weight": "model-00002-of-00003.safetensors",
65
+ "image_prefix.enc.visual.blocks.12.norm2.weight": "model-00002-of-00003.safetensors",
66
+ "image_prefix.enc.visual.blocks.13.attn.proj.bias": "model-00002-of-00003.safetensors",
67
+ "image_prefix.enc.visual.blocks.13.attn.proj.weight": "model-00002-of-00003.safetensors",
68
+ "image_prefix.enc.visual.blocks.13.attn.qkv.bias": "model-00002-of-00003.safetensors",
69
+ "image_prefix.enc.visual.blocks.13.attn.qkv.weight": "model-00002-of-00003.safetensors",
70
+ "image_prefix.enc.visual.blocks.13.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
71
+ "image_prefix.enc.visual.blocks.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
72
+ "image_prefix.enc.visual.blocks.13.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
73
+ "image_prefix.enc.visual.blocks.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
74
+ "image_prefix.enc.visual.blocks.13.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
75
+ "image_prefix.enc.visual.blocks.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
76
+ "image_prefix.enc.visual.blocks.13.norm1.weight": "model-00002-of-00003.safetensors",
77
+ "image_prefix.enc.visual.blocks.13.norm2.weight": "model-00002-of-00003.safetensors",
78
+ "image_prefix.enc.visual.blocks.14.attn.proj.bias": "model-00002-of-00003.safetensors",
79
+ "image_prefix.enc.visual.blocks.14.attn.proj.weight": "model-00002-of-00003.safetensors",
80
+ "image_prefix.enc.visual.blocks.14.attn.qkv.bias": "model-00002-of-00003.safetensors",
81
+ "image_prefix.enc.visual.blocks.14.attn.qkv.weight": "model-00002-of-00003.safetensors",
82
+ "image_prefix.enc.visual.blocks.14.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
83
+ "image_prefix.enc.visual.blocks.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
84
+ "image_prefix.enc.visual.blocks.14.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
85
+ "image_prefix.enc.visual.blocks.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
86
+ "image_prefix.enc.visual.blocks.14.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
87
+ "image_prefix.enc.visual.blocks.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
88
+ "image_prefix.enc.visual.blocks.14.norm1.weight": "model-00002-of-00003.safetensors",
89
+ "image_prefix.enc.visual.blocks.14.norm2.weight": "model-00002-of-00003.safetensors",
90
+ "image_prefix.enc.visual.blocks.15.attn.proj.bias": "model-00002-of-00003.safetensors",
91
+ "image_prefix.enc.visual.blocks.15.attn.proj.weight": "model-00002-of-00003.safetensors",
92
+ "image_prefix.enc.visual.blocks.15.attn.qkv.bias": "model-00002-of-00003.safetensors",
93
+ "image_prefix.enc.visual.blocks.15.attn.qkv.weight": "model-00002-of-00003.safetensors",
94
+ "image_prefix.enc.visual.blocks.15.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
95
+ "image_prefix.enc.visual.blocks.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
96
+ "image_prefix.enc.visual.blocks.15.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
97
+ "image_prefix.enc.visual.blocks.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
98
+ "image_prefix.enc.visual.blocks.15.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
99
+ "image_prefix.enc.visual.blocks.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
100
+ "image_prefix.enc.visual.blocks.15.norm1.weight": "model-00002-of-00003.safetensors",
101
+ "image_prefix.enc.visual.blocks.15.norm2.weight": "model-00002-of-00003.safetensors",
102
+ "image_prefix.enc.visual.blocks.16.attn.proj.bias": "model-00002-of-00003.safetensors",
103
+ "image_prefix.enc.visual.blocks.16.attn.proj.weight": "model-00002-of-00003.safetensors",
104
+ "image_prefix.enc.visual.blocks.16.attn.qkv.bias": "model-00002-of-00003.safetensors",
105
+ "image_prefix.enc.visual.blocks.16.attn.qkv.weight": "model-00002-of-00003.safetensors",
106
+ "image_prefix.enc.visual.blocks.16.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
107
+ "image_prefix.enc.visual.blocks.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
108
+ "image_prefix.enc.visual.blocks.16.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
109
+ "image_prefix.enc.visual.blocks.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
110
+ "image_prefix.enc.visual.blocks.16.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
111
+ "image_prefix.enc.visual.blocks.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
112
+ "image_prefix.enc.visual.blocks.16.norm1.weight": "model-00002-of-00003.safetensors",
113
+ "image_prefix.enc.visual.blocks.16.norm2.weight": "model-00002-of-00003.safetensors",
114
+ "image_prefix.enc.visual.blocks.17.attn.proj.bias": "model-00002-of-00003.safetensors",
115
+ "image_prefix.enc.visual.blocks.17.attn.proj.weight": "model-00002-of-00003.safetensors",
116
+ "image_prefix.enc.visual.blocks.17.attn.qkv.bias": "model-00002-of-00003.safetensors",
117
+ "image_prefix.enc.visual.blocks.17.attn.qkv.weight": "model-00002-of-00003.safetensors",
118
+ "image_prefix.enc.visual.blocks.17.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
119
+ "image_prefix.enc.visual.blocks.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
120
+ "image_prefix.enc.visual.blocks.17.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
121
+ "image_prefix.enc.visual.blocks.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
122
+ "image_prefix.enc.visual.blocks.17.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
123
+ "image_prefix.enc.visual.blocks.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
124
+ "image_prefix.enc.visual.blocks.17.norm1.weight": "model-00002-of-00003.safetensors",
125
+ "image_prefix.enc.visual.blocks.17.norm2.weight": "model-00002-of-00003.safetensors",
126
+ "image_prefix.enc.visual.blocks.18.attn.proj.bias": "model-00002-of-00003.safetensors",
127
+ "image_prefix.enc.visual.blocks.18.attn.proj.weight": "model-00002-of-00003.safetensors",
128
+ "image_prefix.enc.visual.blocks.18.attn.qkv.bias": "model-00002-of-00003.safetensors",
129
+ "image_prefix.enc.visual.blocks.18.attn.qkv.weight": "model-00002-of-00003.safetensors",
130
+ "image_prefix.enc.visual.blocks.18.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
131
+ "image_prefix.enc.visual.blocks.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
132
+ "image_prefix.enc.visual.blocks.18.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
133
+ "image_prefix.enc.visual.blocks.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
134
+ "image_prefix.enc.visual.blocks.18.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
135
+ "image_prefix.enc.visual.blocks.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
136
+ "image_prefix.enc.visual.blocks.18.norm1.weight": "model-00002-of-00003.safetensors",
137
+ "image_prefix.enc.visual.blocks.18.norm2.weight": "model-00002-of-00003.safetensors",
138
+ "image_prefix.enc.visual.blocks.19.attn.proj.bias": "model-00002-of-00003.safetensors",
139
+ "image_prefix.enc.visual.blocks.19.attn.proj.weight": "model-00002-of-00003.safetensors",
140
+ "image_prefix.enc.visual.blocks.19.attn.qkv.bias": "model-00002-of-00003.safetensors",
141
+ "image_prefix.enc.visual.blocks.19.attn.qkv.weight": "model-00002-of-00003.safetensors",
142
+ "image_prefix.enc.visual.blocks.19.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
143
+ "image_prefix.enc.visual.blocks.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
144
+ "image_prefix.enc.visual.blocks.19.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
145
+ "image_prefix.enc.visual.blocks.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
146
+ "image_prefix.enc.visual.blocks.19.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
147
+ "image_prefix.enc.visual.blocks.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
148
+ "image_prefix.enc.visual.blocks.19.norm1.weight": "model-00002-of-00003.safetensors",
149
+ "image_prefix.enc.visual.blocks.19.norm2.weight": "model-00002-of-00003.safetensors",
150
+ "image_prefix.enc.visual.blocks.2.attn.proj.bias": "model-00002-of-00003.safetensors",
151
+ "image_prefix.enc.visual.blocks.2.attn.proj.weight": "model-00002-of-00003.safetensors",
152
+ "image_prefix.enc.visual.blocks.2.attn.qkv.bias": "model-00002-of-00003.safetensors",
153
+ "image_prefix.enc.visual.blocks.2.attn.qkv.weight": "model-00002-of-00003.safetensors",
154
+ "image_prefix.enc.visual.blocks.2.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
155
+ "image_prefix.enc.visual.blocks.2.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
156
+ "image_prefix.enc.visual.blocks.2.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
157
+ "image_prefix.enc.visual.blocks.2.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
158
+ "image_prefix.enc.visual.blocks.2.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
159
+ "image_prefix.enc.visual.blocks.2.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
160
+ "image_prefix.enc.visual.blocks.2.norm1.weight": "model-00002-of-00003.safetensors",
161
+ "image_prefix.enc.visual.blocks.2.norm2.weight": "model-00002-of-00003.safetensors",
162
+ "image_prefix.enc.visual.blocks.20.attn.proj.bias": "model-00002-of-00003.safetensors",
163
+ "image_prefix.enc.visual.blocks.20.attn.proj.weight": "model-00002-of-00003.safetensors",
164
+ "image_prefix.enc.visual.blocks.20.attn.qkv.bias": "model-00002-of-00003.safetensors",
165
+ "image_prefix.enc.visual.blocks.20.attn.qkv.weight": "model-00002-of-00003.safetensors",
166
+ "image_prefix.enc.visual.blocks.20.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
167
+ "image_prefix.enc.visual.blocks.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
168
+ "image_prefix.enc.visual.blocks.20.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
169
+ "image_prefix.enc.visual.blocks.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
170
+ "image_prefix.enc.visual.blocks.20.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
171
+ "image_prefix.enc.visual.blocks.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
172
+ "image_prefix.enc.visual.blocks.20.norm1.weight": "model-00002-of-00003.safetensors",
173
+ "image_prefix.enc.visual.blocks.20.norm2.weight": "model-00002-of-00003.safetensors",
174
+ "image_prefix.enc.visual.blocks.21.attn.proj.bias": "model-00002-of-00003.safetensors",
175
+ "image_prefix.enc.visual.blocks.21.attn.proj.weight": "model-00002-of-00003.safetensors",
176
+ "image_prefix.enc.visual.blocks.21.attn.qkv.bias": "model-00002-of-00003.safetensors",
177
+ "image_prefix.enc.visual.blocks.21.attn.qkv.weight": "model-00002-of-00003.safetensors",
178
+ "image_prefix.enc.visual.blocks.21.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
179
+ "image_prefix.enc.visual.blocks.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
180
+ "image_prefix.enc.visual.blocks.21.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
181
+ "image_prefix.enc.visual.blocks.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
182
+ "image_prefix.enc.visual.blocks.21.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
183
+ "image_prefix.enc.visual.blocks.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
184
+ "image_prefix.enc.visual.blocks.21.norm1.weight": "model-00002-of-00003.safetensors",
185
+ "image_prefix.enc.visual.blocks.21.norm2.weight": "model-00002-of-00003.safetensors",
186
+ "image_prefix.enc.visual.blocks.22.attn.proj.bias": "model-00002-of-00003.safetensors",
187
+ "image_prefix.enc.visual.blocks.22.attn.proj.weight": "model-00002-of-00003.safetensors",
188
+ "image_prefix.enc.visual.blocks.22.attn.qkv.bias": "model-00002-of-00003.safetensors",
189
+ "image_prefix.enc.visual.blocks.22.attn.qkv.weight": "model-00002-of-00003.safetensors",
190
+ "image_prefix.enc.visual.blocks.22.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
191
+ "image_prefix.enc.visual.blocks.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
192
+ "image_prefix.enc.visual.blocks.22.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
193
+ "image_prefix.enc.visual.blocks.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
194
+ "image_prefix.enc.visual.blocks.22.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
195
+ "image_prefix.enc.visual.blocks.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
196
+ "image_prefix.enc.visual.blocks.22.norm1.weight": "model-00002-of-00003.safetensors",
197
+ "image_prefix.enc.visual.blocks.22.norm2.weight": "model-00002-of-00003.safetensors",
198
+ "image_prefix.enc.visual.blocks.23.attn.proj.bias": "model-00003-of-00003.safetensors",
199
+ "image_prefix.enc.visual.blocks.23.attn.proj.weight": "model-00003-of-00003.safetensors",
200
+ "image_prefix.enc.visual.blocks.23.attn.qkv.bias": "model-00002-of-00003.safetensors",
201
+ "image_prefix.enc.visual.blocks.23.attn.qkv.weight": "model-00002-of-00003.safetensors",
202
+ "image_prefix.enc.visual.blocks.23.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
203
+ "image_prefix.enc.visual.blocks.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
204
+ "image_prefix.enc.visual.blocks.23.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
205
+ "image_prefix.enc.visual.blocks.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
206
+ "image_prefix.enc.visual.blocks.23.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
207
+ "image_prefix.enc.visual.blocks.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
208
+ "image_prefix.enc.visual.blocks.23.norm1.weight": "model-00002-of-00003.safetensors",
209
+ "image_prefix.enc.visual.blocks.23.norm2.weight": "model-00002-of-00003.safetensors",
210
+ "image_prefix.enc.visual.blocks.24.attn.proj.bias": "model-00003-of-00003.safetensors",
211
+ "image_prefix.enc.visual.blocks.24.attn.proj.weight": "model-00003-of-00003.safetensors",
212
+ "image_prefix.enc.visual.blocks.24.attn.qkv.bias": "model-00003-of-00003.safetensors",
213
+ "image_prefix.enc.visual.blocks.24.attn.qkv.weight": "model-00003-of-00003.safetensors",
214
+ "image_prefix.enc.visual.blocks.24.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
215
+ "image_prefix.enc.visual.blocks.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
216
+ "image_prefix.enc.visual.blocks.24.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
217
+ "image_prefix.enc.visual.blocks.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
218
+ "image_prefix.enc.visual.blocks.24.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
219
+ "image_prefix.enc.visual.blocks.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
220
+ "image_prefix.enc.visual.blocks.24.norm1.weight": "model-00003-of-00003.safetensors",
221
+ "image_prefix.enc.visual.blocks.24.norm2.weight": "model-00003-of-00003.safetensors",
222
+ "image_prefix.enc.visual.blocks.25.attn.proj.bias": "model-00003-of-00003.safetensors",
223
+ "image_prefix.enc.visual.blocks.25.attn.proj.weight": "model-00003-of-00003.safetensors",
224
+ "image_prefix.enc.visual.blocks.25.attn.qkv.bias": "model-00003-of-00003.safetensors",
225
+ "image_prefix.enc.visual.blocks.25.attn.qkv.weight": "model-00003-of-00003.safetensors",
226
+ "image_prefix.enc.visual.blocks.25.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
227
+ "image_prefix.enc.visual.blocks.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
228
+ "image_prefix.enc.visual.blocks.25.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
229
+ "image_prefix.enc.visual.blocks.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
230
+ "image_prefix.enc.visual.blocks.25.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
231
+ "image_prefix.enc.visual.blocks.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
232
+ "image_prefix.enc.visual.blocks.25.norm1.weight": "model-00003-of-00003.safetensors",
233
+ "image_prefix.enc.visual.blocks.25.norm2.weight": "model-00003-of-00003.safetensors",
234
+ "image_prefix.enc.visual.blocks.26.attn.proj.bias": "model-00003-of-00003.safetensors",
235
+ "image_prefix.enc.visual.blocks.26.attn.proj.weight": "model-00003-of-00003.safetensors",
236
+ "image_prefix.enc.visual.blocks.26.attn.qkv.bias": "model-00003-of-00003.safetensors",
237
+ "image_prefix.enc.visual.blocks.26.attn.qkv.weight": "model-00003-of-00003.safetensors",
238
+ "image_prefix.enc.visual.blocks.26.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
239
+ "image_prefix.enc.visual.blocks.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
240
+ "image_prefix.enc.visual.blocks.26.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
241
+ "image_prefix.enc.visual.blocks.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
242
+ "image_prefix.enc.visual.blocks.26.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
243
+ "image_prefix.enc.visual.blocks.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
244
+ "image_prefix.enc.visual.blocks.26.norm1.weight": "model-00003-of-00003.safetensors",
245
+ "image_prefix.enc.visual.blocks.26.norm2.weight": "model-00003-of-00003.safetensors",
246
+ "image_prefix.enc.visual.blocks.27.attn.proj.bias": "model-00003-of-00003.safetensors",
247
+ "image_prefix.enc.visual.blocks.27.attn.proj.weight": "model-00003-of-00003.safetensors",
248
+ "image_prefix.enc.visual.blocks.27.attn.qkv.bias": "model-00003-of-00003.safetensors",
249
+ "image_prefix.enc.visual.blocks.27.attn.qkv.weight": "model-00003-of-00003.safetensors",
250
+ "image_prefix.enc.visual.blocks.27.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
251
+ "image_prefix.enc.visual.blocks.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
+ "image_prefix.enc.visual.blocks.27.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
253
+ "image_prefix.enc.visual.blocks.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
254
+ "image_prefix.enc.visual.blocks.27.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
255
+ "image_prefix.enc.visual.blocks.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
256
+ "image_prefix.enc.visual.blocks.27.norm1.weight": "model-00003-of-00003.safetensors",
257
+ "image_prefix.enc.visual.blocks.27.norm2.weight": "model-00003-of-00003.safetensors",
258
+ "image_prefix.enc.visual.blocks.28.attn.proj.bias": "model-00003-of-00003.safetensors",
259
+ "image_prefix.enc.visual.blocks.28.attn.proj.weight": "model-00003-of-00003.safetensors",
260
+ "image_prefix.enc.visual.blocks.28.attn.qkv.bias": "model-00003-of-00003.safetensors",
261
+ "image_prefix.enc.visual.blocks.28.attn.qkv.weight": "model-00003-of-00003.safetensors",
262
+ "image_prefix.enc.visual.blocks.28.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
263
+ "image_prefix.enc.visual.blocks.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
264
+ "image_prefix.enc.visual.blocks.28.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
265
+ "image_prefix.enc.visual.blocks.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
266
+ "image_prefix.enc.visual.blocks.28.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
267
+ "image_prefix.enc.visual.blocks.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
268
+ "image_prefix.enc.visual.blocks.28.norm1.weight": "model-00003-of-00003.safetensors",
269
+ "image_prefix.enc.visual.blocks.28.norm2.weight": "model-00003-of-00003.safetensors",
270
+ "image_prefix.enc.visual.blocks.29.attn.proj.bias": "model-00003-of-00003.safetensors",
271
+ "image_prefix.enc.visual.blocks.29.attn.proj.weight": "model-00003-of-00003.safetensors",
272
+ "image_prefix.enc.visual.blocks.29.attn.qkv.bias": "model-00003-of-00003.safetensors",
273
+ "image_prefix.enc.visual.blocks.29.attn.qkv.weight": "model-00003-of-00003.safetensors",
274
+ "image_prefix.enc.visual.blocks.29.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
275
+ "image_prefix.enc.visual.blocks.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
276
+ "image_prefix.enc.visual.blocks.29.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
277
+ "image_prefix.enc.visual.blocks.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
278
+ "image_prefix.enc.visual.blocks.29.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
279
+ "image_prefix.enc.visual.blocks.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
280
+ "image_prefix.enc.visual.blocks.29.norm1.weight": "model-00003-of-00003.safetensors",
281
+ "image_prefix.enc.visual.blocks.29.norm2.weight": "model-00003-of-00003.safetensors",
282
+ "image_prefix.enc.visual.blocks.3.attn.proj.bias": "model-00002-of-00003.safetensors",
283
+ "image_prefix.enc.visual.blocks.3.attn.proj.weight": "model-00002-of-00003.safetensors",
284
+ "image_prefix.enc.visual.blocks.3.attn.qkv.bias": "model-00002-of-00003.safetensors",
285
+ "image_prefix.enc.visual.blocks.3.attn.qkv.weight": "model-00002-of-00003.safetensors",
286
+ "image_prefix.enc.visual.blocks.3.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
287
+ "image_prefix.enc.visual.blocks.3.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
288
+ "image_prefix.enc.visual.blocks.3.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
289
+ "image_prefix.enc.visual.blocks.3.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
290
+ "image_prefix.enc.visual.blocks.3.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
291
+ "image_prefix.enc.visual.blocks.3.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
292
+ "image_prefix.enc.visual.blocks.3.norm1.weight": "model-00002-of-00003.safetensors",
293
+ "image_prefix.enc.visual.blocks.3.norm2.weight": "model-00002-of-00003.safetensors",
294
+ "image_prefix.enc.visual.blocks.30.attn.proj.bias": "model-00003-of-00003.safetensors",
295
+ "image_prefix.enc.visual.blocks.30.attn.proj.weight": "model-00003-of-00003.safetensors",
296
+ "image_prefix.enc.visual.blocks.30.attn.qkv.bias": "model-00003-of-00003.safetensors",
297
+ "image_prefix.enc.visual.blocks.30.attn.qkv.weight": "model-00003-of-00003.safetensors",
298
+ "image_prefix.enc.visual.blocks.30.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
299
+ "image_prefix.enc.visual.blocks.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
300
+ "image_prefix.enc.visual.blocks.30.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
301
+ "image_prefix.enc.visual.blocks.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
302
+ "image_prefix.enc.visual.blocks.30.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
303
+ "image_prefix.enc.visual.blocks.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
304
+ "image_prefix.enc.visual.blocks.30.norm1.weight": "model-00003-of-00003.safetensors",
305
+ "image_prefix.enc.visual.blocks.30.norm2.weight": "model-00003-of-00003.safetensors",
306
+ "image_prefix.enc.visual.blocks.31.attn.proj.bias": "model-00003-of-00003.safetensors",
307
+ "image_prefix.enc.visual.blocks.31.attn.proj.weight": "model-00003-of-00003.safetensors",
308
+ "image_prefix.enc.visual.blocks.31.attn.qkv.bias": "model-00003-of-00003.safetensors",
309
+ "image_prefix.enc.visual.blocks.31.attn.qkv.weight": "model-00003-of-00003.safetensors",
310
+ "image_prefix.enc.visual.blocks.31.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
311
+ "image_prefix.enc.visual.blocks.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
312
+ "image_prefix.enc.visual.blocks.31.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
313
+ "image_prefix.enc.visual.blocks.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
314
+ "image_prefix.enc.visual.blocks.31.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
315
+ "image_prefix.enc.visual.blocks.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
316
+ "image_prefix.enc.visual.blocks.31.norm1.weight": "model-00003-of-00003.safetensors",
317
+ "image_prefix.enc.visual.blocks.31.norm2.weight": "model-00003-of-00003.safetensors",
318
+ "image_prefix.enc.visual.blocks.4.attn.proj.bias": "model-00002-of-00003.safetensors",
319
+ "image_prefix.enc.visual.blocks.4.attn.proj.weight": "model-00002-of-00003.safetensors",
320
+ "image_prefix.enc.visual.blocks.4.attn.qkv.bias": "model-00002-of-00003.safetensors",
321
+ "image_prefix.enc.visual.blocks.4.attn.qkv.weight": "model-00002-of-00003.safetensors",
322
+ "image_prefix.enc.visual.blocks.4.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
323
+ "image_prefix.enc.visual.blocks.4.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
324
+ "image_prefix.enc.visual.blocks.4.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
325
+ "image_prefix.enc.visual.blocks.4.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
326
+ "image_prefix.enc.visual.blocks.4.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
327
+ "image_prefix.enc.visual.blocks.4.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
328
+ "image_prefix.enc.visual.blocks.4.norm1.weight": "model-00002-of-00003.safetensors",
329
+ "image_prefix.enc.visual.blocks.4.norm2.weight": "model-00002-of-00003.safetensors",
330
+ "image_prefix.enc.visual.blocks.5.attn.proj.bias": "model-00002-of-00003.safetensors",
331
+ "image_prefix.enc.visual.blocks.5.attn.proj.weight": "model-00002-of-00003.safetensors",
332
+ "image_prefix.enc.visual.blocks.5.attn.qkv.bias": "model-00002-of-00003.safetensors",
333
+ "image_prefix.enc.visual.blocks.5.attn.qkv.weight": "model-00002-of-00003.safetensors",
334
+ "image_prefix.enc.visual.blocks.5.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
335
+ "image_prefix.enc.visual.blocks.5.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
336
+ "image_prefix.enc.visual.blocks.5.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
337
+ "image_prefix.enc.visual.blocks.5.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
338
+ "image_prefix.enc.visual.blocks.5.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
339
+ "image_prefix.enc.visual.blocks.5.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
340
+ "image_prefix.enc.visual.blocks.5.norm1.weight": "model-00002-of-00003.safetensors",
341
+ "image_prefix.enc.visual.blocks.5.norm2.weight": "model-00002-of-00003.safetensors",
342
+ "image_prefix.enc.visual.blocks.6.attn.proj.bias": "model-00002-of-00003.safetensors",
343
+ "image_prefix.enc.visual.blocks.6.attn.proj.weight": "model-00002-of-00003.safetensors",
344
+ "image_prefix.enc.visual.blocks.6.attn.qkv.bias": "model-00002-of-00003.safetensors",
345
+ "image_prefix.enc.visual.blocks.6.attn.qkv.weight": "model-00002-of-00003.safetensors",
346
+ "image_prefix.enc.visual.blocks.6.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
347
+ "image_prefix.enc.visual.blocks.6.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
348
+ "image_prefix.enc.visual.blocks.6.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
349
+ "image_prefix.enc.visual.blocks.6.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
350
+ "image_prefix.enc.visual.blocks.6.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
351
+ "image_prefix.enc.visual.blocks.6.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
352
+ "image_prefix.enc.visual.blocks.6.norm1.weight": "model-00002-of-00003.safetensors",
353
+ "image_prefix.enc.visual.blocks.6.norm2.weight": "model-00002-of-00003.safetensors",
354
+ "image_prefix.enc.visual.blocks.7.attn.proj.bias": "model-00002-of-00003.safetensors",
355
+ "image_prefix.enc.visual.blocks.7.attn.proj.weight": "model-00002-of-00003.safetensors",
356
+ "image_prefix.enc.visual.blocks.7.attn.qkv.bias": "model-00002-of-00003.safetensors",
357
+ "image_prefix.enc.visual.blocks.7.attn.qkv.weight": "model-00002-of-00003.safetensors",
358
+ "image_prefix.enc.visual.blocks.7.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
359
+ "image_prefix.enc.visual.blocks.7.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
360
+ "image_prefix.enc.visual.blocks.7.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
361
+ "image_prefix.enc.visual.blocks.7.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
362
+ "image_prefix.enc.visual.blocks.7.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
363
+ "image_prefix.enc.visual.blocks.7.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
364
+ "image_prefix.enc.visual.blocks.7.norm1.weight": "model-00002-of-00003.safetensors",
365
+ "image_prefix.enc.visual.blocks.7.norm2.weight": "model-00002-of-00003.safetensors",
366
+ "image_prefix.enc.visual.blocks.8.attn.proj.bias": "model-00002-of-00003.safetensors",
367
+ "image_prefix.enc.visual.blocks.8.attn.proj.weight": "model-00002-of-00003.safetensors",
368
+ "image_prefix.enc.visual.blocks.8.attn.qkv.bias": "model-00002-of-00003.safetensors",
369
+ "image_prefix.enc.visual.blocks.8.attn.qkv.weight": "model-00002-of-00003.safetensors",
370
+ "image_prefix.enc.visual.blocks.8.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
371
+ "image_prefix.enc.visual.blocks.8.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
372
+ "image_prefix.enc.visual.blocks.8.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
373
+ "image_prefix.enc.visual.blocks.8.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
374
+ "image_prefix.enc.visual.blocks.8.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
375
+ "image_prefix.enc.visual.blocks.8.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
376
+ "image_prefix.enc.visual.blocks.8.norm1.weight": "model-00002-of-00003.safetensors",
377
+ "image_prefix.enc.visual.blocks.8.norm2.weight": "model-00002-of-00003.safetensors",
378
+ "image_prefix.enc.visual.blocks.9.attn.proj.bias": "model-00002-of-00003.safetensors",
379
+ "image_prefix.enc.visual.blocks.9.attn.proj.weight": "model-00002-of-00003.safetensors",
380
+ "image_prefix.enc.visual.blocks.9.attn.qkv.bias": "model-00002-of-00003.safetensors",
381
+ "image_prefix.enc.visual.blocks.9.attn.qkv.weight": "model-00002-of-00003.safetensors",
382
+ "image_prefix.enc.visual.blocks.9.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
383
+ "image_prefix.enc.visual.blocks.9.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
384
+ "image_prefix.enc.visual.blocks.9.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
385
+ "image_prefix.enc.visual.blocks.9.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
386
+ "image_prefix.enc.visual.blocks.9.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
387
+ "image_prefix.enc.visual.blocks.9.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
388
+ "image_prefix.enc.visual.blocks.9.norm1.weight": "model-00002-of-00003.safetensors",
389
+ "image_prefix.enc.visual.blocks.9.norm2.weight": "model-00002-of-00003.safetensors",
390
+ "image_prefix.enc.visual.merger.ln_q.weight": "model-00003-of-00003.safetensors",
391
+ "image_prefix.enc.visual.merger.mlp.0.bias": "model-00003-of-00003.safetensors",
392
+ "image_prefix.enc.visual.merger.mlp.0.weight": "model-00003-of-00003.safetensors",
393
+ "image_prefix.enc.visual.merger.mlp.2.bias": "model-00003-of-00003.safetensors",
394
+ "image_prefix.enc.visual.merger.mlp.2.weight": "model-00003-of-00003.safetensors",
395
+ "image_prefix.enc.visual.patch_embed.proj.weight": "model-00002-of-00003.safetensors",
396
+ "image_prefix.norm_extra.weight": "model-00003-of-00003.safetensors",
397
+ "lm_head.weight": "model-00002-of-00003.safetensors",
398
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
399
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
400
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
401
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
402
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
403
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
404
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
405
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
406
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
407
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
408
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
409
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
410
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
411
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
412
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
413
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
414
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
415
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
416
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
417
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
418
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
419
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
420
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
421
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
422
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
423
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
424
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
425
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
426
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
427
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
428
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
429
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
430
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
431
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
432
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
433
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
434
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
435
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00003.safetensors",
436
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
437
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
438
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
439
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
440
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
441
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
442
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
443
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
444
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00003.safetensors",
445
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
446
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
447
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
448
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
449
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
450
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
451
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
452
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
453
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00003.safetensors",
454
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
455
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
456
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
457
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
458
+ "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
459
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
460
+ "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
461
+ "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
462
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00003.safetensors",
463
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
464
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
465
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
466
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
467
+ "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
468
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
469
+ "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
470
+ "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
471
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00003.safetensors",
472
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
473
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
474
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
475
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
476
+ "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
477
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
478
+ "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
479
+ "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
480
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
481
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
482
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
483
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
484
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
485
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
486
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
487
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
488
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
489
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
490
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
491
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
492
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
493
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
494
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
495
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
496
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
497
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
498
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
499
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
500
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
501
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
502
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
503
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
504
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
505
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
506
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
507
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
508
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
509
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
510
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
511
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
512
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
513
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
514
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
515
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
516
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
517
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
518
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
519
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
520
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
521
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
522
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
523
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
524
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
525
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
526
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
527
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
528
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
529
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
530
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
531
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
532
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
533
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
534
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
535
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
536
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
537
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
538
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
539
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
540
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
541
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
542
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
543
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
544
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
545
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
546
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
547
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
548
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
549
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
550
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
551
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
552
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
553
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
554
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
555
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
556
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
557
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
558
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
559
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
560
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
561
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
562
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
563
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
564
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
565
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
566
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
567
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
568
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
569
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
570
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00003.safetensors",
571
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
572
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
573
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
574
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
575
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
576
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
577
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
578
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
579
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00003.safetensors",
580
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
581
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
582
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
583
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
584
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
585
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
586
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
587
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
588
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
589
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
590
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
591
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
592
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
593
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
594
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
595
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
596
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
597
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
598
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
599
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
600
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
601
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
602
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
603
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
604
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
605
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
606
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
607
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
608
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
609
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
610
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
611
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
612
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
613
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
614
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
615
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
616
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
617
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
618
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
619
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
620
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
621
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
622
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
623
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
624
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
625
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
626
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
627
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
628
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
629
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
630
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
631
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
632
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
633
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
634
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
635
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
636
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
637
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
638
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
639
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
640
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
641
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
642
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
643
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
644
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
645
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
646
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
647
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
648
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
649
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
650
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
651
+ "model.norm.weight": "model-00002-of-00003.safetensors"
652
+ }
653
+ }
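For reference, the `weight_map` above is what loaders consult to find the shard holding a given parameter. A minimal sketch of a manual lookup (assuming the shards have been downloaded locally; the directory path and parameter name below are illustrative only):

import json
from safetensors import safe_open

ckpt_dir = "./Helium1-VL-2B"  # hypothetical local snapshot directory
with open(f"{ckpt_dir}/model.safetensors.index.json") as f:
    index = json.load(f)

# Find which shard stores one parameter, then read only that tensor.
name = "model.layers.5.self_attn.q_proj.weight"
shard = index["weight_map"][name]  # e.g. "model-00001-of-00003.safetensors"
with safe_open(f"{ckpt_dir}/{shard}", framework="pt", device="cpu") as shard_file:
    weight = shard_file.get_tensor(name)
print(weight.shape)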
modeling_helium1_casa.py ADDED
@@ -0,0 +1,330 @@
1
+ from typing import Any, Callable
2
+ from typing import cast as type_cast
3
+
4
+ import torch
5
+ from transformers.cache_utils import DynamicCache
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.generation.utils import GenerateOutput
8
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
9
+ Qwen2_5_VisionTransformerPretrainedModel,
10
+ )
11
+
12
+ from .image_encoder import Qwen25VLEncoder
13
+ from .configuration_helium1_casa import Helium1CASAConfig
14
+ from .language_helium1_casa import (
15
+ CausalHeliumOutput,
16
+ Helium1CASAAttention,
17
+ Helium1ForCausalLM,
18
+ Helium1RMSNorm,
19
+ )
20
+
21
+
22
+ def meta_project(
23
+ logits: torch.Tensor | list[torch.Tensor],
24
+ projector: torch.nn.Module,
25
+ norm: torch.nn.Module | None = None,
26
+ ) -> torch.Tensor | list[torch.Tensor]:
27
+ """Projection operation that handles both tensors and list of tensors
28
+
29
+ Outputs either an (N, S, D) tensor (same-resolution images) or a list of N (S, D) tensors (where
30
+ S can be a different sequence length per image)
31
+ """
32
+ split_sizes: list[int] | None = None
33
+ if not isinstance(logits, torch.Tensor):
34
+ split_sizes = [_x.shape[0] for _x in logits]
35
+ logits = torch.cat(logits, dim=0)[None, :, :]
36
+ logits = type_cast(torch.Tensor, logits)
37
+ logits = projector(logits)
38
+
39
+ assert isinstance(logits, torch.Tensor)
40
+ if norm is not None:
41
+ logits = norm(logits)
42
+ if split_sizes is not None:
43
+ return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
44
+ return logits
45
+
46
+
47
+ class ImageProjection(torch.nn.Module):
48
+ """Takes in a batch or sequence of images and returns embeddings
49
+ which are then fed to the LM.
50
+
51
+ :param config: Model configuration (e.g. a Helium1CASAConfig)
52
+ :param lm_model_dim: Output dimension (number of channels) for this module
53
+ """
54
+
55
+ def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
56
+ super().__init__()
57
+ self.config = config
58
+ self.out_dim = lm_model_dim
59
+ visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
60
+
61
+ self.enc = Qwen25VLEncoder(visual=visual)
62
+ # Projection layer
63
+ self.proj_extra = self.init_proj_module()
64
+ # Output normalizations
65
+ self.norm_extra = Helium1RMSNorm(self.out_dim)
66
+
67
+ def init_proj_module(self) -> torch.nn.Module:
68
+ """Init the project module for the inserted and/or cross-attended image tokens"""
69
+ if self.config.vision_config.out_dim == self.out_dim:
70
+ return torch.nn.Identity()
71
+ return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)
72
+
73
+ def forward(
74
+ self, x: torch.Tensor | list[torch.Tensor]
75
+ ) -> dict[
76
+ str,
77
+ torch.Tensor | list[torch.Tensor],
78
+ ]:
79
+ """Image embedding mapping
80
+
81
+ :param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
82
+ with shape (C, H, W) (or (H, W, C) in the case of Qwen)
83
+
84
+ :return: Either a tensor with shape (num_total_images, S, D) or, if images
85
+ can have different seq length, a list of `num_total_images` Tensors with shape
86
+ (S, D)
87
+ """
88
+
89
+ # Apply image encoder
90
+ og_dtype = x[0].dtype
91
+ encoded = self.enc(x)["image_embeds"]
92
+ encoded = [_x.to(og_dtype) for _x in encoded]
93
+ if all(x.shape[0] == encoded[0].shape[0] for x in encoded):
94
+ encoded = torch.stack(encoded, dim=0)
95
+
96
+ # Extra projection
97
+ image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
98
+
99
+ # Apply different projection for extra vs cross attended tokens
100
+ return {"image_embeds": image_embeds}
101
+
102
+
103
+ class V2Helium1(Helium1ForCausalLM): # pyright: ignore[reportIncompatibleMethodOverride]
104
+ config_class = Helium1CASAConfig
105
+
106
+ def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
107
+ del kwargs
108
+ super().__init__(config)
109
+ self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)
110
+
111
+ def get_device(self) -> str:
112
+ """Return the device type of the model"""
113
+ return next(self.parameters()).device.type
114
+
115
+ @property
116
+ def token_dim(self) -> int:
117
+ """Returns the number of dimensions for the token representation"""
118
+ return self.config.hidden_size
119
+
120
+ @property
121
+ def rotary_embed(self) -> Callable:
122
+ """Returns the rotary embedding function of the underlying model"""
123
+ return self.model.rotary_emb
124
+
125
+ def _update_model_kwargs_for_generation(
126
+ self,
127
+ outputs: Any,
128
+ model_kwargs: dict[str, Any],
129
+ is_encoder_decoder: bool = False,
130
+ num_new_tokens: int = 1,
131
+ ):
132
+ """This is required to handle multiple gen calls for subtitles"""
133
+ # Call parent to get default updates
134
+ model_kwargs = super()._update_model_kwargs_for_generation(
135
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
136
+ )
137
+ # Used by prepare_inputs_for_generation
138
+ model_kwargs["__is_first_gen_call__"] = False
139
+ return model_kwargs
140
+
141
+ def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride]
142
+ self,
143
+ input_ids: torch.Tensor,
144
+ past_key_values: DynamicCache | None = None,
145
+ **kwargs: Any,
146
+ ):
147
+ __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
148
+ if past_key_values is not None and (
149
+ kwargs.get("cache_position") is None
150
+ or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
151
+ ):
152
+ # We're continuing from a cached state
153
+ past_length = past_key_values._seen_tokens
154
+ kwargs["cache_position"] = torch.arange(
155
+ past_length,
156
+ past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
157
+ dtype=torch.long,
158
+ device=input_ids.device,
159
+ )
160
+
161
+ return super().prepare_inputs_for_generation(
162
+ type_cast(torch.LongTensor, input_ids),
163
+ past_key_values=past_key_values,
164
+ **kwargs,
165
+ )
166
+
167
+ def prepare_multimodal_inputs(
168
+ self,
169
+ # text only training
170
+ input_ids: torch.Tensor | None = None,
171
+ inputs_embeds: torch.Tensor | None = None,
172
+ attention_mask: torch.Tensor | None = None,
173
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
174
+ labels: torch.Tensor | None = None,
175
+ # image values
176
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
177
+ pre_image_tokens: list[int] | None = None,
178
+ post_image_tokens: list[int] | None = None,
179
+ **_kwargs: Any,
180
+ ) -> dict:
181
+ """Get a batch data mixing text and image data"""
182
+ del _kwargs
183
+
184
+ processed_inputs = {
185
+ "input_ids": input_ids,
186
+ "inputs_embeds": inputs_embeds,
187
+ "labels": labels,
188
+ "attention_mask": attention_mask,
189
+ "image_embeds_insertion_points": image_embeds_insertion_points,
190
+ }
191
+ if pixel_values is not None:
192
+ processed_inputs.update(self.image_prefix(pixel_values))
193
+ assert "image_embeds" in processed_inputs
194
+ assert (
195
+ isinstance(processed_inputs["image_embeds"], torch.Tensor)
196
+ and processed_inputs["image_embeds"].ndim == 3
197
+ ) or (
198
+ isinstance(processed_inputs["image_embeds"], list)
199
+ and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
200
+ )
201
+
202
+ # Add kwargs necessary to compute cu_seqlens windows for CASA
203
+ processed_inputs["casa_windows_info"] = {
204
+ "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
205
+ "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
206
+ }
207
+
208
+ return processed_inputs
209
+
210
+ def forward( # pyright: ignore[reportIncompatibleMethodOverride]
211
+ self,
212
+ input_ids: torch.Tensor | None = None,
213
+ inputs_embeds: torch.Tensor | None = None,
214
+ attention_mask: torch.Tensor | None = None,
215
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
216
+ return_loss: bool = True,
217
+ labels: torch.Tensor | None = None,
218
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
219
+ pre_image_tokens: list[int] | None = None,
220
+ post_image_tokens: list[int] | None = None,
221
+ **kwargs: Any,
222
+ ) -> CausalHeliumOutput:
223
+ """Multi modal forward pass"""
224
+ assert input_ids is not None or inputs_embeds is not None
225
+
226
+ if self.training:
227
+ assert return_loss is True, (
228
+ "Helium models always compute its own labels/losses in train mode"
229
+ )
230
+
231
+ # Case 1: For first generation call we need to compute pixel values and CASA states
232
+ if kwargs.get("__is_first_gen_call__", True):
233
+ processed_inputs = self.prepare_multimodal_inputs(
234
+ input_ids=input_ids,
235
+ inputs_embeds=inputs_embeds,
236
+ attention_mask=attention_mask,
237
+ image_embeds_insertion_points=image_embeds_insertion_points,
238
+ pixel_values=pixel_values,
239
+ labels=labels,
240
+ pre_image_tokens=pre_image_tokens,
241
+ post_image_tokens=post_image_tokens,
242
+ )
243
+ processed_inputs.pop("inputs_embeds", None)
244
+ else:
245
+ processed_inputs = {
246
+ "inputs_embeds": self.model.embed_tokens(input_ids),
247
+ "attention_mask": attention_mask,
248
+ }
249
+
250
+ # For Helium prefix, we need to update the positions by the number
251
+ # of image tokens inserted in the first call
252
+ if (
253
+ not self.config.casa_attention
254
+ and (cp := kwargs.get("cache_position", None)) is not None
255
+ and pixel_values is not None
256
+ ):
257
+ start = kwargs["cache_position"][0].item()
258
+ num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
259
+ num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] # type: ignore
260
+ kwargs["cache_position"] = torch.arange(
261
+ start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
262
+ start + num_tokens + num_image_tokens,
263
+ dtype=cp.dtype,
264
+ device=cp.device,
265
+ )
266
+
267
+ kwargs.pop("__is_first_gen_call__", True)
268
+ out = super().forward(
269
+ **processed_inputs, # type: ignore
270
+ **kwargs,
271
+ )
272
+
273
+ return out
274
+
275
+ @torch.no_grad()
276
+ def generate_from_image( # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride]
277
+ self,
278
+ input_ids: torch.Tensor | None = None,
279
+ inputs_embeds: torch.Tensor | None = None,
280
+ attention_mask: torch.Tensor | None = None,
281
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
282
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
283
+ reset_streaming: bool = True,
284
+ **kwargs: Any,
285
+ ) -> "GenerateOutput | torch.LongTensor":
286
+ assert input_ids is not None and inputs_embeds is None, (
287
+ "Input IDs must be provided for generation"
288
+ )
289
+
290
+ # init self-attention KVCache
291
+ if kwargs.get("past_key_values", None) is None:
292
+ kwargs["past_key_values"] = DynamicCache()
293
+
294
+ # To avoid generate warning
295
+ if kwargs.get("pad_token_id", None) is None:
296
+ kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
297
+ if isinstance(kwargs["pad_token_id"], (list, tuple)):
298
+ kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
299
+
300
+ self.start_casa_streaming_states()
301
+ outputs = self.generate(
302
+ input_ids,
303
+ attention_mask=attention_mask,
304
+ pixel_values=pixel_values,
305
+ image_embeds_insertion_points=image_embeds_insertion_points,
306
+ use_cache=True,
307
+ **kwargs,
308
+ )
309
+ if reset_streaming:
310
+ self.reset_casa_streaming_states()
311
+ return outputs
312
+
313
+ def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
314
+ def __reset__(m: torch.nn.Module):
315
+ if isinstance(m, Helium1CASAAttention):
316
+ m._set_streaming(False, ())
317
+ m.reset_streaming()
318
+ if clean_cache:
319
+ del m.streaming_state.k
320
+ del m.streaming_state.v
321
+ del m.streaming_state.casa_handler
322
+
323
+ self.apply(__reset__)
324
+
325
+ def start_casa_streaming_states(self) -> None:
326
+ def __start__(m: torch.nn.Module):
327
+ if isinstance(m, Helium1CASAAttention):
328
+ m._set_streaming(True, ())
329
+
330
+ self.apply(__start__)
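Putting `V2Helium1` together with the processor defined below, generation roughly looks like the following sketch. It assumes the repository's `auto_map` resolves the remote-code classes (see the main model card for the supported instructions); the repository id, image path, and generation settings are illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

repo_id = "kyutai/Helium1-VL-2B"  # illustrative; use this repository's actual id
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda").eval()

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "example.jpg"},  # a path or a PIL.Image
        {"type": "text", "text": "Describe this image."},
    ],
}]
# tokenize_messages returns input_ids, attention_mask, pixel_values and
# image_embeds_insertion_points, which generate_from_image consumes directly.
batch = processor.tokenize_messages(messages).to("cuda", dtype=torch.bfloat16)
output_ids = model.generate_from_image(**batch, max_new_tokens=128)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))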
processing.py ADDED
@@ -0,0 +1,505 @@
1
+ # pylint: disable=no-member # avoid weird pylint warnings from SentencePieceProcessor
2
+ """Text and Image processor for CASA models using Qwen2.5_VL image encoder"""
3
+
4
+ from math import ceil
5
+ from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload
6
+ from typing import cast as type_cast
7
+
8
+ import torch
9
+ import torchvision.transforms.v2 as T
10
+ from einops import rearrange
11
+ from PIL import Image
12
+ from torchvision.transforms import InterpolationMode
13
+ from torchvision.transforms.functional import to_tensor as pil_to_tensor
14
+ from torchvision.transforms.v2 import functional as F
15
+ from transformers.image_processing_utils import BaseImageProcessor
16
+ from transformers.processing_utils import ProcessorMixin
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+
22
+
23
+ ImageMessage = TypedDict(
24
+ "ImageMessage",
25
+ {
26
+ "type": Literal["image"],
27
+ "image": str | Image.Image | None,
28
+ },
29
+ )
30
+
31
+ TextMessage = TypedDict(
32
+ "TextMessage",
33
+ {
34
+ "type": Literal["text"],
35
+ "text": str,
36
+ },
37
+ )
38
+
39
+ MessageContent = list[ImageMessage | TextMessage]
40
+
41
+ Message = TypedDict(
42
+ "Message",
43
+ {
44
+ "role": Literal["system", "user", "assistant"],
45
+ "content": MessageContent,
46
+ },
47
+ )
48
+
49
+ ProcessorInput = list[list[Message]] | list[Message]
50
+
51
+ __INTERP_NAME_TO_MODE__ = {
52
+ "nearest": InterpolationMode.NEAREST,
53
+ "bilinear": InterpolationMode.BILINEAR,
54
+ "bicubic": InterpolationMode.BICUBIC,
55
+ "lanczos": InterpolationMode.LANCZOS,
56
+ }
57
+
58
+ __INTERP_INT_TO_MODE__ = {
59
+ 0: InterpolationMode.NEAREST,
60
+ 2: InterpolationMode.BILINEAR,
61
+ 3: InterpolationMode.BICUBIC,
62
+ 4: InterpolationMode.BOX,
63
+ 5: InterpolationMode.HAMMING,
64
+ 1: InterpolationMode.LANCZOS,
65
+ }
66
+
67
+
68
+ @overload
69
+ def universal_resize(
70
+ img: Image.Image,
71
+ size: tuple[int, int],
72
+ interpolation: str | InterpolationMode | int = "bilinear",
73
+ antialias: bool = True,
74
+ ) -> Image.Image: ...
75
+ @overload
76
+ def universal_resize(
77
+ img: torch.Tensor,
78
+ size: tuple[int, int],
79
+ interpolation: str | InterpolationMode | int = "bilinear",
80
+ antialias: bool = True,
81
+ ) -> torch.Tensor: ...
82
+ def universal_resize(
83
+ img: Image.Image | torch.Tensor,
84
+ size: tuple[int, int],
85
+ interpolation: str | InterpolationMode | int = "bilinear",
86
+ antialias: bool = True,
87
+ ) -> Image.Image | torch.Tensor:
88
+ """Resize that works for PIL.Image, CHW tensor, or BCHW tensor"""
89
+ if isinstance(interpolation, str):
90
+ interpolation = __INTERP_NAME_TO_MODE__[interpolation]
91
+ elif isinstance(interpolation, int):
92
+ interpolation = __INTERP_INT_TO_MODE__[interpolation]
93
+
94
+ return F.resize(
95
+ img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias
96
+ )
97
+
98
+
99
+ @overload
100
+ def convert_to_rgb(img: Image.Image) -> Image.Image: ...
101
+ @overload
102
+ def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
103
+ def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
104
+ """Convert any image to RGB in a way that does not throw PIL warning"""
105
+ if isinstance(img, torch.Tensor):
106
+ return img
107
+ if img.mode == "RGB": # no changes
108
+ return img
109
+ if img.mode == "P": # palette images need to be converted to RGBA first
110
+ return img.convert("RGBA").convert("RGB")
111
+ return img.convert("RGB")
112
+
113
+
114
+ class QwenImageProcessor(BaseImageProcessor):
115
+ """Resizing for the Qwen2.5VL encoder. Note that the normalization is
116
+ handled in the image_encoder in the model forward"""
117
+
118
+ def __init__(
119
+ self,
120
+ img_size: int = 448,
121
+ interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
122
+ max_ratio: int = 10,
123
+ round_to_patch_size: int = 56,
124
+ use_fast: bool = True,
125
+ **kwargs: Any,
126
+ ) -> None:
127
+ # this will also be used in V2llms to determine whether to remove
128
+ # the temporal conv
129
+ self._num_target_channels = 588
130
+ self._merge_size = 2
131
+ self._patch_size = 14
132
+ super().__init__(
133
+ use_fast=use_fast,
134
+ do_normalize=False,
135
+ **kwargs,
136
+ )
137
+ self.img_size = img_size
138
+ self.interpolation = interpolation
139
+ self.max_ratio = max_ratio
140
+ self.round_to_patch_size = round_to_patch_size
141
+
142
+ def resize_transform(
143
+ self, img: Image.Image | torch.Tensor, img_size: int | None = None
144
+ ) -> Image.Image | torch.Tensor:
145
+ if img_size is None:
146
+ img_size = self.img_size
147
+ max_area = img_size**2
148
+ if isinstance(img, Image.Image):
149
+ img = convert_to_rgb(img)
150
+ w_og, h_og = img.size
151
+ else:
152
+ h_og, w_og = img.shape[-2:]
153
+ w, h = w_og, h_og
154
+
155
+ # Qwen requires max ratio of 10 between max and min sizes
156
+ if self.max_ratio > 0:
157
+ w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)
158
+
159
+ # resize to max area
160
+ current_area = w * h
161
+ if current_area > max_area:
162
+ scale = (max_area / current_area) ** 0.5
163
+ w, h = int(w * scale), int(h * scale)
164
+
165
+ # resize to patch size
166
+ if self.round_to_patch_size > 0:
167
+ w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
168
+ h = ceil((h / self.round_to_patch_size)) * self.round_to_patch_size
169
+
170
+ # resize
171
+ if w != w_og or h != h_og:
172
+ img = universal_resize(img, (h, w), self.interpolation)
173
+ if isinstance(img, torch.Tensor):
174
+ img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
175
+ return img
176
+
177
+ def __process_one__(
178
+ self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
179
+ ) -> torch.Tensor:
180
+ """Same operation as __process_one_with_processor__ but without going through numpy"""
181
+ video_or_img = self.resize_transform(video_or_img, img_size)
182
+ if isinstance(video_or_img, Image.Image):
183
+ video_or_img = pil_to_tensor(video_or_img)
184
+ assert isinstance(video_or_img, torch.Tensor)
185
+ if video_or_img.ndim == 3:
186
+ video_or_img = video_or_img[None]
187
+ assert video_or_img.ndim == 4 and video_or_img.shape[1] == 3, (
188
+ f"Invalid shape {video_or_img.shape}."
189
+ )
190
+ t, c, h, w = video_or_img.shape
191
+ p = self._patch_size
192
+ m = self._merge_size
193
+
194
+ # Convert to RGB
195
+ if c == 1:
196
+ video_or_img = video_or_img.expand((-1, 3, -1, -1))
197
+ if c == 4:
198
+ video_or_img = video_or_img[:, :3]
199
+ c = video_or_img.shape[1]
200
+ assert c == 3, "Expecting RGB image in QwenNormalize"
201
+
202
+ # Reshape to t h w c' format
203
+ h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
204
+ rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)
205
+
206
+ video_or_img = rearrange(
207
+ video_or_img,
208
+ "t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
209
+ **rearrange_dict,
210
+ )
211
+ assert video_or_img.shape[-1] == self._num_target_channels, (
212
+ f"{video_or_img.shape[-1]} != {self._num_target_channels}"
213
+ )
214
+ video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))
215
+
216
+ return video_or_img
217
+
218
+ @overload
219
+ def process_images(
220
+ self, image: Image.Image | torch.Tensor, img_size: int | None = None
221
+ ) -> torch.Tensor: ...
222
+ @overload
223
+ def process_images(
224
+ self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
225
+ ) -> list[torch.Tensor]: ...
226
+ def process_images(
227
+ self,
228
+ image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
229
+ img_size: int | None = None,
230
+ ) -> torch.Tensor | list[torch.Tensor]:
231
+ if isinstance(image, list):
232
+ return [self.__process_one__(_x, img_size) for _x in image]
233
+ return self.__process_one__(image, img_size)
234
+
235
+
236
+ class ProcessorOutput(dict):
237
+ input_ids: torch.Tensor
238
+ attention_mask: torch.Tensor
239
+ image_embeds_insertion_points: list[torch.Tensor] | None
240
+ pixel_values: torch.Tensor | list[torch.Tensor] | None
241
+
242
+ def to(
243
+ self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
244
+ ) -> "ProcessorOutput":
245
+ return ProcessorOutput(
246
+ {
247
+ "input_ids": self["input_ids"].to(device),
248
+ "attention_mask": self["attention_mask"].to(device),
249
+ "image_embeds_insertion_points": self["image_embeds_insertion_points"],
250
+ "pixel_values": (
251
+ self["pixel_values"].to(dtype).to(device)
252
+ if isinstance(self["pixel_values"], torch.Tensor)
253
+ else [x.to(dtype).to(device) for x in self["pixel_values"]]
254
+ if self["pixel_values"] is not None
255
+ else None
256
+ ),
257
+ }
258
+ )
259
+
260
+
261
+ class BaseProcessor(ProcessorMixin):
262
+ def __init__(
263
+ self,
264
+ tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
265
+ pre_image_tokens: tuple[int, ...] = (),
266
+ post_image_tokens: tuple[int, ...] = (),
267
+ system_start_tokens: tuple[int, ...] = (),
268
+ system_end_tokens: tuple[int, ...] = (),
269
+ user_start_tokens: tuple[int, ...] = (),
270
+ user_end_tokens: tuple[int, ...] = (),
271
+ asst_start_tokens: tuple[int, ...] = (),
272
+ asst_end_tokens: tuple[int, ...] = (),
273
+ allow_system_prompt: bool = True,
274
+ pad_token: int = 0,
275
+ bos_token: int | None = None,
276
+ ) -> None:
277
+ self.pre_image_tokens = list(pre_image_tokens)
278
+ self.post_image_tokens = list(post_image_tokens)
279
+ self.system_start_tokens = list(system_start_tokens)
280
+ self.system_end_tokens = list(system_end_tokens)
281
+ self.user_start_tokens = list(user_start_tokens)
282
+ self.user_end_tokens = list(user_end_tokens)
283
+ self.asst_start_tokens = list(asst_start_tokens)
284
+ self.asst_end_tokens = list(asst_end_tokens)
285
+ self._allow_system_prompt = allow_system_prompt
286
+ self.tokenizer = tokenizer
287
+ self._image_processor = None
288
+ self._pad_token = pad_token
289
+ self.bos_token = bos_token
290
+
291
+ @property
292
+ def image_processor(self) -> QwenImageProcessor:
293
+ assert self._image_processor is not None
294
+ return self._image_processor
295
+
296
+ def _process_content(
297
+ self,
298
+ message_content: MessageContent,
299
+ role: Literal["system", "user", "assistant"],
300
+ tokenized_messages: list[torch.Tensor],
301
+ insertion_points: list[int],
302
+ image_list: list[torch.Tensor | None],
303
+ token_count: int,
304
+ img_size: int | None = None,
305
+ **kwargs: Any,
306
+ ) -> int:
307
+ mapping = {
308
+ "user": (self.user_start_tokens, self.user_end_tokens),
309
+ "assistant": (self.asst_start_tokens, self.asst_end_tokens),
310
+ "system": (self.system_start_tokens, self.system_end_tokens),
311
+ }
312
+ if role.lower() not in mapping:
313
+ raise ValueError(f"Unknown role '{role}' encountered in messages.")
314
+ start_tokens, end_tokens = mapping[role.lower()]
315
+ # 1) Add the start tokens
316
+ if start_tokens:
317
+ tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long))
318
+ token_count += len(start_tokens)
319
+ # 2) Process the message content one by one (potentially interleaved image and text)
320
+ for part in message_content:
321
+ elt_type = part["type"]
322
+ if elt_type == "image":
323
+ part = cast(ImageMessage, part)
324
+ self._process_image_message(
325
+ part,
326
+ tokenized_messages,
327
+ image_list,
328
+ img_size=img_size,
329
+ )
330
+ token_count += len(self.pre_image_tokens)
331
+ insertion_points.append(token_count)
332
+ token_count += len(self.post_image_tokens)
333
+ else:
334
+ part = cast(TextMessage, part)
335
+ self._process_text_message(
336
+ part["text"],
337
+ role=role,
338
+ token_list=tokenized_messages,
339
+ **kwargs,
340
+ )
341
+ token_count += tokenized_messages[-1].size(0)
342
+ # 3) Add the end tokens
343
+ if end_tokens:
344
+ tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long))
345
+ token_count += len(end_tokens)
346
+ return token_count
347
+
348
+ def _process_text_message(
349
+ self,
350
+ message: str,
351
+ role: Literal["system", "user", "assistant"],
352
+ token_list: list[torch.Tensor],
353
+ **kwargs: Any,
354
+ ) -> None:
355
+ if role.lower() == "system" and not self._allow_system_prompt:
356
+ raise ValueError("System prompts are not allowed in this tokenizer configuration.")
357
+ tokens = self.tokenizer.encode(
358
+ message, add_special_tokens=False, return_tensors="pt", **kwargs
359
+ )
360
+ tokens = cast(torch.Tensor, tokens)
361
+ token_list.append(tokens.flatten().to(torch.long))
362
+
363
+ def _process_image_message(
364
+ self,
365
+ message: ImageMessage,
366
+ token_list: list[torch.Tensor],
367
+ image_list: list[torch.Tensor | None],
368
+ img_size: int | None = None,
369
+ ) -> None:
370
+ img = message["image"]
371
+ if img is None:
372
+ image_list.append(None)
373
+ else:
374
+ image_list.append(
375
+ self.image_processor.process_images(
376
+ self._load_image(img), img_size=img_size
377
+ ).squeeze(0)
378
+ )
379
+ if self.pre_image_tokens:
380
+ token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long))
381
+
382
+ if self.post_image_tokens:
383
+ token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long))
384
+
385
+ def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
386
+ if isinstance(image_path_or_image, str):
387
+ return Image.open(image_path_or_image).convert("RGB")
388
+ return image_path_or_image
389
+
390
+ def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
391
+ return torch.nn.functional.pad(
392
+ tokens,
393
+ (0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
394
+ value=pad_value,
395
+ )
396
+
397
+ def pad_tokenized_messages(
398
+ self,
399
+ tokenized_messages_batch: list[torch.Tensor],
400
+ image_insertion_points_batch: list[torch.Tensor] | None = None,
401
+ ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
402
+ max_len = max(len(x) for x in tokenized_messages_batch)
403
+ if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
404
+ image_insertion_points_batch = [
405
+ x + max_len - len(tokenized_messages_batch[idx])
406
+ for idx, x in enumerate(image_insertion_points_batch)
407
+ ]
408
+ input_ids = torch.stack(
409
+ [
410
+ self._maybe_pad(s, max_len - s.size(0), self._pad_token)
411
+ for s in tokenized_messages_batch
412
+ ],
413
+ dim=0,
414
+ )
415
+ attention_mask = torch.stack(
416
+ [
417
+ self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
418
+ for s in tokenized_messages_batch
419
+ ],
420
+ dim=0,
421
+ )
422
+ return input_ids, attention_mask, image_insertion_points_batch
423
+
424
+ def tokenize_messages(
425
+ self,
426
+ messages: ProcessorInput,
427
+ suppress_bos_token: bool = False,
428
+ **kwargs: Any,
429
+ ) -> ProcessorOutput | None:
430
+ """Tokenize a batch of messages into token IDs suitable for Helium1 CASA model.
431
+
432
+ Args:
433
+ messages (list[list[dict[str, str]]] | list[dict[str, str]]): Batch of message lists (or single list of messages),
434
+ where each message is a dictionary with 'role' and 'content' keys.
435
+ continue_final_message (bool, optional): If True, the final message in each list will not have an end token added.
436
+ Defaults to False.
437
+ suppress_bos_token (bool, optional): If True, the beginning-of-sequence token will not be added.
438
+ Defaults to False.
439
+ **kwargs: Additional keyword arguments passed to the underlying encode method.
440
+ """
441
+ if not messages:
442
+ return None
443
+ if isinstance(messages[0], dict):
444
+ messages = [messages] # type: ignore[assignment]
445
+
446
+ messages = cast(list[list[Message]], messages)
447
+ image_insertion_points_batch = []
448
+ tokenized_messages_batch = []
449
+ image_list: list[torch.Tensor | None] = []
450
+ for msgs in messages:
451
+ # msgs.append({
452
+ # "role": "assistant",
453
+ # "content": [{"type": "text", "text": ""}]
454
+ # })
455
+ tokenized_messages = []
456
+ if not suppress_bos_token and self.bos_token is not None:
457
+ tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
458
+ insertion_points = []
459
+ token_count = 0
460
+ for msg in msgs:
461
+ token_count = self._process_content(
462
+ msg["content"],
463
+ role=msg["role"],
464
+ tokenized_messages=tokenized_messages,
465
+ insertion_points=insertion_points,
466
+ image_list=image_list,
467
+ token_count=token_count,
468
+ **kwargs,
469
+ )
470
+ tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
471
+ image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))
472
+
473
+ if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
474
+ # Remove the assistant end tokens from the final message
475
+ end_token_len = len(self.asst_end_tokens)
476
+ tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
477
+ if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
478
+ # Append the assistant start tokens so generation continues as the assistant
479
+ end_token_len = len(self.asst_end_tokens)
480
+ tokenized_messages_batch[-1] = torch.cat(
481
+ [
482
+ tokenized_messages_batch[-1],
483
+ torch.Tensor(self.asst_start_tokens).to(torch.long),
484
+ ]
485
+ )
486
+
487
+ input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
488
+ tokenized_messages_batch, image_insertion_points_batch
489
+ )
490
+
491
+ if image_list:
492
+ assert sum(img is None for img in image_list) % len(image_list) == 0, (
493
+ "Either all or no image must be None."
494
+ )
495
+ pixel_values: None | torch.Tensor | list[torch.Tensor]
496
+ if image_list[0] is None:
497
+ pixel_values = None
498
+ else:
499
+ pixel_values = cast(list[torch.Tensor], image_list)
500
+ return ProcessorOutput(
501
+ input_ids=input_ids,
502
+ image_embeds_insertion_points=image_embeds_insertion_points,
503
+ attention_mask=attention_mask,
504
+ pixel_values=pixel_values,
505
+ )
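As a rough illustration of the message schema and of what `tokenize_messages` returns (a sketch assuming a `processor` instance loaded as in the earlier generation example; the image is a synthetic placeholder):

from PIL import Image

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": Image.new("RGB", (640, 480))},
        {"type": "text", "text": "What is in this picture?"},
    ],
}]
batch = processor.tokenize_messages(messages)
print(batch["input_ids"].shape)        # (batch, padded_seq_len) token ids
print(batch["attention_mask"].shape)   # same shape, 0 over padding
print(batch["image_embeds_insertion_points"])  # one tensor of insertion indices per sample
print([pv.shape for pv in batch["pixel_values"]])  # one (grid_h, grid_w, channels) tensor per image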
processing_helium1_casa.py ADDED
@@ -0,0 +1,37 @@
1
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
2
+
3
+ from .processing import BaseProcessor, QwenImageProcessor
4
+
5
+
6
+ class Helium1CASAProcessor(BaseProcessor):
7
+ attributes = ["tokenizer"]
8
+ tokenizer_class = "PreTrainedTokenizerFast"
9
+
10
+ def __init__(
11
+ self,
12
+ tokenizer: PreTrainedTokenizerFast,
13
+ pre_image_tokens: tuple[int, ...] = tuple(),
14
+ post_image_tokens: tuple[int, ...] = tuple(),
15
+ system_start_tokens: tuple[int, ...] = tuple(),
16
+ system_end_tokens: tuple[int, ...] = tuple(),
17
+ user_start_tokens: tuple[int, ...] = (104,),
18
+ user_end_tokens: tuple[int, ...] = (105,),
19
+ asst_start_tokens: tuple[int, ...] = (102,),
20
+ asst_end_tokens: tuple[int, ...] = (103,),
21
+ bos_token: int = 1,
22
+ image_size: int = 896,
23
+ ):
24
+ super().__init__(
25
+ tokenizer=tokenizer,
26
+ pre_image_tokens=pre_image_tokens,
27
+ post_image_tokens=post_image_tokens,
28
+ system_start_tokens=system_start_tokens,
29
+ system_end_tokens=system_end_tokens,
30
+ user_start_tokens=user_start_tokens,
31
+ user_end_tokens=user_end_tokens,
32
+ asst_start_tokens=asst_start_tokens,
33
+ asst_end_tokens=asst_end_tokens,
34
+ allow_system_prompt=False,
35
+ bos_token=bos_token,
36
+ )
37
+ self._image_processor = QwenImageProcessor(img_size=image_size)
processor_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_helium1_casa.Helium1CASAProcessor"
4
+ },
5
+ "bos_token": 1,
6
+ "image_size": 896,
7
+ "post_image_tokens": [],
8
+ "pre_image_tokens": [],
9
+ "processor_class": "Helium1CASAProcessor"
10
+ }
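This config is what `AutoProcessor.from_pretrained(..., trust_remote_code=True)` uses (via `auto_map`) to build the processor. Constructing it by hand is roughly equivalent to the sketch below, assuming the repository files are importable locally; the repo id is illustrative.

from transformers import AutoTokenizer
from processing_helium1_casa import Helium1CASAProcessor  # file shown above

repo_id = "kyutai/Helium1-VL-2B"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(repo_id)
processor = Helium1CASAProcessor(
    tokenizer=tokenizer,
    pre_image_tokens=(),   # "pre_image_tokens": []
    post_image_tokens=(),  # "post_image_tokens": []
    bos_token=1,           # "bos_token": 1
    image_size=896,        # "image_size": 896
)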
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90cea6d2a04d6c89a9904853c22aac0c342fc193a75048f4cbee4f98b9c835d8
3
+ size 70505
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast",
3
+ "additional_special_tokens": [
4
+ "<|im_sp_00|>",
5
+ "<|im_sp_01|>",
6
+ "<|im_sp_02|>",
7
+ "<|im_sp_94|>",
8
+ "<|im_sp_95|>",
9
+ "<|im_sp_96|>",
10
+ "<|im_sp_97|>",
11
+ "<|im_sp_98|>",
12
+ "<|im_sp_99|>"
13
+ ]
14
+ }
utils.py ADDED
@@ -0,0 +1,116 @@
1
+ # pylint: disable=protected-access
2
+ """Utils to handle CASA layers construction"""
3
+
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass, fields
6
+ from typing import Any, Callable, Generic, TypeVar
7
+
8
+ import torch
9
+
10
+
11
+ def delta_w_factory(
12
+ org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
13
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
14
+ """Factory for building linear op where the weights are the sum of two layers' weights"""
15
+
16
+ def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
17
+ nonlocal org_lin, new_lin
18
+ bias = None if org_lin.bias is None else org_lin.bias + new_lin.bias
19
+ return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)
20
+
21
+ return _delta_w_fwd
22
+
23
+
24
+ @dataclass
25
+ class StreamingState:
26
+ """Streaming State used by CASA layers at inference to save
27
+ e.g. the offset, the KV Cache and other persistent states"""
28
+
29
+ offset: int = 0
30
+
31
+ def _is_valid_field(self, key: str) -> bool:
32
+ return key in {x.name for x in fields(self)}
33
+
34
+ def _init_field(self, key: str) -> None:
35
+ """Init function for non-arggment dependent defauls"""
36
+ assert self._is_valid_field(key)
37
+ if key == "offset":
38
+ self.offset = 0
39
+ else:
40
+ # for fields which should be set explicitly and cannot be auto-initialized
41
+ setattr(self, key, None)
42
+
43
+ def init(self) -> None:
44
+ for key in [x.name for x in fields(self)]:
45
+ self._init_field(key)
46
+
47
+ def _reset_field(self, name: str) -> None:
48
+ """Resets the given field"""
49
+ self._init_field(name)
50
+
51
+ def reset(self) -> None:
52
+ for f in fields(self):
53
+ self._reset_field(f.name)
54
+
55
+ def _get_field(self, f: str) -> Any:
56
+ """Get field and init if not"""
57
+ assert self._is_valid_field(f)
58
+ if getattr(self, f) is None:
59
+ self._init_field(f)
60
+ return getattr(self, f)
61
+
62
+ def _set_field(self, f: str, value: Any) -> None:
63
+ assert self._is_valid_field(f)
64
+ setattr(self, f, value)
65
+
66
+
67
+ StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
68
+
69
+
70
+ class StreamingModule(torch.nn.Module, Generic[StreamingStateT]): # pylint: disable=abstract-method
71
+ """Overrides Audiocraft's Streaming modules with additional small utils"""
72
+
73
+ def __init__(self, state_class: type) -> None:
74
+ torch.nn.Module.__init__(self)
75
+ self.is_streaming: bool = False
76
+ self.enable_viz: tuple[str, ...] = ()
77
+ self._streaming_state: StreamingStateT = state_class()
78
+
79
+ @property
80
+ def streaming_state(self) -> StreamingStateT:
81
+ return self._streaming_state
82
+
83
+ def _apply_named_streaming(self, fn: Callable):
84
+ """Apply function to all streaming modules"""
85
+ for name, module in self.named_modules():
86
+ if isinstance(module, StreamingModule):
87
+ fn(name, module)
88
+
89
+ def reset_streaming(self):
90
+ """Reset the streaming state."""
91
+
92
+ def _reset(_: str, module: StreamingModule):
93
+ module._streaming_state.reset()
94
+
95
+ self._apply_named_streaming(_reset)
96
+
97
+ def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()):
98
+ """Set all streaming modules in streaming mode"""
99
+
100
+ def _set_streaming(_, module: StreamingModule) -> None:
101
+ module.is_streaming = streaming
102
+ module.enable_viz = viz
103
+ if streaming:
104
+ module.streaming_state.init()
105
+
106
+ self._apply_named_streaming(_set_streaming)
107
+
108
+ @contextmanager
109
+ def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
110
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
111
+ self._set_streaming(stream, viz)
112
+ try:
113
+ yield
114
+ finally:
115
+ self._set_streaming(False, ())
116
+ self.reset_streaming()
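To show how the streaming plumbing is meant to be used, here is a toy sketch built on the classes above (the `CountingLayer` module is made up for illustration; real CASA attention layers keep their KV cache and offsets in the streaming state). It assumes this file is importable as `utils`.

import torch

from utils import StreamingModule, StreamingState


class CountingLayer(StreamingModule[StreamingState]):
    """Toy module (illustration only): tracks how many positions it has consumed."""

    def __init__(self) -> None:
        super().__init__(state_class=StreamingState)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_streaming:
            # The offset persists across calls while streaming is active.
            self.streaming_state.offset += x.shape[-1]
        return x


layer = CountingLayer()
with layer.streaming():
    for chunk in torch.randn(4, 8).split(2, dim=-1):  # four chunks of width 2
        layer(chunk)
    print(layer.streaming_state.offset)  # 8
print(layer.streaming_state.offset)      # 0: streaming() resets the state on exit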