Image-Text-to-Text
Transformers
Safetensors
English
CASA_Helium1_VL_2B
custom_code
ameroyer and mboehle committed (verified)
Commit fc8600b · 0 parent(s)

Super-squash branch 'main' using huggingface_hub


Co-authored-by: mboehle <mboehle@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,6 @@
1
+ model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ model-00002-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ model-00003-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
4
+ readme_images/CASA.png filter=lfs diff=lfs merge=lfs -text
5
+ readme_images/casa_explainer.mp4 filter=lfs diff=lfs merge=lfs -text
6
+ readme_images/half_res_trimmed.mp4 filter=lfs diff=lfs merge=lfs -text
Notice ADDED
@@ -0,0 +1,2 @@
1
+ CASA-Helium1-VL-2B's image encoder is finetuned from the image encoder of Qwen2.5-VL-3B.
2
+ Qwen is licensed under the Qwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
README.md ADDED
@@ -0,0 +1,311 @@
1
+ ---
2
+ license: cc-by-nc-sa-4.0
3
+ language:
4
+ - en
5
+ base_model:
6
+ - kyutai/helium-1-2b
7
+ datasets:
8
+ - HuggingFaceM4/FineVision
9
+ - mvp-lab/LLaVA-OneVision-1.5-Instruct-Data
10
+ ---
11
+
12
+ <img align="right" src="readme_images/CASA.png" width="150px" >
13
+
14
+ # Model Card for CASA-Helium1-VL-2B
15
+
16
+ **CASA** ([Project Page][blog] . [arXiv][casa-arxiv] . [github][casa-git]) stands for **C**ross-**A**ttention via **S**elf-**A**ttention.
17
+ **CASA** is a vision-language fusion paradigm that aims to improve on cross-attention while preserving its practical benefits.
18
+
19
+ Specifically, **CASA** layers inject visual tokens into a text stream via image-to-text cross-attention while additionally enabling
20
+ text-to-text self-interaction in the same layer, constrained to smaller local attention windows.
21
+ This simple modification enables natural gating in the cross-attention mechanism, improving its performance and substantially closing the gap to standard token insertion methods.
22
+ For qualitative samples of CASA used for live video captioning, please check the [associated HuggingFace space](https://huggingface.co/spaces/kyutai/casa-samples).
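For intuition, here is a minimal, self-contained PyTorch sketch of the mechanism written for this card (it is not the released implementation, which lives in `casa_attention.py` and uses variable-length flash attention): within a local window, the text tokens act as queries, while keys and values are computed over the concatenation of image and text tokens, so a single softmax naturally gates between visual and textual context.

```python
import torch
import torch.nn.functional as F


def casa_window_attention(
    text: torch.Tensor, image: torch.Tensor, num_heads: int = 8
) -> torch.Tensor:
    """Toy CASA-style attention for a single local window (projections omitted).

    text:  (T, D) text token embeddings, used as queries.
    image: (I, D) image token embeddings inserted at the start of the window.
    Returns a (T, D) update for the text stream.
    """
    T, D = text.shape
    I = image.shape[0]
    head_dim = D // num_heads
    # Queries come from the text tokens only ...
    q = text.view(T, num_heads, head_dim).transpose(0, 1)
    # ... while keys/values cover image + text tokens jointly, so each text
    # token softly gates between visual and textual context in one softmax.
    kv = torch.cat([image, text], dim=0)
    k = kv.view(I + T, num_heads, head_dim).transpose(0, 1)
    v = kv.view(I + T, num_heads, head_dim).transpose(0, 1)
    # Image tokens are visible to every query; text-to-text attention is causal.
    allowed = torch.ones(T, I + T, dtype=torch.bool)
    allowed[:, I:] = torch.tril(torch.ones(T, T, dtype=torch.bool))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed.unsqueeze(0))
    return out.transpose(0, 1).reshape(T, D)


# Example: 16 text tokens attending over a window containing 4 image tokens.
print(casa_window_attention(torch.randn(16, 64), torch.randn(4, 64)).shape)
```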
23
+
24
+ ![](readme_images/casa_explainer.mp4)
25
+
26
+ ## Model Details
27
+
28
+ ### Model Description
29
+
30
+
31
+ This model page contains the weights for a CASA model trained from a pretrained text-only Helium1-2B backbone and the image encoder of Qwen2.5-VL-3B.
32
+ In the collection, we also provide weights for:
33
+ - [`CASA-Qwen2_5-VL-3B`](https://huggingface.co/kyutai/CASA-Qwen2_5-VL-3B): A CASA model adapted from the full pretrained `Qwen2.5-VL-3B` (the backbone LLM weights are kept frozen)
34
+ - [`CASA-Qwen2_5-VL-3B-LiveCC`](https://huggingface.co/kyutai/CASA-Qwen2_5-VL-3B-LiveCC): A CASA model adapted from the full pretrained `Qwen2.5-VL-3B` and further finetuned for live video captioning.
35
+ - [`Helium1-VL-2B`](https://huggingface.co/kyutai/Helium1-VL-2B): A reference VLM trained from Helium1-2B with a standard token-insertion mechanism in the same setting as `CASA-Helium1-VL-2B`.
36
+
37
+ Model Summary:
38
+ - **Developed by:** Kyutai
39
+ - **Model type:** Multimodal vision+text model based on Cross-Attention
40
+ - **Language(s) (NLP):** English
41
+ - **License:** CC-BY-NC-SA-4.0
42
+ - **LLM Backbone from:** [Helium1 2B](https://huggingface.co/kyutai/helium-1-2b)
43
+ - **Image Encoder from:** [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
44
+ - **Terms of use:** As the released models include frozen weights of the Qwen2.5-VL-3B image encoder, these weights are subject to the [Qwen RESEARCH LICENSE AGREEMENT](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct/blob/main/LICENSE)
45
+
46
+
47
+ ### Model Sources
48
+
49
+ - **Project Page:** [kyutai.org/casa][blog]
50
+ - **Preprint:** [arXiv][casa-arxiv]
51
+ - **Repository:** [Github kyutai-labs/casa][casa-git]
52
+
53
+ ## Uses
54
+
55
+ ### Direct Use
56
+ The intended use of the Helium model is research and development of vision-language systems, including but not limited to image or video understanding.
57
+
58
+ `CASA-Helium1-VL-2B`, `Helium1-VL-2B`, and `CASA-Qwen2_5-VL-3B` can be used as vision-language models to analyze or interpret images given as inputs.
59
+
60
+ `CASA-Qwen2_5-VL-3B-LiveCC` can be used as a vision-language model on streaming video inputs at 2 fps.
61
+
62
+
63
+ The models are primarily intended for use in English. For most downstream use cases, the model should be aligned with supervised fine-tuning, RLHF, or related methods.
64
+
65
+
66
+ ### Out-of-Scope Use
67
+
68
+ The model should not be used in languages other than the ones on which it was trained.
69
+ The model is not intended to be used to impersonate other people or any malicious use of any kind.
70
+
71
+
72
+ ## Bias, Risks, and Limitations
73
+ Our CASA-Helium1 model has not been aligned to human preferences. As such, it can generate incorrect, biased, harmful, or generally unhelpful content. It should therefore not be used for downstream applications without further alignment, evaluation, and risk mitigation.
74
+
75
+ ### Recommendations
76
+
77
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
78
+
79
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed to provide further recommendations.
80
+
81
+ ## How to Get Started with the Model
82
+
83
+ See our [github repository][casa-git] for additional scripts to perform benchmark evaluation and live video captioning.
84
+
85
+
86
+ Below is a short snippet showing how to load our models, process inputs, and run inference using a standard Hugging Face `transformers` pipeline and chat template.
87
+
88
+ ```python
89
+ # Minimal requirements:
90
+ # /// script
91
+ # requires-python = ">=3.10"
92
+ # dependencies = [
93
+ # "rich",
94
+ # "einops>=0.8.1",
95
+ # "torch==2.7.0",
96
+ # "transformers==4.51.3",
97
+ # "torchvision==0.22.0",
98
+ # "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"
99
+ # ]
100
+ # ///
101
+ import torch
102
+ from transformers.models.auto.modeling_auto import AutoModel
103
+ from transformers.models.auto.processing_auto import AutoProcessor
104
+
105
+ model_id = "kyutai/CASA-Helium1-VL-2B"
106
+ model = AutoModel.from_pretrained(
107
+ model_id,
108
+ torch_dtype=torch.bfloat16,
109
+ attn_implementation="flash_attention_2",
110
+ trust_remote_code=True,
111
+ ).cuda()
112
+ processor = AutoProcessor.from_pretrained(
113
+ model_id,
114
+ trust_remote_code=True,
115
+ )
116
+
117
+ conversation = [
118
+ {
119
+ "role": "user",
120
+ "content": [
121
+ {
122
+ "type": "image",
123
+ "image": "assets/casa_model.png",
124
+ },
125
+ {
126
+ "type": "text",
127
+ "text": "Describe this image.",
128
+ },
129
+ ],
130
+ },
131
+ ]
132
+ inputs = processor.tokenize_messages(messages=conversation)
133
+ inputs = inputs.to(model.device)
134
+ input_len = inputs["input_ids"].shape[1]
135
+ output_ids = model.generate_from_image(
136
+ **inputs,
137
+ max_new_tokens=512,
138
+ pre_image_tokens=processor.pre_image_tokens,
139
+ post_image_tokens=processor.post_image_tokens,
140
+ eos_token_id=model.generation_config.eos_token_id,
141
+ )[0, input_len:]
142
+ response = processor.tokenizer.decode(output_ids, skip_special_tokens=True)
143
+ print(response)
144
+ ```
145
+
146
+
147
+
148
+ ## Training Details
149
+
150
+ Please have a look at our associated [research paper][casa-arxiv] for details on the training pipeline.
151
+
152
+ ### Training Data
153
+
154
+ To train our CASA-Helium models we use the [FineVision](https://huggingface.co/datasets/HuggingFaceM4/FineVision)
155
+ dataset, as well as a small, non-overlapping subset of [Llava-OneVision-1.5-Instruct](https://github.com/EvolvingLMMs-Lab/LLaVA-OneVision-1.5).
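As a rough, illustrative sketch of mixing two Hugging Face datasets with a fixed sampling ratio (the toy data and the 90/10 ratio below are assumptions for illustration, not the actual training recipe):

```python
from datasets import Dataset, interleave_datasets

# Toy stand-ins for the two sources; the real loading and filtering are not shown.
finevision_like = Dataset.from_dict({"source": ["FineVision"] * 9})
onevision_like = Dataset.from_dict({"source": ["LLaVA-OneVision-1.5-Instruct"] * 1})

# Interleave with a fixed sampling ratio (90/10 here is only an example).
mixture = interleave_datasets(
    [finevision_like, onevision_like], probabilities=[0.9, 0.1], seed=0
)
print(mixture[0])
```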
156
+
157
+
158
+ ## Evaluation
159
+ We evaluate our models on a range of benchmarks covering document understanding (`DocVQA`), chart understanding (`ChartQA`, `InfoVQA`),
160
+ visual text reading (`TextVQA`, `OCRBench`), and general QA (`RealWorldQA`, `AI2D`, `GQA`, `MME`). Results are reported below. Please refer to our [project page][blog] and [arXiv paper][casa-arxiv] for additional evaluations.
161
+
162
+ <table style="border-collapse: collapse;">
163
+ <thead>
+ <tr>
164
+ <th rowspan="2" align="left">Model</th>
165
+ <th colspan="3" align="center">Document / Chart</th>
166
+ <th colspan="2" align="center">Scene Text</th>
167
+ <th colspan="4" align="center">Knowledge / QA</th>
168
+ </tr>
169
+ <tr>
170
+ <th>ChartQA</th>
171
+ <th>DocVQA</th>
172
+ <th>InfoVQA</th>
173
+ <th>OCRBench</th>
174
+ <th>TextVQA</th>
175
+ <th>RealWorldQA</th>
176
+ <th>AI2D</th>
177
+ <th>GQA</th>
178
+ <th>MME</th>
179
+ </tr>
180
+ </thead>
181
+ <tbody>
182
+ <tr>
183
+ <td align="left">Helium1-VL-2B</td>
184
+ <td>81.6</td><td>89.1</td><td>61.8</td>
185
+ <td>728</td><td>75.5</td>
186
+ <td>59.9</td><td>67.7</td><td>55.5</td><td>1732</td>
187
+ </tr>
188
+ <tr>
189
+ <td align="left"><span style="color:#fb923c;"><strong>CASA-Helium1-VL-2B</strong></span></td>
190
+ <td>73.4</td><td>83.7</td><td>48.6</td>
191
+ <td>723</td><td>71.0</td>
192
+ <td>58.3</td><td>63.3</td><td>54.6</td><td>1572</td>
193
+ </tr>
194
+ <tr>
195
+ <td align="left"><span style="color:#60a5fa;">mPLUG-Owl3 8B</span></td>
196
+ <td>59.2<sup>†</sup></td><td>55.9<sup>†</sup></td><td>36.8<sup>†</sup></td>
197
+ <td>527<sup>†</sup></td><td>69.0</td>
198
+ <td>63.9<sup>†</sup></td><td>73.4</td><td>65.0</td><td>1940<sup>†</sup></td>
199
+ </tr>
200
+ <tr>
201
+ <td align="left"><span style="color:#60a5fa;">mPLUG-Owl3 2B</span></td>
202
+ <td>48.5<sup>†</sup></td><td>48.2<sup>†</sup></td><td>28.1<sup>†</sup></td>
203
+ <td>450<sup>†</sup></td><td>62.6</td>
204
+ <td>56.9<sup>†</sup></td><td>62.6</td><td>61.0</td><td>1551<sup>†</sup></td>
205
+ </tr>
206
+ </tbody>
207
+ </table>
208
+ <p>
209
+ <sup>†</sup> Reproduced with the publicly available models on Hugging Face. &ensp;
210
+ <!-- ◇ Results and model not publicly available. -->
211
+ </p>
212
+
213
+ <p align="center">
214
+ <em>
215
+ Results for <code>CASA-Helium1-VL-2B</code> compared to a recent cross-attention baseline (blue), and our token-insertion
216
+ baseline (<code>Helium1-VL-2B</code>) trained in the same conditions. CASA outperforms current SoTA
217
+ cross-attention-based VLMs, narrowing the gap to insertion-based approaches.
218
+ </em>
219
+ </p>
220
+
221
+ <table style="border-collapse: collapse;">
222
+ <thead>
223
+ <tr>
224
+ <th rowspan="2" align="left">Model</th>
225
+ <th colspan="3" align="center">Document / Chart</th>
226
+ <th colspan="2" align="center">Scene Text</th>
227
+ <th colspan="4" align="center">Knowledge / QA</th>
228
+ </tr>
229
+ <tr>
230
+ <th>ChartQA</th>
231
+ <th>DocVQA</th>
232
+ <th>InfoVQA</th>
233
+ <th>OCRBench</th>
234
+ <th>TextVQA</th>
235
+ <th>RealWorldQA</th>
236
+ <th>AI2D</th>
237
+ <th>GQA</th>
238
+ <th>MME</th>
239
+ </tr>
240
+ </thead>
241
+
242
+ <tbody>
243
+ <tr>
244
+ <td align="left">
245
+ Qwen2.5-VL-3B
246
+ </td>
247
+ <td>84.0</td><td>93.6</td><td>77.1</td>
248
+ <td>797</td><td>79.3</td>
249
+ <td>62.2<sup>†</sup></td><td>81.6</td><td>61.0<sup>†</sup></td><td>2249<sup>†</sup></td>
250
+ </tr>
251
+ <tr>
252
+ <td align="left">
253
+ <span style="color:#fb923c;"><strong>CASA-Qwen2_5-VL-3B</strong></span>
254
+ </td>
255
+ <td>82.4</td><td>88.9</td><td>59.6</td>
256
+ <td>790</td><td>77.4</td>
257
+ <td>62.5</td><td>75.1</td><td>59.4</td><td>1918</td>
258
+ </tr>
259
+ </tbody>
260
+ </table>
261
+
262
+ <p>
263
+ <sup>†</sup> Reproduced with the publicly available models on Hugging Face.
264
+ </p>
265
+
266
+ <p align="center">
267
+ <em>
268
+ Results for <code>CASA-Qwen2_5-VL-3B</code>, adapted from frozen Qwen2.5-VL. CASA reaches performance close to the original
269
+ insertion-based model while training only
270
+ the CASA layers and last blocks of the image encoder.
271
+ </em>
272
+ </p>
273
+
274
+
275
+ ## Technical Specifications
276
+
277
+
278
+ ### Compute Infrastructure
279
+
280
+ `CASA-Helium1-VL-2B` was trained starting from a `Helium1-2B` LLM and the image encoder from `Qwen2.5-VL-3B`.
281
+ We finetune the whole LLM backbone as well as the last four blocks of the image encoder.
282
+ The currently released model was trained on four DGX nodes with 8 H100 GPUs.
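As a rough sketch of this kind of partial finetuning (the `blocks` attribute below is a placeholder for illustration, not the actual module name used in the released code):

```python
import torch


def freeze_all_but_last_blocks(
    vision_encoder: torch.nn.Module, num_trainable: int = 4
) -> None:
    """Freeze a ViT-style image encoder except for its last `num_trainable` blocks.

    Assumes the encoder exposes its transformer blocks as `vision_encoder.blocks`
    (a placeholder name for illustration).
    """
    for param in vision_encoder.parameters():
        param.requires_grad = False
    for block in list(vision_encoder.blocks)[-num_trainable:]:
        for param in block.parameters():
            param.requires_grad = True
    # The LLM backbone, by contrast, is left fully trainable.
```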
283
+
284
+
285
+ #### Software
286
+
287
+ Our training and inference code is implemented in PyTorch.
288
+
289
+ ## Citation
290
+
291
+ ```
292
+ @article{kyutai2025casa,
293
+ author = {Moritz Böhle and Amélie Royer and Juliette Marrie and Edouard Grave and Patrick Pérez},
294
+ year = {2025},
295
+ title = {CASA: Cross-Attention via Self-Attention},
296
+ journal = {ArXiv},
297
+ url = {https://arxiv.org/abs/2512.19535}
298
+ }
299
+ ```
300
+
301
+
302
+ ## Model Card Authors and Contact
303
+
304
+ * Amelie Royer
305
+ * Moritz Boehle
306
+ * Juliette Marrie
307
+
308
+
309
+ [blog]: https://kyutai.org/casa
310
+ [casa-arxiv]: https://arxiv.org/abs/2512.19535
311
+ [casa-git]: https://github.com/kyutai-labs/casa
__init__.py ADDED
File without changes
casa_attention.py ADDED
@@ -0,0 +1,1010 @@
1
+ """CASA layers"""
2
+
3
+ import bisect
4
+ from dataclasses import dataclass
5
+ from itertools import accumulate
6
+ from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload
7
+ from typing import cast as type_cast
8
+
9
+ import torch
10
+ from transformers.configuration_utils import PretrainedConfig
11
+
12
+ from .utils import StreamingModule, StreamingState, delta_w_factory
13
+
14
+ if TYPE_CHECKING:
15
+ from transformers.configuration_utils import PretrainedConfig
16
+
17
+ try:
18
+ from flash_attn import flash_attn_varlen_func
19
+ except ImportError:
20
+ flash_attn_varlen_func = None # type: ignore
21
+
22
+
23
+ WindowsComputeKwargs = TypedDict(
24
+ "WindowsComputeKwargs",
25
+ {
26
+ "num_post_image_tokens": int,
27
+ "num_pre_image_tokens": int,
28
+ },
29
+ total=False,
30
+ )
31
+
32
+
33
+ def __split_n_merge__(
34
+ x: torch.Tensor,
35
+ sample_lengths: list[int],
36
+ padding_side: Literal["left", "right"] = "right",
37
+ pad_value: int | float | bool = 0,
38
+ ) -> torch.Tensor:
39
+ max_sample_length = max(sample_lengths)
40
+ pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2))
41
+ return torch.stack(
42
+ [
43
+ torch.nn.functional.pad(
44
+ _x,
45
+ pad_tuple + (0, max_sample_length - _x.shape[0])
46
+ if padding_side == "right"
47
+ else pad_tuple + (max_sample_length - _x.shape[0], 0),
48
+ value=pad_value,
49
+ )
50
+ for _x in torch.split(x, sample_lengths, dim=0)
51
+ ],
52
+ dim=0,
53
+ )
54
+
55
+
56
+ @overload
57
+ def insert_image_tokens(
58
+ inputs_embeds: torch.Tensor,
59
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
60
+ image_embeds_insertion_points: list[torch.Tensor],
61
+ recover_batch_dim: Literal[True],
62
+ attention_mask: torch.Tensor | None = None,
63
+ padding_side: Literal["left", "right"] = "right",
64
+ keep_only_attended: bool = False,
65
+ pad_output: int | float | bool = 0.0,
66
+ ) -> tuple[
67
+ torch.Tensor,
68
+ None,
69
+ torch.Tensor | None,
70
+ torch.Tensor,
71
+ ]: ...
72
+ @overload
73
+ def insert_image_tokens(
74
+ inputs_embeds: torch.Tensor,
75
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
76
+ image_embeds_insertion_points: list[torch.Tensor],
77
+ recover_batch_dim: Literal[False],
78
+ attention_mask: torch.Tensor | None = None,
79
+ padding_side: Literal["left", "right"] = "right",
80
+ keep_only_attended: bool = False,
81
+ pad_output: int | float | bool = 0.0,
82
+ ) -> tuple[
83
+ torch.Tensor,
84
+ list[int],
85
+ torch.Tensor | None,
86
+ torch.Tensor,
87
+ ]: ...
88
+ def insert_image_tokens(
89
+ inputs_embeds: torch.Tensor,
90
+ image_embeds: torch.Tensor | Sequence[torch.Tensor],
91
+ image_embeds_insertion_points: list[torch.Tensor],
92
+ recover_batch_dim: bool = True,
93
+ attention_mask: torch.Tensor | None = None,
94
+ padding_side: Literal["left", "right"] = "right",
95
+ keep_only_attended: bool = False,
96
+ pad_output: int | float | bool = 0.0,
97
+ ) -> tuple[
98
+ torch.Tensor | torch.Tensor,
99
+ list[int] | None,
100
+ torch.Tensor | torch.Tensor | None,
101
+ torch.Tensor | torch.Tensor,
102
+ ]:
103
+ """
104
+ Insert image embeddings into text embeddings
105
+
106
+ Args:
107
+ inputs_embeds (torch.Tensor): (B, S, D) input token embeddings.
108
+ image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings.
109
+ image_embeds_insertion_points (list[torch.Tensor]): Insertion indices.
110
+ attention_mask (torch.Tensor, optional): (B, S) attention mask.
111
+ padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images.
112
+ recover_batch_dim (bool): Whether to re-batch the fused sequence (True) or keep it flattened and also return per-sample lengths (False).
113
+ keep_only_attended: This is only applicable when recover_batch_dim is False; whether to
114
+ remove any non-attended tokens in the whole array. In this case, the attention
115
+ mask returned is **still the original one**, so we can remember which indices have been
116
+ removed
117
+ Returns:
118
+ output (torch.Tensor): (B, S + Ni * Nt) gather indices or (B, S + Ni * Nt, D) fused sequence
119
+ image_embeds (torch.Tensor): (B, Ni * Nt) image embeds, padded and batched if the input was a list
120
+ attention_mask (torch.Tensor): Same shape, 1 for real tokens, 0 for image and text padding.
121
+ image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions.
122
+ """
123
+ if isinstance(image_embeds, list) and len(image_embeds) == 0:
124
+ batch_size, text_seq_length, token_dim = inputs_embeds.shape
125
+ if recover_batch_dim:
126
+ return (
127
+ inputs_embeds,
128
+ None,
129
+ attention_mask,
130
+ torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool),
131
+ )
132
+ else:
133
+ flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1]
134
+ return (
135
+ torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])),
136
+ [text_seq_length] * inputs_embeds.shape[0],
137
+ attention_mask.flatten() if attention_mask is not None else None,
138
+ torch.zeros((flattened_seq_length, 1), dtype=torch.bool),
139
+ )
140
+
141
+ # Sanity checks
142
+ if isinstance(image_embeds, torch.Tensor):
143
+ assert inputs_embeds.shape[-1] == image_embeds.shape[-1]
144
+ else:
145
+ assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds)
146
+
147
+ batch_size, text_seq_length, token_dim = inputs_embeds.shape
148
+ image_seq_length = [x.shape[0] for x in image_embeds]
149
+
150
+ # Flatten insertion points
151
+ insertion_offset = []
152
+ counter, offset_from_text, offset_from_image = 0, 0, 0
153
+ for sample in image_embeds_insertion_points:
154
+ for pt in sample:
155
+ insertion_offset.append(pt + offset_from_image + offset_from_text)
156
+ offset_from_image += image_seq_length[counter]
157
+ counter += 1
158
+ offset_from_text += text_seq_length
159
+ image_insert_positions = [
160
+ x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx])
161
+ ]
162
+
163
+ # Flatten image embeds
164
+ if isinstance(image_embeds, list):
165
+ image_embeds = torch.cat(image_embeds, dim=0)
166
+ else:
167
+ image_embeds = type_cast(torch.Tensor, image_embeds)
168
+ image_embeds = torch.reshape(image_embeds, (-1, token_dim))
169
+
170
+ # Flatten text embeds across batch dim (B x S, D)
171
+ inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim))
172
+ flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length)
173
+ text_insert_positions = sorted(
174
+ set(range(flattened_seq_length)).difference(set(image_insert_positions))
175
+ )
176
+
177
+ # Scatter image embeds in the flattened dict
178
+ # scatter text related stuff
179
+ output = torch.empty(
180
+ (flattened_seq_length, token_dim),
181
+ device=inputs_embeds.device,
182
+ dtype=inputs_embeds.dtype,
183
+ )
184
+ txt_positions_tensor = torch.Tensor(text_insert_positions).to(
185
+ dtype=torch.long, device=inputs_embeds.device
186
+ )
187
+ output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds)
188
+ attention_mask_new: torch.Tensor | None = None
189
+ if attention_mask is not None:
190
+ attention_mask_new = torch.ones(
191
+ (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
192
+ )
193
+ attention_mask_new.scatter_(
194
+ 0, txt_positions_tensor, attention_mask.flatten().to(torch.bool)
195
+ )
196
+
197
+ # scatter image related stuff
198
+ image_tokens_mask = torch.zeros(
199
+ (flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
200
+ )
201
+ img_positions_tensor = torch.Tensor(image_insert_positions).to(
202
+ device=inputs_embeds.device, dtype=torch.long
203
+ )
204
+ output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds)
205
+ image_tokens_mask.scatter_(0, img_positions_tensor, True)
206
+
207
+ # Compute expected sample length, taking into account the real batch
208
+ # i.e. recover the batch dimension of image embeddings
209
+ sample_lengths = []
210
+ counter = 0
211
+ for sample_idx, pts in enumerate(image_embeds_insertion_points):
212
+ num_image_tokens = 0
213
+ for _ in pts:
214
+ num_image_tokens += image_seq_length[counter]
215
+ counter += 1
216
+ if keep_only_attended and attention_mask is not None:
217
+ attended_seq_length = torch.sum(attention_mask[sample_idx]).cpu().item()
218
+ sample_lengths.append(attended_seq_length + num_image_tokens)
219
+ else:
220
+ sample_lengths.append(text_seq_length + num_image_tokens)
221
+
222
+ # For CASA attention, we can keep everything flattened and return
223
+ # the sample_lengths for the blockwise attention
224
+ if not recover_batch_dim:
225
+ if keep_only_attended and attention_mask_new is not None:
226
+ output = output[attention_mask_new]
227
+ image_tokens_mask = image_tokens_mask[attention_mask_new]
228
+ return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None]
229
+
230
+ # Otherwise, time to (pad) and reshape
231
+ # Easy case: everything has the same length
232
+ if all(x == sample_lengths[0] for x in sample_lengths):
233
+ output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim))
234
+ image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1))
235
+ if attention_mask_new is not None:
236
+ attention_mask_new = torch.reshape(attention_mask_new, (batch_size, sample_lengths[0]))
237
+ # if there is any size mismatch we break into a
238
+ # list and pad again
239
+ else:
240
+ # split and merge
241
+ output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output)
242
+ # note that the extra padding tokens are also marked as image tokens to be removed later
243
+ image_tokens_mask = __split_n_merge__(
244
+ image_tokens_mask, sample_lengths, padding_side, True
245
+ )[:, :, None]
246
+ if attention_mask_new is not None:
247
+ attention_mask_new = __split_n_merge__(
248
+ attention_mask_new, sample_lengths, padding_side, 0
249
+ )
250
+ # Return
251
+ return output, sample_lengths, attention_mask_new, image_tokens_mask
252
+
253
+
254
+ def get_sample_lengths_from_insertion_points(
255
+ image_embeds_insertion_points: list[torch.Tensor],
256
+ image_embeds: torch.Tensor | list[torch.Tensor] | None,
257
+ total_seq_len: int | None = None,
258
+ attention_mask: torch.Tensor | None = None,
259
+ **kwargs: WindowsComputeKwargs,
260
+ ) -> tuple[list[tuple[int, bool]], list[int]]:
261
+ """Compute sample lengths as if each image insertion point defines a
262
+ new document (ex document ID)
263
+ """
264
+ num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0))
265
+ num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0))
266
+ squashed_samples_lengths = type_cast(
267
+ list[list[int]] | None, kwargs.get("squashed_samples_lengths", None)
268
+ )
269
+ if squashed_samples_lengths is not None:
270
+ assert len(squashed_samples_lengths) == len(image_embeds_insertion_points)
271
+
272
+ def __insert_next_sample__(
273
+ batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False
274
+ ) -> None:
275
+ nonlocal attention_mask
276
+ nonlocal text_sample_lengths, full_sample_lengths
277
+ nonlocal cum_samples_lengths, current_image_offset
278
+ nonlocal last_image_idx, current_image_idx, current_length
279
+ # Add the sample between [last_insrt_pt, insrt_pt] with breaks in
280
+ # between any squashed samples we find on the way
281
+ start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt)
282
+ added_sample = False
283
+ for end_of_sample in cum_samples_lengths[start_pt:]:
284
+ # we will break the loop at the end when end_of_sample = insrt_pt
285
+ end_of_sample = min(end_of_sample, insrt_pt)
286
+
287
+ # Add between [last_insrt_pt, end_of_sample]
288
+ current_length = end_of_sample - last_insrt_pt
289
+ if attention_mask is not None:
290
+ current_length -= int(
291
+ torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item()
292
+ )
293
+ if current_length > 0:
294
+ added_sample = True
295
+ text_sample_lengths.append(
296
+ (current_length, end_of_batch_sample and insrt_pt == end_of_sample)
297
+ )
298
+ # add image tokens to current_length
299
+ if current_image_idx > 0 and image_embeds is not None:
300
+ images_in_sample = [
301
+ img_idx
302
+ for img_idx in range(last_image_idx, current_image_idx)
303
+ if img_idx < len(image_embeds_insertion_points[batch_idx])
304
+ and last_insrt_pt
305
+ <= image_embeds_insertion_points[batch_idx][img_idx]
306
+ < end_of_sample
307
+ ]
308
+ if len(images_in_sample) > 0:
309
+ num_image_tokens = sum(
310
+ _x.shape[0]
311
+ for _x in image_embeds[
312
+ current_image_offset + images_in_sample[0] : current_image_offset
313
+ + images_in_sample[-1]
314
+ + 1
315
+ ]
316
+ )
317
+ current_length += num_image_tokens
318
+ full_sample_lengths.append(current_length)
319
+
320
+ # prepare for next loop
321
+ last_insrt_pt = end_of_sample
322
+ if end_of_sample == insrt_pt:
323
+ break
324
+ # End of loop: Catching weird use case where we may end up on a span
325
+ # full of padding tokens which will not get added due to current_length > 0
326
+ if end_of_batch_sample:
327
+ assert added_sample, "Weird edge case. Don't do that, thank you"
328
+ text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
329
+
330
+ # End of loop: Catching weird use case where we may end up on a span
331
+ # full of padding tokens which will not get added due to current_length > 0
332
+ if end_of_batch_sample:
333
+ assert added_sample, "Weird edge case. Don't do that, thank you"
334
+ text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
335
+
336
+ current_image_offset = 0
337
+ text_sample_lengths, full_sample_lengths = [], []
338
+ cum_samples_lengths: list[int] = []
339
+ current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
340
+ for batch_idx, pts in enumerate(image_embeds_insertion_points):
341
+ if squashed_samples_lengths is not None:
342
+ cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx]))
343
+ else:
344
+ assert total_seq_len is not None
345
+ cum_samples_lengths = [total_seq_len]
346
+
347
+ for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()):
348
+ # check if the images are consecutive in which way we want
349
+ # them to belong to the same window
350
+ if current_image_idx >= 1 and insrt_pt == (
351
+ image_embeds_insertion_points[batch_idx][current_image_idx - 1]
352
+ + num_pre_image_tokens
353
+ + num_post_image_tokens
354
+ ):
355
+ continue
356
+ # Otherwise, we found a new sample
357
+ # not very important but for completeness: the insertion points come *after*
358
+ # the pre-image tokens per design but for the document-id mask it is more consistent to
359
+ # have them correspond to the same image
360
+ insrt_pt -= num_pre_image_tokens
361
+
362
+ # Update text and full sample lengths
363
+ if insrt_pt > last_insrt_pt:
364
+ __insert_next_sample__(
365
+ batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False
366
+ )
367
+ last_image_idx = current_image_idx
368
+ last_insrt_pt = insrt_pt
369
+
370
+ # End of batch: add sample in progress and reset
371
+ current_image_idx += 1
372
+ if cum_samples_lengths[-1] > last_insrt_pt:
373
+ __insert_next_sample__(
374
+ batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True
375
+ )
376
+ current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
377
+ current_image_offset += len(pts)
378
+
379
+ # Sanity check that the is_eob markers are correctly placed
380
+ assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), (
381
+ f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs"
382
+ f" from original batch size ({len(image_embeds_insertion_points)})"
383
+ )
384
+ return text_sample_lengths, full_sample_lengths
385
+
386
+
387
+ class CASAAttentionHandler:
388
+ def __init__(
389
+ self,
390
+ inputs_embeds: torch.Tensor,
391
+ image_embeds: torch.Tensor | list[torch.Tensor],
392
+ image_embeds_insertion_points: list[torch.Tensor],
393
+ attention_mask: torch.Tensor | None = None,
394
+ rope_fn: Callable | None = None,
395
+ windows: Literal["batch", "squashed", "images", "turn_based"] = "images",
396
+ use_asymetric_q_kv: bool = True,
397
+ casa_windows_info: None | dict = None,
398
+ ):
399
+ """Initialize the structure holding the query buffer for CASA attention layers
400
+ (ie the **flattened** text+image inserted tokens).
401
+ Note that this structure is shared across all casa layers, and it gets updated
402
+ with the current hidden states at every layer; this is merely a buffer to keep
403
+ scatter_ operations in-place as much as possible
404
+
405
+ In this module, the embeddings related values (image_tokens_mask,
406
+ text_sample_lengths etc) are stored under the assumption of a tensor
407
+ which is *flattened* and *without padding tokens*
408
+ Only the attention mask is kept as-is (text-only, batched, padded) to
409
+ be able to recover original shapes when needed
410
+ """
411
+ super().__init__()
412
+ assert windows == "images" # for inference code release
413
+ # Note 1: Unless overridden, text/full_sample_lengths are defined such that one
414
+ # document = one sample in the batch
415
+ if attention_mask is None:
416
+ text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds]
417
+ else:
418
+ text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask]
419
+ (
420
+ full_inputs_embeds,
421
+ full_sample_lengths,
422
+ # Full attention mask is only needed at inference to
423
+ # flatten the KV-Cache and remove padding tokens
424
+ _,
425
+ self.image_tokens_mask,
426
+ ) = insert_image_tokens(
427
+ inputs_embeds=inputs_embeds,
428
+ image_embeds=image_embeds,
429
+ image_embeds_insertion_points=image_embeds_insertion_points,
430
+ attention_mask=attention_mask,
431
+ recover_batch_dim=False,
432
+ keep_only_attended=attention_mask is not None,
433
+ )
434
+ assert self.image_tokens_mask.ndim == 2
435
+ self.image_embeds = image_embeds
436
+ self.image_embeds_insertion_points = image_embeds_insertion_points
437
+ self.attention_mask = None if attention_mask is None else attention_mask.bool()
438
+ self.use_asymetric_qkv = use_asymetric_q_kv
439
+ # At inference, we have to use asymetric QKV for efficiency
440
+ if self.attention_mask is not None:
441
+ self.use_asymetric_qkv = True
442
+
443
+ # Build CASA windows
444
+ assert casa_windows_info is not None
445
+ text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points(
446
+ image_embeds_insertion_points=image_embeds_insertion_points,
447
+ image_embeds=image_embeds,
448
+ total_seq_len=inputs_embeds.shape[1],
449
+ attention_mask=self.attention_mask,
450
+ **casa_windows_info, # pyright: ignore
451
+ )
452
+
453
+ # Sanity checks on the sample lengths
454
+ self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0]
455
+ self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0]
456
+
457
+ assert len(self.text_sample_lengths) == len(self.full_sample_lengths), (
458
+ f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}"
459
+ f" != full sample lengths {len(self.full_sample_lengths)}"
460
+ )
461
+ if self.attention_mask is None:
462
+ num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1]
463
+ else:
464
+ num_unpadded_text_tokens = int(
465
+ torch.sum(type_cast(torch.Tensor, attention_mask)).item()
466
+ )
467
+ assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
468
+ f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
469
+ )
470
+ assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], (
471
+ f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
472
+ )
473
+
474
+ # Finally we can compute cu_seqlen based on sample lengths
475
+ self.max_seqlen_q = max(self.text_sample_lengths)[0]
476
+ self.cu_seqlens_q = self.get_cu_seqlens(
477
+ [x[0] for x in self.text_sample_lengths], device=inputs_embeds.device
478
+ )
479
+
480
+ self.max_seqlen_kv = max(self.full_sample_lengths)
481
+ self.cu_seqlens_kv = self.get_cu_seqlens(
482
+ self.full_sample_lengths, device=inputs_embeds.device
483
+ )
484
+
485
+ # For inference: We save the length of the current document
486
+ # to trim the KV cache appropriately
487
+ self.current_doc_lengths = self.full_sample_lengths
488
+
489
+ # Precompute position embeddings
490
+ self.position_embeds = None
491
+ self.rope_fn = rope_fn
492
+ if self.rope_fn is not None:
493
+ self.position_embeds = self.compute_position_embeddings(
494
+ self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds
495
+ )
496
+
497
+ @property
498
+ def batch_lengths(self) -> list[int]:
499
+ """Return a (batch_size,) list of integers containing the
500
+ number of (non-padded) text tokens for each sample in the batch"""
501
+ bls = [0]
502
+ for ln, eob in self.text_sample_lengths:
503
+ bls[-1] += ln
504
+ if eob:
505
+ bls.append(0)
506
+ return bls[:-1]
507
+
508
+ @property
509
+ def full_batch_lengths(self) -> list[int]:
510
+ """Same as batch_lengths for text+image tokens"""
511
+ bls = [0]
512
+ for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths):
513
+ bls[-1] += ln
514
+ if eob:
515
+ bls.append(0)
516
+ return bls[:-1]
517
+
518
+ def get_cu_seqlens(
519
+ self, sample_lengths: list[int], device: torch.device | None
520
+ ) -> torch.Tensor:
521
+ """Update cu_seqlengths according to the given sample_lengths"""
522
+ return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to(
523
+ dtype=torch.int32, device=device
524
+ )
525
+
526
+ def compute_position_embeddings(
527
+ self,
528
+ rope_fn: Callable,
529
+ sample_lengths: list[int],
530
+ dummy_for_dtype_and_device: torch.Tensor,
531
+ ) -> tuple[torch.Tensor, torch.Tensor]:
532
+ """Compute info required for position embeddings. Can be override e.g. for Qwen"""
533
+ # option 1: Standard range
534
+ # position_ids = torch.arange(0, full_inputs_embeds.shape[0])
535
+ # option 2: Follows document boundary
536
+ position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
537
+ return rope_fn(
538
+ dummy_for_dtype_and_device,
539
+ position_ids.to(dummy_for_dtype_and_device.device)[None, ...],
540
+ )
541
+
542
+ def get_position_embedding(
543
+ self,
544
+ key: Literal["q", "kv"],
545
+ num_queries: int = 0,
546
+ ) -> tuple[torch.Tensor, torch.Tensor] | None:
547
+ if self.position_embeds is None:
548
+ return None
549
+ cos, sin = self.position_embeds
550
+ bls = self.full_batch_lengths
551
+ # For Q, we only want the text-only posembeds
552
+ if key == "q" and self.use_asymetric_qkv:
553
+ bls = self.batch_lengths
554
+ cos, sin = cos[:, ~self.image_tokens_mask[:, 0]], sin[:, ~self.image_tokens_mask[:, 0]]
555
+ elif key not in {"q", "kv"}:
556
+ raise ValueError(f"Unknow for position embedding {key}")
557
+
558
+ # Easy case: training or first step at inference: we use all the posembeds
559
+ if num_queries == 0:
560
+ return cos, sin
561
+ # If num queries is given, we need to trim for *every sample in the batch*
562
+ cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)]
563
+ sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)]
564
+ return torch.cat(cos, dim=1), torch.cat(sin, dim=1)
565
+
566
+ def get_full_embeds(
567
+ self, hidden_states: torch.Tensor, norm_fn: Callable | None
568
+ ) -> torch.Tensor:
569
+ """Update attended hidden states in the current query buffer
570
+
571
+ :param hidden_states: (b, s, d) Tensor input to the CASA attention layer
572
+ """
573
+ assert self.image_embeds is not None
574
+ return insert_image_tokens(
575
+ inputs_embeds=hidden_states,
576
+ image_embeds=self.image_embeds
577
+ if norm_fn is None
578
+ else norm_fn(self.image_embeds)
579
+ if isinstance(self.image_embeds, torch.Tensor)
580
+ else [norm_fn(_x) for _x in self.image_embeds],
581
+ image_embeds_insertion_points=self.image_embeds_insertion_points,
582
+ attention_mask=self.attention_mask,
583
+ recover_batch_dim=False,
584
+ keep_only_attended=self.attention_mask is not None,
585
+ )[0][None, :, :]
586
+
587
+ def recover_text_embeds(
588
+ self,
589
+ hidden_states_out: torch.Tensor,
590
+ hidden_states_in: torch.Tensor,
591
+ update_image_embeddings: bool = False,
592
+ ) -> torch.Tensor:
593
+ """Returns text embeddings from the query buffer, including non-attended tokens at inference"""
594
+ if update_image_embeddings and not self.use_asymetric_qkv:
595
+ raise NotImplementedError("Implement image embeddings updates for asymetric QKV")
596
+ # Remove image tokens in the symetric case
597
+ if not self.use_asymetric_qkv:
598
+ hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]]
599
+
600
+ # if there's no attention mask, we are in the right-padded case
601
+ # (keep_only_attended = False) we can directly return the query
602
+ # outputs (which don't contain the image)
603
+ if self.attention_mask is None:
604
+ return hidden_states_out
605
+
606
+ # Otherwise, we need to "scatter" back only the text-attended tokens to the original
607
+ # hidden states, which contain the paddings
608
+ num_queries = hidden_states_in.shape[1]
609
+
610
+ # Case 1: the padded hidden_states_in is larger than hidden_states_out
611
+ # we rebatch+pad hidden_state_out before doing the scattering
612
+ if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]:
613
+ s = torch.split(hidden_states_out, self.batch_lengths, dim=0)
614
+ assert max(_s.shape[0] for _s in s) <= num_queries # sanity check
615
+ s = [
616
+ torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0)
617
+ for _s in s
618
+ ]
619
+ return torch.where(
620
+ self.attention_mask[:, -num_queries:, None],
621
+ torch.stack(s),
622
+ hidden_states_in,
623
+ )
624
+ # If both have the same shape, it means hidden_states_in contained no padding
625
+ # so we can directly return hidden states out
626
+ return hidden_states_out
627
+
628
+ def extend(self, num_tokens: int, offset: int = 0):
629
+ """Extend all necessary values of the Handler for infenrece
630
+ Note: this implementation curently assumes a single conversation at a time
631
+ (otherwise image tokens mask would have to change) and that tokens added are
632
+ attended to"""
633
+ # image embeds is inserted in the first step and stored in the KV cache
634
+ self.image_embeds = None
635
+
636
+ # Update attention mask (non-flattened) (assumes all new tokens are attended to)
637
+ if self.attention_mask is not None:
638
+ self.attention_mask = torch.nn.functional.pad(
639
+ self.attention_mask, (0, num_tokens), value=1
640
+ )
641
+
642
+ # Update image token mask (assumes only one image/conversation
643
+ # is started at once so that we always extend by zero)
644
+ # Note that the mask is stored flattened to avoid padding so we have to
645
+ # do something a bit ugly and inefficient here
646
+ imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0)
647
+ imtokmask = [torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask]
648
+ self.image_tokens_mask = torch.cat(imtokmask, dim=0)
649
+
650
+ # Recompute cumulative document lengths after assigning the new
651
+ # number of tokens to each sample in the batch
652
+ for idx, (ln, is_eob) in enumerate(self.text_sample_lengths):
653
+ if is_eob:
654
+ self.text_sample_lengths[idx] = (num_tokens + ln, is_eob)
655
+ self.full_sample_lengths[idx] += num_tokens
656
+
657
+ # Recompute cu_seqlens
658
+ # First step: Technically this never occurs, but we keep it for completeness
659
+ if offset == 0:
660
+ self.max_seqlen_q = max(self.text_sample_lengths)[0]
661
+ self.cu_seqlens_q = self.get_cu_seqlens(
662
+ [x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device
663
+ )
664
+
665
+ self.max_seqlen_kv = max(self.full_sample_lengths)
666
+ self.cu_seqlens_kv = self.get_cu_seqlens(
667
+ self.full_sample_lengths, device=self.cu_seqlens_kv.device
668
+ )
669
+ # Step > 0: the annoying part is since flashattn_varlen does not accept
670
+ # 0-len documents, we need to remove documents from the KV Cache when they're past
671
+ # their windows. In our current setting, this means we only want to keep the latest
672
+ # documents
673
+ else:
674
+ self.max_seqlen_q = num_tokens
675
+ self.cu_seqlens_q = self.get_cu_seqlens(
676
+ [num_tokens for (_, eob) in self.text_sample_lengths if eob],
677
+ device=self.cu_seqlens_q.device,
678
+ )
679
+
680
+ final_doc_lengths = [
681
+ ln
682
+ for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths)
683
+ if eob
684
+ ]
685
+ self.current_doc_lengths = final_doc_lengths
686
+ self.max_seqlen_kv = max(self.current_doc_lengths)
687
+ self.cu_seqlens_kv = self.get_cu_seqlens(
688
+ final_doc_lengths,
689
+ device=self.cu_seqlens_kv.device,
690
+ )
691
+ # Update position embeddings
692
+ if self.rope_fn is not None and self.position_embeds is not None:
693
+ self.position_embeds = self.compute_position_embeddings(
694
+ self.rope_fn,
695
+ self.full_sample_lengths,
696
+ dummy_for_dtype_and_device=self.position_embeds[0],
697
+ )
698
+
699
+
700
+ @dataclass
701
+ class CASAAttentionStreamingState(StreamingState):
702
+ """Streaming State for CASA Atention module. Keep the hidden"""
703
+
704
+ k: torch.Tensor = None # pyright: ignore[reportAssignmentType]
705
+ v: torch.Tensor = None # pyright: ignore[reportAssignmentType]
706
+ recover_batched_trims: list[int] = None # pyright: ignore[reportAssignmentType]
707
+ casa_handler: CASAAttentionHandler = None # pyright: ignore[reportAssignmentType]
708
+
709
+ def maybe_get_casa_handler(
710
+ self,
711
+ casa_handler: CASAAttentionHandler | None,
712
+ is_first_casa_layer: bool = False,
713
+ num_queries: int = -1,
714
+ ) -> CASAAttentionHandler | None:
715
+ # Set given Casa Handler the first time we reach this
716
+ if self.casa_handler is None:
717
+ self.casa_handler = casa_handler # pyright: ignore
718
+ # subsequent calls: we need to extend shape to accommodate new tokens
719
+ # however because CASA handler is shared across layers, we only need to do it once
720
+ if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer:
721
+ # since CasaHandler is shared, we only use its extend step once
722
+ self.casa_handler.extend(num_queries, offset=self.offset)
723
+ return self.casa_handler
724
+
725
+ def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor:
726
+ """Recover batched key/value states with left padding"""
727
+ s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1)
728
+ mlen = max(_s.shape[1] for _s in s)
729
+ # Remember the added padding so that we can re-flatten KV later
730
+ if self.recover_batched_trims is None:
731
+ self.recover_batched_trims = [mlen - _s.shape[1] for _s in s]
732
+ s = [torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0) for _s in s]
733
+ return torch.cat(s, dim=0)
734
+
735
+ def __get_flattened_kv__(
736
+ self, k: torch.Tensor | None = None, v: torch.Tensor | None = None
737
+ ) -> tuple[torch.Tensor, torch.Tensor]:
738
+ """
739
+ Flatten and remove padding for use with flash_attn_varlen_func
740
+ """
741
+ k = self.k if k is None else k
742
+ v = self.v if v is None else v
743
+ assert k is not None and v is not None
744
+
745
+ # Since every batch at least contributes one document,
746
+ # we can use this to check whether we are in streaming mode with dropped docs.
747
+ # If so, we should trim the kv cache accordingly
748
+ if len(self.casa_handler.current_doc_lengths) == len(k):
749
+ k = torch.cat(
750
+ [
751
+ _k[self.recover_batched_trims[idx] :][-doc_len:]
752
+ for idx, _k, doc_len in zip(
753
+ range(len(k)), k, self.casa_handler.current_doc_lengths
754
+ )
755
+ ]
756
+ )
757
+ v = torch.cat(
758
+ [
759
+ _v[self.recover_batched_trims[idx] :][-doc_len:]
760
+ for idx, _v, doc_len in zip(
761
+ range(len(k)), v, self.casa_handler.current_doc_lengths
762
+ )
763
+ ]
764
+ )
765
+ return k[None, ...], v[None, ...]
766
+
767
+ k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)])
768
+ v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)])
769
+ return k[None, ...], v[None, ...]
770
+
771
+ def extend_kv(
772
+ self, key_states: torch.Tensor, value_states: torch.Tensor
773
+ ) -> tuple[torch.Tensor, torch.Tensor]:
774
+ """
775
+ Extend the KV cache with the new key/value states and return it flattened
776
+ """
777
+ assert self.casa_handler is not None
778
+ if self.k is None and self.v is None:
779
+ # Init with batch-padded key and value states
780
+ self.k = self.__recover_batched_kv__(key_states)
781
+ self.v = self.__recover_batched_kv__(value_states)
782
+ return self.__get_flattened_kv__()
783
+ if self.k is not None and self.v is not None:
784
+ # this is during generation; normally there is no padding at this stage
785
+ # so we can directly reshape the flattened key states
786
+ rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3])
787
+ self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1)
788
+ self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1)
789
+ return self.__get_flattened_kv__()
790
+
791
+ raise ValueError("Impossible configuration (k and v updates are desynchronized )")
792
+
793
+
794
+ class CASAAttention(StreamingModule[CASAAttentionStreamingState]):
795
+ def __init__(
796
+ self,
797
+ config: "PretrainedConfig",
798
+ layer_idx: int | None,
799
+ self_attn: torch.nn.Module | None = None,
800
+ input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
801
+ ):
802
+ super().__init__(CASAAttentionStreamingState)
803
+ self.head_dim = config.head_dim
804
+ self.config = config
805
+
806
+ self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0)
807
+ self.use_delta_w = config.casa_delta_w
808
+
809
+ self.q_proj_casa = self.init_from_config_proj("q", config)
810
+ self.k_proj_casa = self.init_from_config_proj("k", config)
811
+ self.v_proj_casa = self.init_from_config_proj("v", config)
812
+ self.o_proj_casa = self.init_from_config_proj("o", config)
813
+
814
+ # Delta_w
815
+ self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
816
+ self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
817
+ self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
818
+ self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
819
+
820
+ if config.casa_delta_w:
821
+ assert self_attn is not None
822
+ self.set_delta_w(self_attn)
823
+
824
+ # Layer norm
825
+ self.norm_fn: Callable | None = None
826
+ if config.xa_norm_on_images:
827
+ assert input_layernorm_fn is not None
828
+ self.norm_fn = input_layernorm_fn
829
+
830
+ def init_from_mha(self, self_attn: torch.nn.Module):
831
+ assert self_attn is not None
832
+ with torch.no_grad():
833
+ assert hasattr(self_attn, "q_proj")
834
+ for key in ["q", "k", "v", "o"]:
835
+ src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj"))
836
+ tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa"))
837
+ tgt.weight.copy_(src.weight)
838
+ if tgt.bias is not None and src.bias is not None:
839
+ tgt.bias.copy_(src.bias)
840
+
841
+ def set_delta_w(self, self_attn: torch.nn.Module):
842
+ """Delta w setup"""
843
+ self.override_q_proj = delta_w_factory(
844
+ self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj)
845
+ )
846
+ self.override_k_proj = delta_w_factory(
847
+ self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj)
848
+ )
849
+ self.override_v_proj = delta_w_factory(
850
+ self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj)
851
+ )
852
+ self.override_o_proj = delta_w_factory(
853
+ self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj)
854
+ )
855
+
856
+ with torch.no_grad():
857
+ torch.nn.init.zeros_(self.q_proj_casa.weight)
858
+ torch.nn.init.zeros_(self.k_proj_casa.weight)
859
+ torch.nn.init.zeros_(self.v_proj_casa.weight)
860
+ torch.nn.init.zeros_(self.o_proj_casa.weight)
861
+ if self.q_proj_casa.bias is not None:
862
+ torch.nn.init.zeros_(self.q_proj_casa.bias)
863
+ if self.k_proj_casa.bias is not None:
864
+ torch.nn.init.zeros_(self.k_proj_casa.bias)
865
+ if self.v_proj_casa.bias is not None:
866
+ torch.nn.init.zeros_(self.v_proj_casa.bias)
867
+ if self.o_proj_casa.bias is not None:
868
+ torch.nn.init.zeros_(self.o_proj_casa.bias)
869
+
870
+ def init_from_config_proj(
871
+ self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
872
+ ) -> torch.nn.Linear:
873
+ """Initialize the Linear proj in this module"""
874
+ raise NotImplementedError("Abastract class.")
875
+
876
+ def apply_position_embeddings(
877
+ self,
878
+ key: Literal["q", "kv"],
879
+ x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
880
+ casa_handler: CASAAttentionHandler | None,
881
+ num_queries: int = 0,
882
+ unsqueeze_dim: int = 1,
883
+ ) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
884
+ """Apply position embeddings to query and key states"""
885
+ raise NotImplementedError("Abastract class.")
886
+
887
+ def forward(
888
+ self,
889
+ hidden_states: torch.Tensor,
890
+ casa_handler: CASAAttentionHandler | None,
891
+ ) -> torch.Tensor | None:
892
+ """Generic forward for CASA uses for instance in `helium1_attention`"""
893
+ og_dtype = hidden_states.dtype
894
+ if self.is_streaming:
895
+ casa_handler = self.streaming_state.maybe_get_casa_handler(
896
+ casa_handler,
897
+ is_first_casa_layer=self.is_first_casa_layer,
898
+ num_queries=hidden_states.shape[1],
899
+ )
900
+
901
+ # Case of text-only samples at training (or inference when no handler was cached)
902
+ # in this case we just skip CASA so we return None (no casa_update)
903
+ if casa_handler is None:
904
+ return None
905
+
906
+ if self.is_streaming:
907
+ assert casa_handler.use_asymetric_qkv, (
908
+ "You should set `use_asymetric_qkv` to True during inference"
909
+ )
910
+
911
+ og_shape = hidden_states.shape
912
+
913
+ # Build Q inputs
914
+ if casa_handler.use_asymetric_qkv:
915
+ q_inputs = hidden_states.flatten(0, 1)[None, ...]
916
+ if casa_handler.attention_mask is not None:
917
+ q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()]
918
+ else:
919
+ q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
920
+
921
+ # Case 1: Training or first inference step
922
+ if not self.is_streaming or self.streaming_state.offset == 0:
923
+ kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
924
+ else:
925
+ # during streaming, the KV cache including image embeddings
926
+ # will be inserted later so for now we only update the incoming queries
927
+ kv_inputs = q_inputs
928
+
929
+ # Compute QKV for the blockwise attention
930
+ bs, total_seq_len = kv_inputs.shape[:2]
931
+ hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim)
932
+ hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim)
933
+
934
+ if self.override_q_proj is None:
935
+ query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q)
936
+ else:
937
+ query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q)
938
+
939
+ if self.override_k_proj is None:
940
+ key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv)
941
+ else:
942
+ key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv)
943
+
944
+ if self.override_v_proj is None:
945
+ value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv)
946
+ else:
947
+ value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv)
948
+
949
+ # Apply position embedding at the right offset
950
+ num_queries = 0
951
+ if self.streaming and self.streaming_state.offset > 0:
952
+ num_queries = og_shape[1]
953
+
954
+ query_states = self.apply_position_embeddings(
955
+ "q", query_states, num_queries=num_queries, casa_handler=casa_handler
956
+ )
957
+ key_states = self.apply_position_embeddings(
958
+ "kv", key_states, num_queries=num_queries, casa_handler=casa_handler
959
+ )
960
+ assert flash_attn_varlen_func is not None, (
961
+ "flash_attention is not installed but required for block-wise attention"
962
+ )
963
+
964
+ # FlashAttention has a different, more efficient implementation for streaming.
965
+ # In that case, the KV cache has to be batched and has been extended
966
+ # to accommodate the shape of the new updates
967
+ if self.is_streaming:
968
+ key_states, value_states = self.streaming_state.extend_kv(
969
+ key_states=key_states, value_states=value_states
970
+ )
971
+ if casa_handler.use_asymetric_qkv:
972
+ cu_seqlens_q = casa_handler.cu_seqlens_q
973
+ max_seqlen_q = casa_handler.max_seqlen_q
974
+ else:
975
+ cu_seqlens_q = casa_handler.cu_seqlens_kv
976
+ max_seqlen_q = casa_handler.max_seqlen_kv
977
+ assert cu_seqlens_q[-1] == query_states.shape[1], (
978
+ f"{cu_seqlens_q[-1]} != {query_states.shape[1]}"
979
+ )
980
+ assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], (
981
+ f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}"
982
+ )
983
+ # Variable-length flash attention over the concatenated block-wise sequences
984
+ attn_output: torch.Tensor = flash_attn_varlen_func(
985
+ query_states[0].to(torch.bfloat16),
986
+ key_states[0].to(torch.bfloat16),
987
+ value_states[0].to(torch.bfloat16),
988
+ cu_seqlens_q=cu_seqlens_q,
989
+ cu_seqlens_k=casa_handler.cu_seqlens_kv,
990
+ max_seqlen_q=max_seqlen_q,
991
+ max_seqlen_k=casa_handler.max_seqlen_kv,
992
+ dropout_p=0.0,
993
+ # softmax_scale=None, # defaults to 1/sqrt(d)
994
+ causal=True,
995
+ ).to(og_dtype)
996
+
997
+ attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous()
998
+ if self.override_o_proj is None:
999
+ attn_output = self.o_proj_casa(attn_output)
1000
+ else:
1001
+ attn_output = self.override_o_proj(attn_output)
1002
+
1003
+ attn_output = casa_handler.recover_text_embeds(
1004
+ attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds
1005
+ )
1006
+ attn_output = attn_output.reshape(og_shape)
1007
+
1008
+ if self.is_streaming:
1009
+ self.streaming_state.offset += attn_output.shape[1]
1010
+ return attn_output
config.json ADDED
@@ -0,0 +1,77 @@
1
+ {
2
+ "attention_bias": false,
3
+ "attention_dropout": 0.0,
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_helium1_casa.Helium1CASAConfig",
6
+ "AutoModel": "modeling_helium1_casa.V2Helium1"
7
+ },
8
+ "bos_token_id": 1,
9
+ "casa_attention": true,
10
+ "casa_delta_w": false,
11
+ "casa_use_asymetric_qkv": true,
12
+ "casa_windows": "images",
13
+ "eos_token_id": null,
14
+ "head_dim": 128,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2048,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 8192,
19
+ "mask_squash_blockwise": false,
20
+ "max_position_embeddings": 4096,
21
+ "mlp_bias": false,
22
+ "model_type": "CASA_Helium1_VL_2B",
23
+ "num_attention_heads": 16,
24
+ "num_hidden_layers": 28,
25
+ "num_key_value_heads": 8,
26
+ "pad_token_id": 3,
27
+ "post_image_tokens": [],
28
+ "pre_image_tokens": [],
29
+ "pretraining_tp": 1,
30
+ "rms_norm_eps": 1e-08,
31
+ "rope_scaling": null,
32
+ "rope_theta": 20000.0,
33
+ "tie_word_embeddings": false,
34
+ "torch_dtype": "bfloat16",
35
+ "transformers_version": "4.51.3",
36
+ "use_cache": true,
37
+ "vision_config": {
38
+ "depth": 32,
39
+ "fullatt_block_indexes": [
40
+ 7,
41
+ 15,
42
+ 23,
43
+ 31
44
+ ],
45
+ "hidden_act": "silu",
46
+ "hidden_size": 1280,
47
+ "image_mean": [
48
+ 0.48145466,
49
+ 0.4578275,
50
+ 0.40821073
51
+ ],
52
+ "image_std": [
53
+ 0.26862954,
54
+ 0.26130258,
55
+ 0.27577711
56
+ ],
57
+ "in_channels": 3,
58
+ "in_chans": 3,
59
+ "intermediate_size": 3420,
60
+ "model_type": "qwen2_5_vl",
61
+ "num_heads": 16,
62
+ "out_dim": 2048,
63
+ "out_hidden_size": 2048,
64
+ "patch_size": 14,
65
+ "spatial_merge_size": 2,
66
+ "spatial_patch_size": 14,
67
+ "temporal_patch_size": 1,
68
+ "tokens_per_second": 2,
69
+ "window_size": 112
70
+ },
71
+ "vocab_size": 64000,
72
+ "xa_custom_norm": true,
73
+ "xa_layers": [],
74
+ "xa_norm_on_images": true,
75
+ "xa_order": "ca_first",
76
+ "xa_update_image_embeds": false
77
+ }
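The `auto_map` entries above route `AutoConfig` and `AutoModel` to the custom classes shipped in this repository, so the checkpoint loads through the `trust_remote_code` path. A minimal loading sketch, assuming the repo id `kyutai/CASA-Helium1-VL-2B` (an assumption, not stated in this file):

```python
import torch
from transformers import AutoConfig, AutoModel

repo_id = "kyutai/CASA-Helium1-VL-2B"  # assumed repo id for this model card

# AutoConfig resolves to Helium1CASAConfig and AutoModel to V2Helium1 via auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # matches the "torch_dtype" field above
)
print(type(config).__name__, type(model).__name__)
```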
configuration_helium1_casa.py ADDED
@@ -0,0 +1,270 @@
1
+ from typing import Any, Literal
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
5
+
6
+
7
+ class Helium1CASAConfig(PretrainedConfig):
8
+ r"""
9
+ Helium1 Config augmented with CASA options
10
+
11
+
12
+ Args:
13
+ vocab_size (`int`, *optional*, defaults to 32000):
14
+ Vocabulary size of the Helium1 model. Defines the number of different tokens that can be represented by the
15
+ `input_ids` passed when calling [`Helium1Model`]
16
+ hidden_size (`int`, *optional*, defaults to 4096):
17
+ Dimension of the hidden representations.
18
+ intermediate_size (`int`, *optional*, defaults to 11008):
19
+ Dimension of the MLP representations.
20
+ num_hidden_layers (`int`, *optional*, defaults to 32):
21
+ Number of hidden layers in the Transformer decoder.
22
+ num_attention_heads (`int`, *optional*, defaults to 32):
23
+ Number of attention heads for each attention layer in the Transformer decoder.
24
+ num_key_value_heads (`int`, *optional*):
25
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
26
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
27
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
28
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
29
+ by meanpooling all the original heads within that group. For more details checkout [this
30
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
31
+ `num_attention_heads`.
32
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
33
+ The non-linear activation function (function or string) in the decoder.
34
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
35
+ The maximum sequence length that this model might ever be used with.
37
+ initializer_range (`float`, *optional*, defaults to 0.02):
38
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
39
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
40
+ The epsilon used by the rms normalization layers.
41
+ use_cache (`bool`, *optional*, defaults to `True`):
42
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
43
+ relevant if `config.is_decoder=True`.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id.
46
+ bos_token_id (`int`, *optional*, defaults to 1):
47
+ Beginning of stream token id.
48
+ eos_token_id (`int`, *optional*, defaults to 2):
49
+ End of stream token id.
50
+ pretraining_tp (`int`, *optional*, defaults to 1):
51
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
52
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
53
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
54
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
55
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
56
+ Whether to tie weight embeddings
57
+ rope_theta (`float`, *optional*, defaults to 10000.0):
58
+ The base period of the RoPE embeddings.
59
+ rope_scaling (`Dict`, *optional*):
60
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
61
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
62
+ accordingly.
63
+ Expected contents:
64
+ `rope_type` (`str`):
65
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
66
+ 'llama3'], with 'default' being the original RoPE implementation.
67
+ `factor` (`float`, *optional*):
68
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
69
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
70
+ original maximum pre-trained length.
71
+ `original_max_position_embeddings` (`int`, *optional*):
72
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
73
+ pretraining.
74
+ `attention_factor` (`float`, *optional*):
75
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
76
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
77
+ `factor` field to infer the suggested value.
78
+ `beta_fast` (`float`, *optional*):
79
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
80
+ ramp function. If unspecified, it defaults to 32.
81
+ `beta_slow` (`float`, *optional*):
82
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
83
+ ramp function. If unspecified, it defaults to 1.
84
+ `short_factor` (`List[float]`, *optional*):
85
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
86
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
87
+ size divided by the number of attention heads divided by 2
88
+ `long_factor` (`List[float]`, *optional*):
89
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
90
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
91
+ size divided by the number of attention heads divided by 2
92
+ `low_freq_factor` (`float`, *optional*):
93
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
94
+ `high_freq_factor` (`float`, *optional*):
95
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
96
+ attention_bias (`bool`, *optional*, defaults to `False`):
97
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
98
+ attention_dropout (`float`, *optional*, defaults to 0.0):
99
+ The dropout ratio for the attention probabilities.
100
+ mlp_bias (`bool`, *optional*, defaults to `False`):
101
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
102
+ head_dim (`int`, *optional*):
103
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
104
+
105
+ """
106
+
107
+ model_type = "helium1_casa"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ # Default tensor parallel plan for base model `Helium1Model`
110
+ base_model_tp_plan = {
111
+ "layers.*.self_attn.q_proj": "colwise",
112
+ "layers.*.self_attn.k_proj": "colwise",
113
+ "layers.*.self_attn.v_proj": "colwise",
114
+ "layers.*.self_attn.o_proj": "rowwise",
115
+ "layers.*.mlp.gate_proj": "colwise",
116
+ "layers.*.mlp.up_proj": "colwise",
117
+ "layers.*.mlp.down_proj": "rowwise",
118
+ }
119
+ base_model_pp_plan = { # pyright: ignore[reportAssignmentType]
120
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
121
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
122
+ "norm": (["hidden_states"], ["hidden_states"]),
123
+ }
124
+
125
+ def __init__(
126
+ self,
127
+ vocab_size: int = 32000,
128
+ hidden_size: int = 4096,
129
+ intermediate_size: int = 11008,
130
+ num_hidden_layers: int = 32,
131
+ num_attention_heads: int = 32,
132
+ num_key_value_heads: None | int = None,
133
+ head_dim: None | int = None,
134
+ hidden_act: str = "silu",
135
+ attention_dropout: float = 0.0,
136
+ max_position_embeddings: int = 2048,
137
+ initializer_range: float = 0.02,
138
+ rms_norm_eps: float = 1e-6,
139
+ use_cache: bool = True,
140
+ tie_word_embeddings: bool = False,
141
+ rope_theta: float = 10000.0,
142
+ pad_token_id: int = 3,
143
+ eos_token_id: int = 2,
144
+ bos_token_id: int = 1,
145
+ pretraining_tp: int = 1,
146
+ rope_scaling: None | dict = None,
147
+ attention_bias: bool = False,
148
+ mlp_bias: bool = False,
149
+ # Our fusion mechanisms
150
+ # Common to all fusion mechanisms
151
+ xa_layers: None | tuple = None,
152
+ xa_order: Literal["ca_first", "parallel", "instead"] = "ca_first",
153
+ xa_norm_on_images: bool = False,
154
+ xa_update_image_embeds: bool = False,
155
+ mask_squash_blockwise: bool = False,
156
+ # CASA
157
+ casa_attention: bool = False,
158
+ casa_delta_w: bool = False,
159
+ casa_windows: Literal["batch", "squashed", "images", "turn_based"] = "batch",
160
+ casa_use_asymetric_qkv: bool = True,
161
+ xa_custom_norm: bool = False,
162
+ # Qwen2.5-VL vision config
163
+ vision_config: dict[str, Any] | None = None,
164
+ **kwargs: Any,
165
+ ):
166
+ from transformers.modeling_rope_utils import rope_config_validation
167
+
168
+ self.vocab_size = vocab_size
169
+ self.max_position_embeddings = max_position_embeddings
170
+ self.hidden_size = hidden_size
171
+ self.intermediate_size = intermediate_size
172
+ self.num_hidden_layers = num_hidden_layers
173
+ self.num_attention_heads = num_attention_heads
174
+
175
+ # for backward compatibility
176
+ if num_key_value_heads is None:
177
+ num_key_value_heads = num_attention_heads
178
+
179
+ self.num_key_value_heads = num_key_value_heads
180
+ self.head_dim = (
181
+ head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
182
+ )
183
+ self.hidden_act = hidden_act
184
+ self.initializer_range = initializer_range
185
+ self.rms_norm_eps = rms_norm_eps
186
+ self.pretraining_tp = pretraining_tp
187
+ self.use_cache = use_cache
188
+ self.rope_theta = rope_theta
189
+ self.rope_scaling = rope_scaling
190
+ self.attention_bias = attention_bias
191
+ self.attention_dropout = attention_dropout
192
+ self.mlp_bias = mlp_bias
193
+ # Validate the correctness of rotary position embeddings parameters
194
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
195
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
196
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
197
+ rope_config_validation(self)
198
+
199
+ self.head_dim = self.hidden_size // self.num_attention_heads
200
+ self.xa_layers = xa_layers
201
+ self.xa_order: Literal["ca_first", "parallel", "instead"] = xa_order
202
+ self.xa_norm_on_images = xa_norm_on_images
203
+ self.xa_update_image_embeds = xa_update_image_embeds
204
+ self.mask_squash_blockwise = mask_squash_blockwise
205
+ # CASA config
206
+ self.casa_attention = casa_attention
207
+ self.casa_delta_w = casa_delta_w
208
+ self.casa_windows: Literal["batch", "squashed", "images", "turn_based"] = casa_windows
209
+ self.casa_use_asymetric_qkv = casa_use_asymetric_qkv
210
+ self.xa_custom_norm = xa_custom_norm
211
+
212
+ if vision_config is None:
213
+ vision_config = dict()
214
+ self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
215
+ self.vision_config.temporal_patch_size = 1
216
+ self.vision_config.image_mean = [0.48145466, 0.4578275, 0.40821073]
217
+ self.vision_config.image_std = [0.26862954, 0.26130258, 0.27577711]
218
+ self.vision_config.out_dim = 2048
219
+
220
+ self.pre_image_tokens = []
221
+ self.post_image_tokens = []
222
+
223
+ super().__init__(
224
+ pad_token_id=pad_token_id,
225
+ bos_token_id=bos_token_id,
226
+ eos_token_id=eos_token_id,
227
+ tie_word_embeddings=tie_word_embeddings,
228
+ **kwargs,
229
+ )
230
+
231
+
232
+ if __name__ == "__main__":
233
+ import argparse
234
+ from pathlib import Path
235
+
236
+ import rich
237
+ import yaml
238
+ from transformers.models.auto.configuration_auto import AutoConfig
239
+
240
+ parser = argparse.ArgumentParser()
241
+ parser.add_argument("--out_dir", type=str, default="./saved_config/")
242
+ parser.add_argument(
243
+ "--ckpt_path",
244
+ type=str,
245
+ default="/lustre/scwpod02/client/kyutai/juliette/experiments/finext_casa_896_xtxt_up_b20_64gpu/fdf76e6774",
246
+ )
247
+ args = parser.parse_args()
248
+ path = Path(args.ckpt_path) / "kyuteye_config.yml"
249
+
250
+ helium_config = AutoConfig.from_pretrained("kyutai/helium-1-2b")
251
+ vision_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct").vision_config
252
+
253
+ # Create the CASA config by merging the Helium1 text config and the Qwen vision config
254
+ config = Helium1CASAConfig(
255
+ **helium_config.to_dict(), # all helium parameters
256
+ vision_config=vision_config.to_dict(), # override or add vision_config
257
+ )
258
+
259
+ with open(path) as stream:
260
+ kconfig = yaml.safe_load(stream)
261
+
262
+ # Overwrite config values for keys present in both kconfig and config
263
+ for key in set(kconfig.keys()).intersection(set(config.to_dict().keys())):
264
+ rich.print(f"Overwriting [bold green]{key:>50s}[/]: [bold red]{kconfig[key]}")
265
+ setattr(config, key, kconfig[key])
266
+ # TODO: handle casa_own_norm -> xa_custom_norm
267
+ print("Configuration successfully loaded.")
268
+ # Save config to json
269
+ config.save_pretrained(args.out_dir)
270
+ print(f"Configuration saved to {args.out_dir}/config.json")
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": [
5
+ 103,
6
+ 3
7
+ ],
8
+ "pad_token_id": 3,
9
+ "transformers_version": "4.51.3"
10
+ }
image_encoder.py ADDED
@@ -0,0 +1,57 @@
1
+ """Qwen2.5VL encoder with delayed normalization"""
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
6
+ Qwen2_5_VisionTransformerPretrainedModel,
7
+ )
8
+
9
+
10
+ def prepare_for_qwen_encoder(
11
+ x: torch.Tensor | list[torch.Tensor], mean: torch.Tensor, std: torch.Tensor
12
+ ) -> tuple[torch.Tensor, torch.Tensor]:
13
+ """
14
+ Preprocessing for Qwen encoder
15
+ Image mean and std come from processor.image_processor.image_mean and image_std
16
+ """
17
+ grid_thw = torch.Tensor([[1, img.shape[0], img.shape[1]] for img in x]).to(x[0].device)
18
+ hws_flatten_shape = torch.prod(grid_thw, dim=-1)
19
+ x = torch.cat(
20
+ [img.reshape((int(hws_flatten_shape[idx].item()), -1)) for idx, img in enumerate(x)],
21
+ dim=0,
22
+ )
23
+ assert x.min() >= 0.0 and x.max() <= 1.0
24
+ og_shape = x.shape
25
+ x = rearrange(x, "L (c d) -> L c d", c=3)
26
+ x = (x - mean) / std
27
+ x = x.view(og_shape).to(torch.bfloat16)
28
+ return x, grid_thw
29
+
30
+
31
+ class Qwen25VLEncoder(torch.nn.Module):
32
+ """Qwen2.5 VL encoder with pre/post processing to be compatible for
33
+ our CASA attention implementation"""
34
+
35
+ def __init__(
36
+ self,
37
+ visual: "Qwen2_5_VisionTransformerPretrainedModel",
38
+ ):
39
+ super().__init__()
40
+ self.visual = visual
41
+ self.image_mean = torch.tensor(self.visual.config.image_mean).view(1, 3, 1)
42
+ self.image_std = torch.tensor(self.visual.config.image_std).view(1, 3, 1)
43
+
44
+ def forward(
45
+ self, x: torch.Tensor | list[torch.Tensor]
46
+ ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
47
+ x, grid_thw = prepare_for_qwen_encoder(
48
+ x, mean=self.image_mean.to(x[0].device), std=self.image_std.to(x[0].device)
49
+ )
50
+
51
+ grid_thw = grid_thw.type(torch.int)
52
+ assert len(x) == grid_thw.prod(dim=1).sum()
53
+ out = self.visual(x, grid_thw=grid_thw)
54
+
55
+ split_sizes = (grid_thw.prod(dim=-1) // self.visual.spatial_merge_size**2).tolist()
56
+ embeds = list(torch.split(out, split_sizes, dim=0)) # Ni * (seq, C)
57
+ return {"image_embeds": embeds, "grid_thw": grid_thw}
language_helium1_casa.py ADDED
@@ -0,0 +1,1077 @@
1
+ # ADAPTED FROM https://github.com/huggingface/transformers/blob/main/src/transformers/models/helium/modeling_helium.py
2
+ # GIT HASH 1b222903c3e1cfd9492d75e4b2548aa8bd458674
3
+ import logging
4
+ import math
5
+ from dataclasses import dataclass
6
+ from functools import partial
7
+ from typing import Any, Callable, Literal, Optional
8
+ from typing import cast as type_cast
9
+
10
+ import torch
11
+ from torch import nn
12
+ from transformers import (
13
+ ROPE_INIT_FUNCTIONS, # pyright: ignore[reportPrivateImportUsage]
14
+ dynamic_rope_update, # pyright: ignore[reportPrivateImportUsage]
15
+ )
16
+ from transformers.activations import ACT2FN
17
+ from transformers.cache_utils import Cache, DynamicCache
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.generation.utils import GenerationMixin
20
+ from transformers.loss.loss_utils import ForCausalLMLoss
21
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
22
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
23
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
24
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
+ from transformers.processing_utils import Unpack
26
+ from transformers.utils.generic import LossKwargs, can_return_tuple
27
+ from transformers.utils.import_utils import is_torch_flex_attn_available
28
+
29
+ from .casa_attention import CASAAttention, CASAAttentionHandler, insert_image_tokens
30
+ from .configuration_helium1_casa import Helium1CASAConfig
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ if is_torch_flex_attn_available():
35
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
36
+
37
+
38
+ def remove_image_tokens(
39
+ inputs_embeds: torch.Tensor,
40
+ image_tokens_mask: torch.Tensor,
41
+ ) -> torch.Tensor:
42
+ """Remove the image tokens from inputs_embeds as indicated by image_tokens_mask
43
+
44
+ :param inputs_embeds: Tokens of shape (Batch, Seqlen, Dims) containing image tokens
45
+ :param image_tokens_mask: 1-0 mask indicating where image tokens are; (Batch, Seqlen)
46
+
47
+ :return: Tokens tensor of shape (Batch, S' < Seqlen, Dims)
48
+ """
49
+ image_seq_lengths = torch.sum(image_tokens_mask, dim=1)[:, 0]
50
+ image_seq_length = int(image_seq_lengths[0].item())
51
+ assert torch.all(image_seq_lengths == image_seq_length)
52
+ new_shape = (
53
+ inputs_embeds.shape[0],
54
+ inputs_embeds.shape[1] - image_seq_length,
55
+ inputs_embeds.shape[-1],
56
+ )
57
+ tokens = torch.masked_select(
58
+ inputs_embeds,
59
+ torch.logical_not(image_tokens_mask).expand((-1, -1, inputs_embeds.shape[-1])),
60
+ )
61
+ return tokens.reshape(new_shape)
62
+
63
+
64
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
65
+ """
66
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
67
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
68
+ """
69
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
70
+ if n_rep == 1:
71
+ return hidden_states
72
+ hidden_states = hidden_states[:, :, None, :, :].expand(
73
+ batch, num_key_value_heads, n_rep, slen, head_dim
74
+ )
75
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
76
+
77
+
78
+ def eager_attention_forward(
79
+ module: "HeliumAttention",
80
+ query: torch.Tensor,
81
+ key: torch.Tensor,
82
+ value: torch.Tensor,
83
+ attention_mask: None | torch.Tensor,
84
+ scaling: float,
85
+ dropout: float = 0.0,
86
+ **kwargs: Any,
87
+ ):
88
+ del kwargs # unused
89
+ key_states = repeat_kv(key, module.num_key_value_groups)
90
+ value_states = repeat_kv(value, module.num_key_value_groups)
91
+
92
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
93
+ if attention_mask is not None:
94
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
95
+ attn_weights = attn_weights + causal_mask
96
+
97
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
98
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
99
+ attn_output = torch.matmul(attn_weights, value_states)
100
+ attn_output = attn_output.transpose(1, 2).contiguous()
101
+
102
+ return attn_output, attn_weights
103
+
104
+
105
+ # Different Attention Classes
106
+ class HeliumAttention(torch.nn.Module):
107
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
108
+
109
+ def __init__(self, config: Helium1CASAConfig, layer_idx: None | int = None):
110
+ super().__init__()
111
+ self.config = config
112
+ assert layer_idx is not None
113
+ self.layer_idx: int = layer_idx
114
+
115
+ self.apply_rotary_fn = ApplyRotaryPosEmbHelium1()
116
+ self.head_dim = getattr(
117
+ config, "head_dim", config.hidden_size // config.num_attention_heads
118
+ )
119
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
120
+ self.scaling = 1 / math.sqrt(self.head_dim)
121
+ self.attention_dropout = config.attention_dropout
122
+ self.is_causal = True
123
+
124
+ self.q_proj = nn.Linear(
125
+ config.hidden_size,
126
+ config.num_attention_heads * self.head_dim,
127
+ bias=config.attention_bias,
128
+ )
129
+ self.k_proj = nn.Linear(
130
+ config.hidden_size,
131
+ config.num_key_value_heads * self.head_dim,
132
+ bias=config.attention_bias,
133
+ )
134
+ self.v_proj = nn.Linear(
135
+ config.hidden_size,
136
+ config.num_key_value_heads * self.head_dim,
137
+ bias=config.attention_bias,
138
+ )
139
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
140
+
141
+ def forward(
142
+ self,
143
+ hidden_states: torch.Tensor,
144
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
145
+ attention_mask: None | torch.Tensor,
146
+ past_key_values: None | Cache = None,
147
+ cache_position: None | torch.LongTensor = None,
148
+ **kwargs: Unpack[FlashAttentionKwargs],
149
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
150
+ # del (cache_position, past_key_value) # we use our own generate/caching
151
+ bs, seq_len, _ = hidden_states.shape
152
+ # Get QKV
153
+ hidden_shape = (bs, seq_len, -1, self.head_dim)
154
+
155
+ # Embed Queries
156
+ # Shape: (batch_size, num_heads, seq_len, head_dim)
157
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
158
+ num_queries = query_states.shape[2]
159
+
160
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
161
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
162
+
163
+ # Applies rotation
164
+ cos, sin = position_embeddings
165
+ query_states, key_states = self.apply_rotary_fn(
166
+ query_states, key_states, cos, sin, num_queries=num_queries
167
+ )
168
+ assert key_states is not None and query_states is not None
169
+
170
+ attention_interface: Callable = eager_attention_forward
171
+
172
+ if self.config._attn_implementation != "eager":
173
+ if self.config._attn_implementation == "sdpa" and kwargs.get(
174
+ "output_attentions", False
175
+ ):
176
+ print(
177
+ "`torch.nn.functional.scaled_dot_product_attention` does not support"
178
+ " `output_attentions=True`. Falling back to "
179
+ "eager attention. This warning can be removed using the argument"
180
+ " `attn_implementation="eager"` when loading the model.'
181
+ )
182
+ else:
183
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
184
+
185
+ if past_key_values is not None:
186
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
187
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
188
+ key_states, value_states = past_key_values.update(
189
+ key_states, value_states, self.layer_idx, cache_kwargs
190
+ )
191
+ attn_output, attn_weights = attention_interface(
192
+ self,
193
+ query_states,
194
+ key_states,
195
+ value_states,
196
+ attention_mask,
197
+ dropout=0.0 if not self.training else self.attention_dropout,
198
+ scaling=self.scaling,
199
+ **kwargs,
200
+ )
201
+ attn_output = attn_output.reshape(bs, num_queries, -1).contiguous()
202
+ attn_output = self.o_proj(attn_output)
203
+
204
+ assert isinstance(attn_output, torch.Tensor)
205
+ return attn_output, attn_weights
206
+
207
+
208
+ class ApplyRotaryPosEmbHelium1:
209
+ @staticmethod
210
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
211
+ """Rotates half the hidden dims of the input."""
212
+ x1 = x[..., : x.shape[-1] // 2]
213
+ x2 = x[..., x.shape[-1] // 2 :]
214
+ return torch.cat((-x2, x1), dim=-1)
215
+
216
+ @staticmethod
217
+ def __call__(
218
+ q: torch.Tensor,
219
+ k: torch.Tensor,
220
+ cos: torch.Tensor,
221
+ sin: torch.Tensor,
222
+ position_ids: torch.Tensor | None = None,
223
+ unsqueeze_dim: int = 1,
224
+ num_queries: int | None = None,
225
+ ) -> tuple[torch.Tensor, torch.Tensor]:
226
+ """Applies Rotary Position Embedding to the query and key tensors.
227
+
228
+ Args:
229
+ q (`torch.Tensor`): The query tensor.
230
+ k (`torch.Tensor`): The key tensor.
231
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
232
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
233
+ position_ids (`torch.Tensor`, *optional*):
234
+ Deprecated and unused.
235
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
236
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
237
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
238
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
239
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
240
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
241
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
242
+ Returns:
243
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
244
+ """
245
+ del position_ids
246
+ cos = cos.unsqueeze(unsqueeze_dim)
247
+ sin = sin.unsqueeze(unsqueeze_dim)
248
+ if num_queries is None:
249
+ offset = 0
250
+ else:
251
+ offset = -num_queries
252
+
253
+ q_embed = (q * cos[:, :, offset:]) + (
254
+ ApplyRotaryPosEmbHelium1.rotate_half(q) * sin[:, :, offset:]
255
+ )
256
+ k_embed = (k * cos) + (ApplyRotaryPosEmbHelium1.rotate_half(k) * sin)
257
+
258
+ return q_embed, k_embed
259
+
260
+
261
+ class HeliumRotaryEmbedding(nn.Module):
262
+ def __init__(self, config: Helium1CASAConfig, device: None | torch.device | str = None):
263
+ super().__init__()
264
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
265
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
266
+ else:
267
+ self.rope_type = "default"
268
+ self.max_seq_len_cached = config.max_position_embeddings
269
+ self.original_max_seq_len = config.max_position_embeddings
270
+
271
+ self.config = config
272
+ assert self.rope_type in ROPE_INIT_FUNCTIONS, (
273
+ f"Invalid rope type {self.rope_type}. Supported types are: {list(ROPE_INIT_FUNCTIONS.keys())}"
274
+ )
275
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
276
+ inv_freq, self.attention_scaling = self.rope_init_fn(config, device=device)
277
+ self.inv_freq: torch.Tensor # only defined for typing
278
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
279
+ self.original_inv_freq = self.inv_freq
280
+
281
+ @torch.no_grad()
282
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
283
+ def forward(
284
+ self, x: torch.Tensor, position_ids: torch.Tensor
285
+ ) -> tuple[torch.Tensor, torch.Tensor]:
286
+ inv_freq_expanded = (
287
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
288
+ )
289
+ position_ids_expanded = position_ids[:, None, :].float()
290
+
291
+ device_type = (
292
+ x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
293
+ )
294
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
295
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
296
+ emb = torch.cat((freqs, freqs), dim=-1)
297
+ cos = emb.cos() * self.attention_scaling
298
+ sin = emb.sin() * self.attention_scaling
299
+
300
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
301
+
302
+
303
+ class Helium1CASAAttention(CASAAttention):
304
+ """A CASA Attention layer compatible with Qwen"""
305
+
306
+ def __init__(
307
+ self,
308
+ config: Helium1CASAConfig,
309
+ layer_idx: int | None,
310
+ self_attn: torch.nn.Module | None = None,
311
+ input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
312
+ ):
313
+ # Only adding this init for typing purposes for the config
314
+ super().__init__(config, layer_idx, self_attn, input_layernorm_fn) # pyright: ignore[reportArgumentType]
315
+
316
+ @staticmethod
317
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
318
+ """Rotates half the hidden dims of the input."""
319
+ x1 = x[..., : x.shape[-1] // 2]
320
+ x2 = x[..., x.shape[-1] // 2 :]
321
+ return torch.cat((-x2, x1), dim=-1)
322
+
323
+ def apply_position_embeddings(
324
+ self,
325
+ key: Literal["q", "kv"],
326
+ x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
327
+ casa_handler: CASAAttentionHandler | None,
328
+ num_queries: int = 0,
329
+ unsqueeze_dim: int = 1,
330
+ ) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
331
+ """Apply position embeddings to query and key states"""
332
+ if casa_handler is not None:
333
+ posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)
334
+
335
+ if posemb is not None:
336
+ x = x.transpose(1, 2).to(torch.float32)
337
+ x = (x * posemb[0].unsqueeze(dim=unsqueeze_dim)) + (
338
+ self.rotate_half(x) * posemb[1].unsqueeze(dim=unsqueeze_dim)
339
+ )
340
+ return x.transpose(1, 2)
341
+ return x
342
+
343
+ def init_from_config_proj(
344
+ self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
345
+ ) -> torch.nn.Linear:
346
+ """Initialize the Linear proj in this module"""
347
+ num_heads = config.num_key_value_heads if key in {"k", "v"} else config.num_attention_heads
348
+ return torch.nn.Linear(
349
+ config.hidden_size,
350
+ num_heads * config.head_dim,
351
+ bias=config.attention_bias if key != "o" else False,
352
+ )
353
+
354
+
355
+ # NORMALISATION LAYER
356
+ def __rms_norm_forward__(
357
+ hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6
358
+ ) -> torch.Tensor:
359
+ input_dtype = hidden_states.dtype
360
+ hidden_states = hidden_states.to(torch.float32)
361
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
362
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
363
+ return weight * hidden_states.to(input_dtype)
364
+
365
+
366
+ class Helium1RMSNorm(nn.Module):
367
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
368
+ """
369
+ Helium1RMSNorm is equivalent to T5LayerNorm
370
+ """
371
+ super().__init__()
372
+ self.weight = nn.Parameter(torch.ones(hidden_size))
373
+ self.variance_epsilon = eps
374
+
375
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
376
+ return __rms_norm_forward__(hidden_states, self.weight, self.variance_epsilon)
377
+
378
+ def extra_repr(self):
379
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
380
+
381
+
382
+ def delta_w_factory_rms_norm(
383
+ org_lin: Helium1RMSNorm, new_lin: Helium1RMSNorm
384
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
385
+ """Factory for building rms norm where the weights are the sum of two layers' weights"""
386
+
387
+ def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
388
+ nonlocal org_lin, new_lin
389
+ return __rms_norm_forward__(
390
+ input, org_lin.weight + new_lin.weight, new_lin.variance_epsilon
391
+ )
392
+
393
+ return _delta_w_fwd
394
+
395
+
396
+ # FULL CONNECTED LAYER
397
+
398
+
399
+ class HeliumMLP(nn.Module):
400
+ def __init__(self, config: Helium1CASAConfig) -> None:
401
+ super().__init__()
402
+ self.config = config
403
+ self.hidden_size = config.hidden_size
404
+ self.intermediate_size = config.intermediate_size
405
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
406
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
407
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
408
+ self.act_fn = ACT2FN[config.hidden_act]
409
+
410
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
411
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
412
+ return down_proj
413
+
414
+
415
+ class HeliumDecoderLayer(nn.Module):
416
+ def __init__(self, config: Helium1CASAConfig, layer_idx: None | int = None):
417
+ super().__init__()
418
+ self.hidden_size = config.hidden_size
419
+ self.config = config
420
+ self.mlp = HeliumMLP(config)
421
+ self.input_layernorm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
422
+ self.post_attention_layernorm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
423
+
424
+ # Self-attention
425
+ self.self_attn = HeliumAttention(config=config, layer_idx=layer_idx)
426
+
427
+ # Set up the norm for fusion mechanisms; note that this norm is applied to the text tokens
428
+ is_xa_layer = layer_idx is None or not config.xa_layers or layer_idx in config.xa_layers
429
+ self.norm_cross: None | Helium1RMSNorm = None
430
+ self.override_norm_cross: Callable[[torch.Tensor], torch.Tensor] | None = None
431
+ if is_xa_layer and config.casa_attention:
432
+ # Custom normalization layer for the extra fusion module
433
+ if self.config.xa_custom_norm:
434
+ self.norm_cross = Helium1RMSNorm(config.hidden_size)
435
+ if config.casa_delta_w:
436
+ self.override_norm_cross = delta_w_factory_rms_norm(
437
+ self.input_layernorm, self.norm_cross
438
+ )
439
+ with torch.no_grad():
440
+ torch.nn.init.ones_(self.norm_cross.weight)
441
+
442
+ # Set up an additional norm for image tokens, which is set in each individual mechanism
443
+ norm_on_images_fn = (
444
+ None
445
+ if not self.config.xa_norm_on_images
446
+ else self.override_norm_cross
447
+ if self.override_norm_cross is not None
448
+ else self.norm_cross.forward
449
+ if self.norm_cross is not None
450
+ else self.input_layernorm.forward
451
+ )
452
+
453
+ # CASA
454
+ self.casa_attn: Helium1CASAAttention | None = None
455
+ if config.casa_attention and is_xa_layer:
456
+ self.casa_attn = Helium1CASAAttention(
457
+ config, layer_idx, self_attn=self.self_attn, input_layernorm_fn=norm_on_images_fn
458
+ )
459
+
460
+ def forward(
461
+ self,
462
+ hidden_states: torch.Tensor,
463
+ attention_mask: None | torch.Tensor = None,
464
+ position_ids: None | torch.LongTensor = None,
465
+ past_key_values: None | Cache = None,
466
+ output_attentions: None | bool = False,
467
+ use_cache: None | bool = False,
468
+ cache_position: None | torch.LongTensor = None,
469
+ position_embeddings: None
470
+ | tuple[torch.Tensor, torch.Tensor] = None, # necessary, but kept here for BC
471
+ # CASA
472
+ casa_handler: CASAAttentionHandler | None = None,
473
+ cu_seqlens: torch.Tensor | None = None,
474
+ **kwargs: Unpack[FlashAttentionKwargs],
475
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
476
+ # Image fusion mechanisms
477
+ apply_ca = self.casa_attn is not None
478
+ ca_update: torch.Tensor | None = None
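+ # `xa_order` controls how the CASA update is fused with the residual stream:
+ #   "ca_first": add the CASA update before the self-attention block,
+ #   "parallel": add it after self-attention, alongside the self-attention output,
+ #   "instead": return hidden_states + CASA update and skip the rest of the layer.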
479
+ if (
480
+ self.config.xa_order
481
+ in {
482
+ "parallel",
483
+ "ca_first",
484
+ "instead",
485
+ }
486
+ and apply_ca
487
+ ):
488
+ # Apply layer norm
489
+ assert self.norm_cross is not None
490
+ ca_input = (
491
+ self.override_norm_cross
492
+ if self.override_norm_cross is not None
493
+ else self.norm_cross
494
+ )(hidden_states)
495
+ # CASA
496
+ if self.casa_attn is not None:
497
+ ca_update = self.casa_attn(ca_input, casa_handler=casa_handler)
498
+
499
+ # If we're here, it's because we had proper inputs (no text-only samples)
500
+ # so the output must not be None
501
+ if ca_update is not None:
502
+ # `instead`: directly return the output of the CA module as residual
503
+ if self.config.xa_order == "instead":
504
+ outputs = (hidden_states + ca_update,)
505
+ if output_attentions:
506
+ outputs += (
507
+ torch.zeros((), device=ca_update.device, dtype=ca_update.dtype),
508
+ )
509
+ return outputs
510
+
511
+ # `ca_first`: update then continue with normal self-attention
512
+ if self.config.xa_order == "ca_first":
513
+ hidden_states = hidden_states + ca_update
514
+ ca_update = None
515
+
516
+ # Self Attention with initial input layer norm
517
+ residual = hidden_states
518
+ hidden_states, self_attn_weights = self.self_attn(
519
+ hidden_states=self.input_layernorm(hidden_states),
520
+ attention_mask=attention_mask,
521
+ position_ids=position_ids,
522
+ past_key_values=past_key_values,
523
+ output_attentions=output_attentions,
524
+ use_cache=use_cache,
525
+ cache_position=cache_position,
526
+ position_embeddings=position_embeddings,
527
+ cu_seqlens=cu_seqlens,
528
+ **kwargs,
529
+ )
530
+ hidden_states = residual + hidden_states
531
+
532
+ # parallel - residual update
533
+ if self.config.xa_order == "parallel" and apply_ca and ca_update is not None:
534
+ hidden_states = hidden_states + ca_update
535
+
536
+ # Fully Connected layer
537
+ residual = hidden_states
538
+ # MLP updates for image embeddings
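+ # When `xa_update_image_embeds` is set, the handler's image embeddings are passed through
+ # the same post-attention norm and MLP as the text tokens and updated in place for the
+ # next decoder layer.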
539
+ if (
540
+ self.config.xa_update_image_embeds
541
+ and self.casa_attn is not None
542
+ and casa_handler is not None
543
+ and casa_handler.image_embeds is not None
544
+ ):
545
+ # Text flattening
546
+ hs = self.post_attention_layernorm(hidden_states).reshape(-1, hidden_states.shape[-1])
547
+ # Image flattening
548
+ img_seq_lengths = [_x.shape[0] for _x in casa_handler.image_embeds]
549
+ img_residual = torch.cat(list(casa_handler.image_embeds), dim=0)
550
+ update = self.mlp(torch.cat([hs, self.post_attention_layernorm(img_residual)], dim=0))
551
+ # update text
552
+ hidden_states = hidden_states + update[: hs.shape[0]].reshape(hidden_states.shape)
553
+ casa_handler.image_embeds = list(
554
+ torch.split(img_residual + update[hs.shape[0] :], img_seq_lengths)
555
+ )
556
+ else:
557
+ hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
558
+ hidden_states = residual + hidden_states
559
+
560
+ # Outputs
561
+ outputs = (hidden_states,)
562
+ if output_attentions:
563
+ outputs += (self_attn_weights,)
564
+
565
+ return outputs
566
+
567
+
568
+ # FULL HELIUM MODEL
569
+
570
+
571
+ @dataclass
572
+ class CausalHeliumOutput(CausalLMOutputWithPast):
573
+ attention_mask: Optional[torch.Tensor] = None
574
+ num_image_tokens_log: Optional[torch.Tensor] = None
575
+ num_text_tokens_log: Optional[torch.Tensor] = None
576
+
577
+
578
+ class Helium1PreTrainedModel(PreTrainedModel):
579
+ config_class = Helium1CASAConfig
580
+ base_model_prefix = "model"
581
+ supports_gradient_checkpointing = True
582
+ _no_split_modules = ["HeliumDecoderLayer"]
583
+ _skip_keys_device_placement = ["past_key_values"]
584
+ _supports_flash_attn_2 = True
585
+ _supports_sdpa = True
586
+ _supports_flex_attn = True
587
+ _supports_cache_class = True
588
+ _supports_quantized_cache = True
589
+ _supports_static_cache = True
590
+ _supports_attention_backend = True
591
+
592
+ def _init_weights(self, module: torch.nn.Module) -> None:
593
+ std = self.config.initializer_range
594
+ if isinstance(module, nn.Linear):
595
+ module.weight.data.normal_(mean=0.0, std=std)
596
+ if module.bias is not None:
597
+ module.bias.data.zero_()
598
+ elif isinstance(module, nn.Embedding):
599
+ module.weight.data.normal_(mean=0.0, std=std)
600
+ if module.padding_idx is not None:
601
+ module.weight.data[module.padding_idx].zero_()
602
+ elif isinstance(module, Helium1RMSNorm):
603
+ module.weight.data.fill_(1.0)
604
+
605
+
606
+ class Helium1Model(Helium1PreTrainedModel):
607
+ """
608
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HeliumDecoderLayer`]
609
+
610
+ Args:
611
+ config: Helium1CASAConfig
612
+ """
613
+
614
+ def __init__(self, config: Helium1CASAConfig):
615
+ Helium1PreTrainedModel.__init__(self, config)
616
+ self.training: bool
617
+ self._gradient_checkpointing_func: Callable
618
+ self.config = config
619
+ self.padding_idx = config.pad_token_id
620
+ self.vocab_size = config.vocab_size
621
+
622
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
623
+ self.layers = nn.ModuleList(
624
+ [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
625
+ )
626
+ self.norm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
627
+ self.rotary_emb = HeliumRotaryEmbedding(config=config)
628
+ self.gradient_checkpointing = False
629
+
630
+ # Initialize weights and apply final processing
631
+ self.post_init()
632
+
633
+ def get_input_embeddings(self):
634
+ return self.embed_tokens
635
+
636
+ def set_input_embeddings(self, value: nn.Module) -> None:
637
+ self.embed_tokens = value
638
+
639
+ @can_return_tuple
640
+ def forward(
641
+ self,
642
+ input_ids: None | torch.LongTensor = None,
643
+ attention_mask: None | torch.Tensor = None,
644
+ position_ids: None | torch.Tensor = None,
645
+ past_key_values: None | DynamicCache = None,
646
+ inputs_embeds: None | torch.Tensor = None,
647
+ use_cache: None | bool = None,
648
+ output_attentions: None | bool = None,
649
+ output_hidden_states: None | bool = None,
650
+ cache_position: None | torch.Tensor = None,
651
+ # Insertion
652
+ image_tokens_mask: torch.Tensor | None = None,
653
+ # CASA
654
+ casa_handler: CASAAttentionHandler | None = None,
655
+ cu_seqlens: torch.Tensor | None = None,
656
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
657
+ ) -> BaseModelOutputWithPast:
658
+ output_attentions = (
659
+ output_attentions if output_attentions is not None else self.config.output_attentions
660
+ )
661
+ output_hidden_states = (
662
+ output_hidden_states
663
+ if output_hidden_states is not None
664
+ else self.config.output_hidden_states
665
+ )
666
+ use_cache = not self.training and (
667
+ use_cache if use_cache is not None else self.config.use_cache
668
+ )
669
+
670
+ if (input_ids is None) ^ (inputs_embeds is not None):
671
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
672
+
673
+ if self.gradient_checkpointing and self.training and use_cache:
674
+ print(
675
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
676
+ )
677
+ use_cache = False
678
+
679
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
680
+ if not isinstance(past_key_values, (type(None), Cache)):
681
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
682
+
683
+ if inputs_embeds is None:
684
+ inputs_embeds = self.embed_tokens(input_ids)
685
+ assert inputs_embeds is not None
686
+
687
+ if use_cache and past_key_values is None:
688
+ past_key_values = DynamicCache()
689
+
690
+ if cache_position is None:
691
+ past_seen_tokens = 0 if past_key_values is None else past_key_values._seen_tokens
692
+ assert inputs_embeds is not None
693
+ cache_position = torch.arange(
694
+ past_seen_tokens,
695
+ past_seen_tokens + inputs_embeds.shape[1],
696
+ device=inputs_embeds.device,
697
+ )
698
+ assert cache_position is not None
699
+
700
+ if position_ids is None:
701
+ position_ids = cache_position.unsqueeze(0)
702
+
703
+ # Get attention mask
704
+ causal_mask: None | torch.Tensor = self._update_causal_mask(
705
+ attention_mask,
706
+ inputs_embeds,
707
+ cache_position,
708
+ past_key_values,
709
+ output_attentions,
710
+ force_mask=False,
711
+ )
712
+
713
+ # create position embeddings to be shared across the decoder layers
714
+ hidden_states = inputs_embeds
715
+ position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
716
+
717
+ # decoder layers
718
+ all_hidden_states = () if output_hidden_states else None
719
+ all_self_attns = () if output_attentions else None
720
+
721
+ for decoder_layer_idx, decoder_layer in enumerate(
722
+ self.layers[: self.config.num_hidden_layers]
723
+ ):
724
+ is_xa_layer = not self.config.xa_layers or decoder_layer_idx in self.config.xa_layers
725
+ if output_hidden_states:
726
+ if all_hidden_states is None:
727
+ all_hidden_states = ()
728
+ all_hidden_states += (hidden_states,)
729
+
730
+ if self.gradient_checkpointing and self.training:
731
+ layer_outputs = self._gradient_checkpointing_func(
732
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
733
+ hidden_states,
734
+ causal_mask,
735
+ position_ids,
736
+ past_key_values,
737
+ output_attentions,
738
+ use_cache,
739
+ cache_position,
740
+ position_embeddings,
741
+ casa_handler if is_xa_layer else None,
742
+ cu_seqlens,
743
+ )
744
+ else:
745
+ layer_outputs = decoder_layer(
746
+ hidden_states,
747
+ attention_mask=causal_mask,
748
+ position_ids=position_ids,
749
+ past_key_values=past_key_values,
750
+ output_attentions=output_attentions,
751
+ use_cache=use_cache,
752
+ cache_position=cache_position,
753
+ position_embeddings=position_embeddings,
754
+ casa_handler=casa_handler if is_xa_layer else None,
755
+ cu_seqlens=cu_seqlens,
756
+ **flash_attn_kwargs,
757
+ )
758
+
759
+ hidden_states = layer_outputs[0]
760
+
761
+ if output_attentions:
762
+ if all_self_attns is None:
763
+ all_self_attns = ()
764
+ all_self_attns += (layer_outputs[1],)
765
+
766
+ hidden_states = self.norm(hidden_states)
767
+
768
+ # add hidden states from the last decoder layer
769
+ if output_hidden_states:
770
+ if all_hidden_states is None:
771
+ all_hidden_states = ()
772
+ all_hidden_states += (hidden_states,)
773
+
774
+ return BaseModelOutputWithPast(
775
+ last_hidden_state=hidden_states,
776
+ past_key_values=past_key_values if use_cache else None, # pyright: ignore[reportArgumentType]
777
+ hidden_states=all_hidden_states, # pyright: ignore[reportArgumentType]
778
+ attentions=all_self_attns,
779
+ )
780
+
781
+ def _update_causal_mask(
782
+ self,
783
+ attention_mask: torch.Tensor | None,
784
+ input_tensor: torch.Tensor,
785
+ cache_position: torch.Tensor,
786
+ past_key_values: None | DynamicCache | Cache,
787
+ output_attentions: bool = False,
788
+ force_mask: bool = False,
789
+ ) -> torch.Tensor | None:
790
+ if self.config._attn_implementation == "flex_attention":
791
+ if isinstance(attention_mask, torch.Tensor):
792
+ attention_mask = make_flex_block_causal_mask(attention_mask) # type: ignore
793
+ return attention_mask
794
+
795
+ assert attention_mask is None or isinstance(attention_mask, torch.Tensor)
796
+ if self.config._attn_implementation == "flash_attention_2":
797
+ if attention_mask is not None and (force_mask or (attention_mask == 0.0).any()):
798
+ return attention_mask
799
+ return None
800
+
801
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
802
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
803
+ # to infer the attention mask.
804
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
805
+ using_compilable_cache = (
806
+ past_key_values.is_compileable if past_key_values is not None else False
807
+ )
808
+
809
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
810
+ if (
811
+ self.config._attn_implementation == "sdpa"
812
+ and not using_compilable_cache
813
+ and not output_attentions
814
+ ):
815
+ if not force_mask and AttentionMaskConverter._ignore_causal_mask_sdpa(
816
+ attention_mask,
817
+ inputs_embeds=input_tensor,
818
+ past_key_values_length=past_seen_tokens,
819
+ is_training=self.training,
820
+ ):
821
+ return None
822
+
823
+ dtype = input_tensor.dtype
824
+ sequence_length = input_tensor.shape[1]
825
+ if using_compilable_cache and past_key_values is not None:
826
+ target_length = past_key_values.get_max_cache_shape()
827
+ else:
828
+ target_length = (
829
+ attention_mask.shape[-1]
830
+ if isinstance(attention_mask, torch.Tensor)
831
+ else past_seen_tokens + sequence_length
832
+ )
833
+
834
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
835
+ assert target_length is not None
836
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
837
+ attention_mask,
838
+ sequence_length=sequence_length,
839
+ target_length=target_length,
840
+ dtype=dtype,
841
+ cache_position=cache_position,
842
+ batch_size=input_tensor.shape[0],
843
+ )
844
+
845
+ if (
846
+ self.config._attn_implementation == "sdpa"
847
+ and attention_mask is not None
848
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
849
+ and not output_attentions
850
+ ):
851
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
852
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
853
+ # Details: https://github.com/pytorch/pytorch/issues/110213
854
+ min_dtype = torch.finfo(dtype).min
855
+ causal_mask = AttentionMaskConverter._unmask_unattended(
856
+ type_cast(torch.FloatTensor, causal_mask), min_dtype
857
+ )
858
+
859
+ return causal_mask
860
+
861
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor | None,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs: Any,
+     ):
+         """
+         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+         `(batch_size, key_value_length)`; if the input `attention_mask` is already 4D, it is returned unchanged.
+
+         Args:
+             attention_mask (`torch.Tensor`):
+                 A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                 `(batch_size, 1, query_length, key_value_length)`.
+             sequence_length (`int`):
+                 The sequence length being processed.
+             target_length (`int`):
+                 The target length: when generating with static cache, the mask should be as long as the static cache,
+                 to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+             dtype (`torch.dtype`):
+                 The dtype to use for the 4D attention mask.
+             cache_position (`torch.Tensor`):
+                 Indices depicting the position of the input sequence tokens in the sequence.
+             batch_size (`int`):
+                 Batch size.
+         """
+         del kwargs
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=cache_position.device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=cache_position.device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+                     :, None, None, :
+                 ].to(causal_mask.device)
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                     padding_mask, min_dtype
+                 )
+
+         return causal_mask
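+
+     # Worked example for the helper above (illustrative values only): with attention_mask=None,
+     # batch_size=1, sequence_length=2, target_length=4 and cache_position=torch.tensor([2, 3]),
+     # the returned (1, 1, 2, 4) mask is, row per query,
+     #     position 2 -> [0, 0, 0, min_dtype]   # keys 0..2 visible, key 3 still masked
+     #     position 3 -> [0, 0, 0, 0]           # keys 0..3 visible
+     # where 0 means "attend" and min_dtype = torch.finfo(dtype).min means "blocked".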
+
+
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+ class Helium1ForCausalLM(Helium1PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
+         del kwargs
+         super().__init__(config)
+         self.model: Helium1Model
+         self.model = Helium1Model(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self._loss_function = ForCausalLMLoss
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value: nn.Module) -> None:
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self) -> nn.Module:
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder: Helium1Model) -> None:
+         self.model = decoder
+
+     def get_decoder(self) -> Helium1Model:
+         return self.model
+
+     @can_return_tuple
+     def forward(
+         self,
+         input_ids: None | torch.LongTensor = None,
+         attention_mask: None | torch.Tensor = None,
+         position_ids: None | torch.LongTensor = None,
+         past_key_values: None | Cache = None,
+         inputs_embeds: None | torch.Tensor = None,
+         image_embeds: None | torch.Tensor | list[torch.Tensor] = None,
+         image_embeds_insertion_points: None | list[torch.Tensor] = None,
+         labels: None | torch.LongTensor = None,
+         use_cache: None | bool = None,
+         output_attentions: None | bool = None,
+         output_hidden_states: None | bool = None,
+         cache_position: None | torch.LongTensor = None,
+         logits_to_keep: int | torch.Tensor = 0,
+         # CASA
+         casa_windows_info: None | dict = None,
+         **kwargs: Unpack[KwargsForCausalLM],
+     ) -> CausalHeliumOutput:
+         r"""
+         Helium1 augmented with CASA layers.
+         """
+         output_attentions = (
+             output_attentions if output_attentions is not None else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         if input_ids is not None:
+             assert inputs_embeds is None, (
+                 "Need to provide only one of `input_ids` or `inputs_embeds`."
+             )
+             inputs_embeds = self.model.embed_tokens(input_ids)
+         assert inputs_embeds is not None
+
+         # Setup image + text token fusion
+         bs, og_seq_len, _ = inputs_embeds.shape
+         image_tokens_mask: torch.Tensor | None = None
+         casa_handler: CASAAttentionHandler | None = None
+
+         num_image_tokens = -1
+         if image_embeds is not None:
+             num_image_tokens = sum(_x.shape[0] for _x in image_embeds)
+             assert image_embeds_insertion_points is not None, (
+                 "Missing image embeddings insertion points"
+             )
+             # B1. CASA layers: we need to init the shared handler
+             if self.model.config.casa_attention:
+                 casa_handler = CASAAttentionHandler(
+                     # for text tokens, we don't need the actual values
+                     inputs_embeds=torch.zeros_like(inputs_embeds),
+                     # for image embeddings, we put the real inputs as these stay fixed
+                     image_embeds=image_embeds,
+                     image_embeds_insertion_points=image_embeds_insertion_points,
+                     # the attention mask is only needed at inference / left padding
+                     attention_mask=None if self.training else attention_mask,
+                     rope_fn=self.model.rotary_emb,
+                     windows=self.model.config.casa_windows,
+                     use_asymetric_q_kv=self.model.config.casa_use_asymetric_qkv,
+                     # further params are fed to the function computing attention
+                     casa_windows_info=casa_windows_info,
+                 )
+             # B2. Direct image insertion
+             else:
+                 inputs_embeds, _, attention_mask, image_tokens_mask = insert_image_tokens(
+                     inputs_embeds=inputs_embeds,
+                     image_embeds=image_embeds,
+                     image_embeds_insertion_points=image_embeds_insertion_points,
+                     attention_mask=attention_mask,
+                     padding_side="right" if self.training else "left",
+                     recover_batch_dim=True,
+                 )
+
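+         # To illustrate the two branches above (rough numbers, not taken from this repo): with a
+         # 7-token prompt, one 4-token image and an insertion point after token 3, B2 splices the
+         # image tokens into the stream so `self.model` sees an 11-token sequence, while B1 keeps
+         # the 7 text tokens and instead exposes the 4 image tokens through `casa_handler`.
+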
+         del image_embeds
+         del input_ids
+         outputs: BaseModelOutputWithPast = self.model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             cache_position=cache_position,
+             image_tokens_mask=image_tokens_mask,
+             casa_handler=casa_handler,
+             **kwargs,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         assert hidden_states is not None
+         if image_tokens_mask is not None:
+             hidden_states = remove_image_tokens(hidden_states, image_tokens_mask)
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         slice_indices = (
+             slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         )
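+         # Note that `logits_to_keep=1` keeps only the last position (the usual decoding case),
+         # while the default 0 keeps every position, since `slice(-0, None)` spans the full sequence.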
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 vocab_size=self.config.vocab_size,
+                 **kwargs,
+             )
+         out = CausalHeliumOutput(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             num_image_tokens_log=torch.tensor(num_image_tokens).to(logits.device).to(torch.float),
+             num_text_tokens_log=torch.tensor(og_seq_len).to(logits.device).to(torch.float),
+         )
+         return out
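+
+
+ # Minimal usage sketch (illustrative only: it assumes the repo id below and that the accompanying
+ # processor / vision tower produce `image_embeds` and `image_embeds_insertion_points`; neither is
+ # guaranteed by this file alone):
+ #
+ #     model = AutoModelForCausalLM.from_pretrained("kyutai/CASA-Helium1-VL-2B", trust_remote_code=True)
+ #     out = model(input_ids=input_ids, image_embeds=[image_feats],
+ #                 image_embeds_insertion_points=[insertion_points])
+ #     next_token_id = out.logits[:, -1].argmax(dim=-1)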
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:84b4a971906aaacc7c0acf8ac98d4f59afe073954284790855b70f8ab8488df3
+ size 4987411648
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea916a815ac0deb1a397aa6ef2edf5a8baca4811327597fccf2f21ef67cf4295
+ size 4993506144
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0bc7d4d08240d3de6fedfbe8711c16345c70e0fbebddfc89685cb361f208aeb
+ size 2195900304
model.safetensors.index.json ADDED
@@ -0,0 +1,793 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 12176723968
4
+ },
5
+ "weight_map": {
6
+ "image_prefix.enc.visual.blocks.0.attn.proj.bias": "model-00002-of-00003.safetensors",
7
+ "image_prefix.enc.visual.blocks.0.attn.proj.weight": "model-00002-of-00003.safetensors",
8
+ "image_prefix.enc.visual.blocks.0.attn.qkv.bias": "model-00002-of-00003.safetensors",
9
+ "image_prefix.enc.visual.blocks.0.attn.qkv.weight": "model-00002-of-00003.safetensors",
10
+ "image_prefix.enc.visual.blocks.0.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
11
+ "image_prefix.enc.visual.blocks.0.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
12
+ "image_prefix.enc.visual.blocks.0.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
13
+ "image_prefix.enc.visual.blocks.0.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
14
+ "image_prefix.enc.visual.blocks.0.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
15
+ "image_prefix.enc.visual.blocks.0.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
16
+ "image_prefix.enc.visual.blocks.0.norm1.weight": "model-00002-of-00003.safetensors",
17
+ "image_prefix.enc.visual.blocks.0.norm2.weight": "model-00002-of-00003.safetensors",
18
+ "image_prefix.enc.visual.blocks.1.attn.proj.bias": "model-00002-of-00003.safetensors",
19
+ "image_prefix.enc.visual.blocks.1.attn.proj.weight": "model-00002-of-00003.safetensors",
20
+ "image_prefix.enc.visual.blocks.1.attn.qkv.bias": "model-00002-of-00003.safetensors",
21
+ "image_prefix.enc.visual.blocks.1.attn.qkv.weight": "model-00002-of-00003.safetensors",
22
+ "image_prefix.enc.visual.blocks.1.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
23
+ "image_prefix.enc.visual.blocks.1.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
24
+ "image_prefix.enc.visual.blocks.1.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
25
+ "image_prefix.enc.visual.blocks.1.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
26
+ "image_prefix.enc.visual.blocks.1.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
27
+ "image_prefix.enc.visual.blocks.1.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
28
+ "image_prefix.enc.visual.blocks.1.norm1.weight": "model-00002-of-00003.safetensors",
29
+ "image_prefix.enc.visual.blocks.1.norm2.weight": "model-00002-of-00003.safetensors",
30
+ "image_prefix.enc.visual.blocks.10.attn.proj.bias": "model-00003-of-00003.safetensors",
31
+ "image_prefix.enc.visual.blocks.10.attn.proj.weight": "model-00003-of-00003.safetensors",
32
+ "image_prefix.enc.visual.blocks.10.attn.qkv.bias": "model-00003-of-00003.safetensors",
33
+ "image_prefix.enc.visual.blocks.10.attn.qkv.weight": "model-00003-of-00003.safetensors",
34
+ "image_prefix.enc.visual.blocks.10.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
35
+ "image_prefix.enc.visual.blocks.10.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
36
+ "image_prefix.enc.visual.blocks.10.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
37
+ "image_prefix.enc.visual.blocks.10.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
38
+ "image_prefix.enc.visual.blocks.10.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
39
+ "image_prefix.enc.visual.blocks.10.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
40
+ "image_prefix.enc.visual.blocks.10.norm1.weight": "model-00003-of-00003.safetensors",
41
+ "image_prefix.enc.visual.blocks.10.norm2.weight": "model-00003-of-00003.safetensors",
42
+ "image_prefix.enc.visual.blocks.11.attn.proj.bias": "model-00003-of-00003.safetensors",
43
+ "image_prefix.enc.visual.blocks.11.attn.proj.weight": "model-00003-of-00003.safetensors",
44
+ "image_prefix.enc.visual.blocks.11.attn.qkv.bias": "model-00003-of-00003.safetensors",
45
+ "image_prefix.enc.visual.blocks.11.attn.qkv.weight": "model-00003-of-00003.safetensors",
46
+ "image_prefix.enc.visual.blocks.11.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
47
+ "image_prefix.enc.visual.blocks.11.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
48
+ "image_prefix.enc.visual.blocks.11.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
49
+ "image_prefix.enc.visual.blocks.11.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
50
+ "image_prefix.enc.visual.blocks.11.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
51
+ "image_prefix.enc.visual.blocks.11.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
52
+ "image_prefix.enc.visual.blocks.11.norm1.weight": "model-00003-of-00003.safetensors",
53
+ "image_prefix.enc.visual.blocks.11.norm2.weight": "model-00003-of-00003.safetensors",
54
+ "image_prefix.enc.visual.blocks.12.attn.proj.bias": "model-00003-of-00003.safetensors",
55
+ "image_prefix.enc.visual.blocks.12.attn.proj.weight": "model-00003-of-00003.safetensors",
56
+ "image_prefix.enc.visual.blocks.12.attn.qkv.bias": "model-00003-of-00003.safetensors",
57
+ "image_prefix.enc.visual.blocks.12.attn.qkv.weight": "model-00003-of-00003.safetensors",
58
+ "image_prefix.enc.visual.blocks.12.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
59
+ "image_prefix.enc.visual.blocks.12.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
60
+ "image_prefix.enc.visual.blocks.12.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
61
+ "image_prefix.enc.visual.blocks.12.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
62
+ "image_prefix.enc.visual.blocks.12.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
63
+ "image_prefix.enc.visual.blocks.12.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
64
+ "image_prefix.enc.visual.blocks.12.norm1.weight": "model-00003-of-00003.safetensors",
65
+ "image_prefix.enc.visual.blocks.12.norm2.weight": "model-00003-of-00003.safetensors",
66
+ "image_prefix.enc.visual.blocks.13.attn.proj.bias": "model-00003-of-00003.safetensors",
67
+ "image_prefix.enc.visual.blocks.13.attn.proj.weight": "model-00003-of-00003.safetensors",
68
+ "image_prefix.enc.visual.blocks.13.attn.qkv.bias": "model-00003-of-00003.safetensors",
69
+ "image_prefix.enc.visual.blocks.13.attn.qkv.weight": "model-00003-of-00003.safetensors",
70
+ "image_prefix.enc.visual.blocks.13.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
71
+ "image_prefix.enc.visual.blocks.13.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
72
+ "image_prefix.enc.visual.blocks.13.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
73
+ "image_prefix.enc.visual.blocks.13.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
74
+ "image_prefix.enc.visual.blocks.13.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
75
+ "image_prefix.enc.visual.blocks.13.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
76
+ "image_prefix.enc.visual.blocks.13.norm1.weight": "model-00003-of-00003.safetensors",
77
+ "image_prefix.enc.visual.blocks.13.norm2.weight": "model-00003-of-00003.safetensors",
78
+ "image_prefix.enc.visual.blocks.14.attn.proj.bias": "model-00003-of-00003.safetensors",
79
+ "image_prefix.enc.visual.blocks.14.attn.proj.weight": "model-00003-of-00003.safetensors",
80
+ "image_prefix.enc.visual.blocks.14.attn.qkv.bias": "model-00003-of-00003.safetensors",
81
+ "image_prefix.enc.visual.blocks.14.attn.qkv.weight": "model-00003-of-00003.safetensors",
82
+ "image_prefix.enc.visual.blocks.14.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
83
+ "image_prefix.enc.visual.blocks.14.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
84
+ "image_prefix.enc.visual.blocks.14.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
85
+ "image_prefix.enc.visual.blocks.14.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
86
+ "image_prefix.enc.visual.blocks.14.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
87
+ "image_prefix.enc.visual.blocks.14.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
88
+ "image_prefix.enc.visual.blocks.14.norm1.weight": "model-00003-of-00003.safetensors",
89
+ "image_prefix.enc.visual.blocks.14.norm2.weight": "model-00003-of-00003.safetensors",
90
+ "image_prefix.enc.visual.blocks.15.attn.proj.bias": "model-00003-of-00003.safetensors",
91
+ "image_prefix.enc.visual.blocks.15.attn.proj.weight": "model-00003-of-00003.safetensors",
92
+ "image_prefix.enc.visual.blocks.15.attn.qkv.bias": "model-00003-of-00003.safetensors",
93
+ "image_prefix.enc.visual.blocks.15.attn.qkv.weight": "model-00003-of-00003.safetensors",
94
+ "image_prefix.enc.visual.blocks.15.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
95
+ "image_prefix.enc.visual.blocks.15.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
96
+ "image_prefix.enc.visual.blocks.15.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
97
+ "image_prefix.enc.visual.blocks.15.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
98
+ "image_prefix.enc.visual.blocks.15.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
99
+ "image_prefix.enc.visual.blocks.15.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
100
+ "image_prefix.enc.visual.blocks.15.norm1.weight": "model-00003-of-00003.safetensors",
101
+ "image_prefix.enc.visual.blocks.15.norm2.weight": "model-00003-of-00003.safetensors",
102
+ "image_prefix.enc.visual.blocks.16.attn.proj.bias": "model-00003-of-00003.safetensors",
103
+ "image_prefix.enc.visual.blocks.16.attn.proj.weight": "model-00003-of-00003.safetensors",
104
+ "image_prefix.enc.visual.blocks.16.attn.qkv.bias": "model-00003-of-00003.safetensors",
105
+ "image_prefix.enc.visual.blocks.16.attn.qkv.weight": "model-00003-of-00003.safetensors",
106
+ "image_prefix.enc.visual.blocks.16.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
107
+ "image_prefix.enc.visual.blocks.16.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
108
+ "image_prefix.enc.visual.blocks.16.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
109
+ "image_prefix.enc.visual.blocks.16.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
110
+ "image_prefix.enc.visual.blocks.16.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
111
+ "image_prefix.enc.visual.blocks.16.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
112
+ "image_prefix.enc.visual.blocks.16.norm1.weight": "model-00003-of-00003.safetensors",
113
+ "image_prefix.enc.visual.blocks.16.norm2.weight": "model-00003-of-00003.safetensors",
114
+ "image_prefix.enc.visual.blocks.17.attn.proj.bias": "model-00003-of-00003.safetensors",
115
+ "image_prefix.enc.visual.blocks.17.attn.proj.weight": "model-00003-of-00003.safetensors",
116
+ "image_prefix.enc.visual.blocks.17.attn.qkv.bias": "model-00003-of-00003.safetensors",
117
+ "image_prefix.enc.visual.blocks.17.attn.qkv.weight": "model-00003-of-00003.safetensors",
118
+ "image_prefix.enc.visual.blocks.17.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
119
+ "image_prefix.enc.visual.blocks.17.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
120
+ "image_prefix.enc.visual.blocks.17.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
121
+ "image_prefix.enc.visual.blocks.17.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
122
+ "image_prefix.enc.visual.blocks.17.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
123
+ "image_prefix.enc.visual.blocks.17.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
124
+ "image_prefix.enc.visual.blocks.17.norm1.weight": "model-00003-of-00003.safetensors",
125
+ "image_prefix.enc.visual.blocks.17.norm2.weight": "model-00003-of-00003.safetensors",
126
+ "image_prefix.enc.visual.blocks.18.attn.proj.bias": "model-00003-of-00003.safetensors",
127
+ "image_prefix.enc.visual.blocks.18.attn.proj.weight": "model-00003-of-00003.safetensors",
128
+ "image_prefix.enc.visual.blocks.18.attn.qkv.bias": "model-00003-of-00003.safetensors",
129
+ "image_prefix.enc.visual.blocks.18.attn.qkv.weight": "model-00003-of-00003.safetensors",
130
+ "image_prefix.enc.visual.blocks.18.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
131
+ "image_prefix.enc.visual.blocks.18.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
132
+ "image_prefix.enc.visual.blocks.18.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
133
+ "image_prefix.enc.visual.blocks.18.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
134
+ "image_prefix.enc.visual.blocks.18.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
135
+ "image_prefix.enc.visual.blocks.18.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
136
+ "image_prefix.enc.visual.blocks.18.norm1.weight": "model-00003-of-00003.safetensors",
137
+ "image_prefix.enc.visual.blocks.18.norm2.weight": "model-00003-of-00003.safetensors",
138
+ "image_prefix.enc.visual.blocks.19.attn.proj.bias": "model-00003-of-00003.safetensors",
139
+ "image_prefix.enc.visual.blocks.19.attn.proj.weight": "model-00003-of-00003.safetensors",
140
+ "image_prefix.enc.visual.blocks.19.attn.qkv.bias": "model-00003-of-00003.safetensors",
141
+ "image_prefix.enc.visual.blocks.19.attn.qkv.weight": "model-00003-of-00003.safetensors",
142
+ "image_prefix.enc.visual.blocks.19.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
143
+ "image_prefix.enc.visual.blocks.19.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
144
+ "image_prefix.enc.visual.blocks.19.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
145
+ "image_prefix.enc.visual.blocks.19.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
146
+ "image_prefix.enc.visual.blocks.19.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
147
+ "image_prefix.enc.visual.blocks.19.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
148
+ "image_prefix.enc.visual.blocks.19.norm1.weight": "model-00003-of-00003.safetensors",
149
+ "image_prefix.enc.visual.blocks.19.norm2.weight": "model-00003-of-00003.safetensors",
150
+ "image_prefix.enc.visual.blocks.2.attn.proj.bias": "model-00002-of-00003.safetensors",
151
+ "image_prefix.enc.visual.blocks.2.attn.proj.weight": "model-00002-of-00003.safetensors",
152
+ "image_prefix.enc.visual.blocks.2.attn.qkv.bias": "model-00002-of-00003.safetensors",
153
+ "image_prefix.enc.visual.blocks.2.attn.qkv.weight": "model-00002-of-00003.safetensors",
154
+ "image_prefix.enc.visual.blocks.2.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
155
+ "image_prefix.enc.visual.blocks.2.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
156
+ "image_prefix.enc.visual.blocks.2.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
157
+ "image_prefix.enc.visual.blocks.2.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
158
+ "image_prefix.enc.visual.blocks.2.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
159
+ "image_prefix.enc.visual.blocks.2.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
160
+ "image_prefix.enc.visual.blocks.2.norm1.weight": "model-00002-of-00003.safetensors",
161
+ "image_prefix.enc.visual.blocks.2.norm2.weight": "model-00002-of-00003.safetensors",
162
+ "image_prefix.enc.visual.blocks.20.attn.proj.bias": "model-00003-of-00003.safetensors",
163
+ "image_prefix.enc.visual.blocks.20.attn.proj.weight": "model-00003-of-00003.safetensors",
164
+ "image_prefix.enc.visual.blocks.20.attn.qkv.bias": "model-00003-of-00003.safetensors",
165
+ "image_prefix.enc.visual.blocks.20.attn.qkv.weight": "model-00003-of-00003.safetensors",
166
+ "image_prefix.enc.visual.blocks.20.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
167
+ "image_prefix.enc.visual.blocks.20.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
168
+ "image_prefix.enc.visual.blocks.20.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
169
+ "image_prefix.enc.visual.blocks.20.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
170
+ "image_prefix.enc.visual.blocks.20.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
171
+ "image_prefix.enc.visual.blocks.20.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
172
+ "image_prefix.enc.visual.blocks.20.norm1.weight": "model-00003-of-00003.safetensors",
173
+ "image_prefix.enc.visual.blocks.20.norm2.weight": "model-00003-of-00003.safetensors",
174
+ "image_prefix.enc.visual.blocks.21.attn.proj.bias": "model-00003-of-00003.safetensors",
175
+ "image_prefix.enc.visual.blocks.21.attn.proj.weight": "model-00003-of-00003.safetensors",
176
+ "image_prefix.enc.visual.blocks.21.attn.qkv.bias": "model-00003-of-00003.safetensors",
177
+ "image_prefix.enc.visual.blocks.21.attn.qkv.weight": "model-00003-of-00003.safetensors",
178
+ "image_prefix.enc.visual.blocks.21.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
179
+ "image_prefix.enc.visual.blocks.21.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
180
+ "image_prefix.enc.visual.blocks.21.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
181
+ "image_prefix.enc.visual.blocks.21.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
182
+ "image_prefix.enc.visual.blocks.21.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
183
+ "image_prefix.enc.visual.blocks.21.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
184
+ "image_prefix.enc.visual.blocks.21.norm1.weight": "model-00003-of-00003.safetensors",
185
+ "image_prefix.enc.visual.blocks.21.norm2.weight": "model-00003-of-00003.safetensors",
186
+ "image_prefix.enc.visual.blocks.22.attn.proj.bias": "model-00003-of-00003.safetensors",
187
+ "image_prefix.enc.visual.blocks.22.attn.proj.weight": "model-00003-of-00003.safetensors",
188
+ "image_prefix.enc.visual.blocks.22.attn.qkv.bias": "model-00003-of-00003.safetensors",
189
+ "image_prefix.enc.visual.blocks.22.attn.qkv.weight": "model-00003-of-00003.safetensors",
190
+ "image_prefix.enc.visual.blocks.22.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
191
+ "image_prefix.enc.visual.blocks.22.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
192
+ "image_prefix.enc.visual.blocks.22.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
193
+ "image_prefix.enc.visual.blocks.22.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
194
+ "image_prefix.enc.visual.blocks.22.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
195
+ "image_prefix.enc.visual.blocks.22.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
196
+ "image_prefix.enc.visual.blocks.22.norm1.weight": "model-00003-of-00003.safetensors",
197
+ "image_prefix.enc.visual.blocks.22.norm2.weight": "model-00003-of-00003.safetensors",
198
+ "image_prefix.enc.visual.blocks.23.attn.proj.bias": "model-00003-of-00003.safetensors",
199
+ "image_prefix.enc.visual.blocks.23.attn.proj.weight": "model-00003-of-00003.safetensors",
200
+ "image_prefix.enc.visual.blocks.23.attn.qkv.bias": "model-00003-of-00003.safetensors",
201
+ "image_prefix.enc.visual.blocks.23.attn.qkv.weight": "model-00003-of-00003.safetensors",
202
+ "image_prefix.enc.visual.blocks.23.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
203
+ "image_prefix.enc.visual.blocks.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
204
+ "image_prefix.enc.visual.blocks.23.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
205
+ "image_prefix.enc.visual.blocks.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
206
+ "image_prefix.enc.visual.blocks.23.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
207
+ "image_prefix.enc.visual.blocks.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
208
+ "image_prefix.enc.visual.blocks.23.norm1.weight": "model-00003-of-00003.safetensors",
209
+ "image_prefix.enc.visual.blocks.23.norm2.weight": "model-00003-of-00003.safetensors",
210
+ "image_prefix.enc.visual.blocks.24.attn.proj.bias": "model-00003-of-00003.safetensors",
211
+ "image_prefix.enc.visual.blocks.24.attn.proj.weight": "model-00003-of-00003.safetensors",
212
+ "image_prefix.enc.visual.blocks.24.attn.qkv.bias": "model-00003-of-00003.safetensors",
213
+ "image_prefix.enc.visual.blocks.24.attn.qkv.weight": "model-00003-of-00003.safetensors",
214
+ "image_prefix.enc.visual.blocks.24.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
215
+ "image_prefix.enc.visual.blocks.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
216
+ "image_prefix.enc.visual.blocks.24.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
217
+ "image_prefix.enc.visual.blocks.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
218
+ "image_prefix.enc.visual.blocks.24.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
219
+ "image_prefix.enc.visual.blocks.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
220
+ "image_prefix.enc.visual.blocks.24.norm1.weight": "model-00003-of-00003.safetensors",
221
+ "image_prefix.enc.visual.blocks.24.norm2.weight": "model-00003-of-00003.safetensors",
222
+ "image_prefix.enc.visual.blocks.25.attn.proj.bias": "model-00003-of-00003.safetensors",
223
+ "image_prefix.enc.visual.blocks.25.attn.proj.weight": "model-00003-of-00003.safetensors",
224
+ "image_prefix.enc.visual.blocks.25.attn.qkv.bias": "model-00003-of-00003.safetensors",
225
+ "image_prefix.enc.visual.blocks.25.attn.qkv.weight": "model-00003-of-00003.safetensors",
226
+ "image_prefix.enc.visual.blocks.25.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
227
+ "image_prefix.enc.visual.blocks.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
228
+ "image_prefix.enc.visual.blocks.25.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
229
+ "image_prefix.enc.visual.blocks.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
230
+ "image_prefix.enc.visual.blocks.25.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
231
+ "image_prefix.enc.visual.blocks.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
232
+ "image_prefix.enc.visual.blocks.25.norm1.weight": "model-00003-of-00003.safetensors",
233
+ "image_prefix.enc.visual.blocks.25.norm2.weight": "model-00003-of-00003.safetensors",
234
+ "image_prefix.enc.visual.blocks.26.attn.proj.bias": "model-00003-of-00003.safetensors",
235
+ "image_prefix.enc.visual.blocks.26.attn.proj.weight": "model-00003-of-00003.safetensors",
236
+ "image_prefix.enc.visual.blocks.26.attn.qkv.bias": "model-00003-of-00003.safetensors",
237
+ "image_prefix.enc.visual.blocks.26.attn.qkv.weight": "model-00003-of-00003.safetensors",
238
+ "image_prefix.enc.visual.blocks.26.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
239
+ "image_prefix.enc.visual.blocks.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
240
+ "image_prefix.enc.visual.blocks.26.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
241
+ "image_prefix.enc.visual.blocks.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
242
+ "image_prefix.enc.visual.blocks.26.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
243
+ "image_prefix.enc.visual.blocks.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
244
+ "image_prefix.enc.visual.blocks.26.norm1.weight": "model-00003-of-00003.safetensors",
245
+ "image_prefix.enc.visual.blocks.26.norm2.weight": "model-00003-of-00003.safetensors",
246
+ "image_prefix.enc.visual.blocks.27.attn.proj.bias": "model-00003-of-00003.safetensors",
247
+ "image_prefix.enc.visual.blocks.27.attn.proj.weight": "model-00003-of-00003.safetensors",
248
+ "image_prefix.enc.visual.blocks.27.attn.qkv.bias": "model-00003-of-00003.safetensors",
249
+ "image_prefix.enc.visual.blocks.27.attn.qkv.weight": "model-00003-of-00003.safetensors",
250
+ "image_prefix.enc.visual.blocks.27.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
251
+ "image_prefix.enc.visual.blocks.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
+ "image_prefix.enc.visual.blocks.27.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
253
+ "image_prefix.enc.visual.blocks.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
254
+ "image_prefix.enc.visual.blocks.27.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
255
+ "image_prefix.enc.visual.blocks.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
256
+ "image_prefix.enc.visual.blocks.27.norm1.weight": "model-00003-of-00003.safetensors",
257
+ "image_prefix.enc.visual.blocks.27.norm2.weight": "model-00003-of-00003.safetensors",
258
+ "image_prefix.enc.visual.blocks.28.attn.proj.bias": "model-00003-of-00003.safetensors",
259
+ "image_prefix.enc.visual.blocks.28.attn.proj.weight": "model-00003-of-00003.safetensors",
260
+ "image_prefix.enc.visual.blocks.28.attn.qkv.bias": "model-00003-of-00003.safetensors",
261
+ "image_prefix.enc.visual.blocks.28.attn.qkv.weight": "model-00003-of-00003.safetensors",
262
+ "image_prefix.enc.visual.blocks.28.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
263
+ "image_prefix.enc.visual.blocks.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
264
+ "image_prefix.enc.visual.blocks.28.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
265
+ "image_prefix.enc.visual.blocks.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
266
+ "image_prefix.enc.visual.blocks.28.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
267
+ "image_prefix.enc.visual.blocks.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
268
+ "image_prefix.enc.visual.blocks.28.norm1.weight": "model-00003-of-00003.safetensors",
269
+ "image_prefix.enc.visual.blocks.28.norm2.weight": "model-00003-of-00003.safetensors",
270
+ "image_prefix.enc.visual.blocks.29.attn.proj.bias": "model-00003-of-00003.safetensors",
271
+ "image_prefix.enc.visual.blocks.29.attn.proj.weight": "model-00003-of-00003.safetensors",
272
+ "image_prefix.enc.visual.blocks.29.attn.qkv.bias": "model-00003-of-00003.safetensors",
273
+ "image_prefix.enc.visual.blocks.29.attn.qkv.weight": "model-00003-of-00003.safetensors",
274
+ "image_prefix.enc.visual.blocks.29.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
275
+ "image_prefix.enc.visual.blocks.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
276
+ "image_prefix.enc.visual.blocks.29.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
277
+ "image_prefix.enc.visual.blocks.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
278
+ "image_prefix.enc.visual.blocks.29.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
279
+ "image_prefix.enc.visual.blocks.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
280
+ "image_prefix.enc.visual.blocks.29.norm1.weight": "model-00003-of-00003.safetensors",
281
+ "image_prefix.enc.visual.blocks.29.norm2.weight": "model-00003-of-00003.safetensors",
282
+ "image_prefix.enc.visual.blocks.3.attn.proj.bias": "model-00002-of-00003.safetensors",
283
+ "image_prefix.enc.visual.blocks.3.attn.proj.weight": "model-00002-of-00003.safetensors",
284
+ "image_prefix.enc.visual.blocks.3.attn.qkv.bias": "model-00002-of-00003.safetensors",
285
+ "image_prefix.enc.visual.blocks.3.attn.qkv.weight": "model-00002-of-00003.safetensors",
286
+ "image_prefix.enc.visual.blocks.3.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
287
+ "image_prefix.enc.visual.blocks.3.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
288
+ "image_prefix.enc.visual.blocks.3.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
289
+ "image_prefix.enc.visual.blocks.3.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
290
+ "image_prefix.enc.visual.blocks.3.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
291
+ "image_prefix.enc.visual.blocks.3.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
292
+ "image_prefix.enc.visual.blocks.3.norm1.weight": "model-00002-of-00003.safetensors",
293
+ "image_prefix.enc.visual.blocks.3.norm2.weight": "model-00002-of-00003.safetensors",
294
+ "image_prefix.enc.visual.blocks.30.attn.proj.bias": "model-00003-of-00003.safetensors",
295
+ "image_prefix.enc.visual.blocks.30.attn.proj.weight": "model-00003-of-00003.safetensors",
296
+ "image_prefix.enc.visual.blocks.30.attn.qkv.bias": "model-00003-of-00003.safetensors",
297
+ "image_prefix.enc.visual.blocks.30.attn.qkv.weight": "model-00003-of-00003.safetensors",
298
+ "image_prefix.enc.visual.blocks.30.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
299
+ "image_prefix.enc.visual.blocks.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
300
+ "image_prefix.enc.visual.blocks.30.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
301
+ "image_prefix.enc.visual.blocks.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
302
+ "image_prefix.enc.visual.blocks.30.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
303
+ "image_prefix.enc.visual.blocks.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
304
+ "image_prefix.enc.visual.blocks.30.norm1.weight": "model-00003-of-00003.safetensors",
305
+ "image_prefix.enc.visual.blocks.30.norm2.weight": "model-00003-of-00003.safetensors",
306
+ "image_prefix.enc.visual.blocks.31.attn.proj.bias": "model-00003-of-00003.safetensors",
307
+ "image_prefix.enc.visual.blocks.31.attn.proj.weight": "model-00003-of-00003.safetensors",
308
+ "image_prefix.enc.visual.blocks.31.attn.qkv.bias": "model-00003-of-00003.safetensors",
309
+ "image_prefix.enc.visual.blocks.31.attn.qkv.weight": "model-00003-of-00003.safetensors",
310
+ "image_prefix.enc.visual.blocks.31.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
311
+ "image_prefix.enc.visual.blocks.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
312
+ "image_prefix.enc.visual.blocks.31.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
313
+ "image_prefix.enc.visual.blocks.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
314
+ "image_prefix.enc.visual.blocks.31.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
315
+ "image_prefix.enc.visual.blocks.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
316
+ "image_prefix.enc.visual.blocks.31.norm1.weight": "model-00003-of-00003.safetensors",
317
+ "image_prefix.enc.visual.blocks.31.norm2.weight": "model-00003-of-00003.safetensors",
318
+ "image_prefix.enc.visual.blocks.4.attn.proj.bias": "model-00002-of-00003.safetensors",
319
+ "image_prefix.enc.visual.blocks.4.attn.proj.weight": "model-00002-of-00003.safetensors",
320
+ "image_prefix.enc.visual.blocks.4.attn.qkv.bias": "model-00002-of-00003.safetensors",
321
+ "image_prefix.enc.visual.blocks.4.attn.qkv.weight": "model-00002-of-00003.safetensors",
322
+ "image_prefix.enc.visual.blocks.4.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
323
+ "image_prefix.enc.visual.blocks.4.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
324
+ "image_prefix.enc.visual.blocks.4.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
325
+ "image_prefix.enc.visual.blocks.4.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
326
+ "image_prefix.enc.visual.blocks.4.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
327
+ "image_prefix.enc.visual.blocks.4.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
328
+ "image_prefix.enc.visual.blocks.4.norm1.weight": "model-00002-of-00003.safetensors",
329
+ "image_prefix.enc.visual.blocks.4.norm2.weight": "model-00002-of-00003.safetensors",
330
+ "image_prefix.enc.visual.blocks.5.attn.proj.bias": "model-00002-of-00003.safetensors",
331
+ "image_prefix.enc.visual.blocks.5.attn.proj.weight": "model-00002-of-00003.safetensors",
332
+ "image_prefix.enc.visual.blocks.5.attn.qkv.bias": "model-00002-of-00003.safetensors",
333
+ "image_prefix.enc.visual.blocks.5.attn.qkv.weight": "model-00002-of-00003.safetensors",
334
+ "image_prefix.enc.visual.blocks.5.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
335
+ "image_prefix.enc.visual.blocks.5.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
336
+ "image_prefix.enc.visual.blocks.5.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
337
+ "image_prefix.enc.visual.blocks.5.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
338
+ "image_prefix.enc.visual.blocks.5.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
339
+ "image_prefix.enc.visual.blocks.5.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
340
+ "image_prefix.enc.visual.blocks.5.norm1.weight": "model-00002-of-00003.safetensors",
341
+ "image_prefix.enc.visual.blocks.5.norm2.weight": "model-00002-of-00003.safetensors",
342
+ "image_prefix.enc.visual.blocks.6.attn.proj.bias": "model-00003-of-00003.safetensors",
343
+ "image_prefix.enc.visual.blocks.6.attn.proj.weight": "model-00003-of-00003.safetensors",
344
+ "image_prefix.enc.visual.blocks.6.attn.qkv.bias": "model-00003-of-00003.safetensors",
345
+ "image_prefix.enc.visual.blocks.6.attn.qkv.weight": "model-00003-of-00003.safetensors",
346
+ "image_prefix.enc.visual.blocks.6.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
347
+ "image_prefix.enc.visual.blocks.6.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
348
+ "image_prefix.enc.visual.blocks.6.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
349
+ "image_prefix.enc.visual.blocks.6.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
350
+ "image_prefix.enc.visual.blocks.6.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
351
+ "image_prefix.enc.visual.blocks.6.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
352
+ "image_prefix.enc.visual.blocks.6.norm1.weight": "model-00002-of-00003.safetensors",
353
+ "image_prefix.enc.visual.blocks.6.norm2.weight": "model-00002-of-00003.safetensors",
354
+ "image_prefix.enc.visual.blocks.7.attn.proj.bias": "model-00003-of-00003.safetensors",
355
+ "image_prefix.enc.visual.blocks.7.attn.proj.weight": "model-00003-of-00003.safetensors",
356
+ "image_prefix.enc.visual.blocks.7.attn.qkv.bias": "model-00003-of-00003.safetensors",
357
+ "image_prefix.enc.visual.blocks.7.attn.qkv.weight": "model-00003-of-00003.safetensors",
358
+ "image_prefix.enc.visual.blocks.7.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
359
+ "image_prefix.enc.visual.blocks.7.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
360
+ "image_prefix.enc.visual.blocks.7.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
361
+ "image_prefix.enc.visual.blocks.7.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
362
+ "image_prefix.enc.visual.blocks.7.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
363
+ "image_prefix.enc.visual.blocks.7.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
364
+ "image_prefix.enc.visual.blocks.7.norm1.weight": "model-00003-of-00003.safetensors",
365
+ "image_prefix.enc.visual.blocks.7.norm2.weight": "model-00003-of-00003.safetensors",
366
+ "image_prefix.enc.visual.blocks.8.attn.proj.bias": "model-00003-of-00003.safetensors",
367
+ "image_prefix.enc.visual.blocks.8.attn.proj.weight": "model-00003-of-00003.safetensors",
368
+ "image_prefix.enc.visual.blocks.8.attn.qkv.bias": "model-00003-of-00003.safetensors",
369
+ "image_prefix.enc.visual.blocks.8.attn.qkv.weight": "model-00003-of-00003.safetensors",
370
+ "image_prefix.enc.visual.blocks.8.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
371
+ "image_prefix.enc.visual.blocks.8.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
372
+ "image_prefix.enc.visual.blocks.8.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
373
+ "image_prefix.enc.visual.blocks.8.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
374
+ "image_prefix.enc.visual.blocks.8.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
375
+ "image_prefix.enc.visual.blocks.8.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
376
+ "image_prefix.enc.visual.blocks.8.norm1.weight": "model-00003-of-00003.safetensors",
377
+ "image_prefix.enc.visual.blocks.8.norm2.weight": "model-00003-of-00003.safetensors",
378
+ "image_prefix.enc.visual.blocks.9.attn.proj.bias": "model-00003-of-00003.safetensors",
379
+ "image_prefix.enc.visual.blocks.9.attn.proj.weight": "model-00003-of-00003.safetensors",
380
+ "image_prefix.enc.visual.blocks.9.attn.qkv.bias": "model-00003-of-00003.safetensors",
381
+ "image_prefix.enc.visual.blocks.9.attn.qkv.weight": "model-00003-of-00003.safetensors",
382
+ "image_prefix.enc.visual.blocks.9.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
383
+ "image_prefix.enc.visual.blocks.9.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
384
+ "image_prefix.enc.visual.blocks.9.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
385
+ "image_prefix.enc.visual.blocks.9.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
386
+ "image_prefix.enc.visual.blocks.9.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
387
+ "image_prefix.enc.visual.blocks.9.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
388
+ "image_prefix.enc.visual.blocks.9.norm1.weight": "model-00003-of-00003.safetensors",
389
+ "image_prefix.enc.visual.blocks.9.norm2.weight": "model-00003-of-00003.safetensors",
390
+ "image_prefix.enc.visual.merger.ln_q.weight": "model-00003-of-00003.safetensors",
391
+ "image_prefix.enc.visual.merger.mlp.0.bias": "model-00003-of-00003.safetensors",
392
+ "image_prefix.enc.visual.merger.mlp.0.weight": "model-00003-of-00003.safetensors",
393
+ "image_prefix.enc.visual.merger.mlp.2.bias": "model-00003-of-00003.safetensors",
394
+ "image_prefix.enc.visual.merger.mlp.2.weight": "model-00003-of-00003.safetensors",
395
+ "image_prefix.enc.visual.patch_embed.proj.weight": "model-00002-of-00003.safetensors",
396
+ "image_prefix.norm_extra.weight": "model-00003-of-00003.safetensors",
397
+ "lm_head.weight": "model-00002-of-00003.safetensors",
398
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
399
+ "model.layers.0.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
400
+ "model.layers.0.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
401
+ "model.layers.0.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
402
+ "model.layers.0.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
403
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
404
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
405
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
406
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
407
+ "model.layers.0.norm_cross.weight": "model-00001-of-00003.safetensors",
408
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
409
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
410
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
411
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
412
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
413
+ "model.layers.1.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
414
+ "model.layers.1.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
415
+ "model.layers.1.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
416
+ "model.layers.1.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
417
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
418
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
419
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
420
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
421
+ "model.layers.1.norm_cross.weight": "model-00001-of-00003.safetensors",
422
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
423
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
424
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
425
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
426
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
427
+ "model.layers.10.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
428
+ "model.layers.10.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
429
+ "model.layers.10.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
430
+ "model.layers.10.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
431
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
432
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
433
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
434
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
435
+ "model.layers.10.norm_cross.weight": "model-00001-of-00003.safetensors",
436
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
437
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
438
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
439
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
440
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
441
+ "model.layers.11.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
442
+ "model.layers.11.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
443
+ "model.layers.11.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
444
+ "model.layers.11.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
445
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
446
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
447
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
448
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
449
+ "model.layers.11.norm_cross.weight": "model-00001-of-00003.safetensors",
450
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
451
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
452
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
453
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
454
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
455
+ "model.layers.12.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
456
+ "model.layers.12.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
457
+ "model.layers.12.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
458
+ "model.layers.12.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
459
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00003.safetensors",
460
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
461
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
462
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
463
+ "model.layers.12.norm_cross.weight": "model-00001-of-00003.safetensors",
464
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
465
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
466
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
467
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
468
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
469
+ "model.layers.13.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
470
+ "model.layers.13.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
471
+ "model.layers.13.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
472
+ "model.layers.13.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
473
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00003.safetensors",
474
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
475
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
476
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
477
+ "model.layers.13.norm_cross.weight": "model-00001-of-00003.safetensors",
478
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
479
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
480
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
481
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
482
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
483
+ "model.layers.14.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
484
+ "model.layers.14.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
485
+ "model.layers.14.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
486
+ "model.layers.14.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
487
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00003.safetensors",
488
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
489
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
490
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
491
+ "model.layers.14.norm_cross.weight": "model-00002-of-00003.safetensors",
492
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
493
+ "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
494
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
495
+ "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
496
+ "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
497
+ "model.layers.15.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
498
+ "model.layers.15.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
499
+ "model.layers.15.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
500
+ "model.layers.15.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
501
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
502
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
503
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
504
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
505
+ "model.layers.15.norm_cross.weight": "model-00002-of-00003.safetensors",
506
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
507
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
508
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
509
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
510
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
511
+ "model.layers.16.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
512
+ "model.layers.16.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
513
+ "model.layers.16.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
514
+ "model.layers.16.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
515
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
516
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
517
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
518
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
519
+ "model.layers.16.norm_cross.weight": "model-00002-of-00003.safetensors",
520
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
521
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
522
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
523
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
524
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
525
+ "model.layers.17.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
526
+ "model.layers.17.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
527
+ "model.layers.17.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
528
+ "model.layers.17.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
529
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
530
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
531
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
532
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
533
+ "model.layers.17.norm_cross.weight": "model-00002-of-00003.safetensors",
534
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
535
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
536
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
537
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
538
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
539
+ "model.layers.18.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
540
+ "model.layers.18.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
541
+ "model.layers.18.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
542
+ "model.layers.18.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
543
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
544
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
545
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
546
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
547
+ "model.layers.18.norm_cross.weight": "model-00002-of-00003.safetensors",
548
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
549
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
550
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
551
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
552
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
553
+ "model.layers.19.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
554
+ "model.layers.19.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
555
+ "model.layers.19.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
556
+ "model.layers.19.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
557
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
558
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
559
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
560
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
561
+ "model.layers.19.norm_cross.weight": "model-00002-of-00003.safetensors",
562
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
563
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
564
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
565
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
566
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
567
+ "model.layers.2.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
568
+ "model.layers.2.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
569
+ "model.layers.2.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
570
+ "model.layers.2.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
571
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
572
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
573
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
574
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
575
+ "model.layers.2.norm_cross.weight": "model-00001-of-00003.safetensors",
576
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
577
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
578
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
579
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
580
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
581
+ "model.layers.20.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
582
+ "model.layers.20.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
583
+ "model.layers.20.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
584
+ "model.layers.20.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
585
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
586
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
587
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
588
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
589
+ "model.layers.20.norm_cross.weight": "model-00002-of-00003.safetensors",
590
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
591
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
592
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
593
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
594
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
595
+ "model.layers.21.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
596
+ "model.layers.21.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
597
+ "model.layers.21.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
598
+ "model.layers.21.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
599
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
600
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
601
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
602
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
603
+ "model.layers.21.norm_cross.weight": "model-00002-of-00003.safetensors",
604
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
605
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
606
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
607
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
608
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
609
+ "model.layers.22.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
610
+ "model.layers.22.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
611
+ "model.layers.22.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
612
+ "model.layers.22.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
613
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
614
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
615
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
616
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
617
+ "model.layers.22.norm_cross.weight": "model-00002-of-00003.safetensors",
618
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
619
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
620
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
621
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
622
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
623
+ "model.layers.23.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
624
+ "model.layers.23.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
625
+ "model.layers.23.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
626
+ "model.layers.23.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
627
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
628
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
629
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
630
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
631
+ "model.layers.23.norm_cross.weight": "model-00002-of-00003.safetensors",
632
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
633
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
634
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
635
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
636
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
637
+ "model.layers.24.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
638
+ "model.layers.24.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
639
+ "model.layers.24.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
640
+ "model.layers.24.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
641
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
642
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
643
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
644
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
645
+ "model.layers.24.norm_cross.weight": "model-00002-of-00003.safetensors",
646
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
647
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
648
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
649
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
650
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
651
+ "model.layers.25.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
652
+ "model.layers.25.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
653
+ "model.layers.25.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
654
+ "model.layers.25.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
655
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
656
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
657
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
658
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
659
+ "model.layers.25.norm_cross.weight": "model-00002-of-00003.safetensors",
660
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
661
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
662
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
663
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
664
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
665
+ "model.layers.26.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
666
+ "model.layers.26.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
667
+ "model.layers.26.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
668
+ "model.layers.26.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
669
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00003.safetensors",
670
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
671
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
672
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
673
+ "model.layers.26.norm_cross.weight": "model-00002-of-00003.safetensors",
674
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
675
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
676
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
677
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
678
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
679
+ "model.layers.27.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
680
+ "model.layers.27.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
681
+ "model.layers.27.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
682
+ "model.layers.27.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
683
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00003.safetensors",
684
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
685
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
686
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
687
+ "model.layers.27.norm_cross.weight": "model-00002-of-00003.safetensors",
688
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
689
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
690
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
691
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
692
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
693
+ "model.layers.3.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
694
+ "model.layers.3.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
695
+ "model.layers.3.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
696
+ "model.layers.3.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
697
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
698
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
699
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
700
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
701
+ "model.layers.3.norm_cross.weight": "model-00001-of-00003.safetensors",
702
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
703
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
704
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
705
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
706
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
707
+ "model.layers.4.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
708
+ "model.layers.4.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
709
+ "model.layers.4.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
710
+ "model.layers.4.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
711
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
712
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
713
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
714
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
715
+ "model.layers.4.norm_cross.weight": "model-00001-of-00003.safetensors",
716
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
717
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
718
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
719
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
720
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
721
+ "model.layers.5.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
722
+ "model.layers.5.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
723
+ "model.layers.5.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
724
+ "model.layers.5.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
725
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
726
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
727
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
728
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
729
+ "model.layers.5.norm_cross.weight": "model-00001-of-00003.safetensors",
730
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
731
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
732
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
733
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
734
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
735
+ "model.layers.6.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
736
+ "model.layers.6.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
737
+ "model.layers.6.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
738
+ "model.layers.6.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
739
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
740
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
741
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
742
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
743
+ "model.layers.6.norm_cross.weight": "model-00001-of-00003.safetensors",
744
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
745
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
746
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
747
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
748
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
749
+ "model.layers.7.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
750
+ "model.layers.7.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
751
+ "model.layers.7.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
752
+ "model.layers.7.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
753
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
754
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
755
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
756
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
757
+ "model.layers.7.norm_cross.weight": "model-00001-of-00003.safetensors",
758
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
759
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
760
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
761
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
762
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
763
+ "model.layers.8.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
764
+ "model.layers.8.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
765
+ "model.layers.8.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
766
+ "model.layers.8.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
767
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
768
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
769
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
770
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
771
+ "model.layers.8.norm_cross.weight": "model-00001-of-00003.safetensors",
772
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
773
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
774
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
775
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
776
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
777
+ "model.layers.9.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
778
+ "model.layers.9.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
779
+ "model.layers.9.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
780
+ "model.layers.9.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
781
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
782
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
783
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
784
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
785
+ "model.layers.9.norm_cross.weight": "model-00001-of-00003.safetensors",
786
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
787
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
788
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
789
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
790
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
791
+ "model.norm.weight": "model-00002-of-00003.safetensors"
792
+ }
793
+ }
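The index above is a standard safetensors weight map: each tensor name points to the shard that stores it, and every decoder layer carries CASA-specific `casa_attn.{q,k,v,o}_proj_casa` projections and a `norm_cross` weight alongside the usual `self_attn` and `mlp` tensors. As a rough sketch, one might list the CASA-specific tensors in a local checkout like this (the `model.safetensors.index.json` path follows the usual Hugging Face naming and is assumed here, not quoted from this commit):

```python
import json

# Illustrative path: point this at your local copy of the shard index.
with open("model.safetensors.index.json") as f:
    index = json.load(f)

weight_map = index["weight_map"]

# CASA adds per-layer cross-attention projections and a norm_cross weight
# on top of the base self-attention and MLP tensors.
casa_keys = [k for k in weight_map if ".casa_attn." in k or ".norm_cross." in k]
shards = sorted({weight_map[k] for k in casa_keys})
print(f"{len(casa_keys)} CASA-specific tensors across {len(shards)} shard(s)")
for key in sorted(casa_keys)[:6]:
    print(f"  {key} -> {weight_map[key]}")
```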
modeling_helium1_casa.py ADDED
@@ -0,0 +1,330 @@
1
+ from typing import Any, Callable
2
+ from typing import cast as type_cast
3
+
4
+ import torch
5
+ from transformers.cache_utils import DynamicCache
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.generation.utils import GenerateOutput
8
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
9
+ Qwen2_5_VisionTransformerPretrainedModel,
10
+ )
11
+
12
+ from .image_encoder import Qwen25VLEncoder
13
+ from .configuration_helium1_casa import Helium1CASAConfig
14
+ from .language_helium1_casa import (
15
+ CausalHeliumOutput,
16
+ Helium1CASAAttention,
17
+ Helium1ForCausalLM,
18
+ Helium1RMSNorm,
19
+ )
20
+
21
+
22
+ def meta_project(
23
+ logits: torch.Tensor | list[torch.Tensor],
24
+ projector: torch.nn.Module,
25
+ norm: torch.nn.Module | None = None,
26
+ ) -> torch.Tensor | list[torch.Tensor]:
27
+ """Projection operation that handles both tensors and list of tensors
28
+
29
+ Outputs either a (N, S, D) tensor (same-resolution images) or a list of N (S, D) tensors (where
30
+ S can be a different sequence length per image)
31
+ """
32
+ split_sizes: list[int] | None = None
33
+ if not isinstance(logits, torch.Tensor):
34
+ split_sizes = [_x.shape[0] for _x in logits]
35
+ logits = torch.cat(logits, dim=0)[None, :, :]
36
+ logits = type_cast(torch.Tensor, logits)
37
+ logits = projector(logits)
38
+
39
+ assert isinstance(logits, torch.Tensor)
40
+ if norm is not None:
41
+ logits = norm(logits)
42
+ if split_sizes is not None:
43
+ return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
44
+ return logits
45
+
46
+
47
+ class ImageProjection(torch.nn.Module):
48
+ """Takes in a batch or sequence of images and returns embeddings
49
+ which are then fed to the LM.
50
+
51
+ :param config: model configuration (Helium1CASAConfig)
52
+ :param lm_model_dim: Output dimension (number of channels) for this module
53
+ """
54
+
55
+ def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
56
+ super().__init__()
57
+ self.config = config
58
+ self.out_dim = lm_model_dim
59
+ visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
60
+
61
+ self.enc = Qwen25VLEncoder(visual=visual)
62
+ # Projection layer
63
+ self.proj_extra = self.init_proj_module()
64
+ # Output normalizations
65
+ self.norm_extra = Helium1RMSNorm(self.out_dim)
66
+
67
+ def init_proj_module(self) -> torch.nn.Module:
68
+ """Init the project module for the inserted and/or cross-attended image tokens"""
69
+ if self.config.vision_config.out_dim == self.out_dim:
70
+ return torch.nn.Identity()
71
+ return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)
72
+
73
+ def forward(
74
+ self, x: torch.Tensor | list[torch.Tensor]
75
+ ) -> dict[
76
+ str,
77
+ torch.Tensor | list[torch.Tensor],
78
+ ]:
79
+ """Image embedding mapping
80
+
81
+ :param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
82
+ with shape (C, H, W) (or (H, W, C) in the case of Qwen)
83
+
84
+ :return: Either a tensor with shape (num_total_images, S, D) or, if images
85
+ can have different seq length, a list of `num_total_images` Tensors with shape
86
+ (S, D)
87
+ """
88
+
89
+ # Apply image encoder
90
+ og_dtype = x[0].dtype
91
+ encoded = self.enc(x)["image_embeds"]
92
+ encoded = [_x.to(og_dtype) for _x in encoded]
93
+ if all(x.shape[0] == encoded[0].shape[0] for x in encoded):
94
+ encoded = torch.stack(encoded, dim=0)
95
+
96
+ # Extra projection
97
+ image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
98
+
99
+ # Apply different projection for extra vs cross attended tokens
100
+ return {"image_embeds": image_embeds}
101
+
102
+
103
+ class V2Helium1(Helium1ForCausalLM): # pyright: ignore[reportIncompatibleMethodOverride]
104
+ config_class = Helium1CASAConfig
105
+
106
+ def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
107
+ del kwargs
108
+ super().__init__(config)
109
+ self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)
110
+
111
+ def get_device(self) -> str:
112
+ """Return the device type of the model"""
113
+ return next(self.parameters()).device.type
114
+
115
+ @property
116
+ def token_dim(self) -> int:
117
+ """Returns the number of dimensions for the token representation"""
118
+ return self.config.hidden_size
119
+
120
+ @property
121
+ def rotary_embed(self) -> Callable:
122
+ """Returns the rotary embedding function of the underlying model"""
123
+ return self.model.rotary_emb
124
+
125
+ def _update_model_kwargs_for_generation(
126
+ self,
127
+ outputs: Any,
128
+ model_kwargs: dict[str, Any],
129
+ is_encoder_decoder: bool = False,
130
+ num_new_tokens: int = 1,
131
+ ):
132
+ """This is required to handle multiple gen calls for subtitles"""
133
+ # Call parent to get default updates
134
+ model_kwargs = super()._update_model_kwargs_for_generation(
135
+ outputs, model_kwargs, is_encoder_decoder, num_new_tokens
136
+ )
137
+ # Used by prepare_inputs_for_generation
138
+ model_kwargs["__is_first_gen_call__"] = False
139
+ return model_kwargs
140
+
141
+ def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride]
142
+ self,
143
+ input_ids: torch.Tensor,
144
+ past_key_values: DynamicCache | None = None,
145
+ **kwargs: Any,
146
+ ):
147
+ __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
148
+ if past_key_values is not None and (
149
+ kwargs.get("cache_position") is None
150
+ or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
151
+ ):
152
+ # We're continuing from a cached state
153
+ past_length = past_key_values._seen_tokens
154
+ kwargs["cache_position"] = torch.arange(
155
+ past_length,
156
+ past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
157
+ dtype=torch.long,
158
+ device=input_ids.device,
159
+ )
160
+
161
+ return super().prepare_inputs_for_generation(
162
+ type_cast(torch.LongTensor, input_ids),
163
+ past_key_values=past_key_values,
164
+ **kwargs,
165
+ )
166
+
167
+ def prepare_multimodal_inputs(
168
+ self,
169
+ # text only training
170
+ input_ids: torch.Tensor | None = None,
171
+ inputs_embeds: torch.Tensor | None = None,
172
+ attention_mask: torch.Tensor | None = None,
173
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
174
+ labels: torch.Tensor | None = None,
175
+ # image values
176
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
177
+ pre_image_tokens: list[int] | None = None,
178
+ post_image_tokens: list[int] | None = None,
179
+ **_kwargs: Any,
180
+ ) -> dict:
181
+ """Get a batch data mixing text and image data"""
182
+ del _kwargs
183
+
184
+ processed_inputs = {
185
+ "input_ids": input_ids,
186
+ "inputs_embeds": inputs_embeds,
187
+ "labels": labels,
188
+ "attention_mask": attention_mask,
189
+ "image_embeds_insertion_points": image_embeds_insertion_points,
190
+ }
191
+ if pixel_values is not None:
192
+ processed_inputs.update(self.image_prefix(pixel_values))
193
+ assert "image_embeds" in processed_inputs
194
+ assert (
195
+ isinstance(processed_inputs["image_embeds"], torch.Tensor)
196
+ and processed_inputs["image_embeds"].ndim == 3
197
+ ) or (
198
+ isinstance(processed_inputs["image_embeds"], list)
199
+ and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
200
+ )
201
+
202
+ # Add kwargs necessary to compute cu_seqlens windows for CASA
203
+ processed_inputs["casa_windows_info"] = {
204
+ "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
205
+ "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
206
+ }
207
+
208
+ return processed_inputs
209
+
210
+ def forward( # pyright: ignore[reportIncompatibleMethodOverride]
211
+ self,
212
+ input_ids: torch.Tensor | None = None,
213
+ inputs_embeds: torch.Tensor | None = None,
214
+ attention_mask: torch.Tensor | None = None,
215
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
216
+ return_loss: bool = True,
217
+ labels: torch.Tensor | None = None,
218
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
219
+ pre_image_tokens: list[int] | None = None,
220
+ post_image_tokens: list[int] | None = None,
221
+ **kwargs: Any,
222
+ ) -> CausalHeliumOutput:
223
+ """Multi modal forward pass"""
224
+ assert input_ids is not None or inputs_embeds is not None
225
+
226
+ if self.training:
227
+ assert return_loss is True, (
228
+ "Helium models always compute its own labels/losses in train mode"
229
+ )
230
+
231
+ # Case 1: For first generation call we need to compute pixel values and CASA states
232
+ if kwargs.get("__is_first_gen_call__", True):
233
+ processed_inputs = self.prepare_multimodal_inputs(
234
+ input_ids=input_ids,
235
+ inputs_embeds=inputs_embeds,
236
+ attention_mask=attention_mask,
237
+ image_embeds_insertion_points=image_embeds_insertion_points,
238
+ pixel_values=pixel_values,
239
+ labels=labels,
240
+ pre_image_tokens=pre_image_tokens,
241
+ post_image_tokens=post_image_tokens,
242
+ )
243
+ processed_inputs.pop("inputs_embeds", None)
244
+ else:
245
+ processed_inputs = {
246
+ "inputs_embeds": self.model.embed_tokens(input_ids),
247
+ "attention_mask": attention_mask,
248
+ }
249
+
250
+ # For Helium prefix, we need to update the positions by the number
251
+ # of image tokens inserted in the first call
252
+ if (
253
+ not self.config.casa_attention
254
+ and (cp := kwargs.get("cache_position", None)) is not None
255
+ and pixel_values is not None
256
+ ):
257
+ start = kwargs["cache_position"][0].item()
258
+ num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
259
+ num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] # type: ignore
260
+ kwargs["cache_position"] = torch.arange(
261
+ start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
262
+ start + num_tokens + num_image_tokens,
263
+ dtype=cp.dtype,
264
+ device=cp.device,
265
+ )
266
+
267
+ kwargs.pop("__is_first_gen_call__", True)
268
+ out = super().forward(
269
+ **processed_inputs, # type: ignore
270
+ **kwargs,
271
+ )
272
+
273
+ return out
274
+
275
+ @torch.no_grad()
276
+ def generate_from_image( # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride]
277
+ self,
278
+ input_ids: torch.Tensor | None = None,
279
+ inputs_embeds: torch.Tensor | None = None,
280
+ attention_mask: torch.Tensor | None = None,
281
+ image_embeds_insertion_points: list[torch.Tensor] | None = None,
282
+ pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
283
+ reset_streaming: bool = True,
284
+ **kwargs: Any,
285
+ ) -> "GenerateOutput | torch.LongTensor":
286
+ assert input_ids is not None and inputs_embeds is None, (
287
+ "Input IDs must be provided for generation"
288
+ )
289
+
290
+ # init self-attention KVCache
291
+ if kwargs.get("past_key_values", None) is None:
292
+ kwargs["past_key_values"] = DynamicCache()
293
+
294
+ # To avoid generate warning
295
+ if kwargs.get("pad_token_id", None) is None:
296
+ kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
297
+ if isinstance(kwargs["pad_token_id"], (list, tuple)):
298
+ kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
299
+
300
+ self.start_casa_streaming_states()
301
+ outputs = self.generate(
302
+ input_ids,
303
+ attention_mask=attention_mask,
304
+ pixel_values=pixel_values,
305
+ image_embeds_insertion_points=image_embeds_insertion_points,
306
+ use_cache=True,
307
+ **kwargs,
308
+ )
309
+ if reset_streaming:
310
+ self.reset_casa_streaming_states()
311
+ return outputs
312
+
313
+ def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
314
+ def __reset__(m: torch.nn.Module):
315
+ if isinstance(m, Helium1CASAAttention):
316
+ m._set_streaming(False, ())
317
+ m.reset_streaming()
318
+ if clean_cache:
319
+ del m.streaming_state.k
320
+ del m.streaming_state.v
321
+ del m.streaming_state.casa_handler
322
+
323
+ self.apply(__reset__)
324
+
325
+ def start_casa_streaming_states(self) -> None:
326
+ def __start__(m: torch.nn.Module):
327
+ if isinstance(m, Helium1CASAAttention):
328
+ m._set_streaming(True, ())
329
+
330
+ self.apply(__start__)
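`V2Helium1` couples the Qwen2.5-VL image encoder (via `ImageProjection`) with the Helium1 CASA decoder and exposes `generate_from_image`, which switches the CASA attention layers into streaming mode, runs `generate` with a fresh `DynamicCache`, and resets the streaming state afterwards. A minimal usage sketch follows; the repository id, the `AutoModelForCausalLM` entry point and the image path are assumptions for illustration, not taken from this card:

```python
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "kyutai/CASA-Helium1-VL-2B"  # assumed repository id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype=torch.bfloat16
).cuda().eval()

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "example.jpg"},        # hypothetical local image
        {"type": "text", "text": "Describe this image."},
    ],
}]
# ProcessorOutput.to() moves the token ids and casts pixel_values to bfloat16.
batch = processor.tokenize_messages(messages).to("cuda")

out = model.generate_from_image(**batch, max_new_tokens=128)
print(processor.tokenizer.decode(out[0], skip_special_tokens=True))
```

Passing `reset_streaming=False` to `generate_from_image` keeps the CASA streaming state alive across calls, which is what the multi-call bookkeeping in `_update_model_kwargs_for_generation` is there for.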
processing.py ADDED
@@ -0,0 +1,505 @@
1
+ # pylint: disable=no-member # avoid weird pylint warnings from SentencePieceProcessor
2
+ """Text and Image processor for CASA models using Qwen2.5_VL image encoder"""
3
+
4
+ from math import ceil
5
+ from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload
6
+ from typing import cast as type_cast
7
+
8
+ import torch
9
+ import torchvision.transforms.v2 as T
10
+ from einops import rearrange
11
+ from PIL import Image
12
+ from torchvision.transforms import InterpolationMode
13
+ from torchvision.transforms.functional import to_tensor as pil_to_tensor
14
+ from torchvision.transforms.v2 import functional as F
15
+ from transformers.image_processing_utils import BaseImageProcessor
16
+ from transformers.processing_utils import ProcessorMixin
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+
22
+
23
+ ImageMessage = TypedDict(
24
+ "ImageMessage",
25
+ {
26
+ "type": Literal["image"],
27
+ "image": str | Image.Image | None,
28
+ },
29
+ )
30
+
31
+ TextMessage = TypedDict(
32
+ "TextMessage",
33
+ {
34
+ "type": Literal["text"],
35
+ "text": str,
36
+ },
37
+ )
38
+
39
+ MessageContent = list[ImageMessage | TextMessage]
40
+
41
+ Message = TypedDict(
42
+ "Message",
43
+ {
44
+ "role": Literal["system", "user", "assistant"],
45
+ "content": MessageContent,
46
+ },
47
+ )
48
+
49
+ ProcessorInput = list[list[Message]] | list[Message]
50
+
51
+ __INTERP_NAME_TO_MODE__ = {
52
+ "nearest": InterpolationMode.NEAREST,
53
+ "bilinear": InterpolationMode.BILINEAR,
54
+ "bicubic": InterpolationMode.BICUBIC,
55
+ "lanczos": InterpolationMode.LANCZOS,
56
+ }
57
+
58
+ __INTERP_INT_TO_MODE__ = {
59
+ 0: InterpolationMode.NEAREST,
60
+ 2: InterpolationMode.BILINEAR,
61
+ 3: InterpolationMode.BICUBIC,
62
+ 4: InterpolationMode.BOX,
63
+ 5: InterpolationMode.HAMMING,
64
+ 1: InterpolationMode.LANCZOS,
65
+ }
66
+
67
+
68
+ @overload
69
+ def universal_resize(
70
+ img: Image.Image,
71
+ size: tuple[int, int],
72
+ interpolation: str | InterpolationMode | int = "bilinear",
73
+ antialias: bool = True,
74
+ ) -> Image.Image: ...
75
+ @overload
76
+ def universal_resize(
77
+ img: torch.Tensor,
78
+ size: tuple[int, int],
79
+ interpolation: str | InterpolationMode | int = "bilinear",
80
+ antialias: bool = True,
81
+ ) -> torch.Tensor: ...
82
+ def universal_resize(
83
+ img: Image.Image | torch.Tensor,
84
+ size: tuple[int, int],
85
+ interpolation: str | InterpolationMode | int = "bilinear",
86
+ antialias: bool = True,
87
+ ) -> Image.Image | torch.Tensor:
88
+ """Resize that works for PIL.Image, CHW tensor, or BCHW tensor"""
89
+ if isinstance(interpolation, str):
90
+ interpolation = __INTERP_NAME_TO_MODE__[interpolation]
91
+ elif isinstance(interpolation, int):
92
+ interpolation = __INTERP_INT_TO_MODE__[interpolation]
93
+
94
+ return F.resize(
95
+ img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias
96
+ )
97
+
98
+
99
+ @overload
100
+ def convert_to_rgb(img: Image.Image) -> Image.Image: ...
101
+ @overload
102
+ def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
103
+ def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
104
+ """Convert any image to RGB in a way that does not throw PIL warning"""
105
+ if isinstance(img, torch.Tensor):
106
+ return img
107
+ if img.mode == "RGB": # no changes
108
+ return img
109
+ if img.mode == "P": # palette images need to be converted to RGBA first
110
+ return img.convert("RGBA").convert("RGB")
111
+ return img.convert("RGB")
112
+
113
+
114
+ class QwenImageProcessor(BaseImageProcessor):
115
+ """Resizing for the Qwen2.5VL encoder. Note that the normalization is
116
+ handled in the image_encoder in the model forward"""
117
+
118
+ def __init__(
119
+ self,
120
+ img_size: int = 448,
121
+ interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
122
+ max_ratio: int = 10,
123
+ round_to_patch_size: int = 56,
124
+ use_fast: bool = True,
125
+ **kwargs: Any,
126
+ ) -> None:
127
+ # this will also be used in V2llms to determine whether to remove
128
+ # the temporal conv
129
+ self._num_target_channels = 588
130
+ self._merge_size = 2
131
+ self._patch_size = 14
132
+ super().__init__(
133
+ use_fast=use_fast,
134
+ do_normalize=False,
135
+ **kwargs,
136
+ )
137
+ self.img_size = img_size
138
+ self.interpolation = interpolation
139
+ self.max_ratio = max_ratio
140
+ self.round_to_patch_size = round_to_patch_size
141
+
142
+ def resize_transform(
143
+ self, img: Image.Image | torch.Tensor, img_size: int | None = None
144
+ ) -> Image.Image | torch.Tensor:
145
+ if img_size is None:
146
+ img_size = self.img_size
147
+ max_area = img_size**2
148
+ if isinstance(img, Image.Image):
149
+ img = convert_to_rgb(img)
150
+ w_og, h_og = img.size
151
+ else:
152
+ h_og, w_og = img.shape[-2:]
153
+ w, h = w_og, h_og
154
+
155
+ # Qwen requires max ratio of 10 between max and min sizes
156
+ if self.max_ratio > 0:
157
+ w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)
158
+
159
+ # resize to max area
160
+ current_area = w * h
161
+ if current_area > max_area:
162
+ scale = (max_area / current_area) ** 0.5
163
+ w, h = int(w * scale), int(h * scale)
164
+
165
+ # resize to patch size
166
+ if self.round_to_patch_size > 0:
167
+ w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
168
+ h = ceil((h / self.round_to_patch_size)) * self.round_to_patch_size
169
+
170
+ # resize
171
+ if w != w_og or h != h_og:
172
+ img = universal_resize(img, (h, w), self.interpolation)
173
+ if isinstance(img, torch.Tensor):
174
+ img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
175
+ return img
176
+
177
+ def __process_one__(
178
+ self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
179
+ ) -> torch.Tensor:
180
+ """Same operation as __process_one_with_processor__ but without going through numpy"""
181
+ video_or_img = self.resize_transform(video_or_img, img_size)
182
+ if isinstance(video_or_img, Image.Image):
183
+ video_or_img = pil_to_tensor(video_or_img)
184
+ assert isinstance(video_or_img, torch.Tensor)
185
+ if video_or_img.ndim == 3:
186
+ video_or_img = video_or_img[None]
187
+ assert video_or_img.ndim == 4 and video_or_img.shape[1] == 3, (
188
+ f"Invalid shape {video_or_img.shape}."
189
+ )
190
+ t, c, h, w = video_or_img.shape
191
+ p = self._patch_size
192
+ m = self._merge_size
193
+
194
+ # Convert to RGB
195
+ if c == 1:
196
+ video_or_img = video_or_img.expand((-1, 3, -1, -1))
197
+ if c == 4:
198
+ video_or_img = video_or_img[:, :3]
199
+ c = video_or_img.shape[1]
200
+ assert c == 3, "Expecting RGB image in QwenNormalize"
201
+
202
+ # Reshape to t h w c' format
203
+ h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
204
+ rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)
205
+
206
+ video_or_img = rearrange(
207
+ video_or_img,
208
+ "t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
209
+ **rearrange_dict,
210
+ )
211
+ assert video_or_img.shape[-1] == self._num_target_channels, (
212
+ f"{video_or_img.shape[-1]} != {self._num_target_channels}"
213
+ )
214
+ video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))
215
+
216
+ return video_or_img
217
+
218
+ @overload
219
+ def process_images(
220
+ self, image: Image.Image | torch.Tensor, img_size: int | None = None
221
+ ) -> torch.Tensor: ...
222
+ @overload
223
+ def process_images(
224
+ self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
225
+ ) -> list[torch.Tensor]: ...
226
+ def process_images(
227
+ self,
228
+ image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
229
+ img_size: int | None = None,
230
+ ) -> torch.Tensor | list[torch.Tensor]:
231
+ if isinstance(image, list):
232
+ return [self.__process_one__(_x, img_size) for _x in image]
233
+ return self.__process_one__(image, img_size)
234
+
235
+
236
+ class ProcessorOutput(dict):
237
+ input_ids: torch.Tensor
238
+ attention_mask: torch.Tensor
239
+ image_embeds_insertion_points: list[torch.Tensor] | None
240
+ pixel_values: torch.Tensor | list[torch.Tensor] | None
241
+
242
+ def to(
243
+ self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
244
+ ) -> "ProcessorOutput":
245
+ return ProcessorOutput(
246
+ {
247
+ "input_ids": self["input_ids"].to(device),
248
+ "attention_mask": self["attention_mask"].to(device),
249
+ "image_embeds_insertion_points": self["image_embeds_insertion_points"],
250
+ "pixel_values": (
251
+ self["pixel_values"].to(dtype).to(device)
252
+ if isinstance(self["pixel_values"], torch.Tensor)
253
+ else [x.to(dtype).to(device) for x in self["pixel_values"]]
254
+ if self["pixel_values"] is not None
255
+ else None
256
+ ),
257
+ }
258
+ )
259
+
260
+
261
+ class BaseProcessor(ProcessorMixin):
262
+ def __init__(
263
+ self,
264
+ tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
265
+ pre_image_tokens: tuple[int, ...] = (),
266
+ post_image_tokens: tuple[int, ...] = (),
267
+ system_start_tokens: tuple[int, ...] = (),
268
+ system_end_tokens: tuple[int, ...] = (),
269
+ user_start_tokens: tuple[int, ...] = (),
270
+ user_end_tokens: tuple[int, ...] = (),
271
+ asst_start_tokens: tuple[int, ...] = (),
272
+ asst_end_tokens: tuple[int, ...] = (),
273
+ allow_system_prompt: bool = True,
274
+ pad_token: int = 0,
275
+ bos_token: int | None = None,
276
+ ) -> None:
277
+ self.pre_image_tokens = list(pre_image_tokens)
278
+ self.post_image_tokens = list(post_image_tokens)
279
+ self.system_start_tokens = list(system_start_tokens)
280
+ self.system_end_tokens = list(system_end_tokens)
281
+ self.user_start_tokens = list(user_start_tokens)
282
+ self.user_end_tokens = list(user_end_tokens)
283
+ self.asst_start_tokens = list(asst_start_tokens)
284
+ self.asst_end_tokens = list(asst_end_tokens)
285
+ self._allow_system_prompt = allow_system_prompt
286
+ self.tokenizer = tokenizer
287
+ self._image_processor = None
288
+ self._pad_token = pad_token
289
+ self.bos_token = bos_token
290
+
291
+ @property
292
+ def image_processor(self) -> QwenImageProcessor:
293
+ assert self._image_processor is not None
294
+ return self._image_processor
295
+
296
+ def _process_content(
297
+ self,
298
+ message_content: MessageContent,
299
+ role: Literal["system", "user", "assistant"],
300
+ tokenized_messages: list[torch.Tensor],
301
+ insertion_points: list[int],
302
+ image_list: list[torch.Tensor | None],
303
+ token_count: int,
304
+ img_size: int | None = None,
305
+ **kwargs: Any,
306
+ ) -> int:
307
+ mapping = {
308
+ "user": (self.user_start_tokens, self.user_end_tokens),
309
+ "assistant": (self.asst_start_tokens, self.asst_end_tokens),
310
+ "system": (self.system_start_tokens, self.system_end_tokens),
311
+ }
312
+ if role.lower() not in mapping:
313
+ raise ValueError(f"Unknown role '{role}' encountered in messages.")
314
+ start_tokens, end_tokens = mapping[role.lower()]
315
+ # 1) Add the start tokens
316
+ if start_tokens:
317
+ tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long))
318
+ token_count += len(start_tokens)
319
+ # 2) Process the message content one by one (potentially interleaved image and text)
320
+ for part in message_content:
321
+ elt_type = part["type"]
322
+ if elt_type == "image":
323
+ part = cast(ImageMessage, part)
324
+ self._process_image_message(
325
+ part,
326
+ tokenized_messages,
327
+ image_list,
328
+ img_size=img_size,
329
+ )
330
+ token_count += len(self.pre_image_tokens)
331
+ insertion_points.append(token_count)
332
+ token_count += len(self.post_image_tokens)
333
+ else:
334
+ part = cast(TextMessage, part)
335
+ self._process_text_message(
336
+ part["text"],
337
+ role=role,
338
+ token_list=tokenized_messages,
339
+ **kwargs,
340
+ )
341
+ token_count += tokenized_messages[-1].size(0)
342
+ # 3) Add the end tokens
343
+ if end_tokens:
344
+ tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long))
345
+ token_count += len(end_tokens)
346
+ return token_count
347
+
348
+ def _process_text_message(
349
+ self,
350
+ message: str,
351
+ role: Literal["system", "user", "assistant"],
352
+ token_list: list[torch.Tensor],
353
+ **kwargs: Any,
354
+ ) -> None:
355
+ if role.lower() == "system" and not self._allow_system_prompt:
356
+ raise ValueError("System prompts are not allowed in this tokenizer configuration.")
357
+ tokens = self.tokenizer.encode(
358
+ message, add_special_tokens=False, return_tensors="pt", **kwargs
359
+ )
360
+ tokens = cast(torch.Tensor, tokens)
361
+ token_list.append(tokens.flatten().to(torch.long))
362
+
363
+ def _process_image_message(
364
+ self,
365
+ message: ImageMessage,
366
+ token_list: list[torch.Tensor],
367
+ image_list: list[torch.Tensor | None],
368
+ img_size: int | None = None,
369
+ ) -> None:
370
+ img = message["image"]
371
+ if img is None:
372
+ image_list.append(None)
373
+ else:
374
+ image_list.append(
375
+ self.image_processor.process_images(
376
+ self._load_image(img), img_size=img_size
377
+ ).squeeze(0)
378
+ )
379
+ if self.pre_image_tokens:
380
+ token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long))
381
+
382
+ if self.post_image_tokens:
383
+ token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long))
384
+
385
+ def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
386
+ if isinstance(image_path_or_image, str):
387
+ return Image.open(image_path_or_image).convert("RGB")
388
+ return image_path_or_image
389
+
390
+ def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
391
+ return torch.nn.functional.pad(
392
+ tokens,
393
+ (0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
394
+ value=pad_value,
395
+ )
396
+
397
+ def pad_tokenized_messages(
398
+ self,
399
+ tokenized_messages_batch: list[torch.Tensor],
400
+ image_insertion_points_batch: list[torch.Tensor] | None = None,
401
+ ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
402
+ max_len = max(len(x) for x in tokenized_messages_batch)
403
+ if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
404
+ image_insertion_points_batch = [
405
+ x + max_len - len(tokenized_messages_batch[idx])
406
+ for idx, x in enumerate(image_insertion_points_batch)
407
+ ]
408
+ input_ids = torch.stack(
409
+ [
410
+ self._maybe_pad(s, max_len - s.size(0), self._pad_token)
411
+ for s in tokenized_messages_batch
412
+ ],
413
+ dim=0,
414
+ )
415
+ attention_mask = torch.stack(
416
+ [
417
+ self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
418
+ for s in tokenized_messages_batch
419
+ ],
420
+ dim=0,
421
+ )
422
+ return input_ids, attention_mask, image_insertion_points_batch
423
+
424
+ def tokenize_messages(
425
+ self,
426
+ messages: ProcessorInput,
427
+ suppress_bos_token: bool = False,
428
+ **kwargs: Any,
429
+ ) -> ProcessorOutput | None:
430
+ """Tokenize a batch of messages into token IDs suitable for Helium1 CASA model.
431
+
432
+ Args:
433
+ messages (list[list[dict[str, str]]] | list[dict[str, str]]): Batch of message lists (or single list of messages),
434
+ where each message is a dictionary with 'role' and 'content' keys.
435
+ Note that if the final message is from the assistant, its end tokens are not added, so generation can continue it.
437
+ suppress_bos_token (bool, optional): If True, the beginning-of-sequence token will not be added.
438
+ Defaults to False.
439
+ **kwargs: Additional keyword arguments passed to the underlying encode method.
440
+ """
441
+ if not messages:
442
+ return None
443
+ if isinstance(messages[0], dict):
444
+ messages = [messages] # type: ignore[assignment]
445
+
446
+ messages = cast(list[list[Message]], messages)
447
+ image_insertion_points_batch = []
448
+ tokenized_messages_batch = []
449
+ image_list: list[torch.Tensor | None] = []
450
+ for msgs in messages:
451
+ # msgs.append({
452
+ # "role": "assistant",
453
+ # "content": [{"type": "text", "text": ""}]
454
+ # })
455
+ tokenized_messages = []
456
+ if not suppress_bos_token and self.bos_token is not None:
457
+ tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
458
+ insertion_points = []
459
+ token_count = 0
460
+ for msg in msgs:
461
+ token_count = self._process_content(
462
+ msg["content"],
463
+ role=msg["role"],
464
+ tokenized_messages=tokenized_messages,
465
+ insertion_points=insertion_points,
466
+ image_list=image_list,
467
+ token_count=token_count,
468
+ **kwargs,
469
+ )
470
+ tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
471
+ image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))
472
+
473
+ if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
474
+ # Remove the assistant end tokens from the final message
475
+ end_token_len = len(self.asst_end_tokens)
476
+ tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
477
+ if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
478
+ # Append the assistant start tokens so generation begins an assistant turn
479
+ end_token_len = len(self.asst_end_tokens)
480
+ tokenized_messages_batch[-1] = torch.cat(
481
+ [
482
+ tokenized_messages_batch[-1],
483
+ torch.Tensor(self.asst_start_tokens).to(torch.long),
484
+ ]
485
+ )
486
+
487
+ input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
488
+ tokenized_messages_batch, image_insertion_points_batch
489
+ )
490
+
491
+ if image_list:
492
+ assert sum(img is None for img in image_list) % len(image_list) == 0, (
493
+ "Either all or no image must be None."
494
+ )
495
+ pixel_values: None | torch.Tensor | list[torch.Tensor]
496
+ if image_list[0] is None:
497
+ pixel_values = None
498
+ else:
499
+ pixel_values = cast(list[torch.Tensor], image_list)
500
+ return ProcessorOutput(
501
+ input_ids=input_ids,
502
+ image_embeds_insertion_points=image_embeds_insertion_points,
503
+ attention_mask=attention_mask,
504
+ pixel_values=pixel_values,
505
+ )
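The least obvious part of the processing code above is the patch arithmetic in `QwenImageProcessor`: images are resized so that height and width are multiples of `round_to_patch_size` (56 = 14-pixel patches times a 2x2 merge window), capped at an area of `img_size**2`, then unfolded into rows of 588 values (3 channels x 14 x 14) and viewed as `(t, H/14, W/14, 588)`. A small shape check, assuming the `QwenImageProcessor` class defined above is in scope and using random data instead of a real image:

```python
import torch

patch = 14                                # matches _patch_size above (merge window is 2x2)
img = torch.rand(1, 3, 448, 448)          # (t, c, H, W); 448 is already a multiple of 56

proc = QwenImageProcessor(img_size=448)
out = proc.process_images(img)

h = w = 448 // patch                      # 32 x 32 patches per frame
assert out.shape == (1, h, w, 3 * patch * patch)  # (1, 32, 32, 588)
```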
processing_helium1_casa.py ADDED
@@ -0,0 +1,37 @@
1
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
2
+
3
+ from .processing import BaseProcessor, QwenImageProcessor
4
+
5
+
6
+ class Helium1CASAProcessor(BaseProcessor):
7
+ attributes = ["tokenizer"]
8
+ tokenizer_class = "PreTrainedTokenizerFast"
9
+
10
+ def __init__(
11
+ self,
12
+ tokenizer: PreTrainedTokenizerFast,
13
+ pre_image_tokens: tuple[int, ...] = tuple(),
14
+ post_image_tokens: tuple[int, ...] = tuple(),
15
+ system_start_tokens: tuple[int, ...] = tuple(),
16
+ system_end_tokens: tuple[int, ...] = tuple(),
17
+ user_start_tokens: tuple[int, ...] = (104,),
18
+ user_end_tokens: tuple[int, ...] = (105,),
19
+ asst_start_tokens: tuple[int, ...] = (102,),
20
+ asst_end_tokens: tuple[int, ...] = (103,),
21
+ bos_token: int = 1,
22
+ image_size: int = 896,
23
+ ):
24
+ super().__init__(
25
+ tokenizer=tokenizer,
26
+ pre_image_tokens=pre_image_tokens,
27
+ post_image_tokens=post_image_tokens,
28
+ system_start_tokens=system_start_tokens,
29
+ system_end_tokens=system_end_tokens,
30
+ user_start_tokens=user_start_tokens,
31
+ user_end_tokens=user_end_tokens,
32
+ asst_start_tokens=asst_start_tokens,
33
+ asst_end_tokens=asst_end_tokens,
34
+ allow_system_prompt=False,
35
+ bos_token=bos_token,
36
+ )
37
+ self._image_processor = QwenImageProcessor(img_size=image_size)
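
As a point of reference, here is a direct-construction sketch (not from the repository). The package import path is hypothetical, since `processing_helium1_casa.py` uses a relative import and therefore has to live inside a package, and loading the tokenizer with `PreTrainedTokenizerFast.from_pretrained` is likewise an assumption.

```python
# Sketch of constructing the processor by hand instead of via AutoProcessor.
from transformers import PreTrainedTokenizerFast

from casa_helium1_vl import Helium1CASAProcessor  # hypothetical package name for the repo files

tokenizer = PreTrainedTokenizerFast.from_pretrained("kyutai/CASA-Helium1-VL-2B")
# Defaults mirror the signature above: Helium chat-marker ids 102-105, BOS id 1, 896-px images.
processor = Helium1CASAProcessor(tokenizer=tokenizer)
# Any default can be overridden explicitly, e.g. a lower input resolution:
processor_lowres = Helium1CASAProcessor(tokenizer=tokenizer, image_size=448)
```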
processor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "auto_map": {
+     "AutoProcessor": "processing_helium1_casa.Helium1CASAProcessor"
+   },
+   "bos_token": 1,
+   "image_size": 896,
+   "post_image_tokens": [],
+   "pre_image_tokens": [],
+   "processor_class": "Helium1CASAProcessor"
+ }
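
Because `auto_map` points `AutoProcessor` at the class defined in `processing_helium1_casa.py`, the processor can be loaded without any manual imports, provided remote code is trusted. A short sketch (repository id assumed from the model card):

```python
from transformers import AutoProcessor

# auto_map resolves to processing_helium1_casa.Helium1CASAProcessor; trust_remote_code is required.
processor = AutoProcessor.from_pretrained("kyutai/CASA-Helium1-VL-2B", trust_remote_code=True)
print(type(processor).__name__)  # expected: Helium1CASAProcessor
```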
readme_images/CASA.png ADDED

Git LFS Details

  • SHA256: 388d6098d61c64dfee303411d93440dd00a8371806055af3a225aeb70590f746
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
readme_images/casa_explainer.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5ee73e4e8ea65ebc8d3e53d6468a3d2688d5f656259bcae73bffd843e5a0e69
+ size 299297
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
332
+ </div>
333
+
334
+ </header></div>
335
+ <div class="SVELTE_HYDRATER contents" data-target="LastCommit" data-props="{&quot;commitLast&quot;:{&quot;date&quot;:&quot;2025-04-30T14:01:50.000Z&quot;,&quot;verified&quot;:&quot;verified&quot;,&quot;subject&quot;:&quot;Upload tokenizer.model with huggingface_hub&quot;,&quot;authors&quot;:[{&quot;_id&quot;:&quot;6355a3c1805be5a8f30fea49&quot;,&quot;avatar&quot;:&quot;https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/ONMEctCWAeAgF2eZ307si.jpeg&quot;,&quot;isHf&quot;:false,&quot;user&quot;:&quot;lmz&quot;}],&quot;commit&quot;:{&quot;id&quot;:&quot;b8d50a6775dfd77d956b7cd18928736dccd17fe7&quot;,&quot;parentIds&quot;:[&quot;b3d4f57a13777182735134b6aaf4b610767cd08c&quot;]},&quot;title&quot;:&quot;Upload tokenizer.model with huggingface_hub&quot;},&quot;repo&quot;:{&quot;name&quot;:&quot;kyutai/helium-1-2b&quot;,&quot;type&quot;:&quot;model&quot;}}"><div class="from-gray-100-to-white bg-linear-to-t flex flex-wrap items-baseline gap-y-1 rounded-t-lg border border-b-0 px-3 py-2 dark:border-gray-800"><img class="mr-2.5 mt-0.5 h-4 w-4 self-center rounded-full" alt="lmz's picture" src="https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/ONMEctCWAeAgF2eZ307si.jpeg">
336
+ <div class="mr-4 flex flex-none items-center truncate"><a class="hover:underline" href="/lmz">lmz
337
+ </a>
338
+
339
+ </div>
340
+ <div class="mr-4 truncate font-mono text-xs text-gray-500 hover:prose-a:underline sm:text-sm"><!-- HTML_TAG_START -->Upload tokenizer.model with huggingface_hub<!-- HTML_TAG_END --></div>
341
+ <a class="rounded-sm border bg-gray-50 px-1.5 text-sm hover:underline dark:border-gray-800 dark:bg-gray-900" href="/kyutai/helium-1-2b/commit/b8d50a6775dfd77d956b7cd18928736dccd17fe7">b8d50a6</a>
342
+ <span class="mx-2 text-green-500 dark:text-green-600 px-1.5 border-green-100 dark:border-green-800 rounded-full border text-xs uppercase" title="This commit is signed and the signature is verified">verified</span>
343
+ <time class="ml-auto hidden flex-none truncate pl-2 text-gray-500 dark:text-gray-400 lg:block" datetime="2025-04-30T14:01:50" title="Wed, 30 Apr 2025 14:01:50 GMT">7 months ago</time></div></div>
344
+ <div class="relative flex flex-wrap items-center border px-3 py-1.5 text-sm text-gray-800 dark:border-gray-800 dark:bg-gray-900 ">
345
+ <a class="group my-1 mr-4 flex items-center " download="" href="/kyutai/helium-1-2b/resolve/main/tokenizer.model?download=true"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" viewBox="0 0 32 32"><path fill="currentColor" d="M26 24v4H6v-4H4v4a2 2 0 0 0 2 2h20a2 2 0 0 0 2-2v-4zm0-10l-1.41-1.41L17 20.17V2h-2v18.17l-7.59-7.58L6 14l10 10l10-10z"></path></svg>
346
+ download</span>
347
+
348
+ </a><div class="SVELTE_HYDRATER contents" data-target="CopyButton" data-props="{&quot;value&quot;:&quot;https://huggingface.co/kyutai/helium-1-2b/resolve/main/tokenizer.model&quot;,&quot;style&quot;:&quot;blank&quot;,&quot;label&quot;:&quot;Copy download link&quot;,&quot;classNames&quot;:&quot;my-1 mr-4 flex items-center no-underline hover:underline&quot;}"><button class="my-1 mr-4 flex items-center no-underline hover:underline " title="Copy download link" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
349
+ <span class="ml-1.5 ">Copy download link</span></button></div><a class="group my-1 mr-4 flex items-center " href="/kyutai/helium-1-2b/commits/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32" style="transform: rotate(360deg);"><path d="M16 4C9.383 4 4 9.383 4 16s5.383 12 12 12s12-5.383 12-12S22.617 4 16 4zm0 2c5.535 0 10 4.465 10 10s-4.465 10-10 10S6 21.535 6 16S10.465 6 16 6zm-1 2v9h7v-2h-5V8z" fill="currentColor"></path></svg>
350
+ history</span>
351
+
352
+ </a><a class="group my-1 mr-4 flex items-center " href="/kyutai/helium-1-2b/blame/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32" style="transform: rotate(360deg);"><path d="M16 2a14 14 0 1 0 14 14A14 14 0 0 0 16 2zm0 26a12 12 0 1 1 12-12a12 12 0 0 1-12 12z" fill="currentColor"></path><path d="M11.5 11a2.5 2.5 0 1 0 2.5 2.5a2.48 2.48 0 0 0-2.5-2.5z" fill="currentColor"></path><path d="M20.5 11a2.5 2.5 0 1 0 2.5 2.5a2.48 2.48 0 0 0-2.5-2.5z" fill="currentColor"></path></svg>
353
+ blame</span>
354
+
355
+ </a><a class="group my-1 mr-4 flex items-center text-green-600 dark:text-green-500" href="/kyutai/helium-1-2b/edit/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M2 26h28v2H2z" fill="currentColor"></path><path d="M25.4 9c.8-.8.8-2 0-2.8l-3.6-3.6c-.8-.8-2-.8-2.8 0l-15 15V24h6.4l15-15zm-5-5L24 7.6l-3 3L17.4 7l3-3zM6 22v-3.6l10-10l3.6 3.6l-10 10H6z" fill="currentColor"></path></svg>
356
+ contribute</span>
357
+
358
+ </a><a class="group my-1 mr-4 flex items-center " href="/kyutai/helium-1-2b/delete/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M12 12h2v12h-2z" fill="currentColor"></path><path d="M18 12h2v12h-2z" fill="currentColor"></path><path d="M4 6v2h2v20a2 2 0 0 0 2 2h16a2 2 0 0 0 2-2V8h2V6zm4 22V8h16v20z" fill="currentColor"></path><path d="M12 2h8v2h-8z" fill="currentColor"></path></svg>
359
+ delete</span>
360
+
361
+ </a>
362
+
363
+ <div class="mr-4 flex items-center"><div class="SVELTE_HYDRATER contents" data-target="ScanStatusBadge" data-props="{&quot;classNames&quot;:&quot;mr-2&quot;,&quot;scanStatus&quot;:{&quot;status&quot;:&quot;safe&quot;,&quot;protectAiScan&quot;:{&quot;status&quot;:&quot;unscanned&quot;,&quot;message&quot;:null,&quot;reportLink&quot;:&quot;https://protectai.com/insights/models/kyutai/helium-1-2b/5764947fc2e782982c24f363eacd9baea3e821f8/files?blob-id=48679a193304ea9e6dda4c4de9be4d4db590c249&amp;utm_source=huggingface&quot;},&quot;avScan&quot;:{&quot;status&quot;:&quot;safe&quot;,&quot;message&quot;:&quot;No security issues detected&quot;,&quot;reportLink&quot;:&quot;https://fdtn.ai/ai-supply-chain/hugging-face?utm_source=huggingface&quot;,&quot;reportLabel&quot;:&quot;Learn more at Cisco Foundation AI&quot;},&quot;pickleImportScan&quot;:{&quot;status&quot;:&quot;unscanned&quot;,&quot;pickleImports&quot;:[],&quot;version&quot;:&quot;0.0.0&quot;},&quot;virusTotalScan&quot;:{&quot;status&quot;:&quot;safe&quot;,&quot;message&quot;:&quot;0/76 engines detect it as malicious.&quot;,&quot;reportLink&quot;:&quot;https://www.virustotal.com/gui/file/abb8879fdb2001dfae68d0bbdccbe92ae1593bad518abb34c9513f27904ee303?utm_source=huggingface&quot;,&quot;reportLabel&quot;:&quot;See more details on VirusTotal&quot;},&quot;jFrogScan&quot;:{&quot;status&quot;:&quot;unscanned&quot;,&quot;message&quot;:&quot;Not a machine-learning model&quot;,&quot;reportLink&quot;:&quot;&quot;,&quot;reportLabel&quot;:&quot;&quot;}},&quot;repo&quot;:{&quot;name&quot;:&quot;kyutai/helium-1-2b&quot;,&quot;type&quot;:&quot;model&quot;},&quot;revision&quot;:&quot;main&quot;,&quot;filePath&quot;:&quot;tokenizer.model&quot;,&quot;openByDefault&quot;:false}"><div class="sm:relative mr-2"><button class="flex h-[1.125rem] select-none items-center gap-0.5 rounded border pl-0.5 pr-0.5 text-xs leading-tight text-gray-400 hover:cursor-pointer text-gray-400 hover:border-gray-200 hover:bg-gray-50 hover:text-gray-500 dark:border-gray-800 dark:hover:bg-gray-800 dark:hover:text-gray-200 "><svg class="flex-none" width="1em" height="1em" viewBox="0 0 22 28" fill="none" xmlns="http://www.w3.org/2000/svg"><path fill-rule="evenodd" clip-rule="evenodd" d="M15.3634 10.3639C15.8486 10.8491 15.8486 11.6357 15.3634 12.1209L10.9292 16.5551C10.6058 16.8785 10.0814 16.8785 9.7579 16.5551L7.03051 13.8277C6.54532 13.3425 6.54532 12.5558 7.03051 12.0707C7.51569 11.5855 8.30234 11.5855 8.78752 12.0707L9.7579 13.041C10.0814 13.3645 10.6058 13.3645 10.9292 13.041L13.6064 10.3639C14.0916 9.8787 14.8782 9.8787 15.3634 10.3639Z" fill="currentColor"></path><path fill-rule="evenodd" clip-rule="evenodd" d="M10.6666 27.12C4.93329 25.28 0 19.2267 0 12.7867V6.52001C0 5.40001 0.693334 4.41334 1.73333 4.01334L9.73333 1.01334C10.3333 0.786673 11 0.786673 11.6 1.02667L19.6 4.02667C20.1083 4.21658 20.5465 4.55701 20.8562 5.00252C21.1659 5.44803 21.3324 5.97742 21.3333 6.52001V12.7867C21.3333 19.24 16.4 25.28 10.6666 27.12Z" fill="currentColor" fill-opacity="0.22"></path><path d="M10.0845 1.94967L10.0867 1.94881C10.4587 1.8083 10.8666 1.81036 11.2286 1.95515L11.2387 1.95919L11.2489 1.963L19.2489 4.963L19.25 4.96342C19.5677 5.08211 19.8416 5.29488 20.0351 5.57333C20.2285 5.85151 20.3326 6.18203 20.3333 6.52082C20.3333 6.52113 20.3333 6.52144 20.3333 6.52176L20.3333 12.7867C20.3333 18.6535 15.8922 24.2319 10.6666 26.0652C5.44153 24.2316 1 18.6409 1 12.7867V6.52001C1 5.82357 1.42893 5.20343 2.08883 4.94803L10.0845 1.94967Z" stroke="currentColor" stroke-opacity="0.30" 
stroke-width="2"></path></svg>
364
+
365
+ <span class="mr-0.5 max-sm:hidden">Safe</span></button>
366
+
367
+ </div></div>
368
+ </div>
369
+
370
+ <div class="flex items-center gap-x-3 dark:text-gray-300 sm:ml-auto">
371
+ 1.14 MB</div></div>
372
+
373
+ <div class="relative min-h-[100px] rounded-b-lg border border-t-0 leading-tight dark:border-gray-800 dark:bg-gray-925">
374
+ <div class="p-4 py-8 text-center">This file is stored with
375
+ <a class="underline" href="https://huggingface.co/docs/hub/xet/index">Xet</a>
376
+ . It is too big to display, but you can still
377
+ <a download class="underline" href="/kyutai/helium-1-2b/resolve/main/tokenizer.model">download</a>
378
+ it.
379
+ </div>
380
+ <div class="bg-linear-to-br from-gray-50-to-white relative border-t p-4"><div class="text-smd mb-2 flex items-baseline"><h3 class="font-semibold">Large File Pointer Details</h3>
381
+ <span class="ml-2">(</span>
382
+ <a href="/kyutai/helium-1-2b/raw/main/tokenizer.model" class="flex items-center underline decoration-gray-400 hover:decoration-gray-700 dark:decoration-gray-500 dark:hover:decoration-gray-300" target="_blank"><svg class="mr-0.5 text-xs" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M25.7 9.3l-7-7A.908.908 0 0 0 18 2H8a2.006 2.006 0 0 0-2 2v24a2.006 2.006 0 0 0 2 2h16a2.006 2.006 0 0 0 2-2V10a.908.908 0 0 0-.3-.7zM18 4.4l5.6 5.6H18zM24 28H8V4h8v6a2.006 2.006 0 0 0 2 2h6z" fill="currentColor"></path></svg> Raw pointer file
383
+ </a>
384
+ <span>)</span></div>
385
+ <dl class="break-words font-mono text-[0.8rem]"><div class="mr-1 flex md:mb-1"><dt class="mr-1.5 font-semibold">SHA256:</dt>
386
+ <dd class="truncate">abb8879fdb2001dfae68d0bbdccbe92ae1593bad518abb34c9513f27904ee303</dd>
387
+ </div><div class="flex flex-wrap"><dt class="mr-1.5 font-semibold">Pointer size:</dt>
388
+ <dd>132 Bytes</dd>
389
+
390
+ <div class="px-1.5 opacity-40">·</div>
391
+
392
+ <dt class="mr-1.5 font-semibold">Size of remote file:</dt>
393
+ <dd>1.14 MB</dd>
394
+
395
+ <div class="px-1.5 opacity-40">·</div>
396
+ <dt class="mr-1.5 font-semibold">Xet hash:</dt>
397
+ <dd class="truncate">bca9ac44cc00b884e9fb49abf7fa3576e32e568e788ff68ce6cee89d24e2b8e4</dd></div></dl>
398
+ <p class="mt-2 text-sm text-gray-500">Xet efficiently stores Large Files inside Git, intelligently splitting files into unique chunks and
399
+ accelerating uploads and downloads.
400
+ <a class="underline" href="/join/xet" target="_blank">More info</a>.</p></div>
401
+ </div></section></div></main>
402
+
403
+ </div>
404
+ <script>
405
+ import("\/front\/build\/kube-02d86c8\/index.js"); window.moonSha = "kube-02d86c8\/"; window.__hf_deferred =
406
+ {};
407
+ </script>
408
+ <!-- Stripe -->
409
+ <script>
410
+ if (["hf.co", "huggingface.co"].includes(window.location.hostname)) {
411
+ const script = document.createElement("script");
412
+ script.src = "https://js.stripe.com/v3/";
413
+ script.async = true;
414
+ document.head.appendChild(script);
415
+ }
416
+ </script>
417
+ </body>
418
+ </html>
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "additional_special_tokens": [
+     "<|im_sp_00|>",
+     "<|im_sp_01|>",
+     "<|im_sp_02|>",
+     "<|im_sp_94|>",
+     "<|im_sp_95|>",
+     "<|im_sp_96|>",
+     "<|im_sp_97|>",
+     "<|im_sp_98|>",
+     "<|im_sp_99|>"
+   ]
+ }
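
The config above registers the `<|im_sp_XX|>` tokens as additional special tokens on top of the `PreTrainedTokenizerFast` tokenizer. As a quick sanity check, the snippet below is a minimal sketch (not part of this repository) of how these tokens could be inspected after download; the Hub repo id `kyutai/CASA-Helium1-VL-2B` and the use of `trust_remote_code` are assumptions based on the model card tags, not on the files shown here.

```python
# Minimal sketch (not part of this repo): check that the additional special tokens
# from tokenizer_config.json map to single ids. The repo id and trust_remote_code
# usage below are assumptions, not taken from the files shown here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "kyutai/CASA-Helium1-VL-2B", trust_remote_code=True
)

for tok in ("<|im_sp_00|>", "<|im_sp_99|>"):
    token_id = tokenizer.convert_tokens_to_ids(tok)
    # A registered special token should tokenize to exactly one id.
    assert tokenizer.encode(tok, add_special_tokens=False) == [token_id]
    print(tok, "->", token_id)
```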
utils.py ADDED
@@ -0,0 +1,116 @@
+ # pylint: disable=protected-access
+ """Utils to handle CASA layers construction"""
+
+ from contextlib import contextmanager
+ from dataclasses import dataclass, fields
+ from typing import Any, Callable, Generic, TypeVar
+
+ import torch
+
+
+ def delta_w_factory(
+     org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
+     """Factory for building a linear op whose weights are the sum of the two layers' weights"""
+
+     def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
+         nonlocal org_lin, new_lin
+         bias = None if org_lin.bias is None else org_lin.bias + new_lin.bias
+         return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)
+
+     return _delta_w_fwd
+
+
+ @dataclass
+ class StreamingState:
+     """Streaming state used by CASA layers at inference to store
+     e.g. the offset, the KV cache and other persistent state"""
+
+     offset: int = 0
+
+     def _is_valid_field(self, key: str) -> bool:
+         return key in {x.name for x in fields(self)}
+
+     def _init_field(self, key: str) -> None:
+         """Init function for non-argument-dependent defaults"""
+         assert self._is_valid_field(key)
+         if key == "offset":
+             self.offset = 0
+         else:
+             # for fields which should be set explicitly and cannot be auto-initialized
+             setattr(self, key, None)
+
+     def init(self) -> None:
+         for key in [x.name for x in fields(self)]:
+             self._init_field(key)
+
+     def _reset_field(self, name: str) -> None:
+         """Resets the given field"""
+         self._init_field(name)
+
+     def reset(self) -> None:
+         for f in fields(self):
+             self._reset_field(f.name)
+
+     def _get_field(self, f: str) -> Any:
+         """Get a field, initializing it first if it is unset"""
+         assert self._is_valid_field(f)
+         if getattr(self, f) is None:
+             self._init_field(f)
+         return getattr(self, f)
+
+     def _set_field(self, f: str, value: Any) -> None:
+         assert self._is_valid_field(f)
+         setattr(self, f, value)
+
+
+ StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
+
+
+ class StreamingModule(torch.nn.Module, Generic[StreamingStateT]):  # pylint: disable=abstract-method
+     """Overrides Audiocraft's streaming modules with additional small utils"""
+
+     def __init__(self, state_class: type) -> None:
+         torch.nn.Module.__init__(self)
+         self.is_streaming: bool = False
+         self.enable_viz: tuple[str, ...] = ()
+         self._streaming_state: StreamingStateT = state_class()
+
+     @property
+     def streaming_state(self) -> StreamingStateT:
+         return self._streaming_state
+
+     def _apply_named_streaming(self, fn: Callable):
+         """Apply function to all streaming modules"""
+         for name, module in self.named_modules():
+             if isinstance(module, StreamingModule):
+                 fn(name, module)
+
+     def reset_streaming(self):
+         """Reset the streaming state."""
+
+         def _reset(_: str, module: StreamingModule):
+             module._streaming_state.reset()
+
+         self._apply_named_streaming(_reset)
+
+     def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()):
+         """Set all streaming modules in streaming mode"""
+
+         def _set_streaming(_, module: StreamingModule) -> None:
+             module.is_streaming = streaming
+             module.enable_viz = viz
+             if streaming:
+                 module.streaming_state.init()
+
+         self._apply_named_streaming(_set_streaming)
+
+     @contextmanager
+     def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
+         """Context manager to enter streaming mode. Resets the streaming state on exit."""
+         self._set_streaming(stream, viz)
+         try:
+             yield
+         finally:
+             self._set_streaming(False, ())
+             self.reset_streaming()
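
To make the helpers in `utils.py` concrete, the following is an illustrative usage sketch, not code taken from the repository: it exercises `delta_w_factory` on two toy `nn.Linear` layers and drives a minimal `StreamingModule` subclass through the `streaming()` context manager. The layer sizes and the `ToyStream` class are invented for illustration, and the import assumes the file is importable as `utils` (e.g. when run from the repo root).

```python
# Illustrative sketch (not from the repo): exercising the helpers in utils.py.
import torch
from utils import StreamingModule, StreamingState, delta_w_factory

# delta_w_factory: forward pass whose weights are the sum of two linears,
# e.g. a frozen base projection plus a learned delta (sizes are made up here).
base = torch.nn.Linear(16, 16)
delta = torch.nn.Linear(16, 16)
fused_forward = delta_w_factory(base, delta)
x = torch.randn(2, 16)
assert torch.allclose(
    fused_forward(x),
    torch.nn.functional.linear(x, base.weight + delta.weight, base.bias + delta.bias),
)


# StreamingModule: a toy subclass whose state is just the default offset counter.
class ToyStream(StreamingModule[StreamingState]):
    def __init__(self) -> None:
        super().__init__(StreamingState)

    def forward(self, n_tokens: int) -> int:
        if self.is_streaming:
            self._streaming_state.offset += n_tokens
        return self._streaming_state.offset


m = ToyStream()
with m.streaming():
    m(4)
    print(m.streaming_state.offset)  # 4 while streaming
print(m.streaming_state.offset)      # reset to 0 on exiting the context
```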