ameroyer committed on
Commit 8a1bc81 · verified · 0 Parent(s):

Super-squash branch 'main' using huggingface_hub

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Notice ADDED
@@ -0,0 +1,2 @@
1
+ CASA-Qwen2_5-VL-3B-LiveCC is finetuned from Qwen2.5-VL-3B with additional CASA layers.
2
+ Qwen is licensed under the Qwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ datasets:
3
+ - chenjoya/Live-WhisperX-526K
4
+ language:
5
+ - en
6
+ base_model:
7
+ - Qwen/Qwen2.5-VL-3B-Instruct
8
+ pipeline_tag: video-text-to-text
9
+ license: cc-by-nc-sa-4.0
10
+ ---
11
+ Please refer to the [main model card](https://huggingface.co/kyutai/CASA-Helium1-VL-2B) for more information and instructions to run.
12
+
13
+ This model page contains model weights for `CASA-Qwen2_5-VL-3B-LiveCC`, a Qwen2.5-VL model adapted from token-insertion to cross-attention-based image fusion using CASA layers, and further finetuned on LiveCC for live video captioning. We provide weights for the other CASA models in the associated model collection.
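For convenience, a minimal loading sketch (not the documented pipeline; the repo id, the `auto_map` entries from `config.json`, and the use of the base Qwen2.5-VL processor are assumptions — follow the main model card for the supported inference instructions):

```python
# Hypothetical quick-start: load the CASA model via the custom classes registered in config.json.
import torch
from transformers import AutoConfig, AutoModel, AutoProcessor

repo_id = "kyutai/CASA-Qwen2_5-VL-3B-LiveCC"  # assumed repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
# Assumption: preprocessing follows the base Qwen2.5-VL-3B-Instruct processor.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

print(type(model).__name__, config.casa_attention, len(config.xa_layers))
```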
__init__.py ADDED
File without changes
casa_attention.py ADDED
@@ -0,0 +1,1010 @@
1
+ """CASA layers"""
2
+
3
+ import bisect
4
+ from dataclasses import dataclass
5
+ from itertools import accumulate
6
+ from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload
7
+ from typing import cast as type_cast
8
+
9
+ import torch
10
+ from transformers.configuration_utils import PretrainedConfig
11
+
12
+ from .utils import StreamingModule, StreamingState, delta_w_factory
13
+
14
+ if TYPE_CHECKING:
15
+ from transformers.configuration_utils import PretrainedConfig
16
+
17
+ try:
18
+ from flash_attn import flash_attn_varlen_func
19
+ except ImportError:
20
+ flash_attn_varlen_func = None # type: ignore
21
+
22
+
23
+ WindowsComputeKwargs = TypedDict(
24
+ "WindowsComputeKwargs",
25
+ {
26
+ "num_post_image_tokens": int,
27
+ "num_pre_image_tokens": int,
28
+ },
29
+ total=False,
30
+ )
31
+
32
+
33
+ def __split_n_merge__(
34
+ x: torch.Tensor,
35
+ sample_lengths: list[int],
36
+ padding_side: Literal["left", "right"] = "right",
37
+ pad_value: int | float | bool = 0,
38
+ ) -> torch.Tensor:
39
+ max_sample_length = max(sample_lengths)
40
+ pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2))
41
+ return torch.stack(
42
+ [
43
+ torch.nn.functional.pad(
44
+ _x,
45
+ pad_tuple + (0, max_sample_length - _x.shape[0])
46
+ if padding_side == "right"
47
+ else pad_tuple + (max_sample_length - _x.shape[0], 0),
48
+ value=pad_value,
49
+ )
50
+ for _x in torch.split(x, sample_lengths, dim=0)
51
+ ],
52
+ dim=0,
53
+ )
54
+
55
+
56
+ @overload
57
+ def insert_image_tokens(
58
+ inputs_embeds: torch.Tensor,
59
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
60
+ image_embeds_insertion_points: list[torch.Tensor],
61
+ recover_batch_dim: Literal[True],
62
+ attention_mask: torch.Tensor | None = None,
63
+ padding_side: Literal["left", "right"] = "right",
64
+ keep_only_attended: bool = False,
65
+ pad_output: int | float | bool = 0.0,
66
+ ) -> tuple[
67
+ torch.Tensor,
68
+ None,
69
+ torch.Tensor | None,
70
+ torch.Tensor,
71
+ ]: ...
72
+ @overload
73
+ def insert_image_tokens(
74
+ inputs_embeds: torch.Tensor,
75
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
76
+ image_embeds_insertion_points: list[torch.Tensor],
77
+ recover_batch_dim: Literal[False],
78
+ attention_mask: torch.Tensor | None = None,
79
+ padding_side: Literal["left", "right"] = "right",
80
+ keep_only_attended: bool = False,
81
+ pad_output: int | float | bool = 0.0,
82
+ ) -> tuple[
83
+ torch.Tensor,
84
+ list[int],
85
+ torch.Tensor | None,
86
+ torch.Tensor,
87
+ ]: ...
88
+ def insert_image_tokens(
89
+ inputs_embeds: torch.Tensor,
90
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
91
+ image_embeds_insertion_points: list[torch.Tensor],
92
+ recover_batch_dim: bool = True,
93
+ attention_mask: torch.Tensor | None = None,
94
+ padding_side: Literal["left", "right"] = "right",
95
+ keep_only_attended: bool = False,
96
+ pad_output: int | float | bool = 0.0,
97
+ ) -> tuple[
98
+ torch.Tensor,
99
+ list[int] | None,
100
+ torch.Tensor | None,
101
+ torch.Tensor,
102
+ ]:
103
+ """
104
+ Insert image embeddings into text embeddings
105
+
106
+ Args:
107
+ inputs_embeds (torch.Tensor): (B, S, D) input token embeddings.
108
+ image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings.
109
+ image_embeds_insertion_points (list[torch.Tensor]): Insertion indices.
110
+ attention_mask (torch.Tensor, optional): (B, S) attention mask.
111
+ padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images.
112
+ recover_batch_dim (bool): Whether to re-batch (and pad) the fused sequence, or keep it flattened and return per-sample lengths.
113
+ keep_only_attended: This is only applicable when recover_batch_dim is False; whether to
114
+ remove any non-attended tokens in the whole array. In this case, the attention
115
+ mask returned is **still the original one**, so we can remember which indices have been
116
+ removed
117
+ Returns:
118
+ output (torch.Tensor): (B, S + Ni * Nt, D) fused sequence, or (sum(sample_lengths), D) when recover_batch_dim is False.
119
+ sample_lengths (list[int] | None): per-sample lengths of the fused (text + image) sequence.
120
+ attention_mask (torch.Tensor): Updated mask, 1 for text and image tokens, 0 for text padding.
121
+ image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions.
122
+ """
123
+ if isinstance(image_embeds, list) and len(image_embeds) == 0:
124
+ batch_size, text_seq_length, token_dim = inputs_embeds.shape
125
+ if recover_batch_dim:
126
+ return (
127
+ inputs_embeds,
128
+ None,
129
+ attention_mask,
130
+ torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool),
131
+ )
132
+ else:
133
+ flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1]
134
+ return (
135
+ torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])),
136
+ [text_seq_length] * inputs_embeds.shape[0],
137
+ attention_mask.flatten() if attention_mask is not None else None,
138
+ torch.zeros((flattened_seq_length, 1), dtype=torch.bool),
139
+ )
140
+
141
+ # Sanity checks
142
+ if isinstance(image_embeds, torch.Tensor):
143
+ assert inputs_embeds.shape[-1] == image_embeds.shape[-1]
144
+ else:
145
+ assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds)
146
+
147
+ batch_size, text_seq_length, token_dim = inputs_embeds.shape
148
+ image_seq_length = [x.shape[0] for x in image_embeds]
149
+
150
+ # Flatten insertion points
151
+ insertion_offset = []
152
+ counter, offset_from_text, offset_from_image = 0, 0, 0
153
+ for sample in image_embeds_insertion_points:
154
+ for pt in sample:
155
+ insertion_offset.append(pt + offset_from_image + offset_from_text)
156
+ offset_from_image += image_seq_length[counter]
157
+ counter += 1
158
+ offset_from_text += text_seq_length
159
+ image_insert_positions = [
160
+ x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx])
161
+ ]
162
+
163
+ # Flatten image embeds
164
+ if isinstance(image_embeds, list):
165
+ image_embeds = torch.cat(image_embeds, dim=0)
166
+ else:
167
+ image_embeds = type_cast(torch.Tensor, image_embeds)
168
+ image_embeds = torch.reshape(image_embeds, (-1, token_dim))
169
+
170
+ # Flatten text embeds across batch dim (B x S, D)
171
+ inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim))
172
+ flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length)
173
+ text_insert_positions = sorted(
174
+ set(range(flattened_seq_length)).difference(set(image_insert_positions))
175
+ )
176
+
177
+ # Scatter image embeds in the flattened dict
178
+ # scatter text related stuff
179
+ output = torch.empty(
180
+ (flattened_seq_length, token_dim),
181
+ device=inputs_embeds.device,
182
+ dtype=inputs_embeds.dtype,
183
+ )
184
+ txt_positions_tensor = torch.Tensor(text_insert_positions).to(
185
+ dtype=torch.long, device=inputs_embeds.device
186
+ )
187
+ output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds)
188
+ attention_mask_new: torch.Tensor | None = None
189
+ if attention_mask is not None:
190
+ attention_mask_new = torch.ones(
191
+ (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
192
+ )
193
+ attention_mask_new.scatter_(
194
+ 0, txt_positions_tensor, attention_mask.flatten().to(torch.bool)
195
+ )
196
+
197
+ # scatter image related stuff
198
+ image_tokens_mask = torch.zeros(
199
+ (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
200
+ )
201
+ img_positions_tensor = torch.Tensor(image_insert_positions).to(
202
+ device=inputs_embeds.device, dtype=torch.long
203
+ )
204
+ output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds)
205
+ image_tokens_mask.scatter_(0, img_positions_tensor, True)
206
+
207
+ # Compute expected sample length, taking into account the real batch
208
+ # i.e. recover the batch dimension of image embeddings
209
+ sample_lengths = []
210
+ counter = 0
211
+ for sample_idx, pts in enumerate(image_embeds_insertion_points):
212
+ num_image_tokens = 0
213
+ for _ in pts:
214
+ num_image_tokens += image_seq_length[counter]
215
+ counter += 1
216
+ if keep_only_attended and attention_mask is not None:
217
+ attended_seq_length = torch.sum(attention_mask[sample_idx]).cpu().item()
218
+ sample_lengths.append(attended_seq_length + num_image_tokens)
219
+ else:
220
+ sample_lengths.append(text_seq_length + num_image_tokens)
221
+
222
+ # For CASA attention, we can keep everything flattened and return
223
+ # the sample_lengths for the blockwise attention
224
+ if not recover_batch_dim:
225
+ if keep_only_attended and attention_mask_new is not None:
226
+ output = output[attention_mask_new]
227
+ image_tokens_mask = image_tokens_mask[attention_mask_new]
228
+ return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None]
229
+
230
+ # Otherwise, time to (pad) and reshape
231
+ # Easy case: everything has the same length
232
+ if all(x == sample_lengths[0] for x in sample_lengths):
233
+ output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim))
234
+ image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1))
235
+ if attention_mask_new is not None:
236
+ attention_mask_new = torch.reshape(attention_mask_new, (batch_size, sample_lengths[0]))
237
+ # if there is any size mismatch we break into a
238
+ # list and pad again
239
+ else:
240
+ # split and merge
241
+ output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output)
242
+ # note that the extra padding tokens are also marked as image tokens to be removed later
243
+ image_tokens_mask = __split_n_merge__(
244
+ image_tokens_mask, sample_lengths, padding_side, True
245
+ )[:, :, None]
246
+ if attention_mask_new is not None:
247
+ attention_mask_new = __split_n_merge__(
248
+ attention_mask_new, sample_lengths, padding_side, 0
249
+ )
250
+ # Return
251
+ return output, sample_lengths, attention_mask_new, image_tokens_mask
252
+
253
+
254
+ def get_sample_lengths_from_insertion_points(
255
+ image_embeds_insertion_points: list[torch.Tensor],
256
+ image_embeds: torch.Tensor | list[torch.Tensor] | None,
257
+ total_seq_len: int | None = None,
258
+ attention_mask: torch.Tensor | None = None,
259
+ **kwargs: WindowsComputeKwargs,
260
+ ) -> tuple[list[tuple[int, bool]], list[int]]:
261
+ """Compute sample lengths as if each image insertion point defines a
262
+ new document (as with a document-ID attention mask)
263
+ """
264
+ num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0))
265
+ num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0))
266
+ squashed_samples_lengths = type_cast(
267
+ list[list[int]] | None, kwargs.get("squashed_samples_lengths", None)
268
+ )
269
+ if squashed_samples_lengths is not None:
270
+ assert len(squashed_samples_lengths) == len(image_embeds_insertion_points)
271
+
272
+ def __insert_next_sample__(
273
+ batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False
274
+ ) -> None:
275
+ nonlocal attention_mask
276
+ nonlocal text_sample_lengths, full_sample_lengths
277
+ nonlocal cum_samples_lengths, current_image_offset
278
+ nonlocal last_image_idx, current_image_idx, current_length
279
+ # Add the sample between [last_insrt_pt, insrt_pt] with breaks in
280
+ # between any squashed samples we find on the way
281
+ start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt)
282
+ added_sample = False
283
+ for end_of_sample in cum_samples_lengths[start_pt:]:
284
+ # we will break the loop at the end when end_of_sample = insrt_pt
285
+ end_of_sample = min(end_of_sample, insrt_pt)
286
+
287
+ # Add between [last_insrt_pt, end_of_sample]
288
+ current_length = end_of_sample - last_insrt_pt
289
+ if attention_mask is not None:
290
+ current_length -= int(
291
+ torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item()
292
+ )
293
+ if current_length > 0:
294
+ added_sample = True
295
+ text_sample_lengths.append(
296
+ (current_length, end_of_batch_sample and insrt_pt == end_of_sample)
297
+ )
298
+ # add image tokens to current_length
299
+ if current_image_idx > 0 and image_embeds is not None:
300
+ images_in_sample = [
301
+ img_idx
302
+ for img_idx in range(last_image_idx, current_image_idx)
303
+ if img_idx < len(image_embeds_insertion_points[batch_idx])
304
+ and last_insrt_pt
305
+ <= image_embeds_insertion_points[batch_idx][img_idx]
306
+ < end_of_sample
307
+ ]
308
+ if len(images_in_sample) > 0:
309
+ num_image_tokens = sum(
310
+ _x.shape[0]
311
+ for _x in image_embeds[
312
+ current_image_offset + images_in_sample[0] : current_image_offset
313
+ + images_in_sample[-1]
314
+ + 1
315
+ ]
316
+ )
317
+ current_length += num_image_tokens
318
+ full_sample_lengths.append(current_length)
319
+
320
+ # prepare for next loop
321
+ last_insrt_pt = end_of_sample
322
+ if end_of_sample == insrt_pt:
323
+ break
324
+ # End of loop: Catching weird use case where we may end up on a span
325
+ # full of padding tokens which will not get added due to current_length > 0
326
+ if end_of_batch_sample:
327
+ assert added_sample, "Weird edge case. Don't do that, thank you"
328
+ text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
329
+
330
+ # End of loop: Catching weird use case where we may end up on a span
331
+ # full of padding tokens which will not get added due to current_length > 0
332
+ if end_of_batch_sample:
333
+ assert added_sample, "Weird edge case. Don't do that, thank you"
334
+ text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
335
+
336
+ current_image_offset = 0
337
+ text_sample_lengths, full_sample_lengths = [], []
338
+ cum_samples_lengths: list[int] = []
339
+ current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
340
+ for batch_idx, pts in enumerate(image_embeds_insertion_points):
341
+ if squashed_samples_lengths is not None:
342
+ cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx]))
343
+ else:
344
+ assert total_seq_len is not None
345
+ cum_samples_lengths = [total_seq_len]
346
+
347
+ for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()):
348
+ # check if the images are consecutive, in which case we want
349
+ # them to belong to the same window
350
+ if current_image_idx >= 1 and insrt_pt == (
351
+ image_embeds_insertion_points[batch_idx][current_image_idx - 1]
352
+ + num_pre_image_tokens
353
+ + num_post_image_tokens
354
+ ):
355
+ continue
356
+ # Otherwise, we found a new sample
357
+ # not very important but for completeness: the insertion points come *after*
358
+ # the pre-image tokens per design but for the document-id mask it is more consistent to
359
+ # have them correspond to the same image
360
+ insrt_pt -= num_pre_image_tokens
361
+
362
+ # Update text and full sample lengths
363
+ if insrt_pt > last_insrt_pt:
364
+ __insert_next_sample__(
365
+ batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False
366
+ )
367
+ last_image_idx = current_image_idx
368
+ last_insrt_pt = insrt_pt
369
+
370
+ # End of batch: add sample in progress and reset
371
+ current_image_idx += 1
372
+ if cum_samples_lengths[-1] > last_insrt_pt:
373
+ __insert_next_sample__(
374
+ batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True
375
+ )
376
+ current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
377
+ current_image_offset += len(pts)
378
+
379
+ # Sanity check that the is_eob flags are correctly placed
380
+ assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), (
381
+ f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs"
382
+ f" from original batch size ({len(image_embeds_insertion_points)})"
383
+ )
384
+ return text_sample_lengths, full_sample_lengths
385
+
386
+
387
+ class CASAAttentionHandler:
388
+ def __init__(
389
+ self,
390
+ inputs_embeds: torch.Tensor,
391
+ image_embeds: torch.Tensor | list[torch.Tensor],
392
+ image_embeds_insertion_points: list[torch.Tensor],
393
+ attention_mask: torch.Tensor | None = None,
394
+ rope_fn: Callable | None = None,
395
+ windows: Literal["batch", "squashed", "images", "turn_based"] = "images",
396
+ use_asymetric_q_kv: bool = True,
397
+ casa_windows_info: None | dict = None,
398
+ ):
399
+ """Initialize the structure holding the query buffer for CASA attention layers
400
+ (ie the **flattened** text+image inserted tokens).
401
+ Note that this structure is shared across all casa layers, and it gets updated
402
+ with the current hidden states at every layer; this is merely a buffer to keep
403
+ scatter_ operations in-place as much as possible.
404
+
405
+ In this module, the embeddings related values (image_tokens_mask,
406
+ text_sample_lengths etc) are stored under the assumption of a tensor
407
+ which is *flatened* and *witout padding tokens*
408
+ Only the attention mask is kept as-is (text-only, batched, padded) to
409
+ be able to recover original shapes when needed
410
+ """
411
+ super().__init__()
412
+ assert windows == "images" # for inference code release
413
+ # Note 1: Unless overridden, text/full_sample_lengths are defined such that one
414
+ # document = one sample in the batch
415
+ if attention_mask is None:
416
+ text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds]
417
+ else:
418
+ text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask]
419
+ (
420
+ full_inputs_embeds,
421
+ full_sample_lengths,
422
+ # Full attention mask is only needed at inference to
423
+ # flatten the KV-Cache and remove padding tokens
424
+ _,
425
+ self.image_tokens_mask,
426
+ ) = insert_image_tokens(
427
+ inputs_embeds=inputs_embeds,
428
+ image_embeds=image_embeds,
429
+ image_embeds_insertion_points=image_embeds_insertion_points,
430
+ attention_mask=attention_mask,
431
+ recover_batch_dim=False,
432
+ keep_only_attended=attention_mask is not None,
433
+ )
434
+ assert self.image_tokens_mask.ndim == 2
435
+ self.image_embeds = image_embeds
436
+ self.image_embeds_insertion_points = image_embeds_insertion_points
437
+ self.attention_mask = None if attention_mask is None else attention_mask.bool()
438
+ self.use_asymetric_qkv = use_asymetric_q_kv
439
+ # At inference, we have to use asymmetric QKV for efficiency
440
+ if self.attention_mask is not None:
441
+ self.use_asymetric_qkv = True
442
+
443
+ # Build CASA windows
444
+ assert casa_windows_info is not None
445
+ text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points(
446
+ image_embeds_insertion_points=image_embeds_insertion_points,
447
+ image_embeds=image_embeds,
448
+ total_seq_len=inputs_embeds.shape[1],
449
+ attention_mask=self.attention_mask,
450
+ **casa_windows_info, # pyright: ignore
451
+ )
452
+
453
+ # Sanity checks on the sample lengths
454
+ self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0]
455
+ self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0]
456
+
457
+ assert len(self.text_sample_lengths) == len(self.full_sample_lengths), (
458
+ f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}"
459
+ f" != full sample lengths {len(self.full_sample_lengths)}"
460
+ )
461
+ if self.attention_mask is None:
462
+ num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1]
463
+ else:
464
+ num_unpadded_text_tokens = int(
465
+ torch.sum(type_cast(torch.Tensor, attention_mask)).item()
466
+ )
467
+ assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
468
+ f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
469
+ )
470
+ assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], (
471
+ f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
472
+ )
473
+
474
+ # Finally we can compute cu_seqlen based on sample lengths
475
+ self.max_seqlen_q = max(self.text_sample_lengths)[0]
476
+ self.cu_seqlens_q = self.get_cu_seqlens(
477
+ [x[0] for x in self.text_sample_lengths], device=inputs_embeds.device
478
+ )
479
+
480
+ self.max_seqlen_kv = max(self.full_sample_lengths)
481
+ self.cu_seqlens_kv = self.get_cu_seqlens(
482
+ self.full_sample_lengths, device=inputs_embeds.device
483
+ )
484
+
485
+ # For inference: We save the length of the current document
486
+ # to trim the KV cache appropriately
487
+ self.current_doc_lengths = self.full_sample_lengths
488
+
489
+ # Precompute position embeddings
490
+ self.position_embeds = None
491
+ self.rope_fn = rope_fn
492
+ if self.rope_fn is not None:
493
+ self.position_embeds = self.compute_position_embeddings(
494
+ self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds
495
+ )
496
+
497
+ @property
498
+ def batch_lengths(self) -> list[int]:
499
+ """Return a (batch_size,) list of integers containing the
500
+ number of (non-padded) text tokens for each sample in the batch"""
501
+ bls = [0]
502
+ for ln, eob in self.text_sample_lengths:
503
+ bls[-1] += ln
504
+ if eob:
505
+ bls.append(0)
506
+ return bls[:-1]
507
+
508
+ @property
509
+ def full_batch_lengths(self) -> list[int]:
510
+ """Same as batch_lengths for text+image tokens"""
511
+ bls = [0]
512
+ for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths):
513
+ bls[-1] += ln
514
+ if eob:
515
+ bls.append(0)
516
+ return bls[:-1]
517
+
518
+ def get_cu_seqlens(
519
+ self, sample_lengths: list[int], device: torch.device | None
520
+ ) -> torch.Tensor:
521
+ """Update cu_seqlengths according to the given sample_lengths"""
522
+ return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to(
523
+ dtype=torch.int32, device=device
524
+ )
525
+
526
+ def compute_position_embeddings(
527
+ self,
528
+ rope_fn: Callable,
529
+ sample_lengths: list[int],
530
+ dummy_for_dtype_and_device: torch.Tensor,
531
+ ) -> tuple[torch.Tensor, torch.Tensor]:
532
+ """Compute info required for position embeddings. Can be override e.g. for Qwen"""
533
+ # option 1: Standard range
534
+ # position_ids = torch.arange(0, full_inputs_embeds.shape[0])
535
+ # option 2: Follows document boundary
536
+ position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
537
+ return rope_fn(
538
+ dummy_for_dtype_and_device,
539
+ position_ids.to(dummy_for_dtype_and_device.device)[None, ...],
540
+ )
541
+
542
+ def get_position_embedding(
543
+ self,
544
+ key: Literal["q", "kv"],
545
+ num_queries: int = 0,
546
+ ) -> tuple[torch.Tensor, torch.Tensor] | None:
547
+ if self.position_embeds is None:
548
+ return None
549
+ cos, sin = self.position_embeds
550
+ bls = self.full_batch_lengths
551
+ # For Q, we only want the text-only posembeds
552
+ if key == "q" and self.use_asymetric_qkv:
553
+ bls = self.batch_lengths
554
+ cos, sin = cos[:, ~self.image_tokens_mask[:, 0]], sin[:, ~self.image_tokens_mask[:, 0]]
555
+ elif key not in {"q", "kv"}:
556
+ raise ValueError(f"Unknow for position embedding {key}")
557
+
558
+ # Easy case: training or first step at inference: we use all the posembeds
559
+ if num_queries == 0:
560
+ return cos, sin
561
+ # If num queries is given, we need to trim for *every sample in the batch*
562
+ cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)]
563
+ sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)]
564
+ return torch.cat(cos, dim=1), torch.cat(sin, dim=1)
565
+
566
+ def get_full_embeds(
567
+ self, hidden_states: torch.Tensor, norm_fn: Callable | None
568
+ ) -> torch.Tensor:
569
+ """Update attended hidden states in the current query buffer
570
+
571
+ :param hidden_states: (b, s, d) Tensor input to the CASA attention layer
572
+ """
573
+ assert self.image_embeds is not None
574
+ return insert_image_tokens(
575
+ inputs_embeds=hidden_states,
576
+ image_embeds=self.image_embeds
577
+ if norm_fn is None
578
+ else norm_fn(self.image_embeds)
579
+ if isinstance(self.image_embeds, torch.Tensor)
580
+ else [norm_fn(_x) for _x in self.image_embeds],
581
+ image_embeds_insertion_points=self.image_embeds_insertion_points,
582
+ attention_mask=self.attention_mask,
583
+ recover_batch_dim=False,
584
+ keep_only_attended=self.attention_mask is not None,
585
+ )[0][None, :, :]
586
+
587
+ def recover_text_embeds(
588
+ self,
589
+ hidden_states_out: torch.Tensor,
590
+ hidden_states_in: torch.Tensor,
591
+ update_image_embeddings: bool = False,
592
+ ) -> torch.Tensor:
593
+ """Returns text embeddings from the query buffer, including non-attended tokens at inference"""
594
+ if update_image_embeddings and not self.use_asymetric_qkv:
595
+ raise NotImplementedError("Implement image embeddings updates for asymetric QKV")
596
+ # Remove image tokens in the symmetric case
597
+ if not self.use_asymetric_qkv:
598
+ hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]]
599
+
600
+ # if there's no attention mask, we are in the right-padded case
601
+ # (keep_only_attended = False) we can directly return the query
602
+ # outputs (which don't contain the image)
603
+ if self.attention_mask is None:
604
+ return hidden_states_out
605
+
606
+ # Otherwise, we need to "scatter" back only the text-attended tokens to the original
607
+ # hidden states, which contain the paddings
608
+ num_queries = hidden_states_in.shape[1]
609
+
610
+ # Case 1: the padded hidden_states_in is larger than hidden_states_out
611
+ # we rebatch+pad hidden_state_out before doing the scattering
612
+ if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]:
613
+ s = torch.split(hidden_states_out, self.batch_lengths, dim=0)
614
+ assert max(_s.shape[0] for _s in s) <= num_queries # sanity check
615
+ s = [
616
+ torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0)
617
+ for _s in s
618
+ ]
619
+ return torch.where(
620
+ self.attention_mask[:, -num_queries:, None],
621
+ torch.stack(s),
622
+ hidden_states_in,
623
+ )
624
+ # If both have the same shape, it means hidden_states_in contained no padding
625
+ # so we can directly return hidden states out
626
+ return hidden_states_out
627
+
628
+ def extend(self, num_tokens: int, offset: int = 0):
629
+ """Extend all necessary values of the Handler for infenrece
630
+ Note: this implementation currently assumes a single conversation at a time
631
+ (otherwise image tokens mask would have to change) and that tokens added are
632
+ attended to"""
633
+ # image embeds are inserted in the first step and stored in the KV cache
634
+ self.image_embeds = None
635
+
636
+ # Update attention mask (non-flattened) (assumes all new tokens are attended to)
637
+ if self.attention_mask is not None:
638
+ self.attention_mask = torch.nn.functional.pad(
639
+ self.attention_mask, (0, num_tokens), value=1
640
+ )
641
+
642
+ # Update image token mask (assumes only one image/conversation
643
+ # is started at once so that we always extend by zero)
644
+ # Note that the mask is stored flattened to avoid padding so we have to
645
+ # do something a bit ugly and inefficient here
646
+ imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0)
647
+ imtokmask = [torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask]
648
+ self.image_tokens_mask = torch.cat(imtokmask, dim=0)
649
+
650
+ # Recompute cumulative document lengths after assigning the new
651
+ # number of tokens to each sample in the batch
652
+ for idx, (ln, is_eob) in enumerate(self.text_sample_lengths):
653
+ if is_eob:
654
+ self.text_sample_lengths[idx] = (num_tokens + ln, is_eob)
655
+ self.full_sample_lengths[idx] += num_tokens
656
+
657
+ # Recompute cu_seqlens
658
+ # First step: Technically this never occurs, but we keep it for completeness
659
+ if offset == 0:
660
+ self.max_seqlen_q = max(self.text_sample_lengths)[0]
661
+ self.cu_seqlens_q = self.get_cu_seqlens(
662
+ [x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device
663
+ )
664
+
665
+ self.max_seqlen_kv = max(self.full_sample_lengths)
666
+ self.cu_seqlens_kv = self.get_cu_seqlens(
667
+ self.full_sample_lengths, device=self.cu_seqlens_kv.device
668
+ )
669
+ # Step > 0: the annoying part is since flashattn_varlen does not accept
670
+ # 0-len documents, we need to remove documents from the KV Cache when they're past
671
+ # their windows. In our current setting, this means we only want to keep the latest
672
+ # documents
673
+ else:
674
+ self.max_seqlen_q = num_tokens
675
+ self.cu_seqlens_q = self.get_cu_seqlens(
676
+ [num_tokens for (_, eob) in self.text_sample_lengths if eob],
677
+ device=self.cu_seqlens_q.device,
678
+ )
679
+
680
+ final_doc_lengths = [
681
+ ln
682
+ for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths)
683
+ if eob
684
+ ]
685
+ self.current_doc_lengths = final_doc_lengths
686
+ self.max_seqlen_kv = max(self.current_doc_lengths)
687
+ self.cu_seqlens_kv = self.get_cu_seqlens(
688
+ final_doc_lengths,
689
+ device=self.cu_seqlens_kv.device,
690
+ )
691
+ # Update position embeddings
692
+ if self.rope_fn is not None and self.position_embeds is not None:
693
+ self.position_embeds = self.compute_position_embeddings(
694
+ self.rope_fn,
695
+ self.full_sample_lengths,
696
+ dummy_for_dtype_and_device=self.position_embeds[0],
697
+ )
698
+
699
+
700
+ @dataclass
701
+ class CASAAttentionStreamingState(StreamingState):
702
+ """Streaming State for CASA Atention module. Keep the hidden"""
703
+
704
+ k: torch.Tensor = None # pyright: ignore[reportAssignmentType]
705
+ v: torch.Tensor = None # pyright: ignore[reportAssignmentType]
706
+ recover_batched_trims: list[int] = None # pyright: ignore[reportAssignmentType]
707
+ casa_handler: CASAAttentionHandler = None # pyright: ignore[reportAssignmentType]
708
+
709
+ def maybe_get_casa_handler(
710
+ self,
711
+ casa_handler: CASAAttentionHandler | None,
712
+ is_first_casa_layer: bool = False,
713
+ num_queries: int = -1,
714
+ ) -> CASAAttentionHandler | None:
715
+ # Set given Casa Handler the first time we reach this
716
+ if self.casa_handler is None:
717
+ self.casa_handler = casa_handler # pyright: ignore
718
+ # subsequent calls: we need to extend shapes to accommodate the new tokens
719
+ # however because CASA handler is shared across layers, we only need to do it once
720
+ if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer:
721
+ # since CasaHandler is shared, we only use its extend step once
722
+ self.casa_handler.extend(num_queries, offset=self.offset)
723
+ return self.casa_handler
724
+
725
+ def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor:
726
+ """Recover batched key/value states with left padding"""
727
+ s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1)
728
+ mlen = max(_s.shape[1] for _s in s)
729
+ # Remember the added padding so that we can re-flatten KV later
730
+ if self.recover_batched_trims is None:
731
+ self.recover_batched_trims = [mlen - _s.shape[1] for _s in s]
732
+ s = [torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0) for _s in s]
733
+ return torch.cat(s, dim=0)
734
+
735
+ def __get_flattened_kv__(
736
+ self, k: torch.Tensor | None = None, v: torch.Tensor | None = None
737
+ ) -> tuple[torch.Tensor, torch.Tensor]:
738
+ """
739
+ Flatten and remove padding for use with flash_attn_varlen_func
740
+ """
741
+ k = self.k if k is None else k
742
+ v = self.v if v is None else v
743
+ assert k is not None and v is not None
744
+
745
+ # Since every batch at least contributes one document,
746
+ # we can use this to check whether we are in streaming mode with dropped docs.
747
+ # If so, we should trim the kv cache accordingly
748
+ if len(self.casa_handler.current_doc_lengths) == len(k):
749
+ k = torch.cat(
750
+ [
751
+ _k[self.recover_batched_trims[idx] :][-doc_len:]
752
+ for idx, _k, doc_len in zip(
753
+ range(len(k)), k, self.casa_handler.current_doc_lengths
754
+ )
755
+ ]
756
+ )
757
+ v = torch.cat(
758
+ [
759
+ _v[self.recover_batched_trims[idx] :][-doc_len:]
760
+ for idx, _v, doc_len in zip(
761
+ range(len(k)), v, self.casa_handler.current_doc_lengths
762
+ )
763
+ ]
764
+ )
765
+ return k[None, ...], v[None, ...]
766
+
767
+ k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)])
768
+ v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)])
769
+ return k[None, ...], v[None, ...]
770
+
771
+ def extend_kv(
772
+ self, key_states: torch.Tensor, value_states: torch.Tensor
773
+ ) -> tuple[torch.Tensor, torch.Tensor]:
774
+ """
775
+ Extend the KV cache with the new key/value states and return it flattened (without padding).
776
+ """
777
+ assert self.casa_handler is not None
778
+ if self.k is None and self.v is None:
779
+ # Init with batch-padded key and value states
780
+ self.k = self.__recover_batched_kv__(key_states)
781
+ self.v = self.__recover_batched_kv__(value_states)
782
+ return self.__get_flattened_kv__()
783
+ if self.k is not None and self.v is not None:
784
+ # this is during generation; normally there is no padding at this stage
785
+ # so we can directly reshape the flattened key states
786
+ rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3])
787
+ self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1)
788
+ self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1)
789
+ return self.__get_flattened_kv__()
790
+
791
+ raise ValueError("Impossible configuration (k and v updates are desynchronized )")
792
+
793
+
794
+ class CASAAttention(StreamingModule[CASAAttentionStreamingState]):
795
+ def __init__(
796
+ self,
797
+ config: "PretrainedConfig",
798
+ layer_idx: int | None,
799
+ self_attn: torch.nn.Module | None = None,
800
+ input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
801
+ ):
802
+ super().__init__(CASAAttentionStreamingState)
803
+ self.head_dim = config.head_dim
804
+ self.config = config
805
+
806
+ self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0)
807
+ self.use_delta_w = config.casa_delta_w
808
+
809
+ self.q_proj_casa = self.init_from_config_proj("q", config)
810
+ self.k_proj_casa = self.init_from_config_proj("k", config)
811
+ self.v_proj_casa = self.init_from_config_proj("v", config)
812
+ self.o_proj_casa = self.init_from_config_proj("o", config)
813
+
814
+ # Delta_w
815
+ self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
816
+ self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
817
+ self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
818
+ self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
819
+
820
+ if config.casa_delta_w:
821
+ assert self_attn is not None
822
+ self.set_delta_w(self_attn)
823
+
824
+ # Layer norm
825
+ self.norm_fn: Callable | None = None
826
+ if config.xa_norm_on_images:
827
+ assert input_layernorm_fn is not None
828
+ self.norm_fn = input_layernorm_fn
829
+
830
+ def init_from_mha(self, self_attn: torch.nn.Module):
831
+ assert self_attn is not None
832
+ with torch.no_grad():
833
+ assert hasattr(self_attn, "q_proj")
834
+ for key in ["q", "k", "v", "o"]:
835
+ src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj"))
836
+ tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa"))
837
+ tgt.weight.copy_(src.weight)
838
+ if tgt.bias is not None and src.bias is not None:
839
+ tgt.bias.copy_(src.bias)
840
+
841
+ def set_delta_w(self, self_attn: torch.nn.Module):
842
+ """Delta w setup"""
843
+ self.override_q_proj = delta_w_factory(
844
+ self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj)
845
+ )
846
+ self.override_k_proj = delta_w_factory(
847
+ self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj)
848
+ )
849
+ self.override_v_proj = delta_w_factory(
850
+ self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj)
851
+ )
852
+ self.override_o_proj = delta_w_factory(
853
+ self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj)
854
+ )
855
+
856
+ with torch.no_grad():
857
+ torch.nn.init.zeros_(self.q_proj_casa.weight)
858
+ torch.nn.init.zeros_(self.k_proj_casa.weight)
859
+ torch.nn.init.zeros_(self.v_proj_casa.weight)
860
+ torch.nn.init.zeros_(self.o_proj_casa.weight)
861
+ if self.q_proj_casa.bias is not None:
862
+ torch.nn.init.zeros_(self.q_proj_casa.bias)
863
+ if self.k_proj_casa.bias is not None:
864
+ torch.nn.init.zeros_(self.k_proj_casa.bias)
865
+ if self.v_proj_casa.bias is not None:
866
+ torch.nn.init.zeros_(self.v_proj_casa.bias)
867
+ if self.o_proj_casa.bias is not None:
868
+ torch.nn.init.zeros_(self.o_proj_casa.bias)
869
+
870
+ def init_from_config_proj(
871
+ self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
872
+ ) -> torch.nn.Linear:
873
+ """Initialize the Linear proj in this module"""
874
+ raise NotImplementedError("Abastract class.")
875
+
876
+ def apply_position_embeddings(
877
+ self,
878
+ key: Literal["q", "kv"],
879
+ x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
880
+ casa_handler: CASAAttentionHandler | None,
881
+ num_queries: int = 0,
882
+ unsqueeze_dim: int = 1,
883
+ ) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
884
+ """Apply position embeddings to query and key states"""
885
+ raise NotImplementedError("Abastract class.")
886
+
887
+ def forward(
888
+ self,
889
+ hidden_states: torch.Tensor,
890
+ casa_handler: CASAAttentionHandler | None,
891
+ ) -> torch.Tensor | None:
892
+ """Generic forward for CASA uses for instance in `helium1_attention`"""
893
+ og_dtype = hidden_states.dtype
894
+ if self.is_streaming:
895
+ casa_handler = self.streaming_state.maybe_get_casa_handler(
896
+ casa_handler,
897
+ is_first_casa_layer=self.is_first_casa_layer,
898
+ num_queries=hidden_states.shape[1],
899
+ )
900
+
901
+ # Case of text-only samples at training (or inference when no handler was cached)
902
+ # in this case we just skip CASA so we return None (no casa_update)
903
+ if casa_handler is None:
904
+ return None
905
+
906
+ if self.is_streaming:
907
+ assert casa_handler.use_asymetric_qkv, (
908
+ "You should set `use_asymetric_qkv` to True during inference"
909
+ )
910
+
911
+ og_shape = hidden_states.shape
912
+
913
+ # Build Q inputs
914
+ if casa_handler.use_asymetric_qkv:
915
+ q_inputs = hidden_states.flatten(0, 1)[None, ...]
916
+ if casa_handler.attention_mask is not None:
917
+ q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()]
918
+ else:
919
+ q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
920
+
921
+ # Case 1: Training or first inference step
922
+ if not self.is_streaming or self.streaming_state.offset == 0:
923
+ kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
924
+ else:
925
+ # during streaming, the KV cache including image embeddings
926
+ # will be inserted later so for now we only update the incoming queries
927
+ kv_inputs = q_inputs
928
+
929
+ # Compute QKV for the blockwise attention
930
+ bs, total_seq_len = kv_inputs.shape[:2]
931
+ hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim)
932
+ hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim)
933
+
934
+ if self.override_q_proj is None:
935
+ query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q)
936
+ else:
937
+ query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q)
938
+
939
+ if self.override_k_proj is None:
940
+ key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv)
941
+ else:
942
+ key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv)
943
+
944
+ if self.override_v_proj is None:
945
+ value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv)
946
+ else:
947
+ value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv)
948
+
949
+ # Apply position embedding at the right offset
950
+ num_queries = 0
951
+ if self.streaming and self.streaming_state.offset > 0:
952
+ num_queries = og_shape[1]
953
+
954
+ query_states = self.apply_position_embeddings(
955
+ "q", query_states, num_queries=num_queries, casa_handler=casa_handler
956
+ )
957
+ key_states = self.apply_position_embeddings(
958
+ "kv", key_states, num_queries=num_queries, casa_handler=casa_handler
959
+ )
960
+ assert flash_attn_varlen_func is not None, (
961
+ "flash_attention is not installed but required for block-wise attention"
962
+ )
963
+
964
+ # FlashAttention has a different efficient implementation for streaming
965
+ # In that case, the KV cache has to be batched and has been extended
966
+ # to accommodate the shape of the new updates
967
+ if self.is_streaming:
968
+ key_states, value_states = self.streaming_state.extend_kv(
969
+ key_states=key_states, value_states=value_states
970
+ )
971
+ if casa_handler.use_asymetric_qkv:
972
+ cu_seqlens_q = casa_handler.cu_seqlens_q
973
+ max_seqlen_q = casa_handler.max_seqlen_q
974
+ else:
975
+ cu_seqlens_q = casa_handler.cu_seqlens_kv
976
+ max_seqlen_q = casa_handler.max_seqlen_kv
977
+ assert cu_seqlens_q[-1] == query_states.shape[1], (
978
+ f"{cu_seqlens_q[-1]} != {query_states.shape[1]}"
979
+ )
980
+ assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], (
981
+ f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}"
982
+ )
983
+ # Blockwise (varlen) attention over the CASA windows
984
+ attn_output: torch.Tensor = flash_attn_varlen_func(
985
+ query_states[0].to(torch.bfloat16),
986
+ key_states[0].to(torch.bfloat16),
987
+ value_states[0].to(torch.bfloat16),
988
+ cu_seqlens_q=cu_seqlens_q,
989
+ cu_seqlens_k=casa_handler.cu_seqlens_kv,
990
+ max_seqlen_q=max_seqlen_q,
991
+ max_seqlen_k=casa_handler.max_seqlen_kv,
992
+ dropout_p=0.0,
993
+ # softmax_scale=None, # defaults to 1/sqrt(d)
994
+ causal=True,
995
+ ).to(og_dtype)
996
+
997
+ attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous()
998
+ if self.override_o_proj is None:
999
+ attn_output = self.o_proj_casa(attn_output)
1000
+ else:
1001
+ attn_output = self.override_o_proj(attn_output)
1002
+
1003
+ attn_output = casa_handler.recover_text_embeds(
1004
+ attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds
1005
+ )
1006
+ attn_output = attn_output.reshape(og_shape)
1007
+
1008
+ if self.is_streaming:
1009
+ self.streaming_state.offset += attn_output.shape[1]
1010
+ return attn_output
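For intuition, here is a small self-contained sketch of the cu_seqlens bookkeeping that `CASAAttentionHandler` feeds to `flash_attn_varlen_func` (hypothetical window lengths; it mirrors `get_cu_seqlens` rather than calling the classes above):

```python
# Illustrative only: per-window prefix sums, as used for blockwise (varlen) attention.
from itertools import accumulate

import torch

text_sample_lengths = [12, 7, 20]        # query tokens (text only) per CASA window
image_tokens_per_window = [64, 0, 128]   # image tokens inserted into each window
full_sample_lengths = [t + i for t, i in zip(text_sample_lengths, image_tokens_per_window)]

cu_seqlens_q = torch.tensor(list(accumulate(text_sample_lengths, initial=0)), dtype=torch.int32)
cu_seqlens_kv = torch.tensor(list(accumulate(full_sample_lengths, initial=0)), dtype=torch.int32)

print(cu_seqlens_q.tolist())   # [0, 12, 19, 39]
print(cu_seqlens_kv.tolist())  # [0, 76, 83, 231]
# flash_attn_varlen_func then restricts each window's queries to that window's
# keys/values, with max_seqlen_q = max(text_sample_lengths) and
# max_seqlen_k = max(full_sample_lengths).
```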
config.json ADDED
@@ -0,0 +1,122 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "auto_map": {
4
+ "AutoConfig": "configuration_qwen2_5vl_casa.Qwen2_5_VLCASAConfig",
5
+ "AutoModel": "modeling_qwen2_5vl_casa.V2Qwen2_5VL"
6
+ },
7
+ "bos_token_id": 151643,
8
+ "casa_attention": true,
9
+ "casa_delta_w": true,
10
+ "casa_use_asymetric_qkv": true,
11
+ "casa_windows": "images",
12
+ "eos_token_id": 151645,
13
+ "head_dim": 128,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 2048,
16
+ "image_token_id": 151655,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 11008,
19
+ "max_position_embeddings": 128000,
20
+ "max_window_layers": 70,
21
+ "model_type": "CASA_Qwen_2_5_VL_3B_LiveCC",
22
+ "num_attention_heads": 16,
23
+ "num_hidden_layers": 36,
24
+ "num_key_value_heads": 2,
25
+ "rms_norm_eps": 1e-06,
26
+ "rope_scaling": {
27
+ "mrope_section": [
28
+ 16,
29
+ 24,
30
+ 24
31
+ ],
32
+ "rope_type": "default",
33
+ "type": "default"
34
+ },
35
+ "rope_theta": 1000000.0,
36
+ "sliding_window": 32768,
37
+ "tie_word_embeddings": true,
38
+ "torch_dtype": "bfloat16",
39
+ "transformers_version": "4.51.3",
40
+ "use_cache": true,
41
+ "use_sliding_window": false,
42
+ "video_token_id": 151656,
43
+ "vision_config": {
44
+ "depth": 32,
45
+ "fullatt_block_indexes": [
46
+ 7,
47
+ 15,
48
+ 23,
49
+ 31
50
+ ],
51
+ "hidden_act": "silu",
52
+ "hidden_size": 1280,
53
+ "image_mean": [
54
+ 0.48145466,
55
+ 0.4578275,
56
+ 0.40821073
57
+ ],
58
+ "image_std": [
59
+ 0.26862954,
60
+ 0.26130258,
61
+ 0.27577711
62
+ ],
63
+ "in_channels": 3,
64
+ "in_chans": 3,
65
+ "intermediate_size": 3420,
66
+ "model_type": "qwen2_5_vl",
67
+ "num_heads": 16,
68
+ "out_dim": 2048,
69
+ "out_hidden_size": 2048,
70
+ "patch_size": 14,
71
+ "spatial_merge_size": 2,
72
+ "spatial_patch_size": 14,
73
+ "temporal_patch_size": 1,
74
+ "tokens_per_second": 2,
75
+ "window_size": 112
76
+ },
77
+ "vision_end_token_id": 151653,
78
+ "vision_start_token_id": 151652,
79
+ "vision_token_id": 151654,
80
+ "vocab_size": 151936,
81
+ "xa_layers": [
82
+ 0,
83
+ 1,
84
+ 2,
85
+ 3,
86
+ 4,
87
+ 5,
88
+ 6,
89
+ 7,
90
+ 8,
91
+ 9,
92
+ 10,
93
+ 11,
94
+ 12,
95
+ 13,
96
+ 14,
97
+ 15,
98
+ 16,
99
+ 17,
100
+ 18,
101
+ 19,
102
+ 20,
103
+ 21,
104
+ 22,
105
+ 23,
106
+ 24,
107
+ 25,
108
+ 26,
109
+ 27,
110
+ 28,
111
+ 29,
112
+ 30,
113
+ 31,
114
+ 32,
115
+ 33,
116
+ 34,
117
+ 35
118
+ ],
119
+ "xa_norm_on_images": true,
120
+ "xa_order": "parallel",
121
+ "xa_update_image_embeds": false
122
+ }
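A quick way to sanity-check the CASA-specific fields above (a sketch; assumes `config.json` has been downloaded to the working directory):

```python
# Inspect the CASA / cross-attention fields of this config with the standard library only.
import json

with open("config.json") as f:
    cfg = json.load(f)

casa_fields = {k: v for k, v in cfg.items() if k.startswith(("casa_", "xa_"))}
print(casa_fields["casa_windows"], casa_fields["casa_delta_w"])   # "images", True
print(len(casa_fields["xa_layers"]), cfg["num_hidden_layers"])    # 36, 36 -> CASA on every decoder layer
```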
configuration_qwen2_5vl_casa.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import Any, Literal
2
+
3
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
4
+
5
+
6
+ class Qwen2_5_VLCASAConfig(Qwen2_5_VLConfig):
7
+ """Qwen config augmented with CASA options"""
8
+
9
+ model_type = "qwen2_5vl_casa"
10
+
11
+ def __init__(
12
+ self,
13
+ *args: Any,
14
+ # Common to all fusion mechanisms
15
+ xa_layers: None | tuple = None,
16
+ xa_order: Literal["ca_first", "parallel", "instead"] = "ca_first",
17
+ xa_norm_on_images: bool = False,
18
+ xa_update_image_embeds: bool = False,
19
+ # CASA
20
+ casa_attention: bool = False,
21
+ casa_delta_w: bool = False,
22
+ casa_windows: Literal["batch", "squashed", "images", "turn_based"] = "batch",
23
+ casa_use_asymetric_qkv: bool = True,
24
+ **kwargs: Any,
25
+ ):
26
+ super().__init__(*args, **kwargs)
27
+ self.head_dim = self.hidden_size // self.num_attention_heads
28
+ self.xa_layers = xa_layers
29
+ self.xa_order: Literal["ca_first", "parallel", "instead"] = xa_order
30
+ self.xa_norm_on_images = xa_norm_on_images
31
+ self.xa_update_image_embeds = xa_update_image_embeds
32
+ # CASA config
33
+ self.casa_attention = casa_attention
34
+ self.casa_delta_w = casa_delta_w
35
+ self.casa_windows: Literal["batch", "squashed", "images", "turn_based"] = casa_windows
36
+ self.casa_use_asymetric_qkv = casa_use_asymetric_qkv
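A minimal construction sketch for the config class above (assumes the file is importable from a local checkout, with a transformers version that still exposes the text fields directly on `Qwen2_5_VLConfig`, e.g. 4.51.3 as pinned in `config.json`):

```python
# Build a CASA config with explicit CASA options; remaining fields fall back to Qwen2.5-VL defaults.
from configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig

config = Qwen2_5_VLCASAConfig(
    casa_attention=True,
    casa_delta_w=True,
    casa_windows="images",
    xa_layers=tuple(range(36)),   # CASA layers on every decoder block, as in config.json
    xa_order="parallel",
    xa_norm_on_images=True,
)
print(config.model_type, config.head_dim)  # "qwen2_5vl_casa", hidden_size // num_attention_heads
```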
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151643,
4
+ "eos_token_id": 715,
5
+ "transformers_version": "4.51.3"
6
+ }
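Note that the generation-level `eos_token_id` (715) differs from the model-level one (151645) in `config.json`; the shipped defaults can be inspected with the standard API (repo id assumed):

```python
# Load the generation defaults shipped with this repo.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("kyutai/CASA-Qwen2_5-VL-3B-LiveCC")  # assumed repo id
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id)  # 151643, 715
```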
image_encoder.py ADDED
@@ -0,0 +1,57 @@
1
+ """Qwen2.5VL encoder with delayed normalization"""
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
6
+ Qwen2_5_VisionTransformerPretrainedModel,
7
+ )
8
+
9
+
10
+ def prepare_for_qwen_encoder(
11
+ x: torch.Tensor | list[torch.Tensor], mean: torch.Tensor, std: torch.Tensor
12
+ ) -> tuple[torch.Tensor, torch.Tensor]:
13
+ """
14
+ Preprocessing for Qwen encoder
15
+ Image mean and std come from processor.image_processor.image_mean and image_std
16
+ """
17
+ grid_thw = torch.Tensor([[1, img.shape[0], img.shape[1]] for img in x]).to(x[0].device)
18
+ hws_flatten_shape = torch.prod(grid_thw, dim=-1)
19
+ x = torch.cat(
20
+ [img.reshape((int(hws_flatten_shape[idx].item()), -1)) for idx, img in enumerate(x)],
21
+ dim=0,
22
+ )
23
+ assert x.min() >= 0.0 and x.max() <= 1.0
24
+ og_shape = x.shape
25
+ x = rearrange(x, "L (c d) -> L c d", c=3)
26
+ x = (x - mean) / std
27
+ x = x.view(og_shape).to(torch.bfloat16)
28
+ return x, grid_thw
29
+
30
+
31
+ class Qwen25VLEncoder(torch.nn.Module):
32
+ """Qwen2.5 VL encoder with pre/post processing to be compatible for
33
+ our CASA attention implementation"""
34
+
35
+ def __init__(
36
+ self,
37
+ visual: "Qwen2_5_VisionTransformerPretrainedModel",
38
+ ):
39
+ super().__init__()
40
+ self.visual = visual
41
+ self.image_mean = torch.tensor(self.visual.config.image_mean).view(1, 3, 1)
42
+ self.image_std = torch.tensor(self.visual.config.image_std).view(1, 3, 1)
43
+
44
+ def forward(
45
+ self, x: torch.Tensor | list[torch.Tensor]
46
+ ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
47
+ x, grid_thw = prepare_for_qwen_encoder(
48
+ x, mean=self.image_mean.to(x[0].device), std=self.image_std.to(x[0].device)
49
+ )
50
+
51
+ grid_thw = grid_thw.type(torch.int)
52
+ assert len(x) == grid_thw.prod(dim=1).sum()
53
+ out = self.visual(x, grid_thw=grid_thw)
54
+
55
+ split_sizes = (grid_thw.prod(dim=-1) // self.visual.spatial_merge_size**2).tolist()
56
+ embeds = list(torch.split(out, split_sizes, dim=0)) # Ni * (seq, C)
57
+ return {"image_embeds": embeds, "grid_thw": grid_thw}
language_qwen2_5vl_casa.py ADDED
@@ -0,0 +1,276 @@
1
+ from functools import partial
2
+ from typing import Any, Callable, Literal, Optional
3
+
4
+ import torch
5
+ from transformers.cache_utils import Cache
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
8
+ Qwen2_5_VLDecoderLayer,
9
+ Qwen2_5_VLFlashAttention2,
10
+ rotate_half,
11
+ )
12
+
13
+ from .casa_attention import CASAAttention, CASAAttentionHandler
14
+ from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig
15
+
16
+
17
+ class QwenCASAAttentionHandler(CASAAttentionHandler):
18
+ """Overrides CASAAttention with the right pos embedding computation for Qwen"""
19
+
20
+ def __init__(
21
+ self,
22
+ *args: Any,
23
+ get_rope_index: Callable | None = None,
24
+ grid_thw: torch.Tensor | None = None,
25
+ position_ids_offset: int = 0,
26
+ **kwargs: Any,
27
+ ):
28
+ assert get_rope_index is not None, "get_rope_index should be given for QwenCASA"
29
+ self.get_rope_index = partial(get_rope_index, image_grid_thw=grid_thw)
30
+ self.position_ids_offset = position_ids_offset
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def compute_position_embeddings(
34
+ self,
35
+ rope_fn: Callable,
36
+ sample_lengths: list[int],
37
+ dummy_for_dtype_and_device: torch.Tensor,
38
+ ) -> tuple[torch.Tensor, torch.Tensor]:
39
+ """Compute info required for position embeddings. Can be overriden e.g. for Qwen"""
40
+ # Here rope_fn is the "get_rope_index" function from the original mode
41
+ dummy_input_ids = torch.zeros(
42
+ (int(sum(sample_lengths)),), device=dummy_for_dtype_and_device.device, dtype=torch.long
43
+ )
44
+ # Set image token ids
45
+ dummy_input_ids[self.image_tokens_mask[:, 0]] = 151655
46
+
47
+ # required for the weird start of image tokens
48
+ # Highly recommended to use pre and post image tokens with Qwen
49
+ # Add vision start token ids (wherever a 151655 follows a 0)
50
+ start_of_images = torch.logical_and(
51
+ dummy_input_ids == 0,
52
+ torch.nn.functional.pad(dummy_input_ids[1:] == 151655, (0, 1), value=0),
53
+ )
54
+ dummy_input_ids[start_of_images] = 151652
55
+
56
+ # rebatch dummy input ids
57
+ padding_side = "left" if self.attention_mask is not None else "right"
58
+ s = list(torch.split(dummy_input_ids, self.full_batch_lengths))
59
+ mlen = max(_s.shape[0] for _s in s)
60
+ trims = [mlen - _s.shape[0] for _s in s]
61
+ dummy_input_ids = torch.stack(
62
+ [
63
+ torch.nn.functional.pad(
64
+ _s,
65
+ (
66
+ trims[i] if padding_side == "left" else 0,
67
+ trims[i] if padding_side == "right" else 0,
68
+ ),
69
+ value=-1,
70
+ )
71
+ for i, _s in enumerate(s)
72
+ ],
73
+ dim=0,
74
+ )
75
+
76
+ # We need to pass the attention mask to get_rope_index when using left padding
77
+ attention_mask = torch.ones_like(dummy_input_ids)
78
+ for i, t in enumerate(trims):
79
+ if padding_side == "right":
80
+ attention_mask[i, attention_mask.shape[-1] - t :] = 0
81
+ else:
82
+ attention_mask[i, :t] = 0
83
+
84
+ # compute pos embeds shape (3, bs, seq)
85
+ position_ids = (
86
+ self.get_rope_index(dummy_input_ids, attention_mask=attention_mask)[0]
87
+ + self.position_ids_offset
88
+ )
89
+
90
+ # Compute position embeddings and recover the flattened, unpadded shape
91
+ cos, sin = rope_fn(dummy_for_dtype_and_device, position_ids)
92
+ # reflatten seq
93
+ if padding_side == "right":
94
+ cos = torch.cat(
95
+ [cos[:, i : i + 1, : cos.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
96
+ )
97
+ sin = torch.cat(
98
+ [sin[:, i : i + 1, : sin.shape[2] - t, :] for i, t in enumerate(trims)], dim=2
99
+ )
100
+ else:
101
+ cos = torch.cat([cos[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
102
+ sin = torch.cat([sin[:, i : i + 1, t:, :] for i, t in enumerate(trims)], dim=2)
103
+ return cos, sin
104
+
105
+ def get_position_embedding(
106
+ self,
107
+ key: Literal["q", "kv"],
108
+ num_queries: int = 0,
109
+ ) -> tuple[torch.Tensor, torch.Tensor] | None:
110
+ if self.position_embeds is None:
111
+ return None
112
+ cos, sin = self.position_embeds
113
+ # For Q, we only want the text-only posembeds
114
+ if key == "q":
115
+ cos, sin = (
116
+ cos[:, :, ~self.image_tokens_mask[:, 0]],
117
+ sin[:, :, ~self.image_tokens_mask[:, 0]],
118
+ )
119
+ elif key != "kv":
120
+ raise ValueError(f"Unknown key for position embedding {key}")
121
+
122
+ # Easy case (training, or the first inference step): use all the position embeddings
123
+ if num_queries == 0:
124
+ return cos, sin
125
+ # If num queries is given, we need to trim for *every sample in the batch*
126
+ bls = self.full_batch_lengths if key == "kv" else self.batch_lengths
127
+ cos = [x[:, :, -num_queries:] for x in torch.split(cos, bls, dim=2)]
128
+ sin = [x[:, :, -num_queries:] for x in torch.split(sin, bls, dim=2)]
129
+ return torch.cat(cos, dim=2), torch.cat(sin, dim=2)
130
+
131
+
132
+ class QwenCASAAttention(CASAAttention):
133
+ """A CASA Attention layer compatible with Qwen"""
134
+
135
+ def __init__(
136
+ self,
137
+ config: Qwen2_5_VLCASAConfig,
138
+ layer_idx: int | None,
139
+ self_attn: torch.nn.Module | None = None,
140
+ input_layernorm_fn: Callable | None = None,
141
+ ):
142
+ # This init is only added for typing purposes on the config
143
+ super().__init__(config, layer_idx, self_attn, input_layernorm_fn) # pyright: ignore[reportArgumentType]
144
+ assert config.rope_scaling is not None
145
+ self.mrope_section = config.rope_scaling["mrope_section"] * 2
146
+
147
+ def apply_position_embeddings(
148
+ self,
149
+ key: Literal["q", "kv"],
150
+ x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
151
+ casa_handler: CASAAttentionHandler | None,
152
+ num_queries: int = 0,
153
+ unsqueeze_dim: int = 1,
154
+ ) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
155
+ """Apply position embeddings to query and key states"""
156
+ if casa_handler is not None:
157
+ posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)
158
+
159
+ if posemb is not None:
160
+ x = x.transpose(1, 2).to(torch.float32)
161
+ cos, sin = posemb
162
+ cos = torch.cat(
163
+ [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], dim=-1
164
+ ).unsqueeze(unsqueeze_dim)
165
+
166
+ sin = torch.cat(
167
+ [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], dim=-1
168
+ ).unsqueeze(unsqueeze_dim)
169
+
170
+ x = (x * cos) + (rotate_half(x) * sin)
171
+ return x.transpose(1, 2)
172
+ return x
173
+
174
+ def init_from_config_proj(
175
+ self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
176
+ ) -> torch.nn.Linear:
177
+ """Follows modeling_qwen2_5_vl.py initialization"""
178
+ head_dim = config.hidden_size // config.num_attention_heads
179
+ if key == "q":
180
+ return torch.nn.Linear(
181
+ config.hidden_size, config.num_attention_heads * head_dim, bias=True
182
+ )
183
+ if key in {"k", "v"}:
184
+ return torch.nn.Linear(
185
+ config.hidden_size, config.num_key_value_heads * head_dim, bias=True
186
+ )
187
+ if key == "o":
188
+ return torch.nn.Linear(
189
+ config.num_attention_heads * config.head_dim, config.hidden_size, bias=False
190
+ )
191
+ raise NotImplementedError(f"Unknown key {key}")
192
+
193
+
194
+ class Qwen2_5_VLAttention_CASA(Qwen2_5_VLFlashAttention2):
195
+ """
196
+ Qwen Attention with extra CASA Attention layer
197
+ """
198
+
199
+ def __init__(
200
+ self,
201
+ config: Qwen2_5_VLCASAConfig,
202
+ layer_idx: Optional[int] = None,
203
+ input_layernorm: torch.nn.Module | None = None,
204
+ ):
205
+ super().__init__(config, layer_idx) # pyright: ignore[reportArgumentType]
206
+ self.casa_attn = QwenCASAAttention(
207
+ config,
208
+ layer_idx=layer_idx,
209
+ self_attn=self,
210
+ input_layernorm_fn=input_layernorm.forward if input_layernorm is not None else None,
211
+ )
212
+ self.casa_attention_handler: CASAAttentionHandler | None = None
213
+
214
+ @classmethod
215
+ def from_qwen2_5_vl_attention(
216
+ cls, attention: Qwen2_5_VLFlashAttention2, input_layernorm: torch.nn.Module | None
217
+ ):
218
+ """Init this layer from Qwen Attention layer"""
219
+ layer_idx = attention.layer_idx
220
+ assert layer_idx is not None
221
+ new_attention = cls(attention.config, layer_idx=layer_idx, input_layernorm=input_layernorm) # pyright: ignore
222
+ new_attention.load_state_dict(attention.state_dict(), strict=False)
223
+ return new_attention
224
+
225
+ def forward( # pyright: ignore[reportIncompatibleMethodOverride]
226
+ self,
227
+ hidden_states: torch.Tensor,
228
+ attention_mask: Optional[torch.Tensor] = None,
229
+ position_ids: Optional[torch.LongTensor] = None,
230
+ past_key_value: Optional[Cache] = None,
231
+ output_attentions: bool = False,
232
+ use_cache: bool = False,
233
+ cache_position: Optional[torch.LongTensor] = None,
234
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
235
+ ):
236
+ casa_out: None | torch.Tensor = None
237
+ if self.casa_attn is not None and self.config.xa_order in {
238
+ "parallel",
239
+ "ca_first",
240
+ "instead",
241
+ }:
242
+ casa_out = self.casa_attn(
243
+ hidden_states=hidden_states,
244
+ casa_handler=self.casa_attention_handler,
245
+ )
246
+
247
+ if self.config.xa_order == "instead":
248
+ return casa_out, None, None
249
+
250
+ if self.config.xa_order == "ca_first" and casa_out is not None:
251
+ hidden_states, casa_out = casa_out, None
252
+
253
+ attn_output, attn_weights, past_key_values = super().forward(
254
+ hidden_states,
255
+ attention_mask,
256
+ position_ids,
257
+ past_key_value,
258
+ output_attentions,
259
+ use_cache,
260
+ cache_position,
261
+ position_embeddings,
262
+ )
263
+ if self.config.xa_order == "parallel" and casa_out is not None:
264
+ attn_output = casa_out + attn_output
265
+ return attn_output, attn_weights, past_key_values
266
+
267
+
268
+ def add_casa_layers(m: torch.nn.Module, xa_layers: tuple[int, ...] | None):
269
+ """Replace Attention layer by CASA Attention layer as needed"""
270
+ if isinstance(m, Qwen2_5_VLDecoderLayer):
271
+ layer_idx = m.self_attn.layer_idx
272
+ assert layer_idx is not None
273
+ if xa_layers is None or len(xa_layers) == 0 or layer_idx in xa_layers:
274
+ m.self_attn = Qwen2_5_VLAttention_CASA.from_qwen2_5_vl_attention(
275
+ m.self_attn, input_layernorm=m.input_layernorm
276
+ )
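A minimal usage sketch of `add_casa_layers` (hypothetical: the checkpoint name and layer indices are placeholders, and it assumes a transformers release that ships Qwen2.5-VL and a config that already carries the extra CASA fields such as `xa_order`):

```python
from functools import partial

from transformers import Qwen2_5_VLForConditionalGeneration

# Hypothetical checkpoint; any Qwen2.5-VL model whose config includes the CASA fields
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Swap the self-attention of a subset of decoder layers (arbitrary example indices)
# for Qwen2_5_VLAttention_CASA; xa_layers=None or () would convert every layer
model.apply(partial(add_casa_layers, xa_layers=(0, 8, 16, 24)))
```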
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:959dc60fd4e5974349a1e23b79d33edd765bbcb911ea29ae594e4ba1d2872188
3
+ size 4961226720
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aab471816213f5a9cf2b2214de75a7e1d17d9b9371d75e12a8481066db458f18
3
+ size 4993905448
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32abd001c7514589845599d3cc64d739bf9c57acb548c4099cbeab7be9e57f5a
3
+ size 4994485840
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8943ad7ef101666fb28d9b0ac5b27cd1c8381c2d755ce0ddac1f9c6dea33de2d
3
+ size 1425314512
model.safetensors.index.json ADDED
@@ -0,0 +1,1083 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16374804480
4
+ },
5
+ "weight_map": {
6
+ "image_prefix.visual.blocks.0.attn.proj.bias": "model-00003-of-00004.safetensors",
7
+ "image_prefix.visual.blocks.0.attn.proj.weight": "model-00003-of-00004.safetensors",
8
+ "image_prefix.visual.blocks.0.attn.qkv.bias": "model-00003-of-00004.safetensors",
9
+ "image_prefix.visual.blocks.0.attn.qkv.weight": "model-00003-of-00004.safetensors",
10
+ "image_prefix.visual.blocks.0.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
11
+ "image_prefix.visual.blocks.0.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
12
+ "image_prefix.visual.blocks.0.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
13
+ "image_prefix.visual.blocks.0.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
14
+ "image_prefix.visual.blocks.0.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
15
+ "image_prefix.visual.blocks.0.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
16
+ "image_prefix.visual.blocks.0.norm1.weight": "model-00003-of-00004.safetensors",
17
+ "image_prefix.visual.blocks.0.norm2.weight": "model-00003-of-00004.safetensors",
18
+ "image_prefix.visual.blocks.1.attn.proj.bias": "model-00003-of-00004.safetensors",
19
+ "image_prefix.visual.blocks.1.attn.proj.weight": "model-00003-of-00004.safetensors",
20
+ "image_prefix.visual.blocks.1.attn.qkv.bias": "model-00003-of-00004.safetensors",
21
+ "image_prefix.visual.blocks.1.attn.qkv.weight": "model-00003-of-00004.safetensors",
22
+ "image_prefix.visual.blocks.1.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
23
+ "image_prefix.visual.blocks.1.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
24
+ "image_prefix.visual.blocks.1.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
25
+ "image_prefix.visual.blocks.1.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
26
+ "image_prefix.visual.blocks.1.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
27
+ "image_prefix.visual.blocks.1.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
28
+ "image_prefix.visual.blocks.1.norm1.weight": "model-00003-of-00004.safetensors",
29
+ "image_prefix.visual.blocks.1.norm2.weight": "model-00003-of-00004.safetensors",
30
+ "image_prefix.visual.blocks.10.attn.proj.bias": "model-00003-of-00004.safetensors",
31
+ "image_prefix.visual.blocks.10.attn.proj.weight": "model-00003-of-00004.safetensors",
32
+ "image_prefix.visual.blocks.10.attn.qkv.bias": "model-00003-of-00004.safetensors",
33
+ "image_prefix.visual.blocks.10.attn.qkv.weight": "model-00003-of-00004.safetensors",
34
+ "image_prefix.visual.blocks.10.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
35
+ "image_prefix.visual.blocks.10.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
36
+ "image_prefix.visual.blocks.10.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
37
+ "image_prefix.visual.blocks.10.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
38
+ "image_prefix.visual.blocks.10.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
39
+ "image_prefix.visual.blocks.10.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
40
+ "image_prefix.visual.blocks.10.norm1.weight": "model-00003-of-00004.safetensors",
41
+ "image_prefix.visual.blocks.10.norm2.weight": "model-00003-of-00004.safetensors",
42
+ "image_prefix.visual.blocks.11.attn.proj.bias": "model-00003-of-00004.safetensors",
43
+ "image_prefix.visual.blocks.11.attn.proj.weight": "model-00003-of-00004.safetensors",
44
+ "image_prefix.visual.blocks.11.attn.qkv.bias": "model-00003-of-00004.safetensors",
45
+ "image_prefix.visual.blocks.11.attn.qkv.weight": "model-00003-of-00004.safetensors",
46
+ "image_prefix.visual.blocks.11.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
47
+ "image_prefix.visual.blocks.11.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
48
+ "image_prefix.visual.blocks.11.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
49
+ "image_prefix.visual.blocks.11.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
50
+ "image_prefix.visual.blocks.11.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
51
+ "image_prefix.visual.blocks.11.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
52
+ "image_prefix.visual.blocks.11.norm1.weight": "model-00003-of-00004.safetensors",
53
+ "image_prefix.visual.blocks.11.norm2.weight": "model-00003-of-00004.safetensors",
54
+ "image_prefix.visual.blocks.12.attn.proj.bias": "model-00003-of-00004.safetensors",
55
+ "image_prefix.visual.blocks.12.attn.proj.weight": "model-00003-of-00004.safetensors",
56
+ "image_prefix.visual.blocks.12.attn.qkv.bias": "model-00003-of-00004.safetensors",
57
+ "image_prefix.visual.blocks.12.attn.qkv.weight": "model-00003-of-00004.safetensors",
58
+ "image_prefix.visual.blocks.12.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
59
+ "image_prefix.visual.blocks.12.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
60
+ "image_prefix.visual.blocks.12.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
61
+ "image_prefix.visual.blocks.12.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
62
+ "image_prefix.visual.blocks.12.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
63
+ "image_prefix.visual.blocks.12.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
64
+ "image_prefix.visual.blocks.12.norm1.weight": "model-00003-of-00004.safetensors",
65
+ "image_prefix.visual.blocks.12.norm2.weight": "model-00003-of-00004.safetensors",
66
+ "image_prefix.visual.blocks.13.attn.proj.bias": "model-00003-of-00004.safetensors",
67
+ "image_prefix.visual.blocks.13.attn.proj.weight": "model-00003-of-00004.safetensors",
68
+ "image_prefix.visual.blocks.13.attn.qkv.bias": "model-00003-of-00004.safetensors",
69
+ "image_prefix.visual.blocks.13.attn.qkv.weight": "model-00003-of-00004.safetensors",
70
+ "image_prefix.visual.blocks.13.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
71
+ "image_prefix.visual.blocks.13.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
72
+ "image_prefix.visual.blocks.13.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
73
+ "image_prefix.visual.blocks.13.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
74
+ "image_prefix.visual.blocks.13.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
75
+ "image_prefix.visual.blocks.13.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
76
+ "image_prefix.visual.blocks.13.norm1.weight": "model-00003-of-00004.safetensors",
77
+ "image_prefix.visual.blocks.13.norm2.weight": "model-00003-of-00004.safetensors",
78
+ "image_prefix.visual.blocks.14.attn.proj.bias": "model-00003-of-00004.safetensors",
79
+ "image_prefix.visual.blocks.14.attn.proj.weight": "model-00003-of-00004.safetensors",
80
+ "image_prefix.visual.blocks.14.attn.qkv.bias": "model-00003-of-00004.safetensors",
81
+ "image_prefix.visual.blocks.14.attn.qkv.weight": "model-00003-of-00004.safetensors",
82
+ "image_prefix.visual.blocks.14.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
83
+ "image_prefix.visual.blocks.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
84
+ "image_prefix.visual.blocks.14.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
85
+ "image_prefix.visual.blocks.14.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
86
+ "image_prefix.visual.blocks.14.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
87
+ "image_prefix.visual.blocks.14.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
88
+ "image_prefix.visual.blocks.14.norm1.weight": "model-00003-of-00004.safetensors",
89
+ "image_prefix.visual.blocks.14.norm2.weight": "model-00003-of-00004.safetensors",
90
+ "image_prefix.visual.blocks.15.attn.proj.bias": "model-00003-of-00004.safetensors",
91
+ "image_prefix.visual.blocks.15.attn.proj.weight": "model-00003-of-00004.safetensors",
92
+ "image_prefix.visual.blocks.15.attn.qkv.bias": "model-00003-of-00004.safetensors",
93
+ "image_prefix.visual.blocks.15.attn.qkv.weight": "model-00003-of-00004.safetensors",
94
+ "image_prefix.visual.blocks.15.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
95
+ "image_prefix.visual.blocks.15.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
96
+ "image_prefix.visual.blocks.15.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
97
+ "image_prefix.visual.blocks.15.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
98
+ "image_prefix.visual.blocks.15.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
99
+ "image_prefix.visual.blocks.15.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
100
+ "image_prefix.visual.blocks.15.norm1.weight": "model-00003-of-00004.safetensors",
101
+ "image_prefix.visual.blocks.15.norm2.weight": "model-00003-of-00004.safetensors",
102
+ "image_prefix.visual.blocks.16.attn.proj.bias": "model-00004-of-00004.safetensors",
103
+ "image_prefix.visual.blocks.16.attn.proj.weight": "model-00004-of-00004.safetensors",
104
+ "image_prefix.visual.blocks.16.attn.qkv.bias": "model-00004-of-00004.safetensors",
105
+ "image_prefix.visual.blocks.16.attn.qkv.weight": "model-00004-of-00004.safetensors",
106
+ "image_prefix.visual.blocks.16.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
107
+ "image_prefix.visual.blocks.16.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
108
+ "image_prefix.visual.blocks.16.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
109
+ "image_prefix.visual.blocks.16.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
110
+ "image_prefix.visual.blocks.16.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
111
+ "image_prefix.visual.blocks.16.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
112
+ "image_prefix.visual.blocks.16.norm1.weight": "model-00004-of-00004.safetensors",
113
+ "image_prefix.visual.blocks.16.norm2.weight": "model-00004-of-00004.safetensors",
114
+ "image_prefix.visual.blocks.17.attn.proj.bias": "model-00004-of-00004.safetensors",
115
+ "image_prefix.visual.blocks.17.attn.proj.weight": "model-00004-of-00004.safetensors",
116
+ "image_prefix.visual.blocks.17.attn.qkv.bias": "model-00004-of-00004.safetensors",
117
+ "image_prefix.visual.blocks.17.attn.qkv.weight": "model-00004-of-00004.safetensors",
118
+ "image_prefix.visual.blocks.17.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
119
+ "image_prefix.visual.blocks.17.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
120
+ "image_prefix.visual.blocks.17.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
121
+ "image_prefix.visual.blocks.17.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
122
+ "image_prefix.visual.blocks.17.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
123
+ "image_prefix.visual.blocks.17.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
124
+ "image_prefix.visual.blocks.17.norm1.weight": "model-00004-of-00004.safetensors",
125
+ "image_prefix.visual.blocks.17.norm2.weight": "model-00004-of-00004.safetensors",
126
+ "image_prefix.visual.blocks.18.attn.proj.bias": "model-00004-of-00004.safetensors",
127
+ "image_prefix.visual.blocks.18.attn.proj.weight": "model-00004-of-00004.safetensors",
128
+ "image_prefix.visual.blocks.18.attn.qkv.bias": "model-00004-of-00004.safetensors",
129
+ "image_prefix.visual.blocks.18.attn.qkv.weight": "model-00004-of-00004.safetensors",
130
+ "image_prefix.visual.blocks.18.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
131
+ "image_prefix.visual.blocks.18.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
132
+ "image_prefix.visual.blocks.18.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
133
+ "image_prefix.visual.blocks.18.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
134
+ "image_prefix.visual.blocks.18.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
135
+ "image_prefix.visual.blocks.18.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
136
+ "image_prefix.visual.blocks.18.norm1.weight": "model-00004-of-00004.safetensors",
137
+ "image_prefix.visual.blocks.18.norm2.weight": "model-00004-of-00004.safetensors",
138
+ "image_prefix.visual.blocks.19.attn.proj.bias": "model-00004-of-00004.safetensors",
139
+ "image_prefix.visual.blocks.19.attn.proj.weight": "model-00004-of-00004.safetensors",
140
+ "image_prefix.visual.blocks.19.attn.qkv.bias": "model-00004-of-00004.safetensors",
141
+ "image_prefix.visual.blocks.19.attn.qkv.weight": "model-00004-of-00004.safetensors",
142
+ "image_prefix.visual.blocks.19.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
143
+ "image_prefix.visual.blocks.19.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
144
+ "image_prefix.visual.blocks.19.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
145
+ "image_prefix.visual.blocks.19.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
146
+ "image_prefix.visual.blocks.19.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
147
+ "image_prefix.visual.blocks.19.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
148
+ "image_prefix.visual.blocks.19.norm1.weight": "model-00004-of-00004.safetensors",
149
+ "image_prefix.visual.blocks.19.norm2.weight": "model-00004-of-00004.safetensors",
150
+ "image_prefix.visual.blocks.2.attn.proj.bias": "model-00003-of-00004.safetensors",
151
+ "image_prefix.visual.blocks.2.attn.proj.weight": "model-00003-of-00004.safetensors",
152
+ "image_prefix.visual.blocks.2.attn.qkv.bias": "model-00003-of-00004.safetensors",
153
+ "image_prefix.visual.blocks.2.attn.qkv.weight": "model-00003-of-00004.safetensors",
154
+ "image_prefix.visual.blocks.2.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
155
+ "image_prefix.visual.blocks.2.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
156
+ "image_prefix.visual.blocks.2.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
157
+ "image_prefix.visual.blocks.2.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
158
+ "image_prefix.visual.blocks.2.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
159
+ "image_prefix.visual.blocks.2.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
160
+ "image_prefix.visual.blocks.2.norm1.weight": "model-00003-of-00004.safetensors",
161
+ "image_prefix.visual.blocks.2.norm2.weight": "model-00003-of-00004.safetensors",
162
+ "image_prefix.visual.blocks.20.attn.proj.bias": "model-00004-of-00004.safetensors",
163
+ "image_prefix.visual.blocks.20.attn.proj.weight": "model-00004-of-00004.safetensors",
164
+ "image_prefix.visual.blocks.20.attn.qkv.bias": "model-00004-of-00004.safetensors",
165
+ "image_prefix.visual.blocks.20.attn.qkv.weight": "model-00004-of-00004.safetensors",
166
+ "image_prefix.visual.blocks.20.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
167
+ "image_prefix.visual.blocks.20.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
168
+ "image_prefix.visual.blocks.20.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
169
+ "image_prefix.visual.blocks.20.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
170
+ "image_prefix.visual.blocks.20.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
171
+ "image_prefix.visual.blocks.20.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
172
+ "image_prefix.visual.blocks.20.norm1.weight": "model-00004-of-00004.safetensors",
173
+ "image_prefix.visual.blocks.20.norm2.weight": "model-00004-of-00004.safetensors",
174
+ "image_prefix.visual.blocks.21.attn.proj.bias": "model-00004-of-00004.safetensors",
175
+ "image_prefix.visual.blocks.21.attn.proj.weight": "model-00004-of-00004.safetensors",
176
+ "image_prefix.visual.blocks.21.attn.qkv.bias": "model-00004-of-00004.safetensors",
177
+ "image_prefix.visual.blocks.21.attn.qkv.weight": "model-00004-of-00004.safetensors",
178
+ "image_prefix.visual.blocks.21.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
179
+ "image_prefix.visual.blocks.21.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
180
+ "image_prefix.visual.blocks.21.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
181
+ "image_prefix.visual.blocks.21.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
182
+ "image_prefix.visual.blocks.21.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
183
+ "image_prefix.visual.blocks.21.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
184
+ "image_prefix.visual.blocks.21.norm1.weight": "model-00004-of-00004.safetensors",
185
+ "image_prefix.visual.blocks.21.norm2.weight": "model-00004-of-00004.safetensors",
186
+ "image_prefix.visual.blocks.22.attn.proj.bias": "model-00004-of-00004.safetensors",
187
+ "image_prefix.visual.blocks.22.attn.proj.weight": "model-00004-of-00004.safetensors",
188
+ "image_prefix.visual.blocks.22.attn.qkv.bias": "model-00004-of-00004.safetensors",
189
+ "image_prefix.visual.blocks.22.attn.qkv.weight": "model-00004-of-00004.safetensors",
190
+ "image_prefix.visual.blocks.22.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
191
+ "image_prefix.visual.blocks.22.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
192
+ "image_prefix.visual.blocks.22.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
193
+ "image_prefix.visual.blocks.22.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
194
+ "image_prefix.visual.blocks.22.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
195
+ "image_prefix.visual.blocks.22.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
196
+ "image_prefix.visual.blocks.22.norm1.weight": "model-00004-of-00004.safetensors",
197
+ "image_prefix.visual.blocks.22.norm2.weight": "model-00004-of-00004.safetensors",
198
+ "image_prefix.visual.blocks.23.attn.proj.bias": "model-00004-of-00004.safetensors",
199
+ "image_prefix.visual.blocks.23.attn.proj.weight": "model-00004-of-00004.safetensors",
200
+ "image_prefix.visual.blocks.23.attn.qkv.bias": "model-00004-of-00004.safetensors",
201
+ "image_prefix.visual.blocks.23.attn.qkv.weight": "model-00004-of-00004.safetensors",
202
+ "image_prefix.visual.blocks.23.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
203
+ "image_prefix.visual.blocks.23.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
204
+ "image_prefix.visual.blocks.23.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
205
+ "image_prefix.visual.blocks.23.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
206
+ "image_prefix.visual.blocks.23.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
207
+ "image_prefix.visual.blocks.23.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
208
+ "image_prefix.visual.blocks.23.norm1.weight": "model-00004-of-00004.safetensors",
209
+ "image_prefix.visual.blocks.23.norm2.weight": "model-00004-of-00004.safetensors",
210
+ "image_prefix.visual.blocks.24.attn.proj.bias": "model-00004-of-00004.safetensors",
211
+ "image_prefix.visual.blocks.24.attn.proj.weight": "model-00004-of-00004.safetensors",
212
+ "image_prefix.visual.blocks.24.attn.qkv.bias": "model-00004-of-00004.safetensors",
213
+ "image_prefix.visual.blocks.24.attn.qkv.weight": "model-00004-of-00004.safetensors",
214
+ "image_prefix.visual.blocks.24.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
215
+ "image_prefix.visual.blocks.24.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
216
+ "image_prefix.visual.blocks.24.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
217
+ "image_prefix.visual.blocks.24.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
218
+ "image_prefix.visual.blocks.24.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
219
+ "image_prefix.visual.blocks.24.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
220
+ "image_prefix.visual.blocks.24.norm1.weight": "model-00004-of-00004.safetensors",
221
+ "image_prefix.visual.blocks.24.norm2.weight": "model-00004-of-00004.safetensors",
222
+ "image_prefix.visual.blocks.25.attn.proj.bias": "model-00004-of-00004.safetensors",
223
+ "image_prefix.visual.blocks.25.attn.proj.weight": "model-00004-of-00004.safetensors",
224
+ "image_prefix.visual.blocks.25.attn.qkv.bias": "model-00004-of-00004.safetensors",
225
+ "image_prefix.visual.blocks.25.attn.qkv.weight": "model-00004-of-00004.safetensors",
226
+ "image_prefix.visual.blocks.25.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
227
+ "image_prefix.visual.blocks.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
228
+ "image_prefix.visual.blocks.25.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
229
+ "image_prefix.visual.blocks.25.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
230
+ "image_prefix.visual.blocks.25.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
231
+ "image_prefix.visual.blocks.25.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
232
+ "image_prefix.visual.blocks.25.norm1.weight": "model-00004-of-00004.safetensors",
233
+ "image_prefix.visual.blocks.25.norm2.weight": "model-00004-of-00004.safetensors",
234
+ "image_prefix.visual.blocks.26.attn.proj.bias": "model-00004-of-00004.safetensors",
235
+ "image_prefix.visual.blocks.26.attn.proj.weight": "model-00004-of-00004.safetensors",
236
+ "image_prefix.visual.blocks.26.attn.qkv.bias": "model-00004-of-00004.safetensors",
237
+ "image_prefix.visual.blocks.26.attn.qkv.weight": "model-00004-of-00004.safetensors",
238
+ "image_prefix.visual.blocks.26.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
239
+ "image_prefix.visual.blocks.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
240
+ "image_prefix.visual.blocks.26.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
241
+ "image_prefix.visual.blocks.26.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
242
+ "image_prefix.visual.blocks.26.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
243
+ "image_prefix.visual.blocks.26.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
244
+ "image_prefix.visual.blocks.26.norm1.weight": "model-00004-of-00004.safetensors",
245
+ "image_prefix.visual.blocks.26.norm2.weight": "model-00004-of-00004.safetensors",
246
+ "image_prefix.visual.blocks.27.attn.proj.bias": "model-00004-of-00004.safetensors",
247
+ "image_prefix.visual.blocks.27.attn.proj.weight": "model-00004-of-00004.safetensors",
248
+ "image_prefix.visual.blocks.27.attn.qkv.bias": "model-00004-of-00004.safetensors",
249
+ "image_prefix.visual.blocks.27.attn.qkv.weight": "model-00004-of-00004.safetensors",
250
+ "image_prefix.visual.blocks.27.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
251
+ "image_prefix.visual.blocks.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
252
+ "image_prefix.visual.blocks.27.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
253
+ "image_prefix.visual.blocks.27.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
254
+ "image_prefix.visual.blocks.27.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
255
+ "image_prefix.visual.blocks.27.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
256
+ "image_prefix.visual.blocks.27.norm1.weight": "model-00004-of-00004.safetensors",
257
+ "image_prefix.visual.blocks.27.norm2.weight": "model-00004-of-00004.safetensors",
258
+ "image_prefix.visual.blocks.28.attn.proj.bias": "model-00004-of-00004.safetensors",
259
+ "image_prefix.visual.blocks.28.attn.proj.weight": "model-00004-of-00004.safetensors",
260
+ "image_prefix.visual.blocks.28.attn.qkv.bias": "model-00004-of-00004.safetensors",
261
+ "image_prefix.visual.blocks.28.attn.qkv.weight": "model-00004-of-00004.safetensors",
262
+ "image_prefix.visual.blocks.28.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
263
+ "image_prefix.visual.blocks.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
264
+ "image_prefix.visual.blocks.28.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
265
+ "image_prefix.visual.blocks.28.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
266
+ "image_prefix.visual.blocks.28.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
267
+ "image_prefix.visual.blocks.28.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
268
+ "image_prefix.visual.blocks.28.norm1.weight": "model-00004-of-00004.safetensors",
269
+ "image_prefix.visual.blocks.28.norm2.weight": "model-00004-of-00004.safetensors",
270
+ "image_prefix.visual.blocks.29.attn.proj.bias": "model-00004-of-00004.safetensors",
271
+ "image_prefix.visual.blocks.29.attn.proj.weight": "model-00004-of-00004.safetensors",
272
+ "image_prefix.visual.blocks.29.attn.qkv.bias": "model-00004-of-00004.safetensors",
273
+ "image_prefix.visual.blocks.29.attn.qkv.weight": "model-00004-of-00004.safetensors",
274
+ "image_prefix.visual.blocks.29.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
275
+ "image_prefix.visual.blocks.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
276
+ "image_prefix.visual.blocks.29.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
277
+ "image_prefix.visual.blocks.29.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
278
+ "image_prefix.visual.blocks.29.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
279
+ "image_prefix.visual.blocks.29.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
280
+ "image_prefix.visual.blocks.29.norm1.weight": "model-00004-of-00004.safetensors",
281
+ "image_prefix.visual.blocks.29.norm2.weight": "model-00004-of-00004.safetensors",
282
+ "image_prefix.visual.blocks.3.attn.proj.bias": "model-00003-of-00004.safetensors",
283
+ "image_prefix.visual.blocks.3.attn.proj.weight": "model-00003-of-00004.safetensors",
284
+ "image_prefix.visual.blocks.3.attn.qkv.bias": "model-00003-of-00004.safetensors",
285
+ "image_prefix.visual.blocks.3.attn.qkv.weight": "model-00003-of-00004.safetensors",
286
+ "image_prefix.visual.blocks.3.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
287
+ "image_prefix.visual.blocks.3.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
288
+ "image_prefix.visual.blocks.3.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
289
+ "image_prefix.visual.blocks.3.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
290
+ "image_prefix.visual.blocks.3.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
291
+ "image_prefix.visual.blocks.3.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
292
+ "image_prefix.visual.blocks.3.norm1.weight": "model-00003-of-00004.safetensors",
293
+ "image_prefix.visual.blocks.3.norm2.weight": "model-00003-of-00004.safetensors",
294
+ "image_prefix.visual.blocks.30.attn.proj.bias": "model-00004-of-00004.safetensors",
295
+ "image_prefix.visual.blocks.30.attn.proj.weight": "model-00004-of-00004.safetensors",
296
+ "image_prefix.visual.blocks.30.attn.qkv.bias": "model-00004-of-00004.safetensors",
297
+ "image_prefix.visual.blocks.30.attn.qkv.weight": "model-00004-of-00004.safetensors",
298
+ "image_prefix.visual.blocks.30.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
299
+ "image_prefix.visual.blocks.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
300
+ "image_prefix.visual.blocks.30.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
301
+ "image_prefix.visual.blocks.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
302
+ "image_prefix.visual.blocks.30.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
303
+ "image_prefix.visual.blocks.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
304
+ "image_prefix.visual.blocks.30.norm1.weight": "model-00004-of-00004.safetensors",
305
+ "image_prefix.visual.blocks.30.norm2.weight": "model-00004-of-00004.safetensors",
306
+ "image_prefix.visual.blocks.31.attn.proj.bias": "model-00004-of-00004.safetensors",
307
+ "image_prefix.visual.blocks.31.attn.proj.weight": "model-00004-of-00004.safetensors",
308
+ "image_prefix.visual.blocks.31.attn.qkv.bias": "model-00004-of-00004.safetensors",
309
+ "image_prefix.visual.blocks.31.attn.qkv.weight": "model-00004-of-00004.safetensors",
310
+ "image_prefix.visual.blocks.31.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
311
+ "image_prefix.visual.blocks.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
312
+ "image_prefix.visual.blocks.31.mlp.gate_proj.bias": "model-00004-of-00004.safetensors",
313
+ "image_prefix.visual.blocks.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
314
+ "image_prefix.visual.blocks.31.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
315
+ "image_prefix.visual.blocks.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
316
+ "image_prefix.visual.blocks.31.norm1.weight": "model-00004-of-00004.safetensors",
317
+ "image_prefix.visual.blocks.31.norm2.weight": "model-00004-of-00004.safetensors",
318
+ "image_prefix.visual.blocks.4.attn.proj.bias": "model-00003-of-00004.safetensors",
319
+ "image_prefix.visual.blocks.4.attn.proj.weight": "model-00003-of-00004.safetensors",
320
+ "image_prefix.visual.blocks.4.attn.qkv.bias": "model-00003-of-00004.safetensors",
321
+ "image_prefix.visual.blocks.4.attn.qkv.weight": "model-00003-of-00004.safetensors",
322
+ "image_prefix.visual.blocks.4.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
323
+ "image_prefix.visual.blocks.4.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
324
+ "image_prefix.visual.blocks.4.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
325
+ "image_prefix.visual.blocks.4.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
326
+ "image_prefix.visual.blocks.4.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
327
+ "image_prefix.visual.blocks.4.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
328
+ "image_prefix.visual.blocks.4.norm1.weight": "model-00003-of-00004.safetensors",
329
+ "image_prefix.visual.blocks.4.norm2.weight": "model-00003-of-00004.safetensors",
330
+ "image_prefix.visual.blocks.5.attn.proj.bias": "model-00003-of-00004.safetensors",
331
+ "image_prefix.visual.blocks.5.attn.proj.weight": "model-00003-of-00004.safetensors",
332
+ "image_prefix.visual.blocks.5.attn.qkv.bias": "model-00003-of-00004.safetensors",
333
+ "image_prefix.visual.blocks.5.attn.qkv.weight": "model-00003-of-00004.safetensors",
334
+ "image_prefix.visual.blocks.5.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
335
+ "image_prefix.visual.blocks.5.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
336
+ "image_prefix.visual.blocks.5.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
337
+ "image_prefix.visual.blocks.5.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
338
+ "image_prefix.visual.blocks.5.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
339
+ "image_prefix.visual.blocks.5.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
340
+ "image_prefix.visual.blocks.5.norm1.weight": "model-00003-of-00004.safetensors",
341
+ "image_prefix.visual.blocks.5.norm2.weight": "model-00003-of-00004.safetensors",
342
+ "image_prefix.visual.blocks.6.attn.proj.bias": "model-00003-of-00004.safetensors",
343
+ "image_prefix.visual.blocks.6.attn.proj.weight": "model-00003-of-00004.safetensors",
344
+ "image_prefix.visual.blocks.6.attn.qkv.bias": "model-00003-of-00004.safetensors",
345
+ "image_prefix.visual.blocks.6.attn.qkv.weight": "model-00003-of-00004.safetensors",
346
+ "image_prefix.visual.blocks.6.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
347
+ "image_prefix.visual.blocks.6.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
348
+ "image_prefix.visual.blocks.6.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
349
+ "image_prefix.visual.blocks.6.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
350
+ "image_prefix.visual.blocks.6.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
351
+ "image_prefix.visual.blocks.6.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
352
+ "image_prefix.visual.blocks.6.norm1.weight": "model-00003-of-00004.safetensors",
353
+ "image_prefix.visual.blocks.6.norm2.weight": "model-00003-of-00004.safetensors",
354
+ "image_prefix.visual.blocks.7.attn.proj.bias": "model-00003-of-00004.safetensors",
355
+ "image_prefix.visual.blocks.7.attn.proj.weight": "model-00003-of-00004.safetensors",
356
+ "image_prefix.visual.blocks.7.attn.qkv.bias": "model-00003-of-00004.safetensors",
357
+ "image_prefix.visual.blocks.7.attn.qkv.weight": "model-00003-of-00004.safetensors",
358
+ "image_prefix.visual.blocks.7.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
359
+ "image_prefix.visual.blocks.7.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
360
+ "image_prefix.visual.blocks.7.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
361
+ "image_prefix.visual.blocks.7.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
362
+ "image_prefix.visual.blocks.7.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
363
+ "image_prefix.visual.blocks.7.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
364
+ "image_prefix.visual.blocks.7.norm1.weight": "model-00003-of-00004.safetensors",
365
+ "image_prefix.visual.blocks.7.norm2.weight": "model-00003-of-00004.safetensors",
366
+ "image_prefix.visual.blocks.8.attn.proj.bias": "model-00003-of-00004.safetensors",
367
+ "image_prefix.visual.blocks.8.attn.proj.weight": "model-00003-of-00004.safetensors",
368
+ "image_prefix.visual.blocks.8.attn.qkv.bias": "model-00003-of-00004.safetensors",
369
+ "image_prefix.visual.blocks.8.attn.qkv.weight": "model-00003-of-00004.safetensors",
370
+ "image_prefix.visual.blocks.8.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
371
+ "image_prefix.visual.blocks.8.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
372
+ "image_prefix.visual.blocks.8.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
373
+ "image_prefix.visual.blocks.8.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
374
+ "image_prefix.visual.blocks.8.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
375
+ "image_prefix.visual.blocks.8.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
376
+ "image_prefix.visual.blocks.8.norm1.weight": "model-00003-of-00004.safetensors",
377
+ "image_prefix.visual.blocks.8.norm2.weight": "model-00003-of-00004.safetensors",
378
+ "image_prefix.visual.blocks.9.attn.proj.bias": "model-00003-of-00004.safetensors",
379
+ "image_prefix.visual.blocks.9.attn.proj.weight": "model-00003-of-00004.safetensors",
380
+ "image_prefix.visual.blocks.9.attn.qkv.bias": "model-00003-of-00004.safetensors",
381
+ "image_prefix.visual.blocks.9.attn.qkv.weight": "model-00003-of-00004.safetensors",
382
+ "image_prefix.visual.blocks.9.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
383
+ "image_prefix.visual.blocks.9.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
384
+ "image_prefix.visual.blocks.9.mlp.gate_proj.bias": "model-00003-of-00004.safetensors",
385
+ "image_prefix.visual.blocks.9.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
386
+ "image_prefix.visual.blocks.9.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
387
+ "image_prefix.visual.blocks.9.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
388
+ "image_prefix.visual.blocks.9.norm1.weight": "model-00003-of-00004.safetensors",
389
+ "image_prefix.visual.blocks.9.norm2.weight": "model-00003-of-00004.safetensors",
390
+ "image_prefix.visual.merger.ln_q.weight": "model-00004-of-00004.safetensors",
391
+ "image_prefix.visual.merger.mlp.0.bias": "model-00004-of-00004.safetensors",
392
+ "image_prefix.visual.merger.mlp.0.weight": "model-00004-of-00004.safetensors",
393
+ "image_prefix.visual.merger.mlp.2.bias": "model-00004-of-00004.safetensors",
394
+ "image_prefix.visual.merger.mlp.2.weight": "model-00004-of-00004.safetensors",
395
+ "image_prefix.visual.patch_embed.proj.weight": "model-00003-of-00004.safetensors",
396
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
397
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
398
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
399
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
400
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.0.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
403
+ "model.layers.0.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.0.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.0.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
406
+ "model.layers.0.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
407
+ "model.layers.0.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
408
+ "model.layers.0.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
409
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
410
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
411
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
412
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
413
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
414
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
415
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
416
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
417
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
418
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
419
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
420
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
421
+ "model.layers.1.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
422
+ "model.layers.1.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
423
+ "model.layers.1.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
424
+ "model.layers.1.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
425
+ "model.layers.1.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
426
+ "model.layers.1.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
427
+ "model.layers.1.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
428
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
429
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
430
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
431
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
432
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
433
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
434
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
435
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
436
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
437
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
438
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
439
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
440
+ "model.layers.10.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
441
+ "model.layers.10.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
442
+ "model.layers.10.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
443
+ "model.layers.10.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
444
+ "model.layers.10.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
445
+ "model.layers.10.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
446
+ "model.layers.10.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
447
+ "model.layers.10.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
448
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
449
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
450
+ "model.layers.10.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
451
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
452
+ "model.layers.10.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
453
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
454
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
455
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
456
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
457
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
458
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
459
+ "model.layers.11.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
460
+ "model.layers.11.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
461
+ "model.layers.11.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
462
+ "model.layers.11.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
463
+ "model.layers.11.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
464
+ "model.layers.11.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
465
+ "model.layers.11.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
466
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
467
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
468
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
469
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
470
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
471
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
472
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
473
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
474
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
475
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
476
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
477
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
478
+ "model.layers.12.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
479
+ "model.layers.12.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
480
+ "model.layers.12.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
481
+ "model.layers.12.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
482
+ "model.layers.12.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
483
+ "model.layers.12.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
484
+ "model.layers.12.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
485
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
486
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
487
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
488
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
489
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
490
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
491
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
492
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
493
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
494
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
495
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
496
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
497
+ "model.layers.13.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
498
+ "model.layers.13.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
499
+ "model.layers.13.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
500
+ "model.layers.13.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
501
+ "model.layers.13.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
502
+ "model.layers.13.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
503
+ "model.layers.13.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
504
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
505
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
506
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
507
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
508
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
509
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
510
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
511
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
512
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
513
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
514
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
515
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
516
+ "model.layers.14.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
517
+ "model.layers.14.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
518
+ "model.layers.14.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
519
+ "model.layers.14.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
520
+ "model.layers.14.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
521
+ "model.layers.14.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
522
+ "model.layers.14.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
523
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
524
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
525
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
526
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
527
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
528
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
529
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
530
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
531
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
532
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
533
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
534
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
535
+ "model.layers.15.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
536
+ "model.layers.15.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
537
+ "model.layers.15.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
538
+ "model.layers.15.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
539
+ "model.layers.15.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
540
+ "model.layers.15.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
541
+ "model.layers.15.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
542
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
543
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
544
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
545
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
546
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
547
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
548
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
549
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
550
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
551
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
552
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
553
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
554
+ "model.layers.16.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
555
+ "model.layers.16.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
556
+ "model.layers.16.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
557
+ "model.layers.16.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
558
+ "model.layers.16.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
559
+ "model.layers.16.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
560
+ "model.layers.16.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
561
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
562
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
563
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
564
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
565
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
566
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
567
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
568
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
569
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
570
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
571
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
572
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
573
+ "model.layers.17.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
574
+ "model.layers.17.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
575
+ "model.layers.17.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
576
+ "model.layers.17.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
577
+ "model.layers.17.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
578
+ "model.layers.17.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
579
+ "model.layers.17.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
580
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
581
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
582
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
583
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
584
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
585
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
586
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
587
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
588
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
589
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
590
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
591
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
592
+ "model.layers.18.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
593
+ "model.layers.18.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
594
+ "model.layers.18.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
595
+ "model.layers.18.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
596
+ "model.layers.18.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
597
+ "model.layers.18.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
598
+ "model.layers.18.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
599
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
600
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
601
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
602
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
603
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
604
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
605
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
606
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
607
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
608
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
609
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
610
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
611
+ "model.layers.19.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
612
+ "model.layers.19.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
613
+ "model.layers.19.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
614
+ "model.layers.19.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
615
+ "model.layers.19.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
616
+ "model.layers.19.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
617
+ "model.layers.19.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
618
+ "model.layers.19.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
619
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
620
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
621
+ "model.layers.19.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
622
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
623
+ "model.layers.19.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
624
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
625
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
626
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
627
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
628
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
629
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
630
+ "model.layers.2.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
631
+ "model.layers.2.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
632
+ "model.layers.2.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
633
+ "model.layers.2.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
634
+ "model.layers.2.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
635
+ "model.layers.2.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
636
+ "model.layers.2.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
637
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
638
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
639
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
640
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
641
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
642
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
643
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
644
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
645
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
646
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
647
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
648
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
649
+ "model.layers.20.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
650
+ "model.layers.20.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
651
+ "model.layers.20.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
652
+ "model.layers.20.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
653
+ "model.layers.20.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
654
+ "model.layers.20.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
655
+ "model.layers.20.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
656
+ "model.layers.20.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
657
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
658
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
659
+ "model.layers.20.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
660
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
661
+ "model.layers.20.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
662
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
663
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
664
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
665
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
666
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
667
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
668
+ "model.layers.21.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
669
+ "model.layers.21.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
670
+ "model.layers.21.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
671
+ "model.layers.21.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
672
+ "model.layers.21.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
673
+ "model.layers.21.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
674
+ "model.layers.21.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
675
+ "model.layers.21.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
676
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
677
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
678
+ "model.layers.21.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
679
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
680
+ "model.layers.21.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
681
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
682
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00004.safetensors",
683
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
684
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
685
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
686
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
687
+ "model.layers.22.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
688
+ "model.layers.22.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
689
+ "model.layers.22.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
690
+ "model.layers.22.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
691
+ "model.layers.22.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
692
+ "model.layers.22.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
693
+ "model.layers.22.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
694
+ "model.layers.22.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
695
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
696
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
697
+ "model.layers.22.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
698
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
699
+ "model.layers.22.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
700
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
701
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00004.safetensors",
702
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
703
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
704
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
705
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
706
+ "model.layers.23.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
707
+ "model.layers.23.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
708
+ "model.layers.23.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
709
+ "model.layers.23.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
710
+ "model.layers.23.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
711
+ "model.layers.23.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
712
+ "model.layers.23.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
713
+ "model.layers.23.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
714
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
715
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
716
+ "model.layers.23.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
717
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
718
+ "model.layers.23.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
719
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
720
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00004.safetensors",
721
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
722
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
723
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
724
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
725
+ "model.layers.24.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
726
+ "model.layers.24.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
727
+ "model.layers.24.self_attn.casa_attn.o_proj_casa.weight": "model-00002-of-00004.safetensors",
728
+ "model.layers.24.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
729
+ "model.layers.24.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
730
+ "model.layers.24.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
731
+ "model.layers.24.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
732
+ "model.layers.24.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
733
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
734
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
735
+ "model.layers.24.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
736
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
737
+ "model.layers.24.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
738
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
739
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
740
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
741
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
742
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
743
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
744
+ "model.layers.25.self_attn.casa_attn.k_proj_casa.bias": "model-00002-of-00004.safetensors",
745
+ "model.layers.25.self_attn.casa_attn.k_proj_casa.weight": "model-00002-of-00004.safetensors",
746
+ "model.layers.25.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
747
+ "model.layers.25.self_attn.casa_attn.q_proj_casa.bias": "model-00002-of-00004.safetensors",
748
+ "model.layers.25.self_attn.casa_attn.q_proj_casa.weight": "model-00002-of-00004.safetensors",
749
+ "model.layers.25.self_attn.casa_attn.v_proj_casa.bias": "model-00002-of-00004.safetensors",
750
+ "model.layers.25.self_attn.casa_attn.v_proj_casa.weight": "model-00002-of-00004.safetensors",
751
+ "model.layers.25.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
752
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
753
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
754
+ "model.layers.25.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
755
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
756
+ "model.layers.25.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
757
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
758
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
759
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
760
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
761
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
762
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
763
+ "model.layers.26.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
764
+ "model.layers.26.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
765
+ "model.layers.26.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
766
+ "model.layers.26.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
767
+ "model.layers.26.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
768
+ "model.layers.26.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
769
+ "model.layers.26.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
770
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
771
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
772
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
773
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
774
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
775
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
776
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
777
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
778
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
779
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
780
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
781
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
782
+ "model.layers.27.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
783
+ "model.layers.27.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
784
+ "model.layers.27.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
785
+ "model.layers.27.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
786
+ "model.layers.27.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
787
+ "model.layers.27.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
788
+ "model.layers.27.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
789
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
790
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
791
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
792
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
793
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
794
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
795
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
796
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
797
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
798
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
799
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
800
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
801
+ "model.layers.28.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
802
+ "model.layers.28.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
803
+ "model.layers.28.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
804
+ "model.layers.28.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
805
+ "model.layers.28.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
806
+ "model.layers.28.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
807
+ "model.layers.28.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
808
+ "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
809
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
810
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
811
+ "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
812
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
813
+ "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
814
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
815
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
816
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
817
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
818
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
819
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
820
+ "model.layers.29.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
821
+ "model.layers.29.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
822
+ "model.layers.29.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
823
+ "model.layers.29.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
824
+ "model.layers.29.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
825
+ "model.layers.29.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
826
+ "model.layers.29.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
827
+ "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
828
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
829
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
830
+ "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
831
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
832
+ "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
833
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
834
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
835
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
836
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
837
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
838
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
839
+ "model.layers.3.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
840
+ "model.layers.3.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
841
+ "model.layers.3.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
842
+ "model.layers.3.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
843
+ "model.layers.3.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
844
+ "model.layers.3.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
845
+ "model.layers.3.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
846
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
847
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
848
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
849
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
850
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
851
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
852
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
853
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
854
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
855
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
856
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
857
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
858
+ "model.layers.30.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
859
+ "model.layers.30.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
860
+ "model.layers.30.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
861
+ "model.layers.30.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
862
+ "model.layers.30.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
863
+ "model.layers.30.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
864
+ "model.layers.30.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
865
+ "model.layers.30.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
866
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
867
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
868
+ "model.layers.30.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
869
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
870
+ "model.layers.30.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
871
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
872
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
873
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
874
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
875
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
876
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
877
+ "model.layers.31.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
878
+ "model.layers.31.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
879
+ "model.layers.31.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
880
+ "model.layers.31.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
881
+ "model.layers.31.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
882
+ "model.layers.31.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
883
+ "model.layers.31.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
884
+ "model.layers.31.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
885
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
886
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
887
+ "model.layers.31.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
888
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
889
+ "model.layers.31.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
890
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
891
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
892
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
893
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
894
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
895
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
896
+ "model.layers.32.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
897
+ "model.layers.32.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
898
+ "model.layers.32.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
899
+ "model.layers.32.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
900
+ "model.layers.32.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
901
+ "model.layers.32.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
902
+ "model.layers.32.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
903
+ "model.layers.32.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
904
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
905
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
906
+ "model.layers.32.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
907
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
908
+ "model.layers.32.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
909
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
910
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
911
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
912
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
913
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
914
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
915
+ "model.layers.33.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
916
+ "model.layers.33.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
917
+ "model.layers.33.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
918
+ "model.layers.33.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
919
+ "model.layers.33.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
920
+ "model.layers.33.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
921
+ "model.layers.33.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
922
+ "model.layers.33.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
923
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
924
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
925
+ "model.layers.33.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
926
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
927
+ "model.layers.33.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
928
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
929
+ "model.layers.34.input_layernorm.weight": "model-00003-of-00004.safetensors",
930
+ "model.layers.34.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
931
+ "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
932
+ "model.layers.34.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
933
+ "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
934
+ "model.layers.34.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
935
+ "model.layers.34.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
936
+ "model.layers.34.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
937
+ "model.layers.34.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
938
+ "model.layers.34.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
939
+ "model.layers.34.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
940
+ "model.layers.34.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
941
+ "model.layers.34.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
942
+ "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
943
+ "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
944
+ "model.layers.34.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
945
+ "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
946
+ "model.layers.34.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
947
+ "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
948
+ "model.layers.35.input_layernorm.weight": "model-00003-of-00004.safetensors",
949
+ "model.layers.35.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
950
+ "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
951
+ "model.layers.35.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
952
+ "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
953
+ "model.layers.35.self_attn.casa_attn.k_proj_casa.bias": "model-00003-of-00004.safetensors",
954
+ "model.layers.35.self_attn.casa_attn.k_proj_casa.weight": "model-00003-of-00004.safetensors",
955
+ "model.layers.35.self_attn.casa_attn.o_proj_casa.weight": "model-00003-of-00004.safetensors",
956
+ "model.layers.35.self_attn.casa_attn.q_proj_casa.bias": "model-00003-of-00004.safetensors",
957
+ "model.layers.35.self_attn.casa_attn.q_proj_casa.weight": "model-00003-of-00004.safetensors",
958
+ "model.layers.35.self_attn.casa_attn.v_proj_casa.bias": "model-00003-of-00004.safetensors",
959
+ "model.layers.35.self_attn.casa_attn.v_proj_casa.weight": "model-00003-of-00004.safetensors",
960
+ "model.layers.35.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
961
+ "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
962
+ "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
963
+ "model.layers.35.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
964
+ "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
965
+ "model.layers.35.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
966
+ "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
967
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
968
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
969
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
970
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
971
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
972
+ "model.layers.4.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
973
+ "model.layers.4.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
974
+ "model.layers.4.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
975
+ "model.layers.4.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
976
+ "model.layers.4.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
977
+ "model.layers.4.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
978
+ "model.layers.4.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
979
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
980
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
981
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
982
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
983
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
984
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
985
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
986
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
987
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
988
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
989
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
990
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
991
+ "model.layers.5.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
992
+ "model.layers.5.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
993
+ "model.layers.5.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
994
+ "model.layers.5.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
995
+ "model.layers.5.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
996
+ "model.layers.5.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
997
+ "model.layers.5.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
998
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
999
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
1000
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
1001
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
1002
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
1003
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
1004
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
1005
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
1006
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
1007
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
1008
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
1009
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
1010
+ "model.layers.6.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
1011
+ "model.layers.6.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
1012
+ "model.layers.6.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
1013
+ "model.layers.6.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
1014
+ "model.layers.6.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
1015
+ "model.layers.6.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
1016
+ "model.layers.6.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
1017
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
1018
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
1019
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
1020
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
1021
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
1022
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
1023
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
1024
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
1025
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
1026
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
1027
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
1028
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
1029
+ "model.layers.7.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
1030
+ "model.layers.7.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
1031
+ "model.layers.7.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
1032
+ "model.layers.7.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
1033
+ "model.layers.7.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
1034
+ "model.layers.7.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
1035
+ "model.layers.7.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
1036
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
1037
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
1038
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
1039
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
1040
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
1041
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
1042
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
1043
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
1044
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
1045
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
1046
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
1047
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
1048
+ "model.layers.8.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
1049
+ "model.layers.8.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
1050
+ "model.layers.8.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
1051
+ "model.layers.8.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
1052
+ "model.layers.8.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
1053
+ "model.layers.8.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
1054
+ "model.layers.8.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
1055
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
1056
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
1057
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
1058
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
1059
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
1060
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
1061
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
1062
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00004.safetensors",
1063
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
1064
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
1065
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
1066
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
1067
+ "model.layers.9.self_attn.casa_attn.k_proj_casa.bias": "model-00001-of-00004.safetensors",
1068
+ "model.layers.9.self_attn.casa_attn.k_proj_casa.weight": "model-00001-of-00004.safetensors",
1069
+ "model.layers.9.self_attn.casa_attn.o_proj_casa.weight": "model-00001-of-00004.safetensors",
1070
+ "model.layers.9.self_attn.casa_attn.q_proj_casa.bias": "model-00001-of-00004.safetensors",
1071
+ "model.layers.9.self_attn.casa_attn.q_proj_casa.weight": "model-00001-of-00004.safetensors",
1072
+ "model.layers.9.self_attn.casa_attn.v_proj_casa.bias": "model-00001-of-00004.safetensors",
1073
+ "model.layers.9.self_attn.casa_attn.v_proj_casa.weight": "model-00001-of-00004.safetensors",
1074
+ "model.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
1075
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
1076
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
1077
+ "model.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
1078
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
1079
+ "model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
1080
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
1081
+ "model.norm.weight": "model-00003-of-00004.safetensors"
1082
+ }
1083
+ }
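Aside (not part of the committed files): the `weight_map` above is the standard sharded-safetensors index, mapping every parameter name — including the added `casa_attn` projections (`q_proj_casa`, `k_proj_casa`, `v_proj_casa`, `o_proj_casa`) — to the shard file that stores it. Below is a minimal sketch of how such an index is typically consumed, assuming the checkpoint has already been downloaded to a local directory (the path is illustrative, not taken from this repo):

```python
import json
from collections import defaultdict

from safetensors.torch import load_file

# Illustrative local path; adjust to wherever the repository was downloaded.
repo_dir = "CASA-Qwen2_5-VL-3B-LiveCC"

# The index maps each parameter name to the shard that contains it.
with open(f"{repo_dir}/model.safetensors.index.json") as f:
    index = json.load(f)

# Group parameter names by shard file.
params_per_shard: dict[str, list[str]] = defaultdict(list)
for param_name, shard_file in index["weight_map"].items():
    params_per_shard[shard_file].append(param_name)

# Load one shard and pick out only the CASA-specific projections.
shard = load_file(f"{repo_dir}/model-00002-of-00004.safetensors")
casa_tensors = {name: t for name, t in shard.items() if ".casa_attn." in name}
print(f"{len(casa_tensors)} CASA tensors in this shard")
```

In practice `from_pretrained` resolves the index and shards automatically; this sketch only illustrates how the mapping listed above is laid out on disk.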
modeling_qwen2_5vl_casa.py ADDED
@@ -0,0 +1,308 @@
1
+ from functools import partial
2
+ from typing import Any
3
+ from typing import cast as type_cast
4
+
5
+ import torch
6
+ from transformers.cache_utils import DynamicCache
7
+ from transformers.generation.utils import GenerateOutput
8
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
9
+ Qwen2_5_VLCausalLMOutputWithPast,
10
+ Qwen2_5_VLForConditionalGeneration,
11
+ )
12
+
13
+ from .image_encoder import Qwen25VLEncoder
14
+ from .configuration_qwen2_5vl_casa import Qwen2_5_VLCASAConfig
15
+ from .language_qwen2_5vl_casa import (
16
+ Qwen2_5_VLAttention_CASA,
17
+ QwenCASAAttention,
18
+ QwenCASAAttentionHandler,
19
+ add_casa_layers,
20
+ )
21
+
22
+
23
+ class V2Qwen2_5VL(Qwen2_5_VLForConditionalGeneration): # pyright: ignore[reportIncompatibleMethodOverride]
24
+ config_class = Qwen2_5_VLCASAConfig
25
+
26
+ def __init__(self, config: Qwen2_5_VLCASAConfig, **kwargs: Any) -> None:
27
+ del kwargs
28
+ super().__init__(config)
29
+ self.image_prefix = Qwen25VLEncoder(self.visual) # type: ignore[assignment]
30
+ self.visual = None
31
+ self.model.apply(partial(add_casa_layers, xa_layers=self.config.xa_layers))
32
+
33
+ def get_device(self) -> str:
34
+ """Return the device type of the model"""
35
+ return next(self.parameters()).device.type
36
+
37
+ @property
38
+ def token_dim(self) -> int:
39
+ """Returns the number of dimensions for the token representation"""
40
+ return self.config.hidden_size
41
+
42
+ def _update_model_kwargs_for_generation(
43
+ self,
44
+ outputs: Any,
45
+ model_kwargs: dict[str, Any],
46
+ is_encoder_decoder: bool = False,
47
+ num_new_tokens: int = 1,
48
+ ):
49
+ """This is required to handle multiple gen calls for subtitles"""
50
+ # Call parent to get default updates
51
+ model_kwargs = super()._update_model_kwargs_for_generation(
52
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
53
+ )
54
+ # Used by prepare_inputs_for_generation
55
+ model_kwargs["__is_first_gen_call__"] = False
56
+ return model_kwargs
57
+
58
+ def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride]
59
+ self,
60
+ input_ids: torch.Tensor,
61
+ past_key_values: DynamicCache | None = None,
62
+ **kwargs: Any,
63
+ ):
64
+ """Required to handle cache_position = None with QwenVL"""
65
+ __is_first_gen_call__ = kwargs.pop("__is_first_gen_call__", True)
66
+ if past_key_values is not None and (
67
+ kwargs.get("cache_position") is None
68
+ or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
69
+ ):
70
+ # We're continuing from a cached state
71
+ past_length = past_key_values._seen_tokens
72
+ kwargs["cache_position"] = torch.arange(
73
+ past_length,
74
+ past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
75
+ dtype=torch.long,
76
+ device=input_ids.device,
77
+ )
78
+
79
+ return super().prepare_inputs_for_generation(
80
+ input_ids,
81
+ past_key_values=past_key_values,
82
+ **kwargs,
83
+ )
84
+
85
+ def prepare_multimodal_inputs(
86
+ self,
87
+ # text only training
88
+ input_ids: torch.Tensor | None = None,
89
+ inputs_embeds: torch.Tensor | None = None,
90
+ attention_mask: torch.Tensor | None = None,
91
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
92
+ labels: torch.Tensor | None = None,
93
+ # image values
94
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
95
+ pre_image_tokens: list[int] | None = None,
96
+ post_image_tokens: list[int] | None = None,
97
+ **_kwargs: Any,
98
+ ) -> dict:
99
+ """Get a batch data mixing text and image data"""
100
+ del _kwargs
101
+
102
+ processed_inputs: dict = {
103
+ "input_ids": input_ids,
104
+ "inputs_embeds": inputs_embeds,
105
+ "labels": labels,
106
+ "attention_mask": attention_mask,
107
+ "image_embeds_insertion_points": image_embeds_insertion_points,
108
+ }
109
+ if pixel_values is not None:
110
+ processed_inputs.update(self.image_prefix(pixel_values))
111
+ assert "image_embeds" in processed_inputs
112
+ assert (
113
+ isinstance(processed_inputs["image_embeds"], torch.Tensor)
114
+ and processed_inputs["image_embeds"].ndim == 3
115
+ ) or (
116
+ isinstance(processed_inputs["image_embeds"], list)
117
+ and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
118
+ )
119
+
120
+ # Add kwargs necessary to compute cu_seqlens windows for CASA
121
+ processed_inputs["casa_windows_info"] = {
122
+ "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
123
+ "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
124
+ }
125
+
126
+ return processed_inputs
127
+
128
+ def forward( # type: ignore[override] # pylint: disable=W0221
129
+ self,
130
+ input_ids: torch.Tensor | None = None,
131
+ inputs_embeds: torch.Tensor | None = None,
132
+ attention_mask: torch.Tensor | None = None,
133
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
134
+ labels: torch.Tensor | None = None,
135
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
136
+ reinit_casa_handler: bool = True,
137
+ pre_image_tokens: list[int] | None = None,
138
+ post_image_tokens: list[int] | None = None,
139
+ **kwargs: Any,
140
+ ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
141
+ """Multi-modal forward pass"""
142
+
143
+ if reinit_casa_handler:
144
+ processed_inputs = self.prepare_multimodal_inputs(
145
+ input_ids=input_ids,
146
+ inputs_embeds=inputs_embeds,
147
+ attention_mask=attention_mask,
148
+ image_embeds_insertion_points=image_embeds_insertion_points,
149
+ pixel_values=pixel_values,
150
+ labels=labels,
151
+ post_image_tokens=post_image_tokens,
152
+ pre_image_tokens=pre_image_tokens,
153
+ )
154
+ inputs_embeds = type_cast(
155
+ torch.Tensor, self.model.embed_tokens(processed_inputs["input_ids"])
156
+ )
157
+ casa_attention_handler: QwenCASAAttentionHandler | None = None
158
+ image_embeds = processed_inputs.get("image_embeds", None)
159
+ attention_mask = processed_inputs["attention_mask"]
160
+ inst_points = processed_inputs.get("image_embeds_insertion_points", None)
161
+ if image_embeds is None:
162
+ inst_points = None
163
+ casa_attention_handler = QwenCASAAttentionHandler(
164
+ # for text tokens, we don't need the actual values
165
+ inputs_embeds=torch.zeros_like(inputs_embeds),
166
+ # for image embeddings, we put real inputs as this will be fixed
167
+ image_embeds=[] if image_embeds is None else image_embeds,
168
+ image_embeds_insertion_points=inst_points,
169
+ # attention mask is only needed at inference / left padding
170
+ attention_mask=None if self.training else processed_inputs["attention_mask"],
171
+ rope_fn=self.model.rotary_emb,
172
+ windows=self.config.casa_windows,
173
+ casa_windows_info=processed_inputs.pop("casa_windows_info", None),
174
+ use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
175
+ # extra for Qwen
176
+ get_rope_index=self.get_rope_index,
177
+ grid_thw=processed_inputs.get("grid_thw", None),
178
+ )
179
+ self.update_casa_states(casa_attention_handler)
180
+ else:
181
+ inputs_embeds = self.model.embed_tokens(input_ids)
182
+
183
+ # Run Qwen with the attention layers replaced to use CASA
184
+ assert inputs_embeds is not None, "Could not compute input embeddings!"
185
+ out = super().forward(
186
+ inputs_embeds=inputs_embeds, # type: ignore[arg-type]
187
+ attention_mask=attention_mask,
188
+ pixel_values=None,
189
+ **kwargs,
190
+ )
191
+
192
+ return out
193
+
194
+ @torch.no_grad()
195
+ def generate_from_image( # pyright: ignore[reportInconsistentOverload]
196
+ self,
197
+ input_ids: torch.Tensor | None = None,
198
+ inputs_embeds: torch.Tensor | None = None,
199
+ attention_mask: torch.Tensor | None = None,
200
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
201
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
202
+ pre_image_tokens: list[int] | None = None,
203
+ post_image_tokens: list[int] | None = None,
204
+ position_ids_offset: int | None = None,
205
+ reset_streaming: bool = True,
206
+ **kwargs: Any,
207
+ ) -> GenerateOutput | torch.LongTensor:
208
+ """Custom generate function"""
209
+ assert input_ids is not None and inputs_embeds is None, (
210
+ "Input IDs must be provided for generation"
211
+ )
212
+
213
+ # init self-attention KVCache
214
+ if kwargs.get("past_key_values", None) is None:
215
+ kwargs["past_key_values"] = DynamicCache()
216
+
217
+ # To avoid generate warning
218
+ if kwargs.get("pad_token_id", None) is None:
219
+ kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
220
+ if isinstance(kwargs["pad_token_id"], (list, tuple)):
221
+ kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
222
+
223
+ # Init CASA states
224
+ processed_inputs = self.prepare_multimodal_inputs(
225
+ input_ids=input_ids,
226
+ inputs_embeds=inputs_embeds,
227
+ attention_mask=attention_mask,
228
+ image_embeds_insertion_points=image_embeds_insertion_points,
229
+ pixel_values=pixel_values,
230
+ labels=None,
231
+ pre_image_tokens=pre_image_tokens,
232
+ post_image_tokens=post_image_tokens,
233
+ )
234
+
235
+ if pixel_values is not None:
236
+ assert (image_embeds := processed_inputs.get("image_embeds", None)) is not None
237
+ assert (
238
+ insrt_pts := processed_inputs.get("image_embeds_insertion_points", None)
239
+ ) is not None
240
+ casa_attention_handler = QwenCASAAttentionHandler(
241
+ inputs_embeds=torch.empty(
242
+ (input_ids.shape[0], input_ids.shape[1], image_embeds[0].shape[-1]),
243
+ dtype=image_embeds[0].dtype,
244
+ device=image_embeds[0].device,
245
+ ),
246
+ image_embeds=image_embeds,
247
+ image_embeds_insertion_points=insrt_pts,
248
+ attention_mask=attention_mask,
249
+ rope_fn=self.model.rotary_emb,
250
+ windows=self.config.casa_windows,
251
+ casa_windows_info=processed_inputs.pop("casa_windows_info", None),
252
+ use_asymetric_q_kv=self.config.casa_use_asymetric_qkv,
253
+ get_rope_index=self.get_rope_index,
254
+ grid_thw=processed_inputs.get("grid_thw", None),
255
+ position_ids_offset=position_ids_offset or kwargs["past_key_values"]._seen_tokens,
256
+ )
257
+ self.update_casa_states(casa_attention_handler)
258
+ self.start_casa_streaming_states()
259
+ pixel_values = None
260
+
261
+ # Generate
262
+ outputs = self.generate(
263
+ input_ids,
264
+ attention_mask=attention_mask,
265
+ pixel_values=pixel_values,
266
+ use_cache=True,
267
+ reinit_casa_handler=False,
268
+ **kwargs,
269
+ )
270
+
271
+ if reset_streaming:
272
+ self.reset_casa_streaming_states()
273
+ return outputs
274
+
275
+ def update_casa_states(self, handler: QwenCASAAttentionHandler | None):
276
+ """Update handler in all layers"""
277
+
278
+ def __update__(m: torch.nn.Module):
279
+ nonlocal handler
280
+
281
+ if isinstance(m, Qwen2_5_VLAttention_CASA):
282
+ m.casa_attention_handler = handler
283
+
284
+ self.apply(__update__)
285
+
286
+ def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
287
+ def __reset__(m: torch.nn.Module):
288
+ if isinstance(m, QwenCASAAttention):
289
+ m._set_streaming(False, ())
290
+ m.reset_streaming()
291
+ if clean_cache:
292
+ del m.streaming_state.k
293
+ del m.streaming_state.v
294
+ m.streaming_state.k = None # pyright: ignore[reportAttributeAccessIssue]
295
+ m.streaming_state.v = None # pyright: ignore[reportAttributeAccessIssue]
296
+
297
+ elif isinstance(m, Qwen2_5_VLAttention_CASA):
298
+ del m.casa_attention_handler
299
+ m.casa_attention_handler = None
300
+
301
+ self.apply(__reset__)
302
+
303
+ def start_casa_streaming_states(self) -> None:
304
+ def __start__(m: torch.nn.Module):
305
+ if isinstance(m, QwenCASAAttention):
306
+ m._set_streaming(True, ())
307
+
308
+ self.apply(__start__)
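
The class above exposes `generate_from_image` as the CASA generation entry point: it builds the multimodal batch, instantiates a `QwenCASAAttentionHandler`, switches the CASA layers into streaming mode, and then delegates to the regular `generate`. Below is a minimal usage sketch under assumptions not fixed by this file: the repository id, the `AutoModelForCausalLM`/`AutoProcessor` remote-code loading path, and the prompt contents are illustrative only (see the main model card for the supported loading instructions).

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "kyutai/CASA-Qwen2_5-VL-3B-LiveCC"  # assumption: this repository id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(  # assumption: the auto class mapped by this repo
    repo, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda")

# One user turn interleaving an image and a question (schema defined in processing.py below).
messages = [{"role": "user", "content": [
    {"type": "image", "image": Image.open("frame.png")},  # placeholder image path
    {"type": "text", "text": "What is happening in this frame?"},
]}]
batch = processor.tokenize_messages(messages).to("cuda")

out = model.generate_from_image(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    pixel_values=batch["pixel_values"],
    image_embeds_insertion_points=batch["image_embeds_insertion_points"],
    pre_image_tokens=processor.pre_image_tokens,
    post_image_tokens=processor.post_image_tokens,
    max_new_tokens=64,
)
print(processor.tokenizer.decode(out[0], skip_special_tokens=True))
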
processing.py ADDED
@@ -0,0 +1,505 @@
1
+ # pylint: disable=no-member # avoid weird pylint warnings from SentencePieceProcessor
2
+ """Text and Image processor for CASA models using Qwen2.5_VL image encoder"""
3
+
4
+ from math import ceil
5
+ from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload
6
+ from typing import cast as type_cast
7
+
8
+ import torch
9
+ import torchvision.transforms.v2 as T
10
+ from einops import rearrange
11
+ from PIL import Image
12
+ from torchvision.transforms import InterpolationMode
13
+ from torchvision.transforms.functional import to_tensor as pil_to_tensor
14
+ from torchvision.transforms.v2 import functional as F
15
+ from transformers.image_processing_utils import BaseImageProcessor
16
+ from transformers.processing_utils import ProcessorMixin
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+
22
+
23
+ ImageMessage = TypedDict(
24
+ "ImageMessage",
25
+ {
26
+ "type": Literal["image"],
27
+ "image": str | Image.Image | None,
28
+ },
29
+ )
30
+
31
+ TextMessage = TypedDict(
32
+ "TextMessage",
33
+ {
34
+ "type": Literal["text"],
35
+ "text": str,
36
+ },
37
+ )
38
+
39
+ MessageContent = list[ImageMessage | TextMessage]
40
+
41
+ Message = TypedDict(
42
+ "Message",
43
+ {
44
+ "role": Literal["system", "user", "assistant"],
45
+ "content": MessageContent,
46
+ },
47
+ )
48
+
49
+ ProcessorInput = list[list[Message]] | list[Message]
50
+
51
+ __INTERP_NAME_TO_MODE__ = {
52
+ "nearest": InterpolationMode.NEAREST,
53
+ "bilinear": InterpolationMode.BILINEAR,
54
+ "bicubic": InterpolationMode.BICUBIC,
55
+ "lanczos": InterpolationMode.LANCZOS,
56
+ }
57
+
58
+ __INTERP_INT_TO_MODE__ = {
59
+ 0: InterpolationMode.NEAREST,
60
+ 2: InterpolationMode.BILINEAR,
61
+ 3: InterpolationMode.BICUBIC,
62
+ 4: InterpolationMode.BOX,
63
+ 5: InterpolationMode.HAMMING,
64
+ 1: InterpolationMode.LANCZOS,
65
+ }
66
+
67
+
68
+ @overload
69
+ def universal_resize(
70
+ img: Image.Image,
71
+ size: tuple[int, int],
72
+ interpolation: str | InterpolationMode | int = "bilinear",
73
+ antialias: bool = True,
74
+ ) -> Image.Image: ...
75
+ @overload
76
+ def universal_resize(
77
+ img: torch.Tensor,
78
+ size: tuple[int, int],
79
+ interpolation: str | InterpolationMode | int = "bilinear",
80
+ antialias: bool = True,
81
+ ) -> torch.Tensor: ...
82
+ def universal_resize(
83
+ img: Image.Image | torch.Tensor,
84
+ size: tuple[int, int],
85
+ interpolation: str | InterpolationMode | int = "bilinear",
86
+ antialias: bool = True,
87
+ ) -> Image.Image | torch.Tensor:
88
+ """Resize that works for PIL.Image, CHW tensor, or BCHW tensor"""
89
+ if isinstance(interpolation, str):
90
+ interpolation = __INTERP_NAME_TO_MODE__[interpolation]
91
+ elif isinstance(interpolation, int):
92
+ interpolation = __INTERP_INT_TO_MODE__[interpolation]
93
+
94
+ return F.resize(
95
+ img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias
96
+ )
97
+
98
+
99
+ @overload
100
+ def convert_to_rgb(img: Image.Image) -> Image.Image: ...
101
+ @overload
102
+ def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
103
+ def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
104
+ """Convert any image to RGB in a way that does not throw PIL warning"""
105
+ if isinstance(img, torch.Tensor):
106
+ return img
107
+ if img.mode == "RGB": # no changes
108
+ return img
109
+ if img.mode == "P": # palette images need to be converted to RGBA first
110
+ return img.convert("RGBA").convert("RGB")
111
+ return img.convert("RGB")
112
+
113
+
114
+ class QwenImageProcessor(BaseImageProcessor):
115
+ """Resizing for the Qwen2.5VL encoder. Note that the normalization is
116
+ handled in the image_encoder in the model forward"""
117
+
118
+ def __init__(
119
+ self,
120
+ img_size: int = 448,
121
+ interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
122
+ max_ratio: int = 10,
123
+ round_to_patch_size: int = 56,
124
+ use_fast: bool = True,
125
+ **kwargs: Any,
126
+ ) -> None:
127
+ # this will also be used in V2llms to determine whether to remove
128
+ # the temporal conv
129
+ self._num_target_channels = 588
130
+ self._merge_size = 2
131
+ self._patch_size = 14
132
+ super().__init__(
133
+ use_fast=use_fast,
134
+ do_normalize=False,
135
+ **kwargs,
136
+ )
137
+ self.img_size = img_size
138
+ self.interpolation = interpolation
139
+ self.max_ratio = max_ratio
140
+ self.round_to_patch_size = round_to_patch_size
141
+
142
+ def resize_transform(
143
+ self, img: Image.Image | torch.Tensor, img_size: int | None = None
144
+ ) -> Image.Image | torch.Tensor:
145
+ if img_size is None:
146
+ img_size = self.img_size
147
+ max_area = img_size**2
148
+ if isinstance(img, Image.Image):
149
+ img = convert_to_rgb(img)
150
+ w_og, h_og = img.size
151
+ else:
152
+ h_og, w_og = img.shape[-2:]
153
+ w, h = w_og, h_og
154
+
155
+ # Qwen requires an aspect ratio of at most max_ratio (default 10) between the longer and shorter side
156
+ if self.max_ratio > 0:
157
+ w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)
158
+
159
+ # resize to max area
160
+ current_area = w * h
161
+ if current_area > max_area:
162
+ scale = (max_area / current_area) ** 0.5
163
+ w, h = int(w * scale), int(h * scale)
164
+
165
+ # resize to patch size
166
+ if self.round_to_patch_size > 0:
167
+ w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
168
+ h = ceil((h / self.round_to_patch_size)) * self.round_to_patch_size
169
+
170
+ # resize
171
+ if w != w_og or h != h_og:
172
+ img = universal_resize(img, (h, w), self.interpolation)
173
+ if isinstance(img, torch.Tensor):
174
+ img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
175
+ return img
176
+
177
+ def __process_one__(
178
+ self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
179
+ ) -> torch.Tensor:
180
+ """Same operation as __process_one_with_processor__ but without going through numpy"""
181
+ video_or_img = self.resize_transform(video_or_img, img_size)
182
+ if isinstance(video_or_img, Image.Image):
183
+ video_or_img = pil_to_tensor(video_or_img)
184
+ assert isinstance(video_or_img, torch.Tensor)
185
+ if video_or_img.ndim == 3:
186
+ video_or_img = video_or_img[None]
187
+ assert video_or_img.ndim == 4 and video_or_img.shape[1] == 3, (
188
+ f"Invalid shape {video_or_img.shape}."
189
+ )
190
+ t, c, h, w = video_or_img.shape
191
+ p = self._patch_size
192
+ m = self._merge_size
193
+
194
+ # Convert to RGB
195
+ if c == 1:
196
+ video_or_img = video_or_img.expand((-1, 3, -1, -1))
197
+ if c == 4:
198
+ video_or_img = video_or_img[:, :3]
199
+ c = video_or_img.shape[1]
200
+ assert c == 3, "Expecting RGB image in QwenNormalize"
201
+
202
+ # Reshape to t h w c' format
203
+ h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
204
+ rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)
205
+
206
+ video_or_img = rearrange(
207
+ video_or_img,
208
+ "t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
209
+ **rearrange_dict,
210
+ )
211
+ assert video_or_img.shape[-1] == self._num_target_channels, (
212
+ f"{video_or_img.shape[-1]} != {self._num_target_channels}"
213
+ )
214
+ video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))
215
+
216
+ return video_or_img
217
+
218
+ @overload
219
+ def process_images(
220
+ self, image: Image.Image | torch.Tensor, img_size: int | None = None
221
+ ) -> torch.Tensor: ...
222
+ @overload
223
+ def process_images(
224
+ self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
225
+ ) -> list[torch.Tensor]: ...
226
+ def process_images(
227
+ self,
228
+ image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
229
+ img_size: int | None = None,
230
+ ) -> torch.Tensor | list[torch.Tensor]:
231
+ if isinstance(image, list):
232
+ return [self.__process_one__(_x, img_size) for _x in image]
233
+ return self.__process_one__(image, img_size)
234
+
235
+
236
+ class ProcessorOutput(dict):
237
+ input_ids: torch.Tensor
238
+ attention_mask: torch.Tensor
239
+ image_embeds_insertion_points: list[torch.Tensor] | None
240
+ pixel_values: torch.Tensor | list[torch.Tensor] | None
241
+
242
+ def to(
243
+ self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
244
+ ) -> "ProcessorOutput":
245
+ return ProcessorOutput(
246
+ {
247
+ "input_ids": self["input_ids"].to(device),
248
+ "attention_mask": self["attention_mask"].to(device),
249
+ "image_embeds_insertion_points": self["image_embeds_insertion_points"],
250
+ "pixel_values": (
251
+ self["pixel_values"].to(dtype).to(device)
252
+ if isinstance(self["pixel_values"], torch.Tensor)
253
+ else [x.to(dtype).to(device) for x in self["pixel_values"]]
254
+ if self["pixel_values"] is not None
255
+ else None
256
+ ),
257
+ }
258
+ )
259
+
260
+
261
+ class BaseProcessor(ProcessorMixin):
262
+ def __init__(
263
+ self,
264
+ tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
265
+ pre_image_tokens: tuple[int, ...] = (),
266
+ post_image_tokens: tuple[int, ...] = (),
267
+ system_start_tokens: tuple[int, ...] = (),
268
+ system_end_tokens: tuple[int, ...] = (),
269
+ user_start_tokens: tuple[int, ...] = (),
270
+ user_end_tokens: tuple[int, ...] = (),
271
+ asst_start_tokens: tuple[int, ...] = (),
272
+ asst_end_tokens: tuple[int, ...] = (),
273
+ allow_system_prompt: bool = True,
274
+ pad_token: int = 0,
275
+ bos_token: int | None = None,
276
+ ) -> None:
277
+ self.pre_image_tokens = list(pre_image_tokens)
278
+ self.post_image_tokens = list(post_image_tokens)
279
+ self.system_start_tokens = list(system_start_tokens)
280
+ self.system_end_tokens = list(system_end_tokens)
281
+ self.user_start_tokens = list(user_start_tokens)
282
+ self.user_end_tokens = list(user_end_tokens)
283
+ self.asst_start_tokens = list(asst_start_tokens)
284
+ self.asst_end_tokens = list(asst_end_tokens)
285
+ self._allow_system_prompt = allow_system_prompt
286
+ self.tokenizer = tokenizer
287
+ self._image_processor = None
288
+ self._pad_token = pad_token
289
+ self.bos_token = bos_token
290
+
291
+ @property
292
+ def image_processor(self) -> QwenImageProcessor:
293
+ assert self._image_processor is not None
294
+ return self._image_processor
295
+
296
+ def _process_content(
297
+ self,
298
+ message_content: MessageContent,
299
+ role: Literal["system", "user", "assistant"],
300
+ tokenized_messages: list[torch.Tensor],
301
+ insertion_points: list[int],
302
+ image_list: list[torch.Tensor | None],
303
+ token_count: int,
304
+ img_size: int | None = None,
305
+ **kwargs: Any,
306
+ ) -> int:
307
+ mapping = {
308
+ "user": (self.user_start_tokens, self.user_end_tokens),
309
+ "assistant": (self.asst_start_tokens, self.asst_end_tokens),
310
+ "system": (self.system_start_tokens, self.system_end_tokens),
311
+ }
312
+ if role.lower() not in mapping:
313
+ raise ValueError(f"Unknown role '{role}' encountered in messages.")
314
+ start_tokens, end_tokens = mapping[role.lower()]
315
+ # 1) Add the start tokens
316
+ if start_tokens:
317
+ tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long))
318
+ token_count += len(start_tokens)
319
+ # 2) Process the message content one by one (potentially interleaved image and text)
320
+ for part in message_content:
321
+ elt_type = part["type"]
322
+ if elt_type == "image":
323
+ part = cast(ImageMessage, part)
324
+ self._process_image_message(
325
+ part,
326
+ tokenized_messages,
327
+ image_list,
328
+ img_size=img_size,
329
+ )
330
+ token_count += len(self.pre_image_tokens)
331
+ insertion_points.append(token_count)
332
+ token_count += len(self.post_image_tokens)
333
+ else:
334
+ part = cast(TextMessage, part)
335
+ self._process_text_message(
336
+ part["text"],
337
+ role=role,
338
+ token_list=tokenized_messages,
339
+ **kwargs,
340
+ )
341
+ token_count += tokenized_messages[-1].size(0)
342
+ # 3) Add the end tokens
343
+ if end_tokens:
344
+ tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long))
345
+ token_count += len(end_tokens)
346
+ return token_count
347
+
348
+ def _process_text_message(
349
+ self,
350
+ message: str,
351
+ role: Literal["system", "user", "assistant"],
352
+ token_list: list[torch.Tensor],
353
+ **kwargs: Any,
354
+ ) -> None:
355
+ if role.lower() == "system" and not self._allow_system_prompt:
356
+ raise ValueError("System prompts are not allowed in this tokenizer configuration.")
357
+ tokens = self.tokenizer.encode(
358
+ message, add_special_tokens=False, return_tensors="pt", **kwargs
359
+ )
360
+ tokens = cast(torch.Tensor, tokens)
361
+ token_list.append(tokens.flatten().to(torch.long))
362
+
363
+ def _process_image_message(
364
+ self,
365
+ message: ImageMessage,
366
+ token_list: list[torch.Tensor],
367
+ image_list: list[torch.Tensor | None],
368
+ img_size: int | None = None,
369
+ ) -> None:
370
+ img = message["image"]
371
+ if img is None:
372
+ image_list.append(None)
373
+ else:
374
+ image_list.append(
375
+ self.image_processor.process_images(
376
+ self._load_image(img), img_size=img_size
377
+ ).squeeze(0)
378
+ )
379
+ if self.pre_image_tokens:
380
+ token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long))
381
+
382
+ if self.post_image_tokens:
383
+ token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long))
384
+
385
+ def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
386
+ if isinstance(image_path_or_image, str):
387
+ return Image.open(image_path_or_image).convert("RGB")
388
+ return image_path_or_image
389
+
390
+ def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
391
+ return torch.nn.functional.pad(
392
+ tokens,
393
+ (0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
394
+ value=pad_value,
395
+ )
396
+
397
+ def pad_tokenized_messages(
398
+ self,
399
+ tokenized_messages_batch: list[torch.Tensor],
400
+ image_insertion_points_batch: list[torch.Tensor] | None = None,
401
+ ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
402
+ max_len = max(len(x) for x in tokenized_messages_batch)
403
+ if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
404
+ image_insertion_points_batch = [
405
+ x + max_len - len(tokenized_messages_batch[idx])
406
+ for idx, x in enumerate(image_insertion_points_batch)
407
+ ]
408
+ input_ids = torch.stack(
409
+ [
410
+ self._maybe_pad(s, max_len - s.size(0), self._pad_token)
411
+ for s in tokenized_messages_batch
412
+ ],
413
+ dim=0,
414
+ )
415
+ attention_mask = torch.stack(
416
+ [
417
+ self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
418
+ for s in tokenized_messages_batch
419
+ ],
420
+ dim=0,
421
+ )
422
+ return input_ids, attention_mask, image_insertion_points_batch
423
+
424
+ def tokenize_messages(
425
+ self,
426
+ messages: ProcessorInput,
427
+ suppress_bos_token: bool = False,
428
+ **kwargs: Any,
429
+ ) -> ProcessorOutput | None:
430
+ """Tokenize a batch of messages into token IDs suitable for Helium1 CASA model.
431
+
432
+ Args:
433
+ messages (list[list[dict[str, str]]] | list[dict[str, str]]): Batch of message lists (or single list of messages),
434
+ where each message is a list of dictionaries with 'role' and 'content' keys.
435
+ suppress_bos_token (bool, optional): If True, the beginning-of-sequence token will not be added.
438
+ Defaults to False.
439
+ **kwargs: Additional keyword arguments passed to the underlying encode method.
440
+ """
441
+ if not messages:
442
+ return None
443
+ if isinstance(messages[0], dict):
444
+ messages = [messages] # type: ignore[assignment]
445
+
446
+ messages = cast(list[list[Message]], messages)
447
+ image_insertion_points_batch = []
448
+ tokenized_messages_batch = []
449
+ image_list: list[torch.Tensor | None] = []
450
+ for msgs in messages:
451
+ # msgs.append({
452
+ # "role": "assistant",
453
+ # "content": [{"type": "text", "text": ""}]
454
+ # })
455
+ tokenized_messages = []
456
+ if not suppress_bos_token and self.bos_token is not None:
457
+ tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
458
+ insertion_points = []
459
+ token_count = 0
460
+ for msg in msgs:
461
+ token_count = self._process_content(
462
+ msg["content"],
463
+ role=msg["role"],
464
+ tokenized_messages=tokenized_messages,
465
+ insertion_points=insertion_points,
466
+ image_list=image_list,
467
+ token_count=token_count,
468
+ **kwargs,
469
+ )
470
+ tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
471
+ image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))
472
+
473
+ if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
474
+ # Remove the assistant end tokens from the final message
475
+ end_token_len = len(self.asst_end_tokens)
476
+ tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
477
+ if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
478
+ # Append the assistant start tokens so the model continues as the assistant
479
+ end_token_len = len(self.asst_end_tokens)
480
+ tokenized_messages_batch[-1] = torch.cat(
481
+ [
482
+ tokenized_messages_batch[-1],
483
+ torch.Tensor(self.asst_start_tokens).to(torch.long),
484
+ ]
485
+ )
486
+
487
+ input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
488
+ tokenized_messages_batch, image_insertion_points_batch
489
+ )
490
+
491
+ if image_list:
492
+ assert sum(img is None for img in image_list) % len(image_list) == 0, (
493
+ "Either all or no image must be None."
494
+ )
495
+ pixel_values: None | torch.Tensor | list[torch.Tensor]
496
+ if image_list[0] is None:
497
+ pixel_values = None
498
+ else:
499
+ pixel_values = cast(list[torch.Tensor], image_list)
500
+ return ProcessorOutput(
501
+ input_ids=input_ids,
502
+ image_embeds_insertion_points=image_embeds_insertion_points,
503
+ attention_mask=attention_mask,
504
+ pixel_values=pixel_values,
505
+ )
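
`BaseProcessor.tokenize_messages` is the entry point of this file: it interleaves role start/end tokens, text tokens and image placeholder tokens, records the offsets at which image embeddings must later be inserted, pads the batch, and returns everything as a `ProcessorOutput`. A short sketch of the expected message schema and of the returned fields; the image path and prompt are placeholders, and `processor` is assumed to be a `QwenCASAProcessor` (defined in the next file):

from PIL import Image

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [
        {"type": "image", "image": Image.open("example.jpg")},  # placeholder path
        {"type": "text", "text": "Describe the image."},
    ]},
]

out = processor.tokenize_messages(messages)
print(out["input_ids"].shape)                  # (batch, seq_len), padded with the pad token
print(out["attention_mask"].shape)             # (batch, seq_len), zeros over the padding
print(out["image_embeds_insertion_points"])    # one tensor of token offsets per sample
print([x.shape for x in out["pixel_values"]])  # one patchified pixel tensor per image
out = out.to("cuda")                           # ProcessorOutput.to also casts pixel_values to bfloat16 by default
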
processing_qwen2_5vl_casa.py ADDED
@@ -0,0 +1,39 @@
1
+ from typing import Any
2
+
3
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
4
+
5
+ from .processing import BaseProcessor, QwenImageProcessor
6
+
7
+
8
+ class QwenCASAProcessor(BaseProcessor):
9
+ attributes = ["tokenizer"]
10
+ tokenizer_class = "Qwen2Tokenizer"
11
+
12
+ def __init__(
13
+ self,
14
+ tokenizer: Qwen2Tokenizer,
15
+ pre_image_tokens: tuple[int, ...] = (151652,),
16
+ post_image_tokens: tuple[int, ...] = (151653,),
17
+ system_start_tokens: tuple[int, ...] = (151644, 8948, 198),
18
+ system_end_tokens: tuple[int, ...] = (151645, 198),
19
+ user_start_tokens: tuple[int, ...] = (151644, 872, 198),
20
+ user_end_tokens: tuple[int, ...] = (151645, 198),
21
+ asst_start_tokens: tuple[int, ...] = (151644, 77091, 198),
22
+ asst_end_tokens: tuple[int, ...] = (151645, 198),
23
+ image_size: int = 448,
24
+ **kwargs: Any,
25
+ ):
26
+ del kwargs
27
+ super().__init__(
28
+ tokenizer=tokenizer,
29
+ pre_image_tokens=pre_image_tokens,
30
+ post_image_tokens=post_image_tokens,
31
+ system_start_tokens=system_start_tokens,
32
+ system_end_tokens=system_end_tokens,
33
+ user_start_tokens=user_start_tokens,
34
+ user_end_tokens=user_end_tokens,
35
+ asst_start_tokens=asst_start_tokens,
36
+ asst_end_tokens=asst_end_tokens,
37
+ )
38
+
39
+ self._image_processor = QwenImageProcessor(img_size=image_size)
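
The subclass above only pins the Qwen2.5-VL special-token ids (151652/151653 are `<|vision_start|>`/`<|vision_end|>`, and the role headers use `<|im_start|>`/`<|im_end|>`) and attaches a `QwenImageProcessor`. A direct-construction sketch, assuming the tokenizer files of this repository load as a standard Qwen2 tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("kyutai/CASA-Qwen2_5-VL-3B-LiveCC")  # assumption: repo id
processor = QwenCASAProcessor(tokenizer=tok, image_size=448)
# The defaults above match processor_config.json, so no extra arguments are needed.
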
processor_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_qwen2_5vl_casa.QwenCASAProcessor"
4
+ },
5
+ "image_size": 448,
6
+ "post_image_tokens": [
7
+ 151653
8
+ ],
9
+ "pre_image_tokens": [
10
+ 151652
11
+ ],
12
+ "processor_class": "QwenCASAProcessor"
13
+ }
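
The `auto_map` entry above lets `AutoProcessor` resolve to the custom `QwenCASAProcessor` without importing the module manually; this requires `trust_remote_code=True`. A minimal sketch (the repository id is an assumption; any local directory containing these files works the same way):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("kyutai/CASA-Qwen2_5-VL-3B-LiveCC", trust_remote_code=True)
print(processor.pre_image_tokens, processor.post_image_tokens)  # [151652] [151653]
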
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,207 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "151646": {
29
+ "content": "<|object_ref_start|>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "151647": {
37
+ "content": "<|object_ref_end|>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "151648": {
45
+ "content": "<|box_start|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "151649": {
53
+ "content": "<|box_end|>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "151650": {
61
+ "content": "<|quad_start|>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "151651": {
69
+ "content": "<|quad_end|>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "151652": {
77
+ "content": "<|vision_start|>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "151653": {
85
+ "content": "<|vision_end|>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "151654": {
93
+ "content": "<|vision_pad|>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "151655": {
101
+ "content": "<|image_pad|>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "151656": {
109
+ "content": "<|video_pad|>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "151657": {
117
+ "content": "<tool_call>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": false
123
+ },
124
+ "151658": {
125
+ "content": "</tool_call>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": false
131
+ },
132
+ "151659": {
133
+ "content": "<|fim_prefix|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": false
139
+ },
140
+ "151660": {
141
+ "content": "<|fim_middle|>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": false
147
+ },
148
+ "151661": {
149
+ "content": "<|fim_suffix|>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": false
155
+ },
156
+ "151662": {
157
+ "content": "<|fim_pad|>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": false
163
+ },
164
+ "151663": {
165
+ "content": "<|repo_name|>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": false
171
+ },
172
+ "151664": {
173
+ "content": "<|file_sep|>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": false
179
+ }
180
+ },
181
+ "additional_special_tokens": [
182
+ "<|im_start|>",
183
+ "<|im_end|>",
184
+ "<|object_ref_start|>",
185
+ "<|object_ref_end|>",
186
+ "<|box_start|>",
187
+ "<|box_end|>",
188
+ "<|quad_start|>",
189
+ "<|quad_end|>",
190
+ "<|vision_start|>",
191
+ "<|vision_end|>",
192
+ "<|vision_pad|>",
193
+ "<|image_pad|>",
194
+ "<|video_pad|>"
195
+ ],
196
+ "bos_token": null,
197
+ "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
198
+ "clean_up_tokenization_spaces": false,
199
+ "eos_token": "<|im_end|>",
200
+ "errors": "replace",
201
+ "model_max_length": 131072,
202
+ "pad_token": "<|endoftext|>",
203
+ "split_special_tokens": false,
204
+ "tokenizer_class": "Qwen2Tokenizer",
205
+ "unk_token": null,
206
+ "add_bos_token": false
207
+ }
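
The `chat_template` above is the standard Qwen2.5-VL chat format: an injected default system prompt, `<|im_start|>role ... <|im_end|>` turn markers, and `<|vision_start|><|image_pad|><|vision_end|>` placeholders for images. Note that the CASA processor in `processing.py` assembles its prompts directly from token ids rather than through this template; the sketch below only illustrates what the template itself renders (the message content is a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("kyutai/CASA-Qwen2_5-VL-3B-LiveCC")  # assumption: repo id
messages = [{"role": "user", "content": [
    {"type": "image", "image": "frame.png"},
    {"type": "text", "text": "What is shown here?"},
]}]
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>What is shown here?<|im_end|>
# <|im_start|>assistant
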
utils.py ADDED
@@ -0,0 +1,116 @@
1
+ # pylint: disable=protected-access
2
+ """Utils to handle CASA layers construction"""
3
+
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass, fields
6
+ from typing import Any, Callable, Generic, TypeVar
7
+
8
+ import torch
9
+
10
+
11
+ def delta_w_factory(
12
+ org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
13
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
14
+ """Factory for building linear op where the weights are the sum of two layers' weights"""
15
+
16
+ def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
17
+ nonlocal org_lin, new_lin
18
+ bias = None if org_lin.bias is None else org_lin.bias + new_lin.bias
19
+ return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)
20
+
21
+ return _delta_w_fwd
22
+
23
+
24
+ @dataclass
25
+ class StreamingState:
26
+ """Streaming State used by CASA layers at inference to save
27
+ e.g. the offset, the KV Cache and other persistent states"""
28
+
29
+ offset: int = 0
30
+
31
+ def _is_valid_field(self, key: str) -> bool:
32
+ return key in {x.name for x in fields(self)}
33
+
34
+ def _init_field(self, key: str) -> None:
35
+ """Init function for non-arggment dependent defauls"""
36
+ assert self._is_valid_field(key)
37
+ if key == "offset":
38
+ self.offset = 0
39
+ else:
40
+ # for fields which should be set explicitly and cannot be auto-initialized
41
+ setattr(self, key, None)
42
+
43
+ def init(self) -> None:
44
+ for key in [x.name for x in fields(self)]:
45
+ self._init_field(key)
46
+
47
+ def _reset_field(self, name: str) -> None:
48
+ """Resets the given field"""
49
+ self._init_field(name)
50
+
51
+ def reset(self) -> None:
52
+ for f in fields(self):
53
+ self._reset_field(f.name)
54
+
55
+ def _get_field(self, f: str) -> Any:
56
+ """Get field and init if not"""
57
+ assert self._is_valid_field(f)
58
+ if getattr(self, f) is None:
59
+ self._init_field(f)
60
+ return getattr(self, f)
61
+
62
+ def _set_field(self, f: str, value: Any) -> None:
63
+ assert self._is_valid_field(f)
64
+ setattr(self, f, value)
65
+
66
+
67
+ StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
68
+
69
+
70
+ class StreamingModule(torch.nn.Module, Generic[StreamingStateT]): # pylint: disable=abstract-method
71
+ """Overrides Audiocraft's Streaming modules with additional small utils"""
72
+
73
+ def __init__(self, state_class: type) -> None:
74
+ torch.nn.Module.__init__(self)
75
+ self.is_streaming: bool = False
76
+ self.enable_viz: tuple[str, ...] = ()
77
+ self._streaming_state: StreamingStateT = state_class()
78
+
79
+ @property
80
+ def streaming_state(self) -> StreamingStateT:
81
+ return self._streaming_state
82
+
83
+ def _apply_named_streaming(self, fn: Callable):
84
+ """Apply function to all streaming modules"""
85
+ for name, module in self.named_modules():
86
+ if isinstance(module, StreamingModule):
87
+ fn(name, module)
88
+
89
+ def reset_streaming(self):
90
+ """Reset the streaming state."""
91
+
92
+ def _reset(_: str, module: StreamingModule):
93
+ module._streaming_state.reset()
94
+
95
+ self._apply_named_streaming(_reset)
96
+
97
+ def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()):
98
+ """Set all streaming modules in streaming mode"""
99
+
100
+ def _set_streaming(_, module: StreamingModule) -> None:
101
+ module.is_streaming = streaming
102
+ module.enable_viz = viz
103
+ if streaming:
104
+ module.streaming_state.init()
105
+
106
+ self._apply_named_streaming(_set_streaming)
107
+
108
+ @contextmanager
109
+ def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
110
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
111
+ self._set_streaming(stream, viz)
112
+ try:
113
+ yield
114
+ finally:
115
+ self._set_streaming(False, ())
116
+ self.reset_streaming()
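
`StreamingModule` mirrors the Audiocraft-style streaming API: `_set_streaming` toggles every nested streaming submodule and re-initialises its `StreamingState`, and the `streaming()` context manager guarantees the state is reset on exit. A small self-contained sketch with a toy subclass that only counts timesteps (the subclass is illustrative and not part of this file; it assumes this module is importable as `utils`):

import torch

from utils import StreamingModule, StreamingState

class Counter(StreamingModule[StreamingState]):
    """Toy streaming module counting the timesteps seen while streaming."""

    def __init__(self) -> None:
        super().__init__(state_class=StreamingState)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_streaming:
            # Persist the number of timesteps seen across calls in the streaming state.
            self.streaming_state.offset += x.shape[-1]
        return x

m = Counter()
with m.streaming():
    m(torch.zeros(1, 4))
    m(torch.zeros(1, 3))
    print(m.streaming_state.offset)  # 7
print(m.is_streaming, m.streaming_state.offset)  # False 0 -- state reset on exit
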
vocab.json ADDED
The diff for this file is too large to render. See raw diff