Super-squash branch 'main' using huggingface_hub
Browse filesCo-authored-by: mboehle <mboehle@users.noreply.huggingface.co>
- .gitattributes +6 -0
- Notice +2 -0
- README.md +311 -0
- __init__.py +0 -0
- casa_attention.py +1010 -0
- config.json +77 -0
- configuration_helium1_casa.py +270 -0
- generation_config.json +10 -0
- image_encoder.py +57 -0
- language_helium1_casa.py +1077 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +793 -0
- modeling_helium1_casa.py +330 -0
- processing.py +505 -0
- processing_helium1_casa.py +37 -0
- processor_config.json +10 -0
- readme_images/CASA.png +3 -0
- readme_images/casa_explainer.mp4 +3 -0
- tokenizer.json +0 -0
- tokenizer.model +418 -0
- tokenizer_config.json +14 -0
- utils.py +116 -0
.gitattributes
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
model-00002-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
model-00003-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
readme_images/CASA.png filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
readme_images/casa_explainer.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
readme_images/half_res_trimmed.mp4 filter=lfs diff=lfs merge=lfs -text
|
Notice
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CASA-Helium1-VL-2B's image encoder is finetuned from the image encoder of Qwen2.5-VL-3B.
|
| 2 |
+
Qwen is licensed under the Qwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
|
README.md
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-sa-4.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
base_model:
|
| 6 |
+
- kyutai/helium-1-2b
|
| 7 |
+
datasets:
|
| 8 |
+
- HuggingFaceM4/FineVision
|
| 9 |
+
- mvp-lab/LLaVA-OneVision-1.5-Instruct-Data
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
<img align="right" src="readme_images/CASA.png" width="150px" >
|
| 13 |
+
|
| 14 |
+
# Model Card for CASA-Helium1-VL-2B
|
| 15 |
+
|
| 16 |
+
**CASA** ([Project Page][blog] . [arXiv][casa-arxiv] . [github][casa-git]) stands for **C**ross-**A**ttention via **S**elf-**A**ttention.
|
| 17 |
+
**CASA** is a vision-language fusion paradigm that aims to improve on cross-attention while preserving its practical benefits.
|
| 18 |
+
|
| 19 |
+
Specifically, **CASA** layers inject visual tokens into a text stream by using image-to-text cross-attention while additionally enabling
|
| 20 |
+
text-to-text self interaction in the same layer, and contained to smaller local attention windows.
|
| 21 |
+
This simple modification enables natural gating in the cross-attention mechanism, improving its performance and substantially closing the gap to standard token insertion methods.
|
| 22 |
+
For qualitative samples of CASA used for live video captioning, please check the [associated HuggingFace space](https://huggingface.co/spaces/kyutai/casa-samples).
|
| 23 |
+
|
| 24 |
+

|
| 25 |
+
|
| 26 |
+
## Model Details
|
| 27 |
+
|
| 28 |
+
### Model Description
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
This model page contains the model weights for CASA trained from a pretrained text-only Helium1-2B backbone and from the image encoder from Qwen2.5-VL-3B.
|
| 32 |
+
In the collection, we also provides weights for:
|
| 33 |
+
- [`CASA-Qwen2_5-VL-3B`](https://huggingface.co/kyutai/CASA-Qwen2_5-VL-3B): A CASA model adapted from the full pretrained `Qwen2.5-VL-3B` (keeping the backbone LLM weights are kept frozen)
|
| 34 |
+
- [`CASA-Qwen2_5-VL-3B-LiveCC`](https://huggingface.co/kyutai/CASA-Qwen2_5-VL-3B-LiveCC): A CASA model adapted from the full pretrained `Qwen2.5-VL-3B` and futher finetuned for live video captioning.
|
| 35 |
+
- [`Helium1-VL-2B`](https://huggingface.co/kyutai/Helium1-VL-2B): A reference VLM trained from Helium1-2B with standard token insertion mechanism in the same setting as `CASA-Helium1-VL-2B`.
|
| 36 |
+
|
| 37 |
+
Model Summary:
|
| 38 |
+
- **Developed by:** Kyutai
|
| 39 |
+
- **Model type:** Multimodal vision+text model based on Cross-Attention
|
| 40 |
+
- **Language(s) (NLP):** English
|
| 41 |
+
- **License:** CC-BY-NC-SA-4.0
|
| 42 |
+
- **LLM Backboner from:** [Helium1 2B](https://huggingface.co/kyutai/helium-1-2b)
|
| 43 |
+
- **Image Encoder from:** [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
| 44 |
+
- **Terms of use:** As the released models include frozen weights of the Qwen2.5VL-3B image encoder, the weights are subject to the [Qwen RESEARCH LICENSE AGREEMENT](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct/blob/main/LICENSE)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
### Model Sources
|
| 48 |
+
|
| 49 |
+
- **Project Page** [kyutai.org/casa][blog]
|
| 50 |
+
- **Preprint** [arXiv][casa-arxiv]
|
| 51 |
+
- **Repository:** [Github kyutai-labs/casa][casa-git]
|
| 52 |
+
|
| 53 |
+
## Uses
|
| 54 |
+
|
| 55 |
+
### Direct Use
|
| 56 |
+
The intended use of the Helium model is research and development of vision-language systems, including but not limited to image or video understanding.
|
| 57 |
+
|
| 58 |
+
`CASA-Helium1-VL-2B`, `Helium1-VL-2B` and `CASA-Qwen2_5-VL-2B` can be used as vision-language models to analyze or interpret images as input signals.
|
| 59 |
+
|
| 60 |
+
`CASA-Qwen2_5-VL-2B-LiveCC` can be used as a vision-language model on streaming videos as inputs at 2fps.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
The models can be used primarly with English as a language. For most downstream use cases, the model should be aligned with supervised fine-tuning, RLHF or related methods.
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
### Out-of-Scope Use
|
| 67 |
+
|
| 68 |
+
The model should not be used in other languages than the ones on which it was trained.
|
| 69 |
+
The model is not intended to be used to impersonate other people or any malicious use of any kind.
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
## Bias, Risks, and Limitations
|
| 73 |
+
Our CASA-Helium1 model was not aligned to human preferences. As such, the model can generate incorrect, biased, harmful or generally unhelpful content. Thus, the model should not be used for downstream applications without further alignment, evaluations and mitigations of risks.
|
| 74 |
+
|
| 75 |
+
### Recommendations
|
| 76 |
+
|
| 77 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 78 |
+
|
| 79 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 80 |
+
|
| 81 |
+
## How to Get Started with the Model
|
| 82 |
+
|
| 83 |
+
See our [github repository][casa-git] for additional scripts to perform benchmark evaluation and live video captioning.
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
Below is a short snippet to show you how to load our models, process inputs, and run inference, using a standard HuggingFace `transformers` pipeline and chat template.
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
# Minimal requirements:
|
| 90 |
+
# /// script
|
| 91 |
+
# requires-python = ">=3.10"
|
| 92 |
+
# dependencies = [
|
| 93 |
+
# "rich",
|
| 94 |
+
# "einops>=0.8.1",
|
| 95 |
+
# "torch==2.7.0",
|
| 96 |
+
# "transformers==4.51.3",
|
| 97 |
+
# "torchvision==0.22.0",
|
| 98 |
+
# "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"
|
| 99 |
+
# ]
|
| 100 |
+
# ///
|
| 101 |
+
import torch
|
| 102 |
+
from transformers.models.auto.modeling_auto import AutoModel
|
| 103 |
+
from transformers.models.auto.processing_auto import AutoProcessor
|
| 104 |
+
|
| 105 |
+
model_id = "kyutai/CASA-Helium1-VL-2B"
|
| 106 |
+
model = AutoModel.from_pretrained(
|
| 107 |
+
model_id,
|
| 108 |
+
torch_dtype=torch.bfloat16,
|
| 109 |
+
attn_implementation="flash_attention_2",
|
| 110 |
+
trust_remote_code=True,
|
| 111 |
+
).cuda()
|
| 112 |
+
processor = AutoProcessor.from_pretrained(
|
| 113 |
+
model_id,
|
| 114 |
+
trust_remote_code=True,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
conversation = [
|
| 118 |
+
{
|
| 119 |
+
"role": "user",
|
| 120 |
+
"content": [
|
| 121 |
+
{
|
| 122 |
+
"type": "image",
|
| 123 |
+
"image": "assets/casa_model.png",
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"type": "text",
|
| 127 |
+
"text": "Describe this image.",
|
| 128 |
+
},
|
| 129 |
+
],
|
| 130 |
+
},
|
| 131 |
+
]
|
| 132 |
+
inputs = processor.tokenize_messages(messages=conversation)
|
| 133 |
+
inputs = inputs.to(model.device)
|
| 134 |
+
input_len = inputs["input_ids"].shape[1]
|
| 135 |
+
output_ids = model.generate_from_image(
|
| 136 |
+
**inputs,
|
| 137 |
+
max_new_tokens=512,
|
| 138 |
+
pre_image_tokens=processor.pre_image_tokens,
|
| 139 |
+
post_image_tokens=processor.post_image_tokens,
|
| 140 |
+
eos_token_id=model.generation_config.eos_token_id,
|
| 141 |
+
)[0, input_len:]
|
| 142 |
+
response = processor.tokenizer.decode(output_ids, skip_special_tokens=True)
|
| 143 |
+
print(response)
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
## Training Details
|
| 149 |
+
|
| 150 |
+
Please have a look at our associated [research paper][casa-arxiv] for details on the training pipeline.
|
| 151 |
+
|
| 152 |
+
### Training Data
|
| 153 |
+
|
| 154 |
+
To train our CASA-Helium models we use the [FineVision](https://huggingface.co/datasets/HuggingFaceM4/FineVision)
|
| 155 |
+
dataset as well as a small, non overlapping, subset of [Llava-OneVision-1.5-Instruct](https://github.com/EvolvingLMMs-Lab/LLaVA-OneVision-1.5)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
## Evaluation
|
| 159 |
+
We evaluate our models on a range of benchmarks covering document understanding (`DocVQA`), chart understanding (`ChartQA`, `InfoVQA`),
|
| 160 |
+
visual text reading (`TextVQA`, `OCRBench`), and general QA (`RealWorldQA`, `AI2D`, `GQA`, `MME`). Results are reported below. Please refer to our [project page][blog] and [arxiv paper][casa-arxiv] for additional evaluation.
|
| 161 |
+
|
| 162 |
+
<table style="border-collapse: collapse;">
|
| 163 |
+
<tr>
|
| 164 |
+
<th rowspan="2" align="left">Model</th>
|
| 165 |
+
<th colspan="3" align="center">Document / Chart</th>
|
| 166 |
+
<th colspan="2" align="center">Scene Text</th>
|
| 167 |
+
<th colspan="4" align="center">Knowledge / QA</th>
|
| 168 |
+
</tr>
|
| 169 |
+
<tr>
|
| 170 |
+
<th>ChartQA</th>
|
| 171 |
+
<th>DocVQA</th>
|
| 172 |
+
<th>InfoVQA</th>
|
| 173 |
+
<th>OCRBench</th>
|
| 174 |
+
<th>TextVQA</th>
|
| 175 |
+
<th>RealWorldQA</th>
|
| 176 |
+
<th>AI2D</th>
|
| 177 |
+
<th>GQA</th>
|
| 178 |
+
<th>MME</th>
|
| 179 |
+
</tr>
|
| 180 |
+
</thead>
|
| 181 |
+
<tbody>
|
| 182 |
+
<tr>
|
| 183 |
+
<td align="left">Helium1-VL-2B</td>
|
| 184 |
+
<td>81.6</td><td>89.1</td><td>61.8</td>
|
| 185 |
+
<td>728</td><td>75.5</td>
|
| 186 |
+
<td>59.9</td><td>67.7</td><td>55.5</td><td>1732</td>
|
| 187 |
+
</tr>
|
| 188 |
+
<tr>
|
| 189 |
+
<td align="left"><span style="color:#fb923c;"><strong>CASA-Helium1-VL-2B</strong></span></td>
|
| 190 |
+
<td>73.4</td><td>83.7</td><td>48.6</td>
|
| 191 |
+
<td>723</td><td>71.0</td>
|
| 192 |
+
<td>58.3</td><td>63.3</td><td>54.6</td><td>1572</td>
|
| 193 |
+
</tr>
|
| 194 |
+
<tr>
|
| 195 |
+
<td align="left"><span style="color:#60a5fa;">mPLUG-Owl3 8B</span></td>
|
| 196 |
+
<td>59.2<sup>†</sup></td><td>55.9<sup>†</sup></td><td>36.8<sup>†</sup></td>
|
| 197 |
+
<td>527<sup>†</sup></td><td>69.0</td>
|
| 198 |
+
<td>63.9<sup>†</sup></td><td>73.4</td><td>65.0</td><td>1940<sup>†</sup></td>
|
| 199 |
+
</tr>
|
| 200 |
+
<tr>
|
| 201 |
+
<td align="left"><span style="color:#60a5fa;">mPLUG-Owl3 2B</span></td>
|
| 202 |
+
<td>48.5<sup>†</sup></td><td>48.2<sup>†</sup></td><td>28.1<sup>†</sup></td>
|
| 203 |
+
<td>450<sup>†</sup></td><td>62.6</td>
|
| 204 |
+
<td>56.9<sup>†</sup></td><td>62.6</td><td>61.0</td><td>1551<sup>†</sup></td>
|
| 205 |
+
</tr>
|
| 206 |
+
</tbody>
|
| 207 |
+
</table>
|
| 208 |
+
<p>
|
| 209 |
+
<sup>†</sup> Reproduced with the publicly available models on Hugging Face.  
|
| 210 |
+
<!-- ◇ Results and model not publicly available. -->
|
| 211 |
+
</p>
|
| 212 |
+
|
| 213 |
+
<p align="center">
|
| 214 |
+
<em>
|
| 215 |
+
Results for <code>CASA-Helium1-VL-2B</code> compared to a recent cross-attention baseline (blue), and our token insertion
|
| 216 |
+
(<code>Helium1-VL-2B</code> trained in the same conditions. CASA outperforms current SoTA
|
| 217 |
+
cross-attention-based VLMs, narrowing the gap to insertion-based approaches.
|
| 218 |
+
</em>
|
| 219 |
+
</p>
|
| 220 |
+
|
| 221 |
+
<table style="border-collapse: collapse;">
|
| 222 |
+
<thead>
|
| 223 |
+
<tr>
|
| 224 |
+
<th rowspan="2" align="left">Model</th>
|
| 225 |
+
<th colspan="3" align="center">Document / Chart</th>
|
| 226 |
+
<th colspan="2" align="center">Scene Text</th>
|
| 227 |
+
<th colspan="4" align="center">Knowledge / QA</th>
|
| 228 |
+
</tr>
|
| 229 |
+
<tr>
|
| 230 |
+
<th>ChartQA</th>
|
| 231 |
+
<th>DocVQA</th>
|
| 232 |
+
<th>InfoVQA</th>
|
| 233 |
+
<th>OCRBench</th>
|
| 234 |
+
<th>TextVQA</th>
|
| 235 |
+
<th>RealWorldQA</th>
|
| 236 |
+
<th>AI2D</th>
|
| 237 |
+
<th>GQA</th>
|
| 238 |
+
<th>MME</th>
|
| 239 |
+
</tr>
|
| 240 |
+
</thead>
|
| 241 |
+
|
| 242 |
+
<tbody>
|
| 243 |
+
<tr>
|
| 244 |
+
<td align="left">
|
| 245 |
+
Qwen2.5-VL-3B
|
| 246 |
+
</td>
|
| 247 |
+
<td>84.0</td><td>93.6</td><td>77.1</td>
|
| 248 |
+
<td>797</td><td>79.3</td>
|
| 249 |
+
<td>62.2<sup>†</sup></td><td>81.6</td><td>61.0<sup>†</sup></td><td>2249<sup>†</sup></td>
|
| 250 |
+
</tr>
|
| 251 |
+
<tr>
|
| 252 |
+
<td align="left">
|
| 253 |
+
<span style="color:#fb923c;"><strong>CASA-Qwen2_5-VL-3B</strong></span>
|
| 254 |
+
</td>
|
| 255 |
+
<td>82.4</td><td>88.9</td><td>59.6</td>
|
| 256 |
+
<td>790</td><td>77.4</td>
|
| 257 |
+
<td>62.5</td><td>75.1</td><td>59.4</td><td>1918</td>
|
| 258 |
+
</tr>
|
| 259 |
+
</tbody>
|
| 260 |
+
</table>
|
| 261 |
+
|
| 262 |
+
<p>
|
| 263 |
+
<sup>†</sup> Reproduced with the publicly available models on Hugging Face.
|
| 264 |
+
</p>
|
| 265 |
+
|
| 266 |
+
<p align="center">
|
| 267 |
+
<em>
|
| 268 |
+
Results for <code>CASA-Qwen2_5-VL-3B</code>, adapted from frozen Qwen2.5-VL. CASA reaches performance close to the original
|
| 269 |
+
insertion-based model while while training only
|
| 270 |
+
the CASA layers and last blocks of the image encoder.
|
| 271 |
+
</em>
|
| 272 |
+
</p>
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
## Technical Specifications
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
### Compute Infrastructure
|
| 279 |
+
|
| 280 |
+
`CASA-Helium1-2B` was trained starting from a `Helium1-2B` LLM and the image encoder from `Qwen2.5-VL-3B`.
|
| 281 |
+
We finetune the whole LLM backbone as well as the last four blocks of the image encoder.
|
| 282 |
+
The currently released model was trained on four DGX nodes with 8 H100 GPUs.
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
#### Software
|
| 286 |
+
|
| 287 |
+
Our training code and inference code was implemented in Pytorch.
|
| 288 |
+
|
| 289 |
+
## Citation
|
| 290 |
+
|
| 291 |
+
```
|
| 292 |
+
@article{kyutai2025casa,
|
| 293 |
+
author = {Moritz Böhle and Amélie Royer and Juliette Marrie and Edouard Grave and Patrick Pérez},
|
| 294 |
+
year = {2025},
|
| 295 |
+
title = {CASA: Cross-Attention vis Self-Attention},
|
| 296 |
+
journal = {ArXiv},
|
| 297 |
+
url = {https://arxiv.org/abs/2512.19535}
|
| 298 |
+
}
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
## Model Card Authors and Contact
|
| 303 |
+
|
| 304 |
+
* Amelie Royer
|
| 305 |
+
* Moritz Boehle
|
| 306 |
+
* Juliette Marrie
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
[blog]: https://kyutai.org/casa
|
| 310 |
+
[casa-arxiv]: https://arxiv.org/abs/2512.19535
|
| 311 |
+
[casa-git]: https://github.com/kyutai-labs/casa
|
__init__.py
ADDED
|
File without changes
|
casa_attention.py
ADDED
|
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CASA layers"""
|
| 2 |
+
|
| 3 |
+
import bisect
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from itertools import accumulate
|
| 6 |
+
from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload
|
| 7 |
+
from typing import cast as type_cast
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 11 |
+
|
| 12 |
+
from .utils import StreamingModule, StreamingState, delta_w_factory
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from flash_attn import flash_attn_varlen_func
|
| 19 |
+
except ImportError:
|
| 20 |
+
flash_attn_varlen_func = None # type: ignore
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
WindowsComputeKwargs = TypedDict(
|
| 24 |
+
"WindowsComputeKwargs",
|
| 25 |
+
{
|
| 26 |
+
"num_post_image_tokens": int,
|
| 27 |
+
"num_pre_image_tokens": int,
|
| 28 |
+
},
|
| 29 |
+
total=False,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def __split_n_merge__(
|
| 34 |
+
x: torch.Tensor,
|
| 35 |
+
sample_lengths: list[int],
|
| 36 |
+
padding_side: Literal["left", "right"] = "right",
|
| 37 |
+
pad_value: int | float | bool = 0,
|
| 38 |
+
) -> torch.Tensor:
|
| 39 |
+
max_sample_length = max(sample_lengths)
|
| 40 |
+
pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2))
|
| 41 |
+
return torch.stack(
|
| 42 |
+
[
|
| 43 |
+
torch.nn.functional.pad(
|
| 44 |
+
_x,
|
| 45 |
+
pad_tuple + (0, max_sample_length - _x.shape[0])
|
| 46 |
+
if padding_side == "right"
|
| 47 |
+
else pad_tuple + (max_sample_length - _x.shape[0], 0),
|
| 48 |
+
value=pad_value,
|
| 49 |
+
)
|
| 50 |
+
for _x in torch.split(x, sample_lengths, dim=0)
|
| 51 |
+
],
|
| 52 |
+
dim=0,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@overload
|
| 57 |
+
def insert_image_tokens(
|
| 58 |
+
inputs_embeds: torch.Tensor,
|
| 59 |
+
image_embeds: torch.Tensor | Sequence[torch.Tensor],
|
| 60 |
+
image_embeds_insertion_points: list[torch.Tensor],
|
| 61 |
+
recover_batch_dim: Literal[True],
|
| 62 |
+
attention_mask: torch.Tensor | None = None,
|
| 63 |
+
padding_side: Literal["left", "right"] = "right",
|
| 64 |
+
keep_only_attended: bool = False,
|
| 65 |
+
pad_output: int | float | bool = 0.0,
|
| 66 |
+
) -> tuple[
|
| 67 |
+
torch.Tensor,
|
| 68 |
+
None,
|
| 69 |
+
torch.Tensor | None,
|
| 70 |
+
torch.Tensor,
|
| 71 |
+
]: ...
|
| 72 |
+
@overload
|
| 73 |
+
def insert_image_tokens(
|
| 74 |
+
inputs_embeds: torch.Tensor,
|
| 75 |
+
image_embeds: torch.Tensor | Sequence[torch.Tensor],
|
| 76 |
+
image_embeds_insertion_points: list[torch.Tensor],
|
| 77 |
+
recover_batch_dim: Literal[False],
|
| 78 |
+
attention_mask: torch.Tensor | None = None,
|
| 79 |
+
padding_side: Literal["left", "right"] = "right",
|
| 80 |
+
keep_only_attended: bool = False,
|
| 81 |
+
pad_output: int | float | bool = 0.0,
|
| 82 |
+
) -> tuple[
|
| 83 |
+
torch.Tensor,
|
| 84 |
+
list[int],
|
| 85 |
+
torch.Tensor | None,
|
| 86 |
+
torch.Tensor,
|
| 87 |
+
]: ...
|
| 88 |
+
def insert_image_tokens(
|
| 89 |
+
inputs_embeds: torch.Tensor,
|
| 90 |
+
image_embeds: torch.Tensor | Sequence[torch.Tensor],
|
| 91 |
+
image_embeds_insertion_points: list[torch.Tensor],
|
| 92 |
+
recover_batch_dim: bool = True,
|
| 93 |
+
attention_mask: torch.Tensor | None = None,
|
| 94 |
+
padding_side: Literal["left", "right"] = "right",
|
| 95 |
+
keep_only_attended: bool = False,
|
| 96 |
+
pad_output: int | float | bool = 0.0,
|
| 97 |
+
) -> tuple[
|
| 98 |
+
torch.Tensor | torch.Tensor,
|
| 99 |
+
list[int] | None,
|
| 100 |
+
torch.Tensor | torch.Tensor | None,
|
| 101 |
+
torch.Tensor | torch.Tensor,
|
| 102 |
+
]:
|
| 103 |
+
"""
|
| 104 |
+
Insert image embeddings into text embeddings
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
inputs_embeds (torch.Tensor): (B, S, D) input token embeddings.
|
| 108 |
+
image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings.
|
| 109 |
+
image_embeds_insertion_points (list[torch.Tensor]): Insertion indices.
|
| 110 |
+
attention_mask (torch.Tensor, optional): (B, S) attention mask.
|
| 111 |
+
padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images.
|
| 112 |
+
return_indices (bool): Whether to return gather indices or the fused sequence directly.
|
| 113 |
+
keep_only_attended: This is only applicable when recover_batch_dim is False; whether to
|
| 114 |
+
remove any non-attended tokens in the whole array. In this case, the attention
|
| 115 |
+
mask returned is **still the original one**, so we can remember which indices have been
|
| 116 |
+
removed
|
| 117 |
+
Returns:
|
| 118 |
+
output (torch.Tensor): (B, S + Ni * Nt) gather indices or (B, S + Ni * Nt, D) fused sequence
|
| 119 |
+
image_embeds (torch.Tensor): (B, Ni * Nt) image embeds, padded and batch if input was a list
|
| 120 |
+
attention_mask (torch.Tensor): Same shape, 1 for real tokens, 0 for image and text padding.
|
| 121 |
+
image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions.
|
| 122 |
+
"""
|
| 123 |
+
if isinstance(image_embeds, list) and len(image_embeds) == 0:
|
| 124 |
+
batch_size, text_seq_length, token_dim = inputs_embeds.shape
|
| 125 |
+
if recover_batch_dim:
|
| 126 |
+
return (
|
| 127 |
+
inputs_embeds,
|
| 128 |
+
None,
|
| 129 |
+
attention_mask,
|
| 130 |
+
torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool),
|
| 131 |
+
)
|
| 132 |
+
else:
|
| 133 |
+
flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1]
|
| 134 |
+
return (
|
| 135 |
+
torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])),
|
| 136 |
+
[text_seq_length] * inputs_embeds.shape[0],
|
| 137 |
+
attention_mask.flatten() if attention_mask is not None else None,
|
| 138 |
+
torch.zeros((flattened_seq_length, 1), dtype=torch.bool),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Sanity checks
|
| 142 |
+
if isinstance(image_embeds, torch.Tensor):
|
| 143 |
+
assert inputs_embeds.shape[-1] == image_embeds.shape[-1]
|
| 144 |
+
else:
|
| 145 |
+
assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds)
|
| 146 |
+
|
| 147 |
+
batch_size, text_seq_length, token_dim = inputs_embeds.shape
|
| 148 |
+
image_seq_length = [x.shape[0] for x in image_embeds]
|
| 149 |
+
|
| 150 |
+
# Flatten insertion points
|
| 151 |
+
insertion_offset = []
|
| 152 |
+
counter, offset_from_text, offset_from_image = 0, 0, 0
|
| 153 |
+
for sample in image_embeds_insertion_points:
|
| 154 |
+
for pt in sample:
|
| 155 |
+
insertion_offset.append(pt + offset_from_image + offset_from_text)
|
| 156 |
+
offset_from_image += image_seq_length[counter]
|
| 157 |
+
counter += 1
|
| 158 |
+
offset_from_text += text_seq_length
|
| 159 |
+
image_insert_positions = [
|
| 160 |
+
x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx])
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
# Flatten image embeds
|
| 164 |
+
if isinstance(image_embeds, list):
|
| 165 |
+
image_embeds = torch.cat(image_embeds, dim=0)
|
| 166 |
+
else:
|
| 167 |
+
image_embeds = type_cast(torch.Tensor, image_embeds)
|
| 168 |
+
image_embeds = torch.reshape(image_embeds, (-1, token_dim))
|
| 169 |
+
|
| 170 |
+
# Flatten text embeds across batch dim (B x S, D)
|
| 171 |
+
inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim))
|
| 172 |
+
flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length)
|
| 173 |
+
text_insert_positions = sorted(
|
| 174 |
+
set(range(flattened_seq_length)).difference(set(image_insert_positions))
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Scatter image embeds in the flattened dict
|
| 178 |
+
# scatter text related stuff
|
| 179 |
+
output = torch.empty(
|
| 180 |
+
(flattened_seq_length, token_dim),
|
| 181 |
+
device=inputs_embeds.device,
|
| 182 |
+
dtype=inputs_embeds.dtype,
|
| 183 |
+
)
|
| 184 |
+
txt_positions_tensor = torch.Tensor(text_insert_positions).to(
|
| 185 |
+
dtype=torch.long, device=inputs_embeds.device
|
| 186 |
+
)
|
| 187 |
+
output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds)
|
| 188 |
+
attention_mask_new: torch.Tensor | None = None
|
| 189 |
+
if attention_mask is not None:
|
| 190 |
+
attention_mask_new = torch.ones(
|
| 191 |
+
(flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
|
| 192 |
+
)
|
| 193 |
+
attention_mask_new.scatter_(
|
| 194 |
+
0, txt_positions_tensor, attention_mask.flatten().to(torch.bool)
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# scatter image related stuff
|
| 198 |
+
image_tokens_mask = torch.zeros(
|
| 199 |
+
(flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
|
| 200 |
+
)
|
| 201 |
+
img_positions_tensor = torch.Tensor(image_insert_positions).to(
|
| 202 |
+
device=inputs_embeds.device, dtype=torch.long
|
| 203 |
+
)
|
| 204 |
+
output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds)
|
| 205 |
+
image_tokens_mask.scatter_(0, img_positions_tensor, True)
|
| 206 |
+
|
| 207 |
+
# Compute expected sample length, taking into account the real batch
|
| 208 |
+
# i.e. recover the batch dimension of image embeddings
|
| 209 |
+
sample_lengths = []
|
| 210 |
+
counter = 0
|
| 211 |
+
for sample_idx, pts in enumerate(image_embeds_insertion_points):
|
| 212 |
+
num_image_tokens = 0
|
| 213 |
+
for _ in pts:
|
| 214 |
+
num_image_tokens += image_seq_length[counter]
|
| 215 |
+
counter += 1
|
| 216 |
+
if keep_only_attended and attention_mask is not None:
|
| 217 |
+
attended_seq_length = torch.sum(attention_mask[sample_idx]).cpu().item()
|
| 218 |
+
sample_lengths.append(attended_seq_length + num_image_tokens)
|
| 219 |
+
else:
|
| 220 |
+
sample_lengths.append(text_seq_length + num_image_tokens)
|
| 221 |
+
|
| 222 |
+
# For CASA attention, we can keep stuff flatten ad return
|
| 223 |
+
# the sample_lengths for the blockwise attention
|
| 224 |
+
if not recover_batch_dim:
|
| 225 |
+
if keep_only_attended and attention_mask_new is not None:
|
| 226 |
+
output = output[attention_mask_new]
|
| 227 |
+
image_tokens_mask = image_tokens_mask[attention_mask_new]
|
| 228 |
+
return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None]
|
| 229 |
+
|
| 230 |
+
# Otherwise, time to (pad) and reshape
|
| 231 |
+
# Easy case: everything has the same length
|
| 232 |
+
if all(x == sample_lengths[0] for x in sample_lengths):
|
| 233 |
+
output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim))
|
| 234 |
+
image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1))
|
| 235 |
+
if attention_mask_new is not None:
|
| 236 |
+
attention_mask_new = torch.reshape(attention_mask_new, (batch_size, sample_lengths[0]))
|
| 237 |
+
# if there is any size mismatch we break into a
|
| 238 |
+
# list and pad again
|
| 239 |
+
else:
|
| 240 |
+
# split and merge
|
| 241 |
+
output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output)
|
| 242 |
+
# note that the extra padding tokens are also marked as image tokens to be removed later
|
| 243 |
+
image_tokens_mask = __split_n_merge__(
|
| 244 |
+
image_tokens_mask, sample_lengths, padding_side, True
|
| 245 |
+
)[:, :, None]
|
| 246 |
+
if attention_mask_new is not None:
|
| 247 |
+
attention_mask_new = __split_n_merge__(
|
| 248 |
+
attention_mask_new, sample_lengths, padding_side, 0
|
| 249 |
+
)
|
| 250 |
+
# Return
|
| 251 |
+
return output, sample_lengths, attention_mask_new, image_tokens_mask
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def get_sample_lengths_from_insertion_points(
|
| 255 |
+
image_embeds_insertion_points: list[torch.Tensor],
|
| 256 |
+
image_embeds: torch.Tensor | list[torch.Tensor] | None,
|
| 257 |
+
total_seq_len: int | None = None,
|
| 258 |
+
attention_mask: torch.Tensor | None = None,
|
| 259 |
+
**kwargs: WindowsComputeKwargs,
|
| 260 |
+
) -> tuple[list[tuple[int, bool]], list[int]]:
|
| 261 |
+
"""Compute sample lengths as if each image insertion point defines a
|
| 262 |
+
new document (ex document ID)
|
| 263 |
+
"""
|
| 264 |
+
num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0))
|
| 265 |
+
num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0))
|
| 266 |
+
squashed_samples_lengths = type_cast(
|
| 267 |
+
list[list[int]] | None, kwargs.get("squashed_samples_lengths", None)
|
| 268 |
+
)
|
| 269 |
+
if squashed_samples_lengths is not None:
|
| 270 |
+
assert len(squashed_samples_lengths) == len(image_embeds_insertion_points)
|
| 271 |
+
|
| 272 |
+
def __insert_next_sample__(
|
| 273 |
+
batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False
|
| 274 |
+
) -> None:
|
| 275 |
+
nonlocal attention_mask
|
| 276 |
+
nonlocal text_sample_lengths, full_sample_lengths
|
| 277 |
+
nonlocal cum_samples_lengths, current_image_offset
|
| 278 |
+
nonlocal last_image_idx, current_image_idx, current_length
|
| 279 |
+
# Add the sample between [last_insrt_pt, insrt_pt] with breaks in
|
| 280 |
+
# between any squashed samples we find on the way
|
| 281 |
+
start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt)
|
| 282 |
+
added_sample = False
|
| 283 |
+
for end_of_sample in cum_samples_lengths[start_pt:]:
|
| 284 |
+
# we will break the loop at the end when end_of_sample = insrt_pt
|
| 285 |
+
end_of_sample = min(end_of_sample, insrt_pt)
|
| 286 |
+
|
| 287 |
+
# Add between [last_insrt_pt, end_of_sample]
|
| 288 |
+
current_length = end_of_sample - last_insrt_pt
|
| 289 |
+
if attention_mask is not None:
|
| 290 |
+
current_length -= int(
|
| 291 |
+
torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item()
|
| 292 |
+
)
|
| 293 |
+
if current_length > 0:
|
| 294 |
+
added_sample = True
|
| 295 |
+
text_sample_lengths.append(
|
| 296 |
+
(current_length, end_of_batch_sample and insrt_pt == end_of_sample)
|
| 297 |
+
)
|
| 298 |
+
# add image tokens to current_length
|
| 299 |
+
if current_image_idx > 0 and image_embeds is not None:
|
| 300 |
+
images_in_sample = [
|
| 301 |
+
img_idx
|
| 302 |
+
for img_idx in range(last_image_idx, current_image_idx)
|
| 303 |
+
if img_idx < len(image_embeds_insertion_points[batch_idx])
|
| 304 |
+
and last_insrt_pt
|
| 305 |
+
<= image_embeds_insertion_points[batch_idx][img_idx]
|
| 306 |
+
< end_of_sample
|
| 307 |
+
]
|
| 308 |
+
if len(images_in_sample) > 0:
|
| 309 |
+
num_image_tokens = sum(
|
| 310 |
+
_x.shape[0]
|
| 311 |
+
for _x in image_embeds[
|
| 312 |
+
current_image_offset + images_in_sample[0] : current_image_offset
|
| 313 |
+
+ images_in_sample[-1]
|
| 314 |
+
+ 1
|
| 315 |
+
]
|
| 316 |
+
)
|
| 317 |
+
current_length += num_image_tokens
|
| 318 |
+
full_sample_lengths.append(current_length)
|
| 319 |
+
|
| 320 |
+
# prepare for next loop
|
| 321 |
+
last_insrt_pt = end_of_sample
|
| 322 |
+
if end_of_sample == insrt_pt:
|
| 323 |
+
break
|
| 324 |
+
# End of loop: Catching weird use case where we may end up on a span
|
| 325 |
+
# full of padding tokens which will not get added due to current_length > 0
|
| 326 |
+
if end_of_batch_sample:
|
| 327 |
+
assert added_sample, "Weird edge case. Don't do that, thank you"
|
| 328 |
+
text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
|
| 329 |
+
|
| 330 |
+
# End of loop: Catching weird use case where we may end up on a span
|
| 331 |
+
# full of padding tokens which will not get added due to current_length > 0
|
| 332 |
+
if end_of_batch_sample:
|
| 333 |
+
assert added_sample, "Weird edge case. Don't do that, thank you"
|
| 334 |
+
text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
|
| 335 |
+
|
| 336 |
+
current_image_offset = 0
|
| 337 |
+
text_sample_lengths, full_sample_lengths = [], []
|
| 338 |
+
cum_samples_lengths: list[int] = []
|
| 339 |
+
current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
|
| 340 |
+
for batch_idx, pts in enumerate(image_embeds_insertion_points):
|
| 341 |
+
if squashed_samples_lengths is not None:
|
| 342 |
+
cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx]))
|
| 343 |
+
else:
|
| 344 |
+
assert total_seq_len is not None
|
| 345 |
+
cum_samples_lengths = [total_seq_len]
|
| 346 |
+
|
| 347 |
+
for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()):
|
| 348 |
+
# check if the images are consecutive in which way we want
|
| 349 |
+
# them to belong to the same window
|
| 350 |
+
if current_image_idx >= 1 and insrt_pt == (
|
| 351 |
+
image_embeds_insertion_points[batch_idx][current_image_idx - 1]
|
| 352 |
+
+ num_pre_image_tokens
|
| 353 |
+
+ num_post_image_tokens
|
| 354 |
+
):
|
| 355 |
+
continue
|
| 356 |
+
# Otherwise, we found a new sample
|
| 357 |
+
# not very important but for completeness: the insertion points come *after*
|
| 358 |
+
# the pre-image tokens per design but for the document-id mask it is more consistent to
|
| 359 |
+
# have them correspond to the same image
|
| 360 |
+
insrt_pt -= num_pre_image_tokens
|
| 361 |
+
|
| 362 |
+
# Update text and full sample lengths
|
| 363 |
+
if insrt_pt > last_insrt_pt:
|
| 364 |
+
__insert_next_sample__(
|
| 365 |
+
batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False
|
| 366 |
+
)
|
| 367 |
+
last_image_idx = current_image_idx
|
| 368 |
+
last_insrt_pt = insrt_pt
|
| 369 |
+
|
| 370 |
+
# End of batch: add sample in progress and reset
|
| 371 |
+
current_image_idx += 1
|
| 372 |
+
if cum_samples_lengths[-1] > last_insrt_pt:
|
| 373 |
+
__insert_next_sample__(
|
| 374 |
+
batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True
|
| 375 |
+
)
|
| 376 |
+
current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
|
| 377 |
+
current_image_offset += len(pts)
|
| 378 |
+
|
| 379 |
+
# Sanity checks that the is_eob are correctly place
|
| 380 |
+
assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), (
|
| 381 |
+
f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs"
|
| 382 |
+
f" from original batch size ({len(image_embeds_insertion_points)})"
|
| 383 |
+
)
|
| 384 |
+
return text_sample_lengths, full_sample_lengths
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class CASAAttentionHandler:
|
| 388 |
+
def __init__(
|
| 389 |
+
self,
|
| 390 |
+
inputs_embeds: torch.Tensor,
|
| 391 |
+
image_embeds: torch.Tensor | list[torch.Tensor],
|
| 392 |
+
image_embeds_insertion_points: list[torch.Tensor],
|
| 393 |
+
attention_mask: torch.Tensor | None = None,
|
| 394 |
+
rope_fn: Callable | None = None,
|
| 395 |
+
windows: Literal["batch", "squashed", "images", "turn_based"] = "images",
|
| 396 |
+
use_asymetric_q_kv: bool = True,
|
| 397 |
+
casa_windows_info: None | dict = None,
|
| 398 |
+
):
|
| 399 |
+
"""Initialize the structure holding the query buffer for CASA attention layers
|
| 400 |
+
(ie the **flattened** text+image inserted tokens).
|
| 401 |
+
Note that this structure is shared across all casa layers, and it gets updated
|
| 402 |
+
with the current hidden states at every layer; this is merely a buffer to keep
|
| 403 |
+
scatter_ operations in-plae as much as possible
|
| 404 |
+
|
| 405 |
+
In this module, the embeddings related values (image_tokens_mask,
|
| 406 |
+
text_sample_lengths etc) are stored under the assumption of a tensor
|
| 407 |
+
which is *flatened* and *witout padding tokens*
|
| 408 |
+
Only the attention mask is kept as-is (text-only, batched, padded) to
|
| 409 |
+
be able to recover original shapes when needed
|
| 410 |
+
"""
|
| 411 |
+
super().__init__()
|
| 412 |
+
assert windows == "images" # for inference code release
|
| 413 |
+
# Note 1: Unless overriden, text/full_sample_lengths are defined such that one
|
| 414 |
+
# document = one sample in the batch
|
| 415 |
+
if attention_mask is None:
|
| 416 |
+
text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds]
|
| 417 |
+
else:
|
| 418 |
+
text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask]
|
| 419 |
+
(
|
| 420 |
+
full_inputs_embeds,
|
| 421 |
+
full_sample_lengths,
|
| 422 |
+
# Full attention mask is only needed at inference to
|
| 423 |
+
# flatten the KV-Cache and remove padding tokens
|
| 424 |
+
_,
|
| 425 |
+
self.image_tokens_mask,
|
| 426 |
+
) = insert_image_tokens(
|
| 427 |
+
inputs_embeds=inputs_embeds,
|
| 428 |
+
image_embeds=image_embeds,
|
| 429 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 430 |
+
attention_mask=attention_mask,
|
| 431 |
+
recover_batch_dim=False,
|
| 432 |
+
keep_only_attended=attention_mask is not None,
|
| 433 |
+
)
|
| 434 |
+
assert self.image_tokens_mask.ndim == 2
|
| 435 |
+
self.image_embeds = image_embeds
|
| 436 |
+
self.image_embeds_insertion_points = image_embeds_insertion_points
|
| 437 |
+
self.attention_mask = None if attention_mask is None else attention_mask.bool()
|
| 438 |
+
self.use_asymetric_qkv = use_asymetric_q_kv
|
| 439 |
+
# At inference, we have to use asymetric QKV for efficiency
|
| 440 |
+
if self.attention_mask is not None:
|
| 441 |
+
self.use_asymetric_qkv = True
|
| 442 |
+
|
| 443 |
+
# Build CASA windows
|
| 444 |
+
assert casa_windows_info is not None
|
| 445 |
+
text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points(
|
| 446 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 447 |
+
image_embeds=image_embeds,
|
| 448 |
+
total_seq_len=inputs_embeds.shape[1],
|
| 449 |
+
attention_mask=self.attention_mask,
|
| 450 |
+
**casa_windows_info, # pyright: ignore
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Sanity checks on the sample lengths
|
| 454 |
+
self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0]
|
| 455 |
+
self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0]
|
| 456 |
+
|
| 457 |
+
assert len(self.text_sample_lengths) == len(self.full_sample_lengths), (
|
| 458 |
+
f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}"
|
| 459 |
+
f" != full sample lengths {len(self.full_sample_lengths)}"
|
| 460 |
+
)
|
| 461 |
+
if self.attention_mask is None:
|
| 462 |
+
num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1]
|
| 463 |
+
else:
|
| 464 |
+
num_unpadded_text_tokens = int(
|
| 465 |
+
torch.sum(type_cast(torch.Tensor, attention_mask)).item()
|
| 466 |
+
)
|
| 467 |
+
assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
|
| 468 |
+
f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
|
| 469 |
+
)
|
| 470 |
+
assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], (
|
| 471 |
+
f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Finally we can compute cu_seqlen based on sample lengths
|
| 475 |
+
self.max_seqlen_q = max(self.text_sample_lengths)[0]
|
| 476 |
+
self.cu_seqlens_q = self.get_cu_seqlens(
|
| 477 |
+
[x[0] for x in self.text_sample_lengths], device=inputs_embeds.device
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
self.max_seqlen_kv = max(self.full_sample_lengths)
|
| 481 |
+
self.cu_seqlens_kv = self.get_cu_seqlens(
|
| 482 |
+
self.full_sample_lengths, device=inputs_embeds.device
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
# For inference: We save the length of the current document
|
| 486 |
+
# to trim the KV cache appropriately
|
| 487 |
+
self.current_doc_lengths = self.full_sample_lengths
|
| 488 |
+
|
| 489 |
+
# Precompute position embeddings
|
| 490 |
+
self.position_embeds = None
|
| 491 |
+
self.rope_fn = rope_fn
|
| 492 |
+
if self.rope_fn is not None:
|
| 493 |
+
self.position_embeds = self.compute_position_embeddings(
|
| 494 |
+
self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
@property
|
| 498 |
+
def batch_lengths(self) -> list[int]:
|
| 499 |
+
"""Return a (batch_size,) list of integers containing the
|
| 500 |
+
number of (non-padded) text tokens for each sample in the batch"""
|
| 501 |
+
bls = [0]
|
| 502 |
+
for ln, eob in self.text_sample_lengths:
|
| 503 |
+
bls[-1] += ln
|
| 504 |
+
if eob:
|
| 505 |
+
bls.append(0)
|
| 506 |
+
return bls[:-1]
|
| 507 |
+
|
| 508 |
+
@property
|
| 509 |
+
def full_batch_lengths(self) -> list[int]:
|
| 510 |
+
"""Same as batch_lengths for text+image tokens"""
|
| 511 |
+
bls = [0]
|
| 512 |
+
for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths):
|
| 513 |
+
bls[-1] += ln
|
| 514 |
+
if eob:
|
| 515 |
+
bls.append(0)
|
| 516 |
+
return bls[:-1]
|
| 517 |
+
|
| 518 |
+
def get_cu_seqlens(
|
| 519 |
+
self, sample_lengths: list[int], device: torch.device | None
|
| 520 |
+
) -> torch.Tensor:
|
| 521 |
+
"""Update cu_seqlengths according to the given sample_lengths"""
|
| 522 |
+
return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to(
|
| 523 |
+
dtype=torch.int32, device=device
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def compute_position_embeddings(
|
| 527 |
+
self,
|
| 528 |
+
rope_fn: Callable,
|
| 529 |
+
sample_lengths: list[int],
|
| 530 |
+
dummy_for_dtype_and_device: torch.Tensor,
|
| 531 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 532 |
+
"""Compute info required for position embeddings. Can be override e.g. for Qwen"""
|
| 533 |
+
# option 1: Standard range
|
| 534 |
+
# position_ids = torch.arange(0, full_inputs_embeds.shape[0])
|
| 535 |
+
# option 2: Follows document boundary
|
| 536 |
+
position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
|
| 537 |
+
return rope_fn(
|
| 538 |
+
dummy_for_dtype_and_device,
|
| 539 |
+
position_ids.to(dummy_for_dtype_and_device.device)[None, ...],
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
def get_position_embedding(
|
| 543 |
+
self,
|
| 544 |
+
key: Literal["q", "kv"],
|
| 545 |
+
num_queries: int = 0,
|
| 546 |
+
) -> tuple[torch.Tensor, torch.Tensor] | None:
|
| 547 |
+
if self.position_embeds is None:
|
| 548 |
+
return None
|
| 549 |
+
cos, sin = self.position_embeds
|
| 550 |
+
bls = self.full_batch_lengths
|
| 551 |
+
# For Q, we only want the text-only posembeds
|
| 552 |
+
if key == "q" and self.use_asymetric_qkv:
|
| 553 |
+
bls = self.batch_lengths
|
| 554 |
+
cos, sin = cos[:, ~self.image_tokens_mask[:, 0]], sin[:, ~self.image_tokens_mask[:, 0]]
|
| 555 |
+
elif key not in {"q", "kv"}:
|
| 556 |
+
raise ValueError(f"Unknow for position embedding {key}")
|
| 557 |
+
|
| 558 |
+
# Easy case: training or first step at inference: we use all the posembeds
|
| 559 |
+
if num_queries == 0:
|
| 560 |
+
return cos, sin
|
| 561 |
+
# If num queries is given, we need to trim for *every sample in the batch*
|
| 562 |
+
cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)]
|
| 563 |
+
sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)]
|
| 564 |
+
return torch.cat(cos, dim=1), torch.cat(sin, dim=1)
|
| 565 |
+
|
| 566 |
+
def get_full_embeds(
|
| 567 |
+
self, hidden_states: torch.Tensor, norm_fn: Callable | None
|
| 568 |
+
) -> torch.Tensor:
|
| 569 |
+
"""Update attended hidden states in the current query buffer
|
| 570 |
+
|
| 571 |
+
:param hidden_states: (b, s, d) Tensor input to the CASA attention layer"
|
| 572 |
+
"""
|
| 573 |
+
assert self.image_embeds is not None
|
| 574 |
+
return insert_image_tokens(
|
| 575 |
+
inputs_embeds=hidden_states,
|
| 576 |
+
image_embeds=self.image_embeds
|
| 577 |
+
if norm_fn is None
|
| 578 |
+
else norm_fn(self.image_embeds)
|
| 579 |
+
if isinstance(self.image_embeds, torch.Tensor)
|
| 580 |
+
else [norm_fn(_x) for _x in self.image_embeds],
|
| 581 |
+
image_embeds_insertion_points=self.image_embeds_insertion_points,
|
| 582 |
+
attention_mask=self.attention_mask,
|
| 583 |
+
recover_batch_dim=False,
|
| 584 |
+
keep_only_attended=self.attention_mask is not None,
|
| 585 |
+
)[0][None, :, :]
|
| 586 |
+
|
| 587 |
+
def recover_text_embeds(
|
| 588 |
+
self,
|
| 589 |
+
hidden_states_out: torch.Tensor,
|
| 590 |
+
hidden_states_in: torch.Tensor,
|
| 591 |
+
update_image_embeddings: bool = False,
|
| 592 |
+
) -> torch.Tensor:
|
| 593 |
+
"""Returns text embeddings from the query buffer, including non-attended tokens at inference"""
|
| 594 |
+
if update_image_embeddings and not self.use_asymetric_qkv:
|
| 595 |
+
raise NotImplementedError("Implement image embeddings updates for asymetric QKV")
|
| 596 |
+
# Remove image tokens in the symetric case
|
| 597 |
+
if not self.use_asymetric_qkv:
|
| 598 |
+
hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]]
|
| 599 |
+
|
| 600 |
+
# if there's not attention mask, we are in the right padded case
|
| 601 |
+
# (keep_only_attended = False) we can directly return the query
|
| 602 |
+
# outputs (which don't contain the image)
|
| 603 |
+
if self.attention_mask is None:
|
| 604 |
+
return hidden_states_out
|
| 605 |
+
|
| 606 |
+
# Otherwise, we need to "scatter" back only the text-attended tokens to the original
|
| 607 |
+
# hidden states, which contain the paddings
|
| 608 |
+
num_queries = hidden_states_in.shape[1]
|
| 609 |
+
|
| 610 |
+
# Case 1: the padded hidden_states_in is larger than hidden_states_out
|
| 611 |
+
# we rebatch+pad hidden_state_out before doing the scattering
|
| 612 |
+
if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]:
|
| 613 |
+
s = torch.split(hidden_states_out, self.batch_lengths, dim=0)
|
| 614 |
+
assert max(_s.shape[0] for _s in s) <= num_queries # sanity check
|
| 615 |
+
s = [
|
| 616 |
+
torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0)
|
| 617 |
+
for _s in s
|
| 618 |
+
]
|
| 619 |
+
return torch.where(
|
| 620 |
+
self.attention_mask[:, -num_queries:, None],
|
| 621 |
+
torch.stack(s),
|
| 622 |
+
hidden_states_in,
|
| 623 |
+
)
|
| 624 |
+
# If both have the smae shape, it means hidden_states_in contained no padding
|
| 625 |
+
# so we can directly return hidden states out
|
| 626 |
+
return hidden_states_out
|
| 627 |
+
|
| 628 |
+
def extend(self, num_tokens: int, offset: int = 0):
|
| 629 |
+
"""Extend all necessary values of the Handler for infenrece
|
| 630 |
+
Note: this implementation curently assumes a single conversation at a time
|
| 631 |
+
(otherwise image tokens mask would have to change) and that tokens added are
|
| 632 |
+
attended to"""
|
| 633 |
+
# image embeds is inserted in the first step and stored in the KV cache
|
| 634 |
+
self.image_embeds = None
|
| 635 |
+
|
| 636 |
+
# Update attention mask (non-flattened) (assumes all new tokens are attended to)
|
| 637 |
+
if self.attention_mask is not None:
|
| 638 |
+
self.attention_mask = torch.nn.functional.pad(
|
| 639 |
+
self.attention_mask, (0, num_tokens), value=1
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
# Update image token mask (assumes only one image/conversation
|
| 643 |
+
# is started at once so that we always extend by zero)
|
| 644 |
+
# Note that the mask is stored flattened to avoid padding so we have to
|
| 645 |
+
# do something a bit ugly and inefficient here
|
| 646 |
+
imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0)
|
| 647 |
+
imtokmask = [torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask]
|
| 648 |
+
self.image_tokens_mask = torch.cat(imtokmask, dim=0)
|
| 649 |
+
|
| 650 |
+
# Recompute cumulative document lengths after assigning the new
|
| 651 |
+
# number of tokens to each sample in the batch
|
| 652 |
+
for idx, (ln, is_eob) in enumerate(self.text_sample_lengths):
|
| 653 |
+
if is_eob:
|
| 654 |
+
self.text_sample_lengths[idx] = (num_tokens + ln, is_eob)
|
| 655 |
+
self.full_sample_lengths[idx] += num_tokens
|
| 656 |
+
|
| 657 |
+
# Recompute cu sequlen
|
| 658 |
+
# First step: Technically this never occurs, but we keep it for completeness
|
| 659 |
+
if offset == 0:
|
| 660 |
+
self.max_seqlen_q = max(self.text_sample_lengths)[0]
|
| 661 |
+
self.cu_seqlens_q = self.get_cu_seqlens(
|
| 662 |
+
[x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
self.max_seqlen_kv = max(self.full_sample_lengths)
|
| 666 |
+
self.cu_seqlens_kv = self.get_cu_seqlens(
|
| 667 |
+
self.full_sample_lengths, device=self.cu_seqlens_kv.device
|
| 668 |
+
)
|
| 669 |
+
# Step > 0: the annoying part is since flashattn_varlen does not accept
|
| 670 |
+
# 0-len documents, we need to remove documents from the KV Cache when they're past
|
| 671 |
+
# their windows. In our current setting, this means we only want to keep the latest
|
| 672 |
+
# documents
|
| 673 |
+
else:
|
| 674 |
+
self.max_seqlen_q = num_tokens
|
| 675 |
+
self.cu_seqlens_q = self.get_cu_seqlens(
|
| 676 |
+
[num_tokens for (_, eob) in self.text_sample_lengths if eob],
|
| 677 |
+
device=self.cu_seqlens_q.device,
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
final_doc_lengths = [
|
| 681 |
+
ln
|
| 682 |
+
for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths)
|
| 683 |
+
if eob
|
| 684 |
+
]
|
| 685 |
+
self.current_doc_lengths = final_doc_lengths
|
| 686 |
+
self.max_seqlen_kv = max(self.current_doc_lengths)
|
| 687 |
+
self.cu_seqlens_kv = self.get_cu_seqlens(
|
| 688 |
+
final_doc_lengths,
|
| 689 |
+
device=self.cu_seqlens_kv.device,
|
| 690 |
+
)
|
| 691 |
+
# Update position embeddings
|
| 692 |
+
if self.rope_fn is not None and self.position_embeds is not None:
|
| 693 |
+
self.position_embeds = self.compute_position_embeddings(
|
| 694 |
+
self.rope_fn,
|
| 695 |
+
self.full_sample_lengths,
|
| 696 |
+
dummy_for_dtype_and_device=self.position_embeds[0],
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
@dataclass
|
| 701 |
+
class CASAAttentionStreamingState(StreamingState):
|
| 702 |
+
"""Streaming State for CASA Atention module. Keep the hidden"""
|
| 703 |
+
|
| 704 |
+
k: torch.Tensor = None # pyright: ignore[reportAssignmentType]
|
| 705 |
+
v: torch.Tensor = None # pyright: ignore[reportAssignmentType]
|
| 706 |
+
recover_batched_trims: list[int] = None # pyright: ignore[reportAssignmentType]
|
| 707 |
+
casa_handler: CASAAttentionHandler = None # pyright: ignore[reportAssignmentType]
|
| 708 |
+
|
| 709 |
+
def maybe_get_casa_handler(
|
| 710 |
+
self,
|
| 711 |
+
casa_handler: CASAAttentionHandler | None,
|
| 712 |
+
is_first_casa_layer: bool = False,
|
| 713 |
+
num_queries: int = -1,
|
| 714 |
+
) -> CASAAttentionHandler | None:
|
| 715 |
+
# Set given Casa Handler the first time we reach this
|
| 716 |
+
if self.casa_handler is None:
|
| 717 |
+
self.casa_handler = casa_handler # pyright: ignore
|
| 718 |
+
# subsequent calls: we need to extend shape to accomodate new tokens
|
| 719 |
+
# however because CASA handler is shared across layers, we only need to do it once
|
| 720 |
+
if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer:
|
| 721 |
+
# since CasaHandler is shared, we only use its extend step once
|
| 722 |
+
self.casa_handler.extend(num_queries, offset=self.offset)
|
| 723 |
+
return self.casa_handler
|
| 724 |
+
|
| 725 |
+
def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor:
|
| 726 |
+
"""Recover batched key/value states with left padding"""
|
| 727 |
+
s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1)
|
| 728 |
+
mlen = max(_s.shape[1] for _s in s)
|
| 729 |
+
# Remember the added padding so that we can re-flatten KV later
|
| 730 |
+
if self.recover_batched_trims is None:
|
| 731 |
+
self.recover_batched_trims = [mlen - _s.shape[1] for _s in s]
|
| 732 |
+
s = [torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0) for _s in s]
|
| 733 |
+
return torch.cat(s, dim=0)
|
| 734 |
+
|
| 735 |
+
def __get_flattened_kv__(
|
| 736 |
+
self, k: torch.Tensor | None = None, v: torch.Tensor | None = None
|
| 737 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 738 |
+
"""
|
| 739 |
+
Flattened and remove padding to act with flash_attn_func
|
| 740 |
+
"""
|
| 741 |
+
k = self.k if k is None else k
|
| 742 |
+
v = self.v if v is None else v
|
| 743 |
+
assert k is not None and v is not None
|
| 744 |
+
|
| 745 |
+
# Since every batch at least contributes one document,
|
| 746 |
+
# we can use this to check whether we are in streaming mode with dropped docs.
|
| 747 |
+
# If so, we should trim the kv cache accordingly
|
| 748 |
+
if len(self.casa_handler.current_doc_lengths) == len(k):
|
| 749 |
+
k = torch.cat(
|
| 750 |
+
[
|
| 751 |
+
_k[self.recover_batched_trims[idx] :][-doc_len:]
|
| 752 |
+
for idx, _k, doc_len in zip(
|
| 753 |
+
range(len(k)), k, self.casa_handler.current_doc_lengths
|
| 754 |
+
)
|
| 755 |
+
]
|
| 756 |
+
)
|
| 757 |
+
v = torch.cat(
|
| 758 |
+
[
|
| 759 |
+
_v[self.recover_batched_trims[idx] :][-doc_len:]
|
| 760 |
+
for idx, _v, doc_len in zip(
|
| 761 |
+
range(len(k)), v, self.casa_handler.current_doc_lengths
|
| 762 |
+
)
|
| 763 |
+
]
|
| 764 |
+
)
|
| 765 |
+
return k[None, ...], v[None, ...]
|
| 766 |
+
|
| 767 |
+
k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)])
|
| 768 |
+
v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)])
|
| 769 |
+
return k[None, ...], v[None, ...]
|
| 770 |
+
|
| 771 |
+
def extend_kv(
|
| 772 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor
|
| 773 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 774 |
+
"""
|
| 775 |
+
Extend KV Cache while keep
|
| 776 |
+
"""
|
| 777 |
+
assert self.casa_handler is not None
|
| 778 |
+
if self.k is None and self.v is None:
|
| 779 |
+
# Init with batch-padded key and value states
|
| 780 |
+
self.k = self.__recover_batched_kv__(key_states)
|
| 781 |
+
self.v = self.__recover_batched_kv__(value_states)
|
| 782 |
+
return self.__get_flattened_kv__()
|
| 783 |
+
if self.k is not None and self.v is not None:
|
| 784 |
+
# this is during generation; normally there is no padding at this stage
|
| 785 |
+
# so we can directly reshape the flattened key states
|
| 786 |
+
rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3])
|
| 787 |
+
self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1)
|
| 788 |
+
self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1)
|
| 789 |
+
return self.__get_flattened_kv__()
|
| 790 |
+
|
| 791 |
+
raise ValueError("Impossible configuration (k and v updates are desynchronized )")
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
class CASAAttention(StreamingModule[CASAAttentionStreamingState]):
|
| 795 |
+
def __init__(
|
| 796 |
+
self,
|
| 797 |
+
config: "PretrainedConfig",
|
| 798 |
+
layer_idx: int | None,
|
| 799 |
+
self_attn: torch.nn.Module | None = None,
|
| 800 |
+
input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
| 801 |
+
):
|
| 802 |
+
super().__init__(CASAAttentionStreamingState)
|
| 803 |
+
self.head_dim = config.head_dim
|
| 804 |
+
self.config = config
|
| 805 |
+
|
| 806 |
+
self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0)
|
| 807 |
+
self.use_delta_w = config.casa_delta_w
|
| 808 |
+
|
| 809 |
+
self.q_proj_casa = self.init_from_config_proj("q", config)
|
| 810 |
+
self.k_proj_casa = self.init_from_config_proj("k", config)
|
| 811 |
+
self.v_proj_casa = self.init_from_config_proj("v", config)
|
| 812 |
+
self.o_proj_casa = self.init_from_config_proj("o", config)
|
| 813 |
+
|
| 814 |
+
# Delta_w
|
| 815 |
+
self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
|
| 816 |
+
self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
|
| 817 |
+
self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
|
| 818 |
+
self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
|
| 819 |
+
|
| 820 |
+
if config.casa_delta_w:
|
| 821 |
+
assert self_attn is not None
|
| 822 |
+
self.set_delta_w(self_attn)
|
| 823 |
+
|
| 824 |
+
# Layer norm
|
| 825 |
+
self.norm_fn: Callable | None = None
|
| 826 |
+
if config.xa_norm_on_images:
|
| 827 |
+
assert input_layernorm_fn is not None
|
| 828 |
+
self.norm_fn = input_layernorm_fn
|
| 829 |
+
|
| 830 |
+
def init_from_mha(self, self_attn: torch.nn.Module):
|
| 831 |
+
assert self_attn is not None
|
| 832 |
+
with torch.no_grad():
|
| 833 |
+
assert hasattr(self_attn, "q_proj")
|
| 834 |
+
for key in ["q", "k", "v", "o"]:
|
| 835 |
+
src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj"))
|
| 836 |
+
tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa"))
|
| 837 |
+
tgt.weight.copy_(src.weight)
|
| 838 |
+
if tgt.bias is not None and src.bias is not None:
|
| 839 |
+
tgt.bias.copy_(src.bias)
|
| 840 |
+
|
| 841 |
+
def set_delta_w(self, self_attn: torch.nn.Module):
|
| 842 |
+
"""Delta w setup"""
|
| 843 |
+
self.override_q_proj = delta_w_factory(
|
| 844 |
+
self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj)
|
| 845 |
+
)
|
| 846 |
+
self.override_k_proj = delta_w_factory(
|
| 847 |
+
self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj)
|
| 848 |
+
)
|
| 849 |
+
self.override_v_proj = delta_w_factory(
|
| 850 |
+
self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj)
|
| 851 |
+
)
|
| 852 |
+
self.override_o_proj = delta_w_factory(
|
| 853 |
+
self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj)
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
with torch.no_grad():
|
| 857 |
+
torch.nn.init.zeros_(self.q_proj_casa.weight)
|
| 858 |
+
torch.nn.init.zeros_(self.k_proj_casa.weight)
|
| 859 |
+
torch.nn.init.zeros_(self.v_proj_casa.weight)
|
| 860 |
+
torch.nn.init.zeros_(self.o_proj_casa.weight)
|
| 861 |
+
if self.q_proj_casa.bias is not None:
|
| 862 |
+
torch.nn.init.zeros_(self.q_proj_casa.bias)
|
| 863 |
+
if self.k_proj_casa.bias is not None:
|
| 864 |
+
torch.nn.init.zeros_(self.k_proj_casa.bias)
|
| 865 |
+
if self.v_proj_casa.bias is not None:
|
| 866 |
+
torch.nn.init.zeros_(self.v_proj_casa.bias)
|
| 867 |
+
if self.o_proj_casa.bias is not None:
|
| 868 |
+
torch.nn.init.zeros_(self.o_proj_casa.bias)
|
| 869 |
+
|
| 870 |
+
def init_from_config_proj(
|
| 871 |
+
self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
|
| 872 |
+
) -> torch.nn.Linear:
|
| 873 |
+
"""Initialize the Linear proj in this module"""
|
| 874 |
+
raise NotImplementedError("Abastract class.")
|
| 875 |
+
|
| 876 |
+
def apply_position_embeddings(
|
| 877 |
+
self,
|
| 878 |
+
key: Literal["q", "kv"],
|
| 879 |
+
x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
|
| 880 |
+
casa_handler: CASAAttentionHandler | None,
|
| 881 |
+
num_queries: int = 0,
|
| 882 |
+
unsqueeze_dim: int = 1,
|
| 883 |
+
) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
|
| 884 |
+
"""Apply position embeddings to query and key states"""
|
| 885 |
+
raise NotImplementedError("Abastract class.")
|
| 886 |
+
|
| 887 |
+
def forward(
|
| 888 |
+
self,
|
| 889 |
+
hidden_states: torch.Tensor,
|
| 890 |
+
casa_handler: CASAAttentionHandler | None,
|
| 891 |
+
) -> torch.Tensor | None:
|
| 892 |
+
"""Generic forward for CASA uses for instance in `helium1_attention`"""
|
| 893 |
+
og_dtype = hidden_states.dtype
|
| 894 |
+
if self.is_streaming:
|
| 895 |
+
casa_handler = self.streaming_state.maybe_get_casa_handler(
|
| 896 |
+
casa_handler,
|
| 897 |
+
is_first_casa_layer=self.is_first_casa_layer,
|
| 898 |
+
num_queries=hidden_states.shape[1],
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
# Case of text-only samples at training (or inference when no handler was cached)
|
| 902 |
+
# in this case we just skip CASA so we return None (no casa_update)
|
| 903 |
+
if casa_handler is None:
|
| 904 |
+
return None
|
| 905 |
+
|
| 906 |
+
if self.is_streaming:
|
| 907 |
+
assert casa_handler.use_asymetric_qkv, (
|
| 908 |
+
"You should set `use_asymetric_qkv` to True during inference"
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
og_shape = hidden_states.shape
|
| 912 |
+
|
| 913 |
+
# Build Q inputs
|
| 914 |
+
if casa_handler.use_asymetric_qkv:
|
| 915 |
+
q_inputs = hidden_states.flatten(0, 1)[None, ...]
|
| 916 |
+
if casa_handler.attention_mask is not None:
|
| 917 |
+
q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()]
|
| 918 |
+
else:
|
| 919 |
+
q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
|
| 920 |
+
|
| 921 |
+
# Case 1: Training or first inference step
|
| 922 |
+
if not self.is_streaming or self.streaming_state.offset == 0:
|
| 923 |
+
kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
|
| 924 |
+
else:
|
| 925 |
+
# during streaming, the KV cache including image embeddings
|
| 926 |
+
# will be inserted later so for now we only update the incoming queries
|
| 927 |
+
kv_inputs = q_inputs
|
| 928 |
+
|
| 929 |
+
# Compute QKV for the blockwise attention
|
| 930 |
+
bs, total_seq_len = kv_inputs.shape[:2]
|
| 931 |
+
hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim)
|
| 932 |
+
hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim)
|
| 933 |
+
|
| 934 |
+
if self.override_q_proj is None:
|
| 935 |
+
query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q)
|
| 936 |
+
else:
|
| 937 |
+
query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q)
|
| 938 |
+
|
| 939 |
+
if self.override_k_proj is None:
|
| 940 |
+
key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv)
|
| 941 |
+
else:
|
| 942 |
+
key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv)
|
| 943 |
+
|
| 944 |
+
if self.override_v_proj is None:
|
| 945 |
+
value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv)
|
| 946 |
+
else:
|
| 947 |
+
value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv)
|
| 948 |
+
|
| 949 |
+
# Apply position embedding at the right offset
|
| 950 |
+
num_queries = 0
|
| 951 |
+
if self.streaming and self.streaming_state.offset > 0:
|
| 952 |
+
num_queries = og_shape[1]
|
| 953 |
+
|
| 954 |
+
query_states = self.apply_position_embeddings(
|
| 955 |
+
"q", query_states, num_queries=num_queries, casa_handler=casa_handler
|
| 956 |
+
)
|
| 957 |
+
key_states = self.apply_position_embeddings(
|
| 958 |
+
"kv", key_states, num_queries=num_queries, casa_handler=casa_handler
|
| 959 |
+
)
|
| 960 |
+
assert flash_attn_varlen_func is not None, (
|
| 961 |
+
"flash_attention is not installed but required for block-wise attention"
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
# Flashattention has different efficient implem for streaming
|
| 965 |
+
# In that case, the KV cache has to be batched and has been extended
|
| 966 |
+
# to accomodate the shape of ne the new updates
|
| 967 |
+
if self.is_streaming:
|
| 968 |
+
key_states, value_states = self.streaming_state.extend_kv(
|
| 969 |
+
key_states=key_states, value_states=value_states
|
| 970 |
+
)
|
| 971 |
+
if casa_handler.use_asymetric_qkv:
|
| 972 |
+
cu_seqlens_q = casa_handler.cu_seqlens_q
|
| 973 |
+
max_seqlen_q = casa_handler.max_seqlen_q
|
| 974 |
+
else:
|
| 975 |
+
cu_seqlens_q = casa_handler.cu_seqlens_kv
|
| 976 |
+
max_seqlen_q = casa_handler.max_seqlen_kv
|
| 977 |
+
assert cu_seqlens_q[-1] == query_states.shape[1], (
|
| 978 |
+
f"{cu_seqlens_q[-1]} != {query_states.shape[1]}"
|
| 979 |
+
)
|
| 980 |
+
assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], (
|
| 981 |
+
f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}"
|
| 982 |
+
)
|
| 983 |
+
# for quer
|
| 984 |
+
attn_output: torch.Tensor = flash_attn_varlen_func(
|
| 985 |
+
query_states[0].to(torch.bfloat16),
|
| 986 |
+
key_states[0].to(torch.bfloat16),
|
| 987 |
+
value_states[0].to(torch.bfloat16),
|
| 988 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 989 |
+
cu_seqlens_k=casa_handler.cu_seqlens_kv,
|
| 990 |
+
max_seqlen_q=max_seqlen_q,
|
| 991 |
+
max_seqlen_k=casa_handler.max_seqlen_kv,
|
| 992 |
+
dropout_p=0.0,
|
| 993 |
+
# softmax_scale=None, # defaults to 1/sqrt(d)
|
| 994 |
+
causal=True,
|
| 995 |
+
).to(og_dtype)
|
| 996 |
+
|
| 997 |
+
attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous()
|
| 998 |
+
if self.override_o_proj is None:
|
| 999 |
+
attn_output = self.o_proj_casa(attn_output)
|
| 1000 |
+
else:
|
| 1001 |
+
attn_output = self.override_o_proj(attn_output)
|
| 1002 |
+
|
| 1003 |
+
attn_output = casa_handler.recover_text_embeds(
|
| 1004 |
+
attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds
|
| 1005 |
+
)
|
| 1006 |
+
attn_output = attn_output.reshape(og_shape)
|
| 1007 |
+
|
| 1008 |
+
if self.is_streaming:
|
| 1009 |
+
self.streaming_state.offset += attn_output.shape[1]
|
| 1010 |
+
return attn_output
|
config.json
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"attention_dropout": 0.0,
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoConfig": "configuration_helium1_casa.Helium1CASAConfig",
|
| 6 |
+
"AutoModel": "modeling_helium1_casa.V2Helium1"
|
| 7 |
+
},
|
| 8 |
+
"bos_token_id": 1,
|
| 9 |
+
"casa_attention": true,
|
| 10 |
+
"casa_delta_w": false,
|
| 11 |
+
"casa_use_asymetric_qkv": true,
|
| 12 |
+
"casa_windows": "images",
|
| 13 |
+
"eos_token_id": null,
|
| 14 |
+
"head_dim": 128,
|
| 15 |
+
"hidden_act": "silu",
|
| 16 |
+
"hidden_size": 2048,
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 8192,
|
| 19 |
+
"mask_squash_blockwise": false,
|
| 20 |
+
"max_position_embeddings": 4096,
|
| 21 |
+
"mlp_bias": false,
|
| 22 |
+
"model_type": "CASA_Helium1_VL_2B",
|
| 23 |
+
"num_attention_heads": 16,
|
| 24 |
+
"num_hidden_layers": 28,
|
| 25 |
+
"num_key_value_heads": 8,
|
| 26 |
+
"pad_token_id": 3,
|
| 27 |
+
"post_image_tokens": [],
|
| 28 |
+
"pre_image_tokens": [],
|
| 29 |
+
"pretraining_tp": 1,
|
| 30 |
+
"rms_norm_eps": 1e-08,
|
| 31 |
+
"rope_scaling": null,
|
| 32 |
+
"rope_theta": 20000.0,
|
| 33 |
+
"tie_word_embeddings": false,
|
| 34 |
+
"torch_dtype": "bfloat16",
|
| 35 |
+
"transformers_version": "4.51.3",
|
| 36 |
+
"use_cache": true,
|
| 37 |
+
"vision_config": {
|
| 38 |
+
"depth": 32,
|
| 39 |
+
"fullatt_block_indexes": [
|
| 40 |
+
7,
|
| 41 |
+
15,
|
| 42 |
+
23,
|
| 43 |
+
31
|
| 44 |
+
],
|
| 45 |
+
"hidden_act": "silu",
|
| 46 |
+
"hidden_size": 1280,
|
| 47 |
+
"image_mean": [
|
| 48 |
+
0.48145466,
|
| 49 |
+
0.4578275,
|
| 50 |
+
0.40821073
|
| 51 |
+
],
|
| 52 |
+
"image_std": [
|
| 53 |
+
0.26862954,
|
| 54 |
+
0.26130258,
|
| 55 |
+
0.27577711
|
| 56 |
+
],
|
| 57 |
+
"in_channels": 3,
|
| 58 |
+
"in_chans": 3,
|
| 59 |
+
"intermediate_size": 3420,
|
| 60 |
+
"model_type": "qwen2_5_vl",
|
| 61 |
+
"num_heads": 16,
|
| 62 |
+
"out_dim": 2048,
|
| 63 |
+
"out_hidden_size": 2048,
|
| 64 |
+
"patch_size": 14,
|
| 65 |
+
"spatial_merge_size": 2,
|
| 66 |
+
"spatial_patch_size": 14,
|
| 67 |
+
"temporal_patch_size": 1,
|
| 68 |
+
"tokens_per_second": 2,
|
| 69 |
+
"window_size": 112
|
| 70 |
+
},
|
| 71 |
+
"vocab_size": 64000,
|
| 72 |
+
"xa_custom_norm": true,
|
| 73 |
+
"xa_layers": [],
|
| 74 |
+
"xa_norm_on_images": true,
|
| 75 |
+
"xa_order": "ca_first",
|
| 76 |
+
"xa_update_image_embeds": false
|
| 77 |
+
}
|
configuration_helium1_casa.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Literal
|
| 2 |
+
|
| 3 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Helium1CASAConfig(PretrainedConfig):
|
| 8 |
+
r"""
|
| 9 |
+
Helium1 Config augmented with CASA options
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 14 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
| 15 |
+
`inputs_ids` passed when calling [`Helium1Model`]
|
| 16 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 17 |
+
Dimension of the hidden representations.
|
| 18 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
| 19 |
+
Dimension of the MLP representations.
|
| 20 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 21 |
+
Number of hidden layers in the Transformer decoder.
|
| 22 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 23 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 24 |
+
num_key_value_heads (`int`, *optional*):
|
| 25 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 26 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 27 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 28 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 29 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 30 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 31 |
+
`num_attention_heads`.
|
| 32 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 33 |
+
The non-linear activation function (function or string) in the decoder.
|
| 34 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 35 |
+
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
|
| 36 |
+
Llama 2 up to 4096, CodeLlama up to 16384.
|
| 37 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 38 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 39 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 40 |
+
The epsilon used by the rms normalization layers.
|
| 41 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 42 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 43 |
+
relevant if `config.is_decoder=True`.
|
| 44 |
+
pad_token_id (`int`, *optional*):
|
| 45 |
+
Padding token id.
|
| 46 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 47 |
+
Beginning of stream token id.
|
| 48 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 49 |
+
End of stream token id.
|
| 50 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
| 51 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
| 52 |
+
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
| 53 |
+
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
| 54 |
+
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
| 55 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 56 |
+
Whether to tie weight embeddings
|
| 57 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 58 |
+
The base period of the RoPE embeddings.
|
| 59 |
+
rope_scaling (`Dict`, *optional*):
|
| 60 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 61 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 62 |
+
accordingly.
|
| 63 |
+
Expected contents:
|
| 64 |
+
`rope_type` (`str`):
|
| 65 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 66 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 67 |
+
`factor` (`float`, *optional*):
|
| 68 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 69 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 70 |
+
original maximum pre-trained length.
|
| 71 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 72 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 73 |
+
pretraining.
|
| 74 |
+
`attention_factor` (`float`, *optional*):
|
| 75 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 76 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 77 |
+
`factor` field to infer the suggested value.
|
| 78 |
+
`beta_fast` (`float`, *optional*):
|
| 79 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 80 |
+
ramp function. If unspecified, it defaults to 32.
|
| 81 |
+
`beta_slow` (`float`, *optional*):
|
| 82 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 83 |
+
ramp function. If unspecified, it defaults to 1.
|
| 84 |
+
`short_factor` (`List[float]`, *optional*):
|
| 85 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 86 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 87 |
+
size divided by the number of attention heads divided by 2
|
| 88 |
+
`long_factor` (`List[float]`, *optional*):
|
| 89 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 90 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 91 |
+
size divided by the number of attention heads divided by 2
|
| 92 |
+
`low_freq_factor` (`float`, *optional*):
|
| 93 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 94 |
+
`high_freq_factor` (`float`, *optional*):
|
| 95 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 96 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 97 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 98 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 99 |
+
The dropout ratio for the attention probabilities.
|
| 100 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 101 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
| 102 |
+
head_dim (`int`, *optional*):
|
| 103 |
+
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
model_type = "helium1_casa"
|
| 108 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 109 |
+
# Default tensor parallel plan for base model `Helium1Model`
|
| 110 |
+
base_model_tp_plan = {
|
| 111 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 112 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 113 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 114 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 115 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 116 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 117 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 118 |
+
}
|
| 119 |
+
base_model_pp_plan = { # pyright: ignore[reportAssignmentType]
|
| 120 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 121 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 122 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
vocab_size: int = 32000,
|
| 128 |
+
hidden_size: int = 4096,
|
| 129 |
+
intermediate_size: int = 11008,
|
| 130 |
+
num_hidden_layers: int = 32,
|
| 131 |
+
num_attention_heads: int = 32,
|
| 132 |
+
num_key_value_heads: None | int = None,
|
| 133 |
+
head_dim: None | int = None,
|
| 134 |
+
hidden_act: str = "silu",
|
| 135 |
+
attention_dropout: float = 0.0,
|
| 136 |
+
max_position_embeddings: int = 2048,
|
| 137 |
+
initializer_range: float = 0.02,
|
| 138 |
+
rms_norm_eps: float = 1e-6,
|
| 139 |
+
use_cache: bool = True,
|
| 140 |
+
tie_word_embeddings: bool = False,
|
| 141 |
+
rope_theta: float = 10000.0,
|
| 142 |
+
pad_token_id: int = 3,
|
| 143 |
+
eos_token_id: int = 2,
|
| 144 |
+
bos_token_id: int = 1,
|
| 145 |
+
pretraining_tp: int = 1,
|
| 146 |
+
rope_scaling: None | dict = None,
|
| 147 |
+
attention_bias: bool = False,
|
| 148 |
+
mlp_bias: bool = False,
|
| 149 |
+
# Our fusion mechanisms
|
| 150 |
+
# Common to all fusion mechanisms
|
| 151 |
+
xa_layers: None | tuple = None,
|
| 152 |
+
xa_order: Literal["ca_first", "parallel", "instead"] = "ca_first",
|
| 153 |
+
xa_norm_on_images: bool = False,
|
| 154 |
+
xa_update_image_embeds: bool = False,
|
| 155 |
+
mask_squash_blockwise: bool = False,
|
| 156 |
+
# CASA
|
| 157 |
+
casa_attention: bool = False,
|
| 158 |
+
casa_delta_w: bool = False,
|
| 159 |
+
casa_windows: Literal["batch", "squashed", "images", "turn_based"] = "batch",
|
| 160 |
+
casa_use_asymetric_qkv: bool = True,
|
| 161 |
+
xa_custom_norm: bool = False,
|
| 162 |
+
# Qwen2.5-VL vision config
|
| 163 |
+
vision_config: dict[str, Any] | None = None,
|
| 164 |
+
**kwargs: Any,
|
| 165 |
+
):
|
| 166 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 167 |
+
|
| 168 |
+
self.vocab_size = vocab_size
|
| 169 |
+
self.max_position_embeddings = max_position_embeddings
|
| 170 |
+
self.hidden_size = hidden_size
|
| 171 |
+
self.intermediate_size = intermediate_size
|
| 172 |
+
self.num_hidden_layers = num_hidden_layers
|
| 173 |
+
self.num_attention_heads = num_attention_heads
|
| 174 |
+
|
| 175 |
+
# for backward compatibility
|
| 176 |
+
if num_key_value_heads is None:
|
| 177 |
+
num_key_value_heads = num_attention_heads
|
| 178 |
+
|
| 179 |
+
self.num_key_value_heads = num_key_value_heads
|
| 180 |
+
self.head_dim = (
|
| 181 |
+
head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
| 182 |
+
)
|
| 183 |
+
self.hidden_act = hidden_act
|
| 184 |
+
self.initializer_range = initializer_range
|
| 185 |
+
self.rms_norm_eps = rms_norm_eps
|
| 186 |
+
self.pretraining_tp = pretraining_tp
|
| 187 |
+
self.use_cache = use_cache
|
| 188 |
+
self.rope_theta = rope_theta
|
| 189 |
+
self.rope_scaling = rope_scaling
|
| 190 |
+
self.attention_bias = attention_bias
|
| 191 |
+
self.attention_dropout = attention_dropout
|
| 192 |
+
self.mlp_bias = mlp_bias
|
| 193 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 194 |
+
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
| 195 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 196 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 197 |
+
rope_config_validation(self)
|
| 198 |
+
|
| 199 |
+
self.head_dim = self.hidden_size // self.num_attention_heads
|
| 200 |
+
self.xa_layers = xa_layers
|
| 201 |
+
self.xa_order: Literal["ca_first", "parallel", "instead"] = xa_order
|
| 202 |
+
self.xa_norm_on_images = xa_norm_on_images
|
| 203 |
+
self.xa_update_image_embeds = xa_update_image_embeds
|
| 204 |
+
self.mask_squash_blockwise = mask_squash_blockwise
|
| 205 |
+
# CASA config
|
| 206 |
+
self.casa_attention = casa_attention
|
| 207 |
+
self.casa_delta_w = casa_delta_w
|
| 208 |
+
self.casa_windows: Literal["batch", "squashed", "images", "turn_based"] = casa_windows
|
| 209 |
+
self.casa_use_asymetric_qkv = casa_use_asymetric_qkv
|
| 210 |
+
self.xa_custom_norm = xa_custom_norm
|
| 211 |
+
|
| 212 |
+
if vision_config is None:
|
| 213 |
+
vision_config = dict()
|
| 214 |
+
self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
|
| 215 |
+
self.vision_config.temporal_patch_size = 1
|
| 216 |
+
self.vision_config.image_mean = [0.48145466, 0.4578275, 0.40821073]
|
| 217 |
+
self.vision_config.image_std = [0.26862954, 0.26130258, 0.27577711]
|
| 218 |
+
self.vision_config.out_dim = 2048
|
| 219 |
+
|
| 220 |
+
self.pre_image_tokens = []
|
| 221 |
+
self.post_image_tokens = []
|
| 222 |
+
|
| 223 |
+
super().__init__(
|
| 224 |
+
pad_token_id=pad_token_id,
|
| 225 |
+
bos_token_id=bos_token_id,
|
| 226 |
+
eos_token_id=eos_token_id,
|
| 227 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 228 |
+
**kwargs,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if __name__ == "__main__":
|
| 233 |
+
import argparse
|
| 234 |
+
from pathlib import Path
|
| 235 |
+
|
| 236 |
+
import rich
|
| 237 |
+
import yaml
|
| 238 |
+
from transformers.models.auto.configuration_auto import AutoConfig
|
| 239 |
+
|
| 240 |
+
parser = argparse.ArgumentParser()
|
| 241 |
+
parser.add_argument("--out_dir", type=str, default="./saved_config/")
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--ckpt_path",
|
| 244 |
+
type=str,
|
| 245 |
+
default="/lustre/scwpod02/client/kyutai/juliette/experiments/finext_casa_896_xtxt_up_b20_64gpu/fdf76e6774",
|
| 246 |
+
)
|
| 247 |
+
args = parser.parse_args()
|
| 248 |
+
path = Path(args.ckpt_path) / "kyuteye_config.yml"
|
| 249 |
+
|
| 250 |
+
helium_config = AutoConfig.from_pretrained("kyutai/helium-1-2b")
|
| 251 |
+
vision_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct").vision_config
|
| 252 |
+
|
| 253 |
+
# 3) Create YOUR config by merging both
|
| 254 |
+
config = Helium1CASAConfig(
|
| 255 |
+
**helium_config.to_dict(), # all helium parameters
|
| 256 |
+
vision_config=vision_config.to_dict(), # override or add vision_config
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
with open(path) as stream:
|
| 260 |
+
kconfig = yaml.safe_load(stream)
|
| 261 |
+
|
| 262 |
+
# print keys that are in kconfig and in config
|
| 263 |
+
for key in set(kconfig.keys()).intersection(set(config.to_dict().keys())):
|
| 264 |
+
rich.print(f"Overwriting [bold green]{key:>50s}[/]: [bold red]{kconfig[key]}")
|
| 265 |
+
setattr(config, key, kconfig[key])
|
| 266 |
+
# TODO: handle casa_own_norm -> xa_custom_norm
|
| 267 |
+
print("Configuration successfully loaded.")
|
| 268 |
+
# Save config to json
|
| 269 |
+
config.save_pretrained(args.out_dir)
|
| 270 |
+
print(f"Configuration saved to {args.out_dir}/config.json")
|
generation_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": [
|
| 5 |
+
103,
|
| 6 |
+
3
|
| 7 |
+
],
|
| 8 |
+
"pad_token_id": 3,
|
| 9 |
+
"transformers_version": "4.51.3"
|
| 10 |
+
}
|
image_encoder.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Qwen2.5VL encoder with delayed normalization"""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 6 |
+
Qwen2_5_VisionTransformerPretrainedModel,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def prepare_for_qwen_encoder(
|
| 11 |
+
x: torch.Tensor | list[torch.Tensor], mean: torch.Tensor, std: torch.Tensor
|
| 12 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 13 |
+
"""
|
| 14 |
+
Preprocessing for Qwen encoder
|
| 15 |
+
Image mean and std come from processor.image_processor.image_mean and image_std
|
| 16 |
+
"""
|
| 17 |
+
grid_thw = torch.Tensor([[1, img.shape[0], img.shape[1]] for img in x]).to(x[0].device)
|
| 18 |
+
hws_flatten_shape = torch.prod(grid_thw, dim=-1)
|
| 19 |
+
x = torch.cat(
|
| 20 |
+
[img.reshape((int(hws_flatten_shape[idx].item()), -1)) for idx, img in enumerate(x)],
|
| 21 |
+
dim=0,
|
| 22 |
+
)
|
| 23 |
+
assert x.min() >= 0.0 and x.max() <= 1.0
|
| 24 |
+
og_shape = x.shape
|
| 25 |
+
x = rearrange(x, "L (c d) -> L c d", c=3)
|
| 26 |
+
x = (x - mean) / std
|
| 27 |
+
x = x.view(og_shape).to(torch.bfloat16)
|
| 28 |
+
return x, grid_thw
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Qwen25VLEncoder(torch.nn.Module):
|
| 32 |
+
"""Qwen2.5 VL encoder with pre/post processing to be compatible for
|
| 33 |
+
our CASA attention implementation"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
visual: "Qwen2_5_VisionTransformerPretrainedModel",
|
| 38 |
+
):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.visual = visual
|
| 41 |
+
self.image_mean = torch.tensor(self.visual.config.image_mean).view(1, 3, 1)
|
| 42 |
+
self.image_std = torch.tensor(self.visual.config.image_std).view(1, 3, 1)
|
| 43 |
+
|
| 44 |
+
def forward(
|
| 45 |
+
self, x: torch.Tensor | list[torch.Tensor]
|
| 46 |
+
) -> dict[str, torch.Tensor | list[torch.Tensor]]:
|
| 47 |
+
x, grid_thw = prepare_for_qwen_encoder(
|
| 48 |
+
x, mean=self.image_mean.to(x[0].device), std=self.image_std.to(x[0].device)
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
grid_thw = grid_thw.type(torch.int)
|
| 52 |
+
assert len(x) == grid_thw.prod(dim=1).sum()
|
| 53 |
+
out = self.visual(x, grid_thw=grid_thw)
|
| 54 |
+
|
| 55 |
+
split_sizes = (grid_thw.prod(dim=-1) // self.visual.spatial_merge_size**2).tolist()
|
| 56 |
+
embeds = list(torch.split(out, split_sizes, dim=0)) # Ni * (seq, C)
|
| 57 |
+
return {"image_embeds": embeds, "grid_thw": grid_thw}
|
language_helium1_casa.py
ADDED
|
@@ -0,0 +1,1077 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ADAPTED FROM https://github.com/huggingface/transformers/blob/main/src/transformers/models/helium/modeling_helium.py
|
| 2 |
+
# GIT HASH 1b222903c3e1cfd9492d75e4b2548aa8bd458674
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Any, Callable, Literal, Optional
|
| 8 |
+
from typing import cast as type_cast
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
from transformers import (
|
| 13 |
+
ROPE_INIT_FUNCTIONS, # pyright: ignore[reportPrivateImportUsage]
|
| 14 |
+
dynamic_rope_update, # pyright: ignore[reportPrivateImportUsage]
|
| 15 |
+
)
|
| 16 |
+
from transformers.activations import ACT2FN
|
| 17 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 18 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 19 |
+
from transformers.generation.utils import GenerationMixin
|
| 20 |
+
from transformers.loss.loss_utils import ForCausalLMLoss
|
| 21 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 22 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 23 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 24 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 25 |
+
from transformers.processing_utils import Unpack
|
| 26 |
+
from transformers.utils.generic import LossKwargs, can_return_tuple
|
| 27 |
+
from transformers.utils.import_utils import is_torch_flex_attn_available
|
| 28 |
+
|
| 29 |
+
from .casa_attention import CASAAttention, CASAAttentionHandler, insert_image_tokens
|
| 30 |
+
from .configuration_helium1_casa import Helium1CASAConfig
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
if is_torch_flex_attn_available():
|
| 35 |
+
from transformers.integrations.flex_attention import make_flex_block_causal_mask
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def remove_image_tokens(
|
| 39 |
+
inputs_embeds: torch.Tensor,
|
| 40 |
+
image_tokens_mask: torch.Tensor,
|
| 41 |
+
) -> torch.Tensor:
|
| 42 |
+
"""Remove the image tokens from inputs_embeds as indicated by image_tokens_mask
|
| 43 |
+
|
| 44 |
+
:param inputs_embeds: Tokens of shape (Batch, Seqlen, Dims) containing image tokens
|
| 45 |
+
:param image_tokens_mask: 1-0 mask indicating where image tokens are; (Batch, Seqlen)
|
| 46 |
+
|
| 47 |
+
:return: Tokens tensor of shape (Batch, S' < Seqlen, Dims)
|
| 48 |
+
"""
|
| 49 |
+
image_seq_lengths = torch.sum(image_tokens_mask, dim=1)[:, 0]
|
| 50 |
+
image_seq_length = int(image_seq_lengths[0].item())
|
| 51 |
+
assert torch.all(image_seq_lengths == image_seq_length)
|
| 52 |
+
new_shape = (
|
| 53 |
+
inputs_embeds.shape[0],
|
| 54 |
+
inputs_embeds.shape[1] - image_seq_length,
|
| 55 |
+
inputs_embeds.shape[-1],
|
| 56 |
+
)
|
| 57 |
+
tokens = torch.masked_select(
|
| 58 |
+
inputs_embeds,
|
| 59 |
+
torch.logical_not(image_tokens_mask).expand((-1, -1, inputs_embeds.shape[-1])),
|
| 60 |
+
)
|
| 61 |
+
return tokens.reshape(new_shape)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 65 |
+
"""
|
| 66 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 67 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 68 |
+
"""
|
| 69 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 70 |
+
if n_rep == 1:
|
| 71 |
+
return hidden_states
|
| 72 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 73 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 74 |
+
)
|
| 75 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def eager_attention_forward(
|
| 79 |
+
module: "HeliumAttention",
|
| 80 |
+
query: torch.Tensor,
|
| 81 |
+
key: torch.Tensor,
|
| 82 |
+
value: torch.Tensor,
|
| 83 |
+
attention_mask: None | torch.Tensor,
|
| 84 |
+
scaling: float,
|
| 85 |
+
dropout: float = 0.0,
|
| 86 |
+
**kwargs: Any,
|
| 87 |
+
):
|
| 88 |
+
del kwargs # unused
|
| 89 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 90 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 91 |
+
|
| 92 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 93 |
+
if attention_mask is not None:
|
| 94 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 95 |
+
attn_weights = attn_weights + causal_mask
|
| 96 |
+
|
| 97 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 98 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 99 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 100 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 101 |
+
|
| 102 |
+
return attn_output, attn_weights
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Different Attention Classes
|
| 106 |
+
class HeliumAttention(torch.nn.Module):
|
| 107 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 108 |
+
|
| 109 |
+
def __init__(self, config: Helium1CASAConfig, layer_idx: None | int = None):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.config = config
|
| 112 |
+
assert layer_idx is not None
|
| 113 |
+
self.layer_idx: int = layer_idx
|
| 114 |
+
|
| 115 |
+
self.apply_rotary_fn = ApplyRotaryPosEmbHelium1()
|
| 116 |
+
self.head_dim = getattr(
|
| 117 |
+
config, "head_dim", config.hidden_size // config.num_attention_heads
|
| 118 |
+
)
|
| 119 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 120 |
+
self.scaling = 1 / math.sqrt(self.head_dim)
|
| 121 |
+
self.attention_dropout = config.attention_dropout
|
| 122 |
+
self.is_causal = True
|
| 123 |
+
|
| 124 |
+
self.q_proj = nn.Linear(
|
| 125 |
+
config.hidden_size,
|
| 126 |
+
config.num_attention_heads * self.head_dim,
|
| 127 |
+
bias=config.attention_bias,
|
| 128 |
+
)
|
| 129 |
+
self.k_proj = nn.Linear(
|
| 130 |
+
config.hidden_size,
|
| 131 |
+
config.num_key_value_heads * self.head_dim,
|
| 132 |
+
bias=config.attention_bias,
|
| 133 |
+
)
|
| 134 |
+
self.v_proj = nn.Linear(
|
| 135 |
+
config.hidden_size,
|
| 136 |
+
config.num_key_value_heads * self.head_dim,
|
| 137 |
+
bias=config.attention_bias,
|
| 138 |
+
)
|
| 139 |
+
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
| 140 |
+
|
| 141 |
+
def forward(
|
| 142 |
+
self,
|
| 143 |
+
hidden_states: torch.Tensor,
|
| 144 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 145 |
+
attention_mask: None | torch.Tensor,
|
| 146 |
+
past_key_values: None | Cache = None,
|
| 147 |
+
cache_position: None | torch.LongTensor = None,
|
| 148 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 149 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 150 |
+
# del (cache_position, past_key_value) # we use our own generate/caching
|
| 151 |
+
bs, seq_len, _ = hidden_states.shape
|
| 152 |
+
# Get QKV
|
| 153 |
+
hidden_shape = (bs, seq_len, -1, self.head_dim)
|
| 154 |
+
|
| 155 |
+
# Embed Queries
|
| 156 |
+
# Shape: (batch_size, num_heads, seq_len, head_dim)
|
| 157 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 158 |
+
num_queries = query_states.shape[2]
|
| 159 |
+
|
| 160 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 161 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 162 |
+
|
| 163 |
+
# Applies rotation
|
| 164 |
+
cos, sin = position_embeddings
|
| 165 |
+
query_states, key_states = self.apply_rotary_fn(
|
| 166 |
+
query_states, key_states, cos, sin, num_queries=num_queries
|
| 167 |
+
)
|
| 168 |
+
assert key_states is not None and query_states is not None
|
| 169 |
+
|
| 170 |
+
attention_interface: Callable = eager_attention_forward
|
| 171 |
+
|
| 172 |
+
if self.config._attn_implementation != "eager":
|
| 173 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get(
|
| 174 |
+
"output_attentions", False
|
| 175 |
+
):
|
| 176 |
+
print(
|
| 177 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support"
|
| 178 |
+
" `output_attentions=True`. Falling back to "
|
| 179 |
+
'eager attention. This warning can be removed using the argument"\
|
| 180 |
+
" `attn_implementation="eager"` when loading the model.'
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 184 |
+
|
| 185 |
+
if past_key_values is not None:
|
| 186 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 187 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 188 |
+
key_states, value_states = past_key_values.update(
|
| 189 |
+
key_states, value_states, self.layer_idx, cache_kwargs
|
| 190 |
+
)
|
| 191 |
+
attn_output, attn_weights = attention_interface(
|
| 192 |
+
self,
|
| 193 |
+
query_states,
|
| 194 |
+
key_states,
|
| 195 |
+
value_states,
|
| 196 |
+
attention_mask,
|
| 197 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 198 |
+
scaling=self.scaling,
|
| 199 |
+
**kwargs,
|
| 200 |
+
)
|
| 201 |
+
attn_output = attn_output.reshape(bs, num_queries, -1).contiguous()
|
| 202 |
+
attn_output = self.o_proj(attn_output)
|
| 203 |
+
|
| 204 |
+
assert isinstance(attn_output, torch.Tensor)
|
| 205 |
+
return attn_output, attn_weights
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class ApplyRotaryPosEmbHelium1:
|
| 209 |
+
@staticmethod
|
| 210 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 211 |
+
"""Rotates half the hidden dims of the input."""
|
| 212 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 213 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 214 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 215 |
+
|
| 216 |
+
@staticmethod
|
| 217 |
+
def __call__(
|
| 218 |
+
q: torch.Tensor,
|
| 219 |
+
k: torch.Tensor,
|
| 220 |
+
cos: torch.Tensor,
|
| 221 |
+
sin: torch.Tensor,
|
| 222 |
+
position_ids: torch.Tensor | None = None,
|
| 223 |
+
unsqueeze_dim: int = 1,
|
| 224 |
+
num_queries: int | None = None,
|
| 225 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 226 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
q (`torch.Tensor`): The query tensor.
|
| 230 |
+
k (`torch.Tensor`): The key tensor.
|
| 231 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 232 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 233 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 234 |
+
Deprecated and unused.
|
| 235 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 236 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 237 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 238 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 239 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 240 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 241 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 242 |
+
Returns:
|
| 243 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 244 |
+
"""
|
| 245 |
+
del position_ids
|
| 246 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 247 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 248 |
+
if num_queries is None:
|
| 249 |
+
offset = 0
|
| 250 |
+
else:
|
| 251 |
+
offset = -num_queries
|
| 252 |
+
|
| 253 |
+
q_embed = (q * cos[:, :, offset:]) + (
|
| 254 |
+
ApplyRotaryPosEmbHelium1.rotate_half(q) * sin[:, :, offset:]
|
| 255 |
+
)
|
| 256 |
+
k_embed = (k * cos) + (ApplyRotaryPosEmbHelium1.rotate_half(k) * sin)
|
| 257 |
+
|
| 258 |
+
return q_embed, k_embed
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class HeliumRotaryEmbedding(nn.Module):
|
| 262 |
+
def __init__(self, config: Helium1CASAConfig, device: None | torch.device | str = None):
|
| 263 |
+
super().__init__()
|
| 264 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 265 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 266 |
+
else:
|
| 267 |
+
self.rope_type = "default"
|
| 268 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 269 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 270 |
+
|
| 271 |
+
self.config = config
|
| 272 |
+
assert self.rope_type in ROPE_INIT_FUNCTIONS, (
|
| 273 |
+
f"Invalid rope type {self.rope_type}. Supported types are: {list(ROPE_INIT_FUNCTIONS.keys())}"
|
| 274 |
+
)
|
| 275 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 276 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(config, device=device)
|
| 277 |
+
self.inv_freq: torch.Tensor # only defined for typing
|
| 278 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 279 |
+
self.original_inv_freq = self.inv_freq
|
| 280 |
+
|
| 281 |
+
@torch.no_grad()
|
| 282 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 283 |
+
def forward(
|
| 284 |
+
self, x: torch.Tensor, position_ids: torch.Tensor
|
| 285 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 286 |
+
inv_freq_expanded = (
|
| 287 |
+
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 288 |
+
)
|
| 289 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 290 |
+
|
| 291 |
+
device_type = (
|
| 292 |
+
x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 293 |
+
)
|
| 294 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 295 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 296 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 297 |
+
cos = emb.cos() * self.attention_scaling
|
| 298 |
+
sin = emb.sin() * self.attention_scaling
|
| 299 |
+
|
| 300 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class Helium1CASAAttention(CASAAttention):
|
| 304 |
+
"""A CASA Attention layer compatible with Qwen"""
|
| 305 |
+
|
| 306 |
+
def __init__(
|
| 307 |
+
self,
|
| 308 |
+
config: Helium1CASAConfig,
|
| 309 |
+
layer_idx: int | None,
|
| 310 |
+
self_attn: torch.nn.Module | None = None,
|
| 311 |
+
input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
| 312 |
+
):
|
| 313 |
+
# Only adding this init for typing purposes for the config
|
| 314 |
+
super().__init__(config, layer_idx, self_attn, input_layernorm_fn) # pyright: ignore[reportArgumentType]
|
| 315 |
+
|
| 316 |
+
@staticmethod
|
| 317 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 318 |
+
"""Rotates half the hidden dims of the input."""
|
| 319 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 320 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 321 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 322 |
+
|
| 323 |
+
def apply_position_embeddings(
|
| 324 |
+
self,
|
| 325 |
+
key: Literal["q", "kv"],
|
| 326 |
+
x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
|
| 327 |
+
casa_handler: CASAAttentionHandler | None,
|
| 328 |
+
num_queries: int = 0,
|
| 329 |
+
unsqueeze_dim: int = 1,
|
| 330 |
+
) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
|
| 331 |
+
"""Apply position embeddings to query and key states"""
|
| 332 |
+
if casa_handler is not None:
|
| 333 |
+
posemb = casa_handler.get_position_embedding(key, num_queries=num_queries)
|
| 334 |
+
|
| 335 |
+
if posemb is not None:
|
| 336 |
+
x = x.transpose(1, 2).to(torch.float32)
|
| 337 |
+
x = (x * posemb[0].unsqueeze(dim=unsqueeze_dim)) + (
|
| 338 |
+
self.rotate_half(x) * posemb[1].unsqueeze(dim=unsqueeze_dim)
|
| 339 |
+
)
|
| 340 |
+
return x.transpose(1, 2)
|
| 341 |
+
return x
|
| 342 |
+
|
| 343 |
+
def init_from_config_proj(
|
| 344 |
+
self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
|
| 345 |
+
) -> torch.nn.Linear:
|
| 346 |
+
"""Initialize the Linear proj in this module"""
|
| 347 |
+
num_heads = config.num_key_value_heads if key in {"k", "v"} else config.num_attention_heads
|
| 348 |
+
return torch.nn.Linear(
|
| 349 |
+
config.hidden_size,
|
| 350 |
+
num_heads * config.head_dim,
|
| 351 |
+
bias=config.attention_bias if key != "o" else False,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# NORMALISATION LAYER
|
| 356 |
+
def __rms_norm_forward__(
|
| 357 |
+
hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6
|
| 358 |
+
) -> torch.Tensor:
|
| 359 |
+
input_dtype = hidden_states.dtype
|
| 360 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 361 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 362 |
+
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
|
| 363 |
+
return weight * hidden_states.to(input_dtype)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class Helium1RMSNorm(nn.Module):
|
| 367 |
+
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
| 368 |
+
"""
|
| 369 |
+
Helium1RMSNorm is equivalent to T5LayerNorm
|
| 370 |
+
"""
|
| 371 |
+
super().__init__()
|
| 372 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 373 |
+
self.variance_epsilon = eps
|
| 374 |
+
|
| 375 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 376 |
+
return __rms_norm_forward__(hidden_states, self.weight, self.variance_epsilon)
|
| 377 |
+
|
| 378 |
+
def extra_repr(self):
|
| 379 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def delta_w_factory_rms_norm(
|
| 383 |
+
org_lin: Helium1RMSNorm, new_lin: Helium1RMSNorm
|
| 384 |
+
) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 385 |
+
"""Factory for building rms norm where the weights are the sum of two layers' weights"""
|
| 386 |
+
|
| 387 |
+
def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
|
| 388 |
+
nonlocal org_lin, new_lin
|
| 389 |
+
return __rms_norm_forward__(
|
| 390 |
+
input, org_lin.weight + new_lin.weight, new_lin.variance_epsilon
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
return _delta_w_fwd
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# FULL CONNECTED LAYER
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class HeliumMLP(nn.Module):
|
| 400 |
+
def __init__(self, config: Helium1CASAConfig) -> None:
|
| 401 |
+
super().__init__()
|
| 402 |
+
self.config = config
|
| 403 |
+
self.hidden_size = config.hidden_size
|
| 404 |
+
self.intermediate_size = config.intermediate_size
|
| 405 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 406 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 407 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
| 408 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 409 |
+
|
| 410 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 411 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 412 |
+
return down_proj
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class HeliumDecoderLayer(nn.Module):
|
| 416 |
+
def __init__(self, config: Helium1CASAConfig, layer_idx: None | int = None):
|
| 417 |
+
super().__init__()
|
| 418 |
+
self.hidden_size = config.hidden_size
|
| 419 |
+
self.config = config
|
| 420 |
+
self.mlp = HeliumMLP(config)
|
| 421 |
+
self.input_layernorm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 422 |
+
self.post_attention_layernorm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 423 |
+
|
| 424 |
+
# Self-attention
|
| 425 |
+
self.self_attn = HeliumAttention(config=config, layer_idx=layer_idx)
|
| 426 |
+
|
| 427 |
+
# Setup norm for fusion mechanisms; Note that this norm is on the text tokens
|
| 428 |
+
is_xa_layer = layer_idx is None or not config.xa_layers or layer_idx in config.xa_layers
|
| 429 |
+
self.norm_cross: None | Helium1RMSNorm = None
|
| 430 |
+
self.override_norm_cross: Callable[[torch.Tensor], torch.Tensor] | None = None
|
| 431 |
+
if is_xa_layer and config.casa_attention:
|
| 432 |
+
# Custom normalization layer for the extra fusion module
|
| 433 |
+
if self.config.xa_custom_norm:
|
| 434 |
+
self.norm_cross = Helium1RMSNorm(config.hidden_size)
|
| 435 |
+
if config.casa_delta_w:
|
| 436 |
+
self.override_norm_cross = delta_w_factory_rms_norm(
|
| 437 |
+
self.input_layernorm, self.norm_cross
|
| 438 |
+
)
|
| 439 |
+
with torch.no_grad():
|
| 440 |
+
torch.nn.init.ones_(self.norm_cross.weight)
|
| 441 |
+
|
| 442 |
+
# Setup additional norm for images tokens which is set in each individual mechansims
|
| 443 |
+
norm_on_images_fn = (
|
| 444 |
+
None
|
| 445 |
+
if not self.config.xa_norm_on_images
|
| 446 |
+
else self.override_norm_cross
|
| 447 |
+
if self.override_norm_cross is not None
|
| 448 |
+
else self.norm_cross.forward
|
| 449 |
+
if self.norm_cross is not None
|
| 450 |
+
else self.input_layernorm.forward
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# CASA
|
| 454 |
+
self.casa_attn: Helium1CASAAttention | None = None
|
| 455 |
+
if config.casa_attention and is_xa_layer:
|
| 456 |
+
self.casa_attn = Helium1CASAAttention(
|
| 457 |
+
config, layer_idx, self_attn=self.self_attn, input_layernorm_fn=norm_on_images_fn
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
def forward(
|
| 461 |
+
self,
|
| 462 |
+
hidden_states: torch.Tensor,
|
| 463 |
+
attention_mask: None | torch.Tensor = None,
|
| 464 |
+
position_ids: None | torch.LongTensor = None,
|
| 465 |
+
past_key_values: None | Cache = None,
|
| 466 |
+
output_attentions: None | bool = False,
|
| 467 |
+
use_cache: None | bool = False,
|
| 468 |
+
cache_position: None | torch.LongTensor = None,
|
| 469 |
+
position_embeddings: None
|
| 470 |
+
| tuple[torch.Tensor, torch.Tensor] = None, # necessary, but kept here for BC
|
| 471 |
+
# CASA
|
| 472 |
+
casa_handler: CASAAttentionHandler | None = None,
|
| 473 |
+
cu_seqlens: torch.Tensor | None = None,
|
| 474 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 475 |
+
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
|
| 476 |
+
# Image fusion mechanisms
|
| 477 |
+
apply_ca = self.casa_attn is not None
|
| 478 |
+
ca_update: torch.Tensor | None = None
|
| 479 |
+
if (
|
| 480 |
+
self.config.xa_order
|
| 481 |
+
in {
|
| 482 |
+
"parallel",
|
| 483 |
+
"ca_first",
|
| 484 |
+
"instead",
|
| 485 |
+
}
|
| 486 |
+
and apply_ca
|
| 487 |
+
):
|
| 488 |
+
# Apply layer norm
|
| 489 |
+
assert self.norm_cross is not None
|
| 490 |
+
ca_input = (
|
| 491 |
+
self.override_norm_cross
|
| 492 |
+
if self.override_norm_cross is not None
|
| 493 |
+
else self.norm_cross
|
| 494 |
+
)(hidden_states)
|
| 495 |
+
# CASA
|
| 496 |
+
if self.casa_attn is not None:
|
| 497 |
+
ca_update = self.casa_attn(ca_input, casa_handler=casa_handler)
|
| 498 |
+
|
| 499 |
+
# If we're here, it's because we had proper inputs (no text-only samples)
|
| 500 |
+
# so the output better be not None !
|
| 501 |
+
if ca_update is not None:
|
| 502 |
+
# `instead`: directly return the output of the CA module as residual
|
| 503 |
+
if self.config.xa_order == "instead":
|
| 504 |
+
outputs = (hidden_states + ca_update,)
|
| 505 |
+
if output_attentions:
|
| 506 |
+
outputs += (
|
| 507 |
+
torch.zeros((), device=ca_update.device, dtype=ca_update.dtype),
|
| 508 |
+
)
|
| 509 |
+
return outputs
|
| 510 |
+
|
| 511 |
+
# `ca_first`: update then continue with normal self-attention
|
| 512 |
+
if self.config.xa_order == "ca_first":
|
| 513 |
+
hidden_states = hidden_states + ca_update
|
| 514 |
+
ca_update = None
|
| 515 |
+
|
| 516 |
+
# Self Attention with initial input layer norm
|
| 517 |
+
residual = hidden_states
|
| 518 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 519 |
+
hidden_states=self.input_layernorm(hidden_states),
|
| 520 |
+
attention_mask=attention_mask,
|
| 521 |
+
position_ids=position_ids,
|
| 522 |
+
past_key_values=past_key_values,
|
| 523 |
+
output_attentions=output_attentions,
|
| 524 |
+
use_cache=use_cache,
|
| 525 |
+
cache_position=cache_position,
|
| 526 |
+
position_embeddings=position_embeddings,
|
| 527 |
+
cu_seqlens=cu_seqlens,
|
| 528 |
+
**kwargs,
|
| 529 |
+
)
|
| 530 |
+
hidden_states = residual + hidden_states
|
| 531 |
+
|
| 532 |
+
# parallel - residual update
|
| 533 |
+
if self.config.xa_order == "parallel" and apply_ca and ca_update is not None:
|
| 534 |
+
hidden_states = hidden_states + ca_update
|
| 535 |
+
|
| 536 |
+
# Fully Connected layer
|
| 537 |
+
residual = hidden_states
|
| 538 |
+
# MLP updates for image embeddings
|
| 539 |
+
if (
|
| 540 |
+
self.config.xa_update_image_embeds
|
| 541 |
+
and self.casa_attn is not None
|
| 542 |
+
and casa_handler is not None
|
| 543 |
+
and casa_handler.image_embeds is not None
|
| 544 |
+
):
|
| 545 |
+
# Text flattening
|
| 546 |
+
hs = self.post_attention_layernorm(hidden_states).reshape(-1, hidden_states.shape[-1])
|
| 547 |
+
# Image flattening
|
| 548 |
+
img_seq_lengths = [_x.shape[0] for _x in casa_handler.image_embeds]
|
| 549 |
+
img_residual = torch.cat(list(casa_handler.image_embeds), dim=0)
|
| 550 |
+
update = self.mlp(torch.cat([hs, self.post_attention_layernorm(img_residual)], dim=0))
|
| 551 |
+
# update text
|
| 552 |
+
hidden_states = hidden_states + update[: hs.shape[0]].reshape(hidden_states.shape)
|
| 553 |
+
casa_handler.image_embeds = list(
|
| 554 |
+
torch.split(img_residual + update[hs.shape[0] :], img_seq_lengths)
|
| 555 |
+
)
|
| 556 |
+
else:
|
| 557 |
+
hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
|
| 558 |
+
hidden_states = residual + hidden_states
|
| 559 |
+
|
| 560 |
+
# Outputs
|
| 561 |
+
outputs = (hidden_states,)
|
| 562 |
+
if output_attentions:
|
| 563 |
+
outputs += (self_attn_weights,)
|
| 564 |
+
|
| 565 |
+
return outputs
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# FULL HELIUM MODEL
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
@dataclass
|
| 572 |
+
class CausalHeliumOutput(CausalLMOutputWithPast):
|
| 573 |
+
attention_mask: Optional[torch.Tensor] = None
|
| 574 |
+
num_image_tokens_log: Optional[torch.Tensor] = None
|
| 575 |
+
num_text_tokens_log: Optional[torch.Tensor] = None
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class Helium1PreTrainedModel(PreTrainedModel):
|
| 579 |
+
config_class = Helium1CASAConfig
|
| 580 |
+
base_model_prefix = "model"
|
| 581 |
+
supports_gradient_checkpointing = True
|
| 582 |
+
_no_split_modules = ["HeliumDecoderLayer"]
|
| 583 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 584 |
+
_supports_flash_attn_2 = True
|
| 585 |
+
_supports_sdpa = True
|
| 586 |
+
_supports_flex_attn = True
|
| 587 |
+
_supports_cache_class = True
|
| 588 |
+
_supports_quantized_cache = True
|
| 589 |
+
_supports_static_cache = True
|
| 590 |
+
_supports_attention_backend = True
|
| 591 |
+
|
| 592 |
+
def _init_weights(self, module: torch.nn.Module) -> None:
|
| 593 |
+
std = self.config.initializer_range
|
| 594 |
+
if isinstance(module, nn.Linear):
|
| 595 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 596 |
+
if module.bias is not None:
|
| 597 |
+
module.bias.data.zero_()
|
| 598 |
+
elif isinstance(module, nn.Embedding):
|
| 599 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 600 |
+
if module.padding_idx is not None:
|
| 601 |
+
module.weight.data[module.padding_idx].zero_()
|
| 602 |
+
elif isinstance(module, Helium1RMSNorm):
|
| 603 |
+
module.weight.data.fill_(1.0)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
class Helium1Model(Helium1PreTrainedModel):
|
| 607 |
+
"""
|
| 608 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
| 609 |
+
|
| 610 |
+
Args:
|
| 611 |
+
config: Helium1CASAConfig
|
| 612 |
+
"""
|
| 613 |
+
|
| 614 |
+
def __init__(self, config: Helium1CASAConfig):
|
| 615 |
+
Helium1PreTrainedModel.__init__(self, config)
|
| 616 |
+
self.training: bool
|
| 617 |
+
self._gradient_checkpointing_func: Callable
|
| 618 |
+
self.config = config
|
| 619 |
+
self.padding_idx = config.pad_token_id
|
| 620 |
+
self.vocab_size = config.vocab_size
|
| 621 |
+
|
| 622 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 623 |
+
self.layers = nn.ModuleList(
|
| 624 |
+
[HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 625 |
+
)
|
| 626 |
+
self.norm = Helium1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 627 |
+
self.rotary_emb = HeliumRotaryEmbedding(config=config)
|
| 628 |
+
self.gradient_checkpointing = False
|
| 629 |
+
|
| 630 |
+
# Initialize weights and apply final processing
|
| 631 |
+
self.post_init()
|
| 632 |
+
|
| 633 |
+
def get_input_embeddings(self):
|
| 634 |
+
return self.embed_tokens
|
| 635 |
+
|
| 636 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
|
| 637 |
+
self.embed_tokens = value
|
| 638 |
+
|
| 639 |
+
@can_return_tuple
|
| 640 |
+
def forward(
|
| 641 |
+
self,
|
| 642 |
+
input_ids: None | torch.LongTensor = None,
|
| 643 |
+
attention_mask: None | torch.Tensor = None,
|
| 644 |
+
position_ids: None | torch.Tensor = None,
|
| 645 |
+
past_key_values: None | DynamicCache = None,
|
| 646 |
+
inputs_embeds: None | torch.Tensor = None,
|
| 647 |
+
use_cache: None | bool = None,
|
| 648 |
+
output_attentions: None | bool = None,
|
| 649 |
+
output_hidden_states: None | bool = None,
|
| 650 |
+
cache_position: None | torch.Tensor = None,
|
| 651 |
+
# Insertion
|
| 652 |
+
image_tokens_mask: torch.Tensor | None = None,
|
| 653 |
+
# CASA
|
| 654 |
+
casa_handler: CASAAttentionHandler | None = None,
|
| 655 |
+
cu_seqlens: torch.Tensor | None = None,
|
| 656 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 657 |
+
) -> BaseModelOutputWithPast:
|
| 658 |
+
output_attentions = (
|
| 659 |
+
output_attentions if output_attentions is not None else self.config.output_attentions
|
| 660 |
+
)
|
| 661 |
+
output_hidden_states = (
|
| 662 |
+
output_hidden_states
|
| 663 |
+
if output_hidden_states is not None
|
| 664 |
+
else self.config.output_hidden_states
|
| 665 |
+
)
|
| 666 |
+
use_cache = not self.training and (
|
| 667 |
+
use_cache if use_cache is not None else self.config.use_cache
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 671 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 672 |
+
|
| 673 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 674 |
+
print(
|
| 675 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 676 |
+
)
|
| 677 |
+
use_cache = False
|
| 678 |
+
|
| 679 |
+
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
| 680 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 681 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 682 |
+
|
| 683 |
+
if inputs_embeds is None:
|
| 684 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 685 |
+
assert inputs_embeds is not None
|
| 686 |
+
|
| 687 |
+
if use_cache and past_key_values is None:
|
| 688 |
+
past_key_values = DynamicCache()
|
| 689 |
+
|
| 690 |
+
if cache_position is None:
|
| 691 |
+
past_seen_tokens = 0 if past_key_values is None else past_key_values._seen_tokens
|
| 692 |
+
assert inputs_embeds is not None
|
| 693 |
+
cache_position = torch.arange(
|
| 694 |
+
past_seen_tokens,
|
| 695 |
+
past_seen_tokens + inputs_embeds.shape[1],
|
| 696 |
+
device=inputs_embeds.device,
|
| 697 |
+
)
|
| 698 |
+
assert cache_position is not None
|
| 699 |
+
|
| 700 |
+
if position_ids is None:
|
| 701 |
+
position_ids = cache_position.unsqueeze(0)
|
| 702 |
+
|
| 703 |
+
# Get attention mask
|
| 704 |
+
causal_mask: None | torch.Tensor = self._update_causal_mask(
|
| 705 |
+
attention_mask,
|
| 706 |
+
inputs_embeds,
|
| 707 |
+
cache_position,
|
| 708 |
+
past_key_values,
|
| 709 |
+
output_attentions,
|
| 710 |
+
force_mask=False,
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
# create position embeddings to be shared across the decoder layers
|
| 714 |
+
hidden_states = inputs_embeds
|
| 715 |
+
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
|
| 716 |
+
|
| 717 |
+
# decoder layers
|
| 718 |
+
all_hidden_states = () if output_hidden_states else None
|
| 719 |
+
all_self_attns = () if output_attentions else None
|
| 720 |
+
|
| 721 |
+
for decoder_layer_idx, decoder_layer in enumerate(
|
| 722 |
+
self.layers[: self.config.num_hidden_layers]
|
| 723 |
+
):
|
| 724 |
+
is_xa_layer = not self.config.xa_layers or decoder_layer_idx in self.config.xa_layers
|
| 725 |
+
if output_hidden_states is not None:
|
| 726 |
+
if all_hidden_states is None:
|
| 727 |
+
all_hidden_states = ()
|
| 728 |
+
all_hidden_states += (hidden_states,)
|
| 729 |
+
|
| 730 |
+
if self.gradient_checkpointing and self.training:
|
| 731 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 732 |
+
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
| 733 |
+
hidden_states,
|
| 734 |
+
causal_mask,
|
| 735 |
+
position_ids,
|
| 736 |
+
past_key_values,
|
| 737 |
+
output_attentions,
|
| 738 |
+
use_cache,
|
| 739 |
+
cache_position,
|
| 740 |
+
position_embeddings,
|
| 741 |
+
casa_handler if is_xa_layer else None,
|
| 742 |
+
cu_seqlens,
|
| 743 |
+
)
|
| 744 |
+
else:
|
| 745 |
+
layer_outputs = decoder_layer(
|
| 746 |
+
hidden_states,
|
| 747 |
+
attention_mask=causal_mask,
|
| 748 |
+
position_ids=position_ids,
|
| 749 |
+
past_key_values=past_key_values,
|
| 750 |
+
output_attentions=output_attentions,
|
| 751 |
+
use_cache=use_cache,
|
| 752 |
+
cache_position=cache_position,
|
| 753 |
+
position_embeddings=position_embeddings,
|
| 754 |
+
casa_handler=casa_handler if is_xa_layer else None,
|
| 755 |
+
cu_seqlens=cu_seqlens,
|
| 756 |
+
**flash_attn_kwargs,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
hidden_states = layer_outputs[0]
|
| 760 |
+
|
| 761 |
+
if output_attentions:
|
| 762 |
+
if all_self_attns is None:
|
| 763 |
+
all_self_attns = ()
|
| 764 |
+
all_self_attns += (layer_outputs[1],)
|
| 765 |
+
|
| 766 |
+
hidden_states = self.norm(hidden_states)
|
| 767 |
+
|
| 768 |
+
# add hidden states from the last decoder layer
|
| 769 |
+
if output_hidden_states:
|
| 770 |
+
if all_hidden_states is None:
|
| 771 |
+
all_hidden_states = ()
|
| 772 |
+
all_hidden_states += (hidden_states,)
|
| 773 |
+
|
| 774 |
+
return BaseModelOutputWithPast(
|
| 775 |
+
last_hidden_state=hidden_states,
|
| 776 |
+
past_key_values=past_key_values if use_cache else None, # pyright: ignore[reportArgumentType]
|
| 777 |
+
hidden_states=all_hidden_states, # pyright: ignore[reportArgumentType]
|
| 778 |
+
attentions=all_self_attns,
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
def _update_causal_mask(
|
| 782 |
+
self,
|
| 783 |
+
attention_mask: torch.Tensor | None,
|
| 784 |
+
input_tensor: torch.Tensor,
|
| 785 |
+
cache_position: torch.Tensor,
|
| 786 |
+
past_key_values: None | DynamicCache | Cache,
|
| 787 |
+
output_attentions: bool = False,
|
| 788 |
+
force_mask: bool = False,
|
| 789 |
+
) -> torch.Tensor | None:
|
| 790 |
+
if self.config._attn_implementation == "flex_attention":
|
| 791 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 792 |
+
attention_mask = make_flex_block_causal_mask(attention_mask) # type: ignore
|
| 793 |
+
return attention_mask
|
| 794 |
+
|
| 795 |
+
assert attention_mask is None or isinstance(attention_mask, torch.Tensor)
|
| 796 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 797 |
+
if attention_mask is not None and (force_mask or (attention_mask == 0.0).any()):
|
| 798 |
+
return attention_mask
|
| 799 |
+
return None
|
| 800 |
+
|
| 801 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 802 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 803 |
+
# to infer the attention mask.
|
| 804 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 805 |
+
using_compilable_cache = (
|
| 806 |
+
past_key_values.is_compileable if past_key_values is not None else False
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 810 |
+
if (
|
| 811 |
+
self.config._attn_implementation == "sdpa"
|
| 812 |
+
and not using_compilable_cache
|
| 813 |
+
and not output_attentions
|
| 814 |
+
):
|
| 815 |
+
if not force_mask and AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 816 |
+
attention_mask,
|
| 817 |
+
inputs_embeds=input_tensor,
|
| 818 |
+
past_key_values_length=past_seen_tokens,
|
| 819 |
+
is_training=self.training,
|
| 820 |
+
):
|
| 821 |
+
return None
|
| 822 |
+
|
| 823 |
+
dtype = input_tensor.dtype
|
| 824 |
+
sequence_length = input_tensor.shape[1]
|
| 825 |
+
if using_compilable_cache and past_key_values is not None:
|
| 826 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 827 |
+
else:
|
| 828 |
+
target_length = (
|
| 829 |
+
attention_mask.shape[-1]
|
| 830 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 831 |
+
else past_seen_tokens + sequence_length
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 835 |
+
assert target_length is not None
|
| 836 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 837 |
+
attention_mask,
|
| 838 |
+
sequence_length=sequence_length,
|
| 839 |
+
target_length=target_length,
|
| 840 |
+
dtype=dtype,
|
| 841 |
+
cache_position=cache_position,
|
| 842 |
+
batch_size=input_tensor.shape[0],
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
if (
|
| 846 |
+
self.config._attn_implementation == "sdpa"
|
| 847 |
+
and attention_mask is not None
|
| 848 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 849 |
+
and not output_attentions
|
| 850 |
+
):
|
| 851 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 852 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 853 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 854 |
+
min_dtype = torch.finfo(dtype).min
|
| 855 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(
|
| 856 |
+
type_cast(torch.FloatTensor, causal_mask), min_dtype
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
return causal_mask
|
| 860 |
+
|
| 861 |
+
@staticmethod
|
| 862 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 863 |
+
attention_mask: torch.Tensor | None,
|
| 864 |
+
sequence_length: int,
|
| 865 |
+
target_length: int,
|
| 866 |
+
dtype: torch.dtype,
|
| 867 |
+
cache_position: torch.Tensor,
|
| 868 |
+
batch_size: int,
|
| 869 |
+
**kwargs: Any,
|
| 870 |
+
):
|
| 871 |
+
"""
|
| 872 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 873 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 874 |
+
|
| 875 |
+
Args:
|
| 876 |
+
attention_mask (`torch.Tensor`):
|
| 877 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 878 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 879 |
+
sequence_length (`int`):
|
| 880 |
+
The sequence length being processed.
|
| 881 |
+
target_length (`int`):
|
| 882 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 883 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 884 |
+
dtype (`torch.dtype`):
|
| 885 |
+
The dtype to use for the 4D attention mask.
|
| 886 |
+
cache_position (`torch.Tensor`):
|
| 887 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 888 |
+
batch_size (`torch.Tensor`):
|
| 889 |
+
Batch size.
|
| 890 |
+
"""
|
| 891 |
+
del kwargs
|
| 892 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 893 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 894 |
+
causal_mask = attention_mask
|
| 895 |
+
else:
|
| 896 |
+
min_dtype = torch.finfo(dtype).min
|
| 897 |
+
causal_mask = torch.full(
|
| 898 |
+
(sequence_length, target_length),
|
| 899 |
+
fill_value=min_dtype,
|
| 900 |
+
dtype=dtype,
|
| 901 |
+
device=cache_position.device,
|
| 902 |
+
)
|
| 903 |
+
if sequence_length != 1:
|
| 904 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 905 |
+
causal_mask *= torch.arange(
|
| 906 |
+
target_length, device=cache_position.device
|
| 907 |
+
) > cache_position.reshape(-1, 1)
|
| 908 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 909 |
+
if attention_mask is not None:
|
| 910 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 911 |
+
mask_length = attention_mask.shape[-1]
|
| 912 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
|
| 913 |
+
:, None, None, :
|
| 914 |
+
].to(causal_mask.device)
|
| 915 |
+
padding_mask = padding_mask == 0
|
| 916 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 917 |
+
padding_mask, min_dtype
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
return causal_mask
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
class Helium1ForCausalLM(Helium1PreTrainedModel, GenerationMixin):
|
| 927 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 928 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 929 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 930 |
+
|
| 931 |
+
def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
|
| 932 |
+
del kwargs
|
| 933 |
+
super().__init__(config)
|
| 934 |
+
self.model: Helium1Model
|
| 935 |
+
self.model = Helium1Model(config)
|
| 936 |
+
self.vocab_size = config.vocab_size
|
| 937 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 938 |
+
self._loss_function = ForCausalLMLoss
|
| 939 |
+
|
| 940 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 941 |
+
return self.model.embed_tokens
|
| 942 |
+
|
| 943 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
|
| 944 |
+
self.model.embed_tokens = value
|
| 945 |
+
|
| 946 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 947 |
+
return self.lm_head
|
| 948 |
+
|
| 949 |
+
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
|
| 950 |
+
self.lm_head = new_embeddings
|
| 951 |
+
|
| 952 |
+
def set_decoder(self, decoder: Helium1Model) -> None:
|
| 953 |
+
self.model = decoder
|
| 954 |
+
|
| 955 |
+
def get_decoder(self) -> Helium1Model:
|
| 956 |
+
return self.model
|
| 957 |
+
|
| 958 |
+
@can_return_tuple
|
| 959 |
+
def forward(
|
| 960 |
+
self,
|
| 961 |
+
input_ids: None | torch.LongTensor = None,
|
| 962 |
+
attention_mask: None | torch.Tensor = None,
|
| 963 |
+
position_ids: None | torch.LongTensor = None,
|
| 964 |
+
past_key_values: None | Cache = None,
|
| 965 |
+
inputs_embeds: None | torch.Tensor = None,
|
| 966 |
+
image_embeds: None | torch.Tensor | list[torch.Tensor] = None,
|
| 967 |
+
image_embeds_insertion_points: None | list[torch.Tensor] = None,
|
| 968 |
+
labels: None | torch.LongTensor = None,
|
| 969 |
+
use_cache: None | bool = None,
|
| 970 |
+
output_attentions: None | bool = None,
|
| 971 |
+
output_hidden_states: None | bool = None,
|
| 972 |
+
cache_position: None | torch.LongTensor = None,
|
| 973 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 974 |
+
# CASA
|
| 975 |
+
casa_windows_info: None | dict = None,
|
| 976 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 977 |
+
) -> CausalHeliumOutput:
|
| 978 |
+
r"""
|
| 979 |
+
Helium1 augmented with CASA layers
|
| 980 |
+
"""
|
| 981 |
+
output_attentions = (
|
| 982 |
+
output_attentions if output_attentions is not None else self.config.output_attentions
|
| 983 |
+
)
|
| 984 |
+
output_hidden_states = (
|
| 985 |
+
output_hidden_states
|
| 986 |
+
if output_hidden_states is not None
|
| 987 |
+
else self.config.output_hidden_states
|
| 988 |
+
)
|
| 989 |
+
if input_ids is not None:
|
| 990 |
+
assert inputs_embeds is None, (
|
| 991 |
+
"Need to provide only one of `input_ids` or `inputs_embeds`."
|
| 992 |
+
)
|
| 993 |
+
inputs_embeds = self.model.embed_tokens(input_ids)
|
| 994 |
+
assert inputs_embeds is not None
|
| 995 |
+
|
| 996 |
+
# Setup image + text token fusion
|
| 997 |
+
bs, og_seq_len, _ = inputs_embeds.shape
|
| 998 |
+
image_tokens_mask: torch.Tensor | None = None
|
| 999 |
+
casa_handler: CASAAttentionHandler | None = None
|
| 1000 |
+
|
| 1001 |
+
num_image_tokens = -1
|
| 1002 |
+
if image_embeds is not None:
|
| 1003 |
+
num_image_tokens = sum(_x.shape[0] for _x in image_embeds)
|
| 1004 |
+
assert image_embeds_insertion_points is not None, (
|
| 1005 |
+
"Missing image embeddings insertion points"
|
| 1006 |
+
)
|
| 1007 |
+
# B1. CASA layers: We need to init the shared Handler
|
| 1008 |
+
if self.model.config.casa_attention:
|
| 1009 |
+
casa_handler = CASAAttentionHandler(
|
| 1010 |
+
# for text tokens, we don't need the actual values
|
| 1011 |
+
inputs_embeds=torch.zeros_like(inputs_embeds),
|
| 1012 |
+
# for image embeddings, we put real inputs as this will be fixed
|
| 1013 |
+
image_embeds=image_embeds,
|
| 1014 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 1015 |
+
# attention mask is only needed at inference / left padding
|
| 1016 |
+
attention_mask=None if self.training else attention_mask,
|
| 1017 |
+
rope_fn=self.model.rotary_emb,
|
| 1018 |
+
windows=self.model.config.casa_windows,
|
| 1019 |
+
use_asymetric_q_kv=self.model.config.casa_use_asymetric_qkv,
|
| 1020 |
+
# further params are fed to the funtion computing attention
|
| 1021 |
+
casa_windows_info=casa_windows_info,
|
| 1022 |
+
)
|
| 1023 |
+
# B2. Direct image insertion
|
| 1024 |
+
else:
|
| 1025 |
+
inputs_embeds, _, attention_mask, image_tokens_mask = insert_image_tokens(
|
| 1026 |
+
inputs_embeds=inputs_embeds,
|
| 1027 |
+
image_embeds=image_embeds,
|
| 1028 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 1029 |
+
attention_mask=attention_mask,
|
| 1030 |
+
padding_side="right" if self.training else "left",
|
| 1031 |
+
recover_batch_dim=True,
|
| 1032 |
+
)
|
| 1033 |
+
|
| 1034 |
+
del image_embeds
|
| 1035 |
+
del input_ids
|
| 1036 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 1037 |
+
inputs_embeds=inputs_embeds,
|
| 1038 |
+
attention_mask=attention_mask,
|
| 1039 |
+
position_ids=position_ids,
|
| 1040 |
+
past_key_values=past_key_values,
|
| 1041 |
+
use_cache=use_cache,
|
| 1042 |
+
output_attentions=output_attentions,
|
| 1043 |
+
output_hidden_states=output_hidden_states,
|
| 1044 |
+
cache_position=cache_position,
|
| 1045 |
+
image_tokens_mask=image_tokens_mask,
|
| 1046 |
+
casa_handler=casa_handler,
|
| 1047 |
+
**kwargs,
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
hidden_states = outputs.last_hidden_state
|
| 1051 |
+
assert hidden_states is not None
|
| 1052 |
+
if image_tokens_mask is not None:
|
| 1053 |
+
hidden_states = remove_image_tokens(hidden_states, image_tokens_mask)
|
| 1054 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1055 |
+
slice_indices = (
|
| 1056 |
+
slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1057 |
+
)
|
| 1058 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1059 |
+
|
| 1060 |
+
loss = None
|
| 1061 |
+
if labels is not None:
|
| 1062 |
+
loss = self.loss_function(
|
| 1063 |
+
logits=logits,
|
| 1064 |
+
labels=labels,
|
| 1065 |
+
vocab_size=self.config.vocab_size,
|
| 1066 |
+
**kwargs,
|
| 1067 |
+
)
|
| 1068 |
+
out = CausalHeliumOutput(
|
| 1069 |
+
loss=loss,
|
| 1070 |
+
logits=logits,
|
| 1071 |
+
past_key_values=outputs.past_key_values,
|
| 1072 |
+
hidden_states=outputs.hidden_states,
|
| 1073 |
+
attentions=outputs.attentions,
|
| 1074 |
+
num_image_tokens_log=torch.tensor(num_image_tokens).to(logits.device).to(torch.float),
|
| 1075 |
+
num_text_tokens_log=torch.tensor(og_seq_len).to(logits.device).to(torch.float),
|
| 1076 |
+
)
|
| 1077 |
+
return out
|
model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:84b4a971906aaacc7c0acf8ac98d4f59afe073954284790855b70f8ab8488df3
|
| 3 |
+
size 4987411648
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea916a815ac0deb1a397aa6ef2edf5a8baca4811327597fccf2f21ef67cf4295
|
| 3 |
+
size 4993506144
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e0bc7d4d08240d3de6fedfbe8711c16345c70e0fbebddfc89685cb361f208aeb
|
| 3 |
+
size 2195900304
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 12176723968
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"image_prefix.enc.visual.blocks.0.attn.proj.bias": "model-00002-of-00003.safetensors",
|
| 7 |
+
"image_prefix.enc.visual.blocks.0.attn.proj.weight": "model-00002-of-00003.safetensors",
|
| 8 |
+
"image_prefix.enc.visual.blocks.0.attn.qkv.bias": "model-00002-of-00003.safetensors",
|
| 9 |
+
"image_prefix.enc.visual.blocks.0.attn.qkv.weight": "model-00002-of-00003.safetensors",
|
| 10 |
+
"image_prefix.enc.visual.blocks.0.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
|
| 11 |
+
"image_prefix.enc.visual.blocks.0.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 12 |
+
"image_prefix.enc.visual.blocks.0.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
|
| 13 |
+
"image_prefix.enc.visual.blocks.0.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 14 |
+
"image_prefix.enc.visual.blocks.0.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
|
| 15 |
+
"image_prefix.enc.visual.blocks.0.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 16 |
+
"image_prefix.enc.visual.blocks.0.norm1.weight": "model-00002-of-00003.safetensors",
|
| 17 |
+
"image_prefix.enc.visual.blocks.0.norm2.weight": "model-00002-of-00003.safetensors",
|
| 18 |
+
"image_prefix.enc.visual.blocks.1.attn.proj.bias": "model-00002-of-00003.safetensors",
|
| 19 |
+
"image_prefix.enc.visual.blocks.1.attn.proj.weight": "model-00002-of-00003.safetensors",
|
| 20 |
+
"image_prefix.enc.visual.blocks.1.attn.qkv.bias": "model-00002-of-00003.safetensors",
|
| 21 |
+
"image_prefix.enc.visual.blocks.1.attn.qkv.weight": "model-00002-of-00003.safetensors",
|
| 22 |
+
"image_prefix.enc.visual.blocks.1.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
|
| 23 |
+
"image_prefix.enc.visual.blocks.1.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 24 |
+
"image_prefix.enc.visual.blocks.1.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
|
| 25 |
+
"image_prefix.enc.visual.blocks.1.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 26 |
+
"image_prefix.enc.visual.blocks.1.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
|
| 27 |
+
"image_prefix.enc.visual.blocks.1.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 28 |
+
"image_prefix.enc.visual.blocks.1.norm1.weight": "model-00002-of-00003.safetensors",
|
| 29 |
+
"image_prefix.enc.visual.blocks.1.norm2.weight": "model-00002-of-00003.safetensors",
|
| 30 |
+
"image_prefix.enc.visual.blocks.10.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 31 |
+
"image_prefix.enc.visual.blocks.10.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 32 |
+
"image_prefix.enc.visual.blocks.10.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 33 |
+
"image_prefix.enc.visual.blocks.10.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 34 |
+
"image_prefix.enc.visual.blocks.10.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 35 |
+
"image_prefix.enc.visual.blocks.10.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 36 |
+
"image_prefix.enc.visual.blocks.10.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 37 |
+
"image_prefix.enc.visual.blocks.10.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 38 |
+
"image_prefix.enc.visual.blocks.10.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 39 |
+
"image_prefix.enc.visual.blocks.10.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 40 |
+
"image_prefix.enc.visual.blocks.10.norm1.weight": "model-00003-of-00003.safetensors",
|
| 41 |
+
"image_prefix.enc.visual.blocks.10.norm2.weight": "model-00003-of-00003.safetensors",
|
| 42 |
+
"image_prefix.enc.visual.blocks.11.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 43 |
+
"image_prefix.enc.visual.blocks.11.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 44 |
+
"image_prefix.enc.visual.blocks.11.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 45 |
+
"image_prefix.enc.visual.blocks.11.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 46 |
+
"image_prefix.enc.visual.blocks.11.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 47 |
+
"image_prefix.enc.visual.blocks.11.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 48 |
+
"image_prefix.enc.visual.blocks.11.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 49 |
+
"image_prefix.enc.visual.blocks.11.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 50 |
+
"image_prefix.enc.visual.blocks.11.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 51 |
+
"image_prefix.enc.visual.blocks.11.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 52 |
+
"image_prefix.enc.visual.blocks.11.norm1.weight": "model-00003-of-00003.safetensors",
|
| 53 |
+
"image_prefix.enc.visual.blocks.11.norm2.weight": "model-00003-of-00003.safetensors",
|
| 54 |
+
"image_prefix.enc.visual.blocks.12.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 55 |
+
"image_prefix.enc.visual.blocks.12.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 56 |
+
"image_prefix.enc.visual.blocks.12.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 57 |
+
"image_prefix.enc.visual.blocks.12.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 58 |
+
"image_prefix.enc.visual.blocks.12.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 59 |
+
"image_prefix.enc.visual.blocks.12.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 60 |
+
"image_prefix.enc.visual.blocks.12.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 61 |
+
"image_prefix.enc.visual.blocks.12.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 62 |
+
"image_prefix.enc.visual.blocks.12.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 63 |
+
"image_prefix.enc.visual.blocks.12.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 64 |
+
"image_prefix.enc.visual.blocks.12.norm1.weight": "model-00003-of-00003.safetensors",
|
| 65 |
+
"image_prefix.enc.visual.blocks.12.norm2.weight": "model-00003-of-00003.safetensors",
|
| 66 |
+
"image_prefix.enc.visual.blocks.13.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 67 |
+
"image_prefix.enc.visual.blocks.13.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 68 |
+
"image_prefix.enc.visual.blocks.13.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 69 |
+
"image_prefix.enc.visual.blocks.13.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 70 |
+
"image_prefix.enc.visual.blocks.13.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 71 |
+
"image_prefix.enc.visual.blocks.13.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 72 |
+
"image_prefix.enc.visual.blocks.13.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 73 |
+
"image_prefix.enc.visual.blocks.13.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 74 |
+
"image_prefix.enc.visual.blocks.13.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 75 |
+
"image_prefix.enc.visual.blocks.13.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 76 |
+
"image_prefix.enc.visual.blocks.13.norm1.weight": "model-00003-of-00003.safetensors",
|
| 77 |
+
"image_prefix.enc.visual.blocks.13.norm2.weight": "model-00003-of-00003.safetensors",
|
| 78 |
+
"image_prefix.enc.visual.blocks.14.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 79 |
+
"image_prefix.enc.visual.blocks.14.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 80 |
+
"image_prefix.enc.visual.blocks.14.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 81 |
+
"image_prefix.enc.visual.blocks.14.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 82 |
+
"image_prefix.enc.visual.blocks.14.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 83 |
+
"image_prefix.enc.visual.blocks.14.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 84 |
+
"image_prefix.enc.visual.blocks.14.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 85 |
+
"image_prefix.enc.visual.blocks.14.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 86 |
+
"image_prefix.enc.visual.blocks.14.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 87 |
+
"image_prefix.enc.visual.blocks.14.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 88 |
+
"image_prefix.enc.visual.blocks.14.norm1.weight": "model-00003-of-00003.safetensors",
|
| 89 |
+
"image_prefix.enc.visual.blocks.14.norm2.weight": "model-00003-of-00003.safetensors",
|
| 90 |
+
"image_prefix.enc.visual.blocks.15.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 91 |
+
"image_prefix.enc.visual.blocks.15.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 92 |
+
"image_prefix.enc.visual.blocks.15.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 93 |
+
"image_prefix.enc.visual.blocks.15.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 94 |
+
"image_prefix.enc.visual.blocks.15.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 95 |
+
"image_prefix.enc.visual.blocks.15.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 96 |
+
"image_prefix.enc.visual.blocks.15.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 97 |
+
"image_prefix.enc.visual.blocks.15.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 98 |
+
"image_prefix.enc.visual.blocks.15.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 99 |
+
"image_prefix.enc.visual.blocks.15.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 100 |
+
"image_prefix.enc.visual.blocks.15.norm1.weight": "model-00003-of-00003.safetensors",
|
| 101 |
+
"image_prefix.enc.visual.blocks.15.norm2.weight": "model-00003-of-00003.safetensors",
|
| 102 |
+
"image_prefix.enc.visual.blocks.16.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 103 |
+
"image_prefix.enc.visual.blocks.16.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 104 |
+
"image_prefix.enc.visual.blocks.16.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 105 |
+
"image_prefix.enc.visual.blocks.16.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 106 |
+
"image_prefix.enc.visual.blocks.16.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 107 |
+
"image_prefix.enc.visual.blocks.16.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 108 |
+
"image_prefix.enc.visual.blocks.16.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 109 |
+
"image_prefix.enc.visual.blocks.16.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 110 |
+
"image_prefix.enc.visual.blocks.16.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 111 |
+
"image_prefix.enc.visual.blocks.16.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 112 |
+
"image_prefix.enc.visual.blocks.16.norm1.weight": "model-00003-of-00003.safetensors",
|
| 113 |
+
"image_prefix.enc.visual.blocks.16.norm2.weight": "model-00003-of-00003.safetensors",
|
| 114 |
+
"image_prefix.enc.visual.blocks.17.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 115 |
+
"image_prefix.enc.visual.blocks.17.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 116 |
+
"image_prefix.enc.visual.blocks.17.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 117 |
+
"image_prefix.enc.visual.blocks.17.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 118 |
+
"image_prefix.enc.visual.blocks.17.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 119 |
+
"image_prefix.enc.visual.blocks.17.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 120 |
+
"image_prefix.enc.visual.blocks.17.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 121 |
+
"image_prefix.enc.visual.blocks.17.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 122 |
+
"image_prefix.enc.visual.blocks.17.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 123 |
+
"image_prefix.enc.visual.blocks.17.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 124 |
+
"image_prefix.enc.visual.blocks.17.norm1.weight": "model-00003-of-00003.safetensors",
|
| 125 |
+
"image_prefix.enc.visual.blocks.17.norm2.weight": "model-00003-of-00003.safetensors",
|
| 126 |
+
"image_prefix.enc.visual.blocks.18.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 127 |
+
"image_prefix.enc.visual.blocks.18.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 128 |
+
"image_prefix.enc.visual.blocks.18.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 129 |
+
"image_prefix.enc.visual.blocks.18.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 130 |
+
"image_prefix.enc.visual.blocks.18.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 131 |
+
"image_prefix.enc.visual.blocks.18.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 132 |
+
"image_prefix.enc.visual.blocks.18.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 133 |
+
"image_prefix.enc.visual.blocks.18.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 134 |
+
"image_prefix.enc.visual.blocks.18.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 135 |
+
"image_prefix.enc.visual.blocks.18.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 136 |
+
"image_prefix.enc.visual.blocks.18.norm1.weight": "model-00003-of-00003.safetensors",
|
| 137 |
+
"image_prefix.enc.visual.blocks.18.norm2.weight": "model-00003-of-00003.safetensors",
|
| 138 |
+
"image_prefix.enc.visual.blocks.19.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 139 |
+
"image_prefix.enc.visual.blocks.19.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 140 |
+
"image_prefix.enc.visual.blocks.19.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 141 |
+
"image_prefix.enc.visual.blocks.19.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 142 |
+
"image_prefix.enc.visual.blocks.19.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 143 |
+
"image_prefix.enc.visual.blocks.19.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 144 |
+
"image_prefix.enc.visual.blocks.19.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 145 |
+
"image_prefix.enc.visual.blocks.19.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 146 |
+
"image_prefix.enc.visual.blocks.19.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 147 |
+
"image_prefix.enc.visual.blocks.19.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 148 |
+
"image_prefix.enc.visual.blocks.19.norm1.weight": "model-00003-of-00003.safetensors",
|
| 149 |
+
"image_prefix.enc.visual.blocks.19.norm2.weight": "model-00003-of-00003.safetensors",
|
| 150 |
+
"image_prefix.enc.visual.blocks.2.attn.proj.bias": "model-00002-of-00003.safetensors",
|
| 151 |
+
"image_prefix.enc.visual.blocks.2.attn.proj.weight": "model-00002-of-00003.safetensors",
|
| 152 |
+
"image_prefix.enc.visual.blocks.2.attn.qkv.bias": "model-00002-of-00003.safetensors",
|
| 153 |
+
"image_prefix.enc.visual.blocks.2.attn.qkv.weight": "model-00002-of-00003.safetensors",
|
| 154 |
+
"image_prefix.enc.visual.blocks.2.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
|
| 155 |
+
"image_prefix.enc.visual.blocks.2.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 156 |
+
"image_prefix.enc.visual.blocks.2.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
|
| 157 |
+
"image_prefix.enc.visual.blocks.2.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 158 |
+
"image_prefix.enc.visual.blocks.2.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
|
| 159 |
+
"image_prefix.enc.visual.blocks.2.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 160 |
+
"image_prefix.enc.visual.blocks.2.norm1.weight": "model-00002-of-00003.safetensors",
|
| 161 |
+
"image_prefix.enc.visual.blocks.2.norm2.weight": "model-00002-of-00003.safetensors",
|
| 162 |
+
"image_prefix.enc.visual.blocks.20.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 163 |
+
"image_prefix.enc.visual.blocks.20.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 164 |
+
"image_prefix.enc.visual.blocks.20.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 165 |
+
"image_prefix.enc.visual.blocks.20.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 166 |
+
"image_prefix.enc.visual.blocks.20.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 167 |
+
"image_prefix.enc.visual.blocks.20.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 168 |
+
"image_prefix.enc.visual.blocks.20.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 169 |
+
"image_prefix.enc.visual.blocks.20.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 170 |
+
"image_prefix.enc.visual.blocks.20.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 171 |
+
"image_prefix.enc.visual.blocks.20.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 172 |
+
"image_prefix.enc.visual.blocks.20.norm1.weight": "model-00003-of-00003.safetensors",
|
| 173 |
+
"image_prefix.enc.visual.blocks.20.norm2.weight": "model-00003-of-00003.safetensors",
|
| 174 |
+
"image_prefix.enc.visual.blocks.21.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 175 |
+
"image_prefix.enc.visual.blocks.21.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 176 |
+
"image_prefix.enc.visual.blocks.21.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 177 |
+
"image_prefix.enc.visual.blocks.21.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 178 |
+
"image_prefix.enc.visual.blocks.21.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 179 |
+
"image_prefix.enc.visual.blocks.21.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 180 |
+
"image_prefix.enc.visual.blocks.21.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 181 |
+
"image_prefix.enc.visual.blocks.21.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 182 |
+
"image_prefix.enc.visual.blocks.21.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 183 |
+
"image_prefix.enc.visual.blocks.21.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 184 |
+
"image_prefix.enc.visual.blocks.21.norm1.weight": "model-00003-of-00003.safetensors",
|
| 185 |
+
"image_prefix.enc.visual.blocks.21.norm2.weight": "model-00003-of-00003.safetensors",
|
| 186 |
+
"image_prefix.enc.visual.blocks.22.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 187 |
+
"image_prefix.enc.visual.blocks.22.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 188 |
+
"image_prefix.enc.visual.blocks.22.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 189 |
+
"image_prefix.enc.visual.blocks.22.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 190 |
+
"image_prefix.enc.visual.blocks.22.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 191 |
+
"image_prefix.enc.visual.blocks.22.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 192 |
+
"image_prefix.enc.visual.blocks.22.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 193 |
+
"image_prefix.enc.visual.blocks.22.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 194 |
+
"image_prefix.enc.visual.blocks.22.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 195 |
+
"image_prefix.enc.visual.blocks.22.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 196 |
+
"image_prefix.enc.visual.blocks.22.norm1.weight": "model-00003-of-00003.safetensors",
|
| 197 |
+
"image_prefix.enc.visual.blocks.22.norm2.weight": "model-00003-of-00003.safetensors",
|
| 198 |
+
"image_prefix.enc.visual.blocks.23.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 199 |
+
"image_prefix.enc.visual.blocks.23.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 200 |
+
"image_prefix.enc.visual.blocks.23.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 201 |
+
"image_prefix.enc.visual.blocks.23.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 202 |
+
"image_prefix.enc.visual.blocks.23.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 203 |
+
"image_prefix.enc.visual.blocks.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 204 |
+
"image_prefix.enc.visual.blocks.23.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 205 |
+
"image_prefix.enc.visual.blocks.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 206 |
+
"image_prefix.enc.visual.blocks.23.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 207 |
+
"image_prefix.enc.visual.blocks.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 208 |
+
"image_prefix.enc.visual.blocks.23.norm1.weight": "model-00003-of-00003.safetensors",
|
| 209 |
+
"image_prefix.enc.visual.blocks.23.norm2.weight": "model-00003-of-00003.safetensors",
|
| 210 |
+
"image_prefix.enc.visual.blocks.24.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 211 |
+
"image_prefix.enc.visual.blocks.24.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 212 |
+
"image_prefix.enc.visual.blocks.24.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 213 |
+
"image_prefix.enc.visual.blocks.24.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 214 |
+
"image_prefix.enc.visual.blocks.24.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 215 |
+
"image_prefix.enc.visual.blocks.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 216 |
+
"image_prefix.enc.visual.blocks.24.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 217 |
+
"image_prefix.enc.visual.blocks.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 218 |
+
"image_prefix.enc.visual.blocks.24.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 219 |
+
"image_prefix.enc.visual.blocks.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 220 |
+
"image_prefix.enc.visual.blocks.24.norm1.weight": "model-00003-of-00003.safetensors",
|
| 221 |
+
"image_prefix.enc.visual.blocks.24.norm2.weight": "model-00003-of-00003.safetensors",
|
| 222 |
+
"image_prefix.enc.visual.blocks.25.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 223 |
+
"image_prefix.enc.visual.blocks.25.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 224 |
+
"image_prefix.enc.visual.blocks.25.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 225 |
+
"image_prefix.enc.visual.blocks.25.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 226 |
+
"image_prefix.enc.visual.blocks.25.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 227 |
+
"image_prefix.enc.visual.blocks.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 228 |
+
"image_prefix.enc.visual.blocks.25.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 229 |
+
"image_prefix.enc.visual.blocks.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 230 |
+
"image_prefix.enc.visual.blocks.25.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 231 |
+
"image_prefix.enc.visual.blocks.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 232 |
+
"image_prefix.enc.visual.blocks.25.norm1.weight": "model-00003-of-00003.safetensors",
|
| 233 |
+
"image_prefix.enc.visual.blocks.25.norm2.weight": "model-00003-of-00003.safetensors",
|
| 234 |
+
"image_prefix.enc.visual.blocks.26.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 235 |
+
"image_prefix.enc.visual.blocks.26.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 236 |
+
"image_prefix.enc.visual.blocks.26.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 237 |
+
"image_prefix.enc.visual.blocks.26.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 238 |
+
"image_prefix.enc.visual.blocks.26.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 239 |
+
"image_prefix.enc.visual.blocks.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 240 |
+
"image_prefix.enc.visual.blocks.26.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 241 |
+
"image_prefix.enc.visual.blocks.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 242 |
+
"image_prefix.enc.visual.blocks.26.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 243 |
+
"image_prefix.enc.visual.blocks.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 244 |
+
"image_prefix.enc.visual.blocks.26.norm1.weight": "model-00003-of-00003.safetensors",
|
| 245 |
+
"image_prefix.enc.visual.blocks.26.norm2.weight": "model-00003-of-00003.safetensors",
|
| 246 |
+
"image_prefix.enc.visual.blocks.27.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 247 |
+
"image_prefix.enc.visual.blocks.27.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 248 |
+
"image_prefix.enc.visual.blocks.27.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 249 |
+
"image_prefix.enc.visual.blocks.27.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 250 |
+
"image_prefix.enc.visual.blocks.27.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 251 |
+
"image_prefix.enc.visual.blocks.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 252 |
+
"image_prefix.enc.visual.blocks.27.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 253 |
+
"image_prefix.enc.visual.blocks.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 254 |
+
"image_prefix.enc.visual.blocks.27.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 255 |
+
"image_prefix.enc.visual.blocks.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 256 |
+
"image_prefix.enc.visual.blocks.27.norm1.weight": "model-00003-of-00003.safetensors",
|
| 257 |
+
"image_prefix.enc.visual.blocks.27.norm2.weight": "model-00003-of-00003.safetensors",
|
| 258 |
+
"image_prefix.enc.visual.blocks.28.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 259 |
+
"image_prefix.enc.visual.blocks.28.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 260 |
+
"image_prefix.enc.visual.blocks.28.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 261 |
+
"image_prefix.enc.visual.blocks.28.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 262 |
+
"image_prefix.enc.visual.blocks.28.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 263 |
+
"image_prefix.enc.visual.blocks.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 264 |
+
"image_prefix.enc.visual.blocks.28.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 265 |
+
"image_prefix.enc.visual.blocks.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 266 |
+
"image_prefix.enc.visual.blocks.28.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 267 |
+
"image_prefix.enc.visual.blocks.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 268 |
+
"image_prefix.enc.visual.blocks.28.norm1.weight": "model-00003-of-00003.safetensors",
|
| 269 |
+
"image_prefix.enc.visual.blocks.28.norm2.weight": "model-00003-of-00003.safetensors",
|
| 270 |
+
"image_prefix.enc.visual.blocks.29.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 271 |
+
"image_prefix.enc.visual.blocks.29.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 272 |
+
"image_prefix.enc.visual.blocks.29.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 273 |
+
"image_prefix.enc.visual.blocks.29.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 274 |
+
"image_prefix.enc.visual.blocks.29.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 275 |
+
"image_prefix.enc.visual.blocks.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 276 |
+
"image_prefix.enc.visual.blocks.29.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 277 |
+
"image_prefix.enc.visual.blocks.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 278 |
+
"image_prefix.enc.visual.blocks.29.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 279 |
+
"image_prefix.enc.visual.blocks.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 280 |
+
"image_prefix.enc.visual.blocks.29.norm1.weight": "model-00003-of-00003.safetensors",
|
| 281 |
+
"image_prefix.enc.visual.blocks.29.norm2.weight": "model-00003-of-00003.safetensors",
|
| 282 |
+
"image_prefix.enc.visual.blocks.3.attn.proj.bias": "model-00002-of-00003.safetensors",
|
| 283 |
+
"image_prefix.enc.visual.blocks.3.attn.proj.weight": "model-00002-of-00003.safetensors",
|
| 284 |
+
"image_prefix.enc.visual.blocks.3.attn.qkv.bias": "model-00002-of-00003.safetensors",
|
| 285 |
+
"image_prefix.enc.visual.blocks.3.attn.qkv.weight": "model-00002-of-00003.safetensors",
|
| 286 |
+
"image_prefix.enc.visual.blocks.3.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
|
| 287 |
+
"image_prefix.enc.visual.blocks.3.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 288 |
+
"image_prefix.enc.visual.blocks.3.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
|
| 289 |
+
"image_prefix.enc.visual.blocks.3.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 290 |
+
"image_prefix.enc.visual.blocks.3.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
|
| 291 |
+
"image_prefix.enc.visual.blocks.3.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 292 |
+
"image_prefix.enc.visual.blocks.3.norm1.weight": "model-00002-of-00003.safetensors",
|
| 293 |
+
"image_prefix.enc.visual.blocks.3.norm2.weight": "model-00002-of-00003.safetensors",
|
| 294 |
+
"image_prefix.enc.visual.blocks.30.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 295 |
+
"image_prefix.enc.visual.blocks.30.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 296 |
+
"image_prefix.enc.visual.blocks.30.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 297 |
+
"image_prefix.enc.visual.blocks.30.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 298 |
+
"image_prefix.enc.visual.blocks.30.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 299 |
+
"image_prefix.enc.visual.blocks.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 300 |
+
"image_prefix.enc.visual.blocks.30.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 301 |
+
"image_prefix.enc.visual.blocks.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 302 |
+
"image_prefix.enc.visual.blocks.30.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 303 |
+
"image_prefix.enc.visual.blocks.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 304 |
+
"image_prefix.enc.visual.blocks.30.norm1.weight": "model-00003-of-00003.safetensors",
|
| 305 |
+
"image_prefix.enc.visual.blocks.30.norm2.weight": "model-00003-of-00003.safetensors",
|
| 306 |
+
"image_prefix.enc.visual.blocks.31.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 307 |
+
"image_prefix.enc.visual.blocks.31.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 308 |
+
"image_prefix.enc.visual.blocks.31.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 309 |
+
"image_prefix.enc.visual.blocks.31.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 310 |
+
"image_prefix.enc.visual.blocks.31.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 311 |
+
"image_prefix.enc.visual.blocks.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 312 |
+
"image_prefix.enc.visual.blocks.31.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 313 |
+
"image_prefix.enc.visual.blocks.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 314 |
+
"image_prefix.enc.visual.blocks.31.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 315 |
+
"image_prefix.enc.visual.blocks.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 316 |
+
"image_prefix.enc.visual.blocks.31.norm1.weight": "model-00003-of-00003.safetensors",
|
| 317 |
+
"image_prefix.enc.visual.blocks.31.norm2.weight": "model-00003-of-00003.safetensors",
|
| 318 |
+
"image_prefix.enc.visual.blocks.4.attn.proj.bias": "model-00002-of-00003.safetensors",
|
| 319 |
+
"image_prefix.enc.visual.blocks.4.attn.proj.weight": "model-00002-of-00003.safetensors",
|
| 320 |
+
"image_prefix.enc.visual.blocks.4.attn.qkv.bias": "model-00002-of-00003.safetensors",
|
| 321 |
+
"image_prefix.enc.visual.blocks.4.attn.qkv.weight": "model-00002-of-00003.safetensors",
|
| 322 |
+
"image_prefix.enc.visual.blocks.4.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
|
| 323 |
+
"image_prefix.enc.visual.blocks.4.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 324 |
+
"image_prefix.enc.visual.blocks.4.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
|
| 325 |
+
"image_prefix.enc.visual.blocks.4.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 326 |
+
"image_prefix.enc.visual.blocks.4.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
|
| 327 |
+
"image_prefix.enc.visual.blocks.4.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 328 |
+
"image_prefix.enc.visual.blocks.4.norm1.weight": "model-00002-of-00003.safetensors",
|
| 329 |
+
"image_prefix.enc.visual.blocks.4.norm2.weight": "model-00002-of-00003.safetensors",
|
| 330 |
+
"image_prefix.enc.visual.blocks.5.attn.proj.bias": "model-00002-of-00003.safetensors",
|
| 331 |
+
"image_prefix.enc.visual.blocks.5.attn.proj.weight": "model-00002-of-00003.safetensors",
|
| 332 |
+
"image_prefix.enc.visual.blocks.5.attn.qkv.bias": "model-00002-of-00003.safetensors",
|
| 333 |
+
"image_prefix.enc.visual.blocks.5.attn.qkv.weight": "model-00002-of-00003.safetensors",
|
| 334 |
+
"image_prefix.enc.visual.blocks.5.mlp.down_proj.bias": "model-00002-of-00003.safetensors",
|
| 335 |
+
"image_prefix.enc.visual.blocks.5.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 336 |
+
"image_prefix.enc.visual.blocks.5.mlp.gate_proj.bias": "model-00002-of-00003.safetensors",
|
| 337 |
+
"image_prefix.enc.visual.blocks.5.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 338 |
+
"image_prefix.enc.visual.blocks.5.mlp.up_proj.bias": "model-00002-of-00003.safetensors",
|
| 339 |
+
"image_prefix.enc.visual.blocks.5.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 340 |
+
"image_prefix.enc.visual.blocks.5.norm1.weight": "model-00002-of-00003.safetensors",
|
| 341 |
+
"image_prefix.enc.visual.blocks.5.norm2.weight": "model-00002-of-00003.safetensors",
|
| 342 |
+
"image_prefix.enc.visual.blocks.6.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 343 |
+
"image_prefix.enc.visual.blocks.6.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 344 |
+
"image_prefix.enc.visual.blocks.6.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 345 |
+
"image_prefix.enc.visual.blocks.6.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 346 |
+
"image_prefix.enc.visual.blocks.6.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 347 |
+
"image_prefix.enc.visual.blocks.6.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 348 |
+
"image_prefix.enc.visual.blocks.6.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 349 |
+
"image_prefix.enc.visual.blocks.6.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 350 |
+
"image_prefix.enc.visual.blocks.6.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 351 |
+
"image_prefix.enc.visual.blocks.6.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 352 |
+
"image_prefix.enc.visual.blocks.6.norm1.weight": "model-00002-of-00003.safetensors",
|
| 353 |
+
"image_prefix.enc.visual.blocks.6.norm2.weight": "model-00002-of-00003.safetensors",
|
| 354 |
+
"image_prefix.enc.visual.blocks.7.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 355 |
+
"image_prefix.enc.visual.blocks.7.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 356 |
+
"image_prefix.enc.visual.blocks.7.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 357 |
+
"image_prefix.enc.visual.blocks.7.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 358 |
+
"image_prefix.enc.visual.blocks.7.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 359 |
+
"image_prefix.enc.visual.blocks.7.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 360 |
+
"image_prefix.enc.visual.blocks.7.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 361 |
+
"image_prefix.enc.visual.blocks.7.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 362 |
+
"image_prefix.enc.visual.blocks.7.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 363 |
+
"image_prefix.enc.visual.blocks.7.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 364 |
+
"image_prefix.enc.visual.blocks.7.norm1.weight": "model-00003-of-00003.safetensors",
|
| 365 |
+
"image_prefix.enc.visual.blocks.7.norm2.weight": "model-00003-of-00003.safetensors",
|
| 366 |
+
"image_prefix.enc.visual.blocks.8.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 367 |
+
"image_prefix.enc.visual.blocks.8.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 368 |
+
"image_prefix.enc.visual.blocks.8.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 369 |
+
"image_prefix.enc.visual.blocks.8.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 370 |
+
"image_prefix.enc.visual.blocks.8.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 371 |
+
"image_prefix.enc.visual.blocks.8.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 372 |
+
"image_prefix.enc.visual.blocks.8.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 373 |
+
"image_prefix.enc.visual.blocks.8.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 374 |
+
"image_prefix.enc.visual.blocks.8.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 375 |
+
"image_prefix.enc.visual.blocks.8.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 376 |
+
"image_prefix.enc.visual.blocks.8.norm1.weight": "model-00003-of-00003.safetensors",
|
| 377 |
+
"image_prefix.enc.visual.blocks.8.norm2.weight": "model-00003-of-00003.safetensors",
|
| 378 |
+
"image_prefix.enc.visual.blocks.9.attn.proj.bias": "model-00003-of-00003.safetensors",
|
| 379 |
+
"image_prefix.enc.visual.blocks.9.attn.proj.weight": "model-00003-of-00003.safetensors",
|
| 380 |
+
"image_prefix.enc.visual.blocks.9.attn.qkv.bias": "model-00003-of-00003.safetensors",
|
| 381 |
+
"image_prefix.enc.visual.blocks.9.attn.qkv.weight": "model-00003-of-00003.safetensors",
|
| 382 |
+
"image_prefix.enc.visual.blocks.9.mlp.down_proj.bias": "model-00003-of-00003.safetensors",
|
| 383 |
+
"image_prefix.enc.visual.blocks.9.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 384 |
+
"image_prefix.enc.visual.blocks.9.mlp.gate_proj.bias": "model-00003-of-00003.safetensors",
|
| 385 |
+
"image_prefix.enc.visual.blocks.9.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 386 |
+
"image_prefix.enc.visual.blocks.9.mlp.up_proj.bias": "model-00003-of-00003.safetensors",
|
| 387 |
+
"image_prefix.enc.visual.blocks.9.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 388 |
+
"image_prefix.enc.visual.blocks.9.norm1.weight": "model-00003-of-00003.safetensors",
|
| 389 |
+
"image_prefix.enc.visual.blocks.9.norm2.weight": "model-00003-of-00003.safetensors",
|
| 390 |
+
"image_prefix.enc.visual.merger.ln_q.weight": "model-00003-of-00003.safetensors",
|
| 391 |
+
"image_prefix.enc.visual.merger.mlp.0.bias": "model-00003-of-00003.safetensors",
|
| 392 |
+
"image_prefix.enc.visual.merger.mlp.0.weight": "model-00003-of-00003.safetensors",
|
| 393 |
+
"image_prefix.enc.visual.merger.mlp.2.bias": "model-00003-of-00003.safetensors",
|
| 394 |
+
"image_prefix.enc.visual.merger.mlp.2.weight": "model-00003-of-00003.safetensors",
|
| 395 |
+
"image_prefix.enc.visual.patch_embed.proj.weight": "model-00002-of-00003.safetensors",
|
| 396 |
+
"image_prefix.norm_extra.weight": "model-00003-of-00003.safetensors",
|
| 397 |
+
"lm_head.weight": "model-00002-of-00003.safetensors",
|
| 398 |
+
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
| 399 |
+
"model.layers.0.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 400 |
+
"model.layers.0.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 401 |
+
"model.layers.0.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 402 |
+
"model.layers.0.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 403 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 404 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 405 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 406 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 407 |
+
"model.layers.0.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 408 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 409 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 410 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 411 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 412 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 413 |
+
"model.layers.1.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 414 |
+
"model.layers.1.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 415 |
+
"model.layers.1.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 416 |
+
"model.layers.1.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 417 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 418 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 419 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 420 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 421 |
+
"model.layers.1.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 422 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 423 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 424 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 425 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 426 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 427 |
+
"model.layers.10.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 428 |
+
"model.layers.10.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 429 |
+
"model.layers.10.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 430 |
+
"model.layers.10.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 431 |
+
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 432 |
+
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 433 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 434 |
+
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 435 |
+
"model.layers.10.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 436 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 437 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 438 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 439 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 440 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 441 |
+
"model.layers.11.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 442 |
+
"model.layers.11.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 443 |
+
"model.layers.11.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 444 |
+
"model.layers.11.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 445 |
+
"model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 446 |
+
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 447 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 448 |
+
"model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 449 |
+
"model.layers.11.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 450 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 451 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 452 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 453 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 454 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 455 |
+
"model.layers.12.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 456 |
+
"model.layers.12.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 457 |
+
"model.layers.12.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 458 |
+
"model.layers.12.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 459 |
+
"model.layers.12.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 460 |
+
"model.layers.12.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 461 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 462 |
+
"model.layers.12.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 463 |
+
"model.layers.12.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 464 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 465 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 466 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 467 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 468 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 469 |
+
"model.layers.13.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 470 |
+
"model.layers.13.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 471 |
+
"model.layers.13.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 472 |
+
"model.layers.13.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 473 |
+
"model.layers.13.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 474 |
+
"model.layers.13.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 475 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 476 |
+
"model.layers.13.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 477 |
+
"model.layers.13.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 478 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 479 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 480 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 481 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 482 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 483 |
+
"model.layers.14.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 484 |
+
"model.layers.14.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 485 |
+
"model.layers.14.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 486 |
+
"model.layers.14.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 487 |
+
"model.layers.14.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 488 |
+
"model.layers.14.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 489 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 490 |
+
"model.layers.14.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 491 |
+
"model.layers.14.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 492 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 493 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 494 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 495 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 496 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 497 |
+
"model.layers.15.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 498 |
+
"model.layers.15.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 499 |
+
"model.layers.15.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 500 |
+
"model.layers.15.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 501 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 502 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 503 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 504 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 505 |
+
"model.layers.15.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 506 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 507 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 508 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 509 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 510 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 511 |
+
"model.layers.16.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 512 |
+
"model.layers.16.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 513 |
+
"model.layers.16.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 514 |
+
"model.layers.16.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 515 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 516 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 517 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 518 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 519 |
+
"model.layers.16.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 520 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 521 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 522 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 523 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 524 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 525 |
+
"model.layers.17.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 526 |
+
"model.layers.17.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 527 |
+
"model.layers.17.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 528 |
+
"model.layers.17.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 529 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 530 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 531 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 532 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 533 |
+
"model.layers.17.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 534 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 535 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 536 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 537 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 538 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 539 |
+
"model.layers.18.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 540 |
+
"model.layers.18.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 541 |
+
"model.layers.18.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 542 |
+
"model.layers.18.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 543 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 544 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 545 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 546 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 547 |
+
"model.layers.18.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 548 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 549 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 550 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 551 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 552 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 553 |
+
"model.layers.19.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 554 |
+
"model.layers.19.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 555 |
+
"model.layers.19.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 556 |
+
"model.layers.19.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 557 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 558 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 559 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 560 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 561 |
+
"model.layers.19.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 562 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 563 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 564 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 565 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 566 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 567 |
+
"model.layers.2.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 568 |
+
"model.layers.2.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 569 |
+
"model.layers.2.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 570 |
+
"model.layers.2.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 571 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 572 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 573 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 574 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 575 |
+
"model.layers.2.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 576 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 577 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 578 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 579 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 580 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 581 |
+
"model.layers.20.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 582 |
+
"model.layers.20.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 583 |
+
"model.layers.20.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 584 |
+
"model.layers.20.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 585 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 586 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 587 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 588 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 589 |
+
"model.layers.20.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 590 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 591 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 592 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 593 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 594 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 595 |
+
"model.layers.21.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 596 |
+
"model.layers.21.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 597 |
+
"model.layers.21.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 598 |
+
"model.layers.21.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 599 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 600 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 601 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 602 |
+
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 603 |
+
"model.layers.21.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 604 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 605 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 606 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 607 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 608 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 609 |
+
"model.layers.22.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 610 |
+
"model.layers.22.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 611 |
+
"model.layers.22.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 612 |
+
"model.layers.22.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 613 |
+
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 614 |
+
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 615 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 616 |
+
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 617 |
+
"model.layers.22.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 618 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 619 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 620 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 621 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 622 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 623 |
+
"model.layers.23.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 624 |
+
"model.layers.23.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 625 |
+
"model.layers.23.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 626 |
+
"model.layers.23.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 627 |
+
"model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 628 |
+
"model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 629 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 630 |
+
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 631 |
+
"model.layers.23.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 632 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 633 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 634 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 635 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 636 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 637 |
+
"model.layers.24.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 638 |
+
"model.layers.24.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 639 |
+
"model.layers.24.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 640 |
+
"model.layers.24.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 641 |
+
"model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 642 |
+
"model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 643 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 644 |
+
"model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 645 |
+
"model.layers.24.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 646 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 647 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 648 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 649 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 650 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 651 |
+
"model.layers.25.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 652 |
+
"model.layers.25.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 653 |
+
"model.layers.25.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 654 |
+
"model.layers.25.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 655 |
+
"model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 656 |
+
"model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 657 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 658 |
+
"model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 659 |
+
"model.layers.25.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 660 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 661 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 662 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 663 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 664 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 665 |
+
"model.layers.26.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 666 |
+
"model.layers.26.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 667 |
+
"model.layers.26.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 668 |
+
"model.layers.26.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 669 |
+
"model.layers.26.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 670 |
+
"model.layers.26.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 671 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 672 |
+
"model.layers.26.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 673 |
+
"model.layers.26.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 674 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 675 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 676 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 677 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 678 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 679 |
+
"model.layers.27.casa_attn.k_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 680 |
+
"model.layers.27.casa_attn.o_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 681 |
+
"model.layers.27.casa_attn.q_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 682 |
+
"model.layers.27.casa_attn.v_proj_casa.weight": "model-00002-of-00003.safetensors",
|
| 683 |
+
"model.layers.27.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 684 |
+
"model.layers.27.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 685 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 686 |
+
"model.layers.27.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 687 |
+
"model.layers.27.norm_cross.weight": "model-00002-of-00003.safetensors",
|
| 688 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 689 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 690 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 691 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 692 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 693 |
+
"model.layers.3.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 694 |
+
"model.layers.3.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 695 |
+
"model.layers.3.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 696 |
+
"model.layers.3.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 697 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 698 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 699 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 700 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 701 |
+
"model.layers.3.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 702 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 703 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 704 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 705 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 706 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 707 |
+
"model.layers.4.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 708 |
+
"model.layers.4.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 709 |
+
"model.layers.4.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 710 |
+
"model.layers.4.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 711 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 712 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 713 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 714 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 715 |
+
"model.layers.4.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 716 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 717 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 718 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 719 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 720 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 721 |
+
"model.layers.5.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 722 |
+
"model.layers.5.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 723 |
+
"model.layers.5.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 724 |
+
"model.layers.5.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 725 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 726 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 727 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 728 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 729 |
+
"model.layers.5.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 730 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 731 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 732 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 733 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 734 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 735 |
+
"model.layers.6.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 736 |
+
"model.layers.6.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 737 |
+
"model.layers.6.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 738 |
+
"model.layers.6.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 739 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 740 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 741 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 742 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 743 |
+
"model.layers.6.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 744 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 745 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 746 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 747 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 748 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 749 |
+
"model.layers.7.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 750 |
+
"model.layers.7.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 751 |
+
"model.layers.7.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 752 |
+
"model.layers.7.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 753 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 754 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 755 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 756 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 757 |
+
"model.layers.7.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 758 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 759 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 760 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 761 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 762 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 763 |
+
"model.layers.8.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 764 |
+
"model.layers.8.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 765 |
+
"model.layers.8.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 766 |
+
"model.layers.8.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 767 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 768 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 769 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 770 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 771 |
+
"model.layers.8.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 772 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 773 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 774 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 775 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 776 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 777 |
+
"model.layers.9.casa_attn.k_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 778 |
+
"model.layers.9.casa_attn.o_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 779 |
+
"model.layers.9.casa_attn.q_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 780 |
+
"model.layers.9.casa_attn.v_proj_casa.weight": "model-00001-of-00003.safetensors",
|
| 781 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 782 |
+
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 783 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 784 |
+
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 785 |
+
"model.layers.9.norm_cross.weight": "model-00001-of-00003.safetensors",
|
| 786 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 787 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 788 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 789 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 790 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 791 |
+
"model.norm.weight": "model-00002-of-00003.safetensors"
|
| 792 |
+
}
|
| 793 |
+
}
|
modeling_helium1_casa.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable
|
| 2 |
+
from typing import cast as type_cast
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers.cache_utils import DynamicCache
|
| 6 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 7 |
+
from transformers.generation.utils import GenerateOutput
|
| 8 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 9 |
+
Qwen2_5_VisionTransformerPretrainedModel,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from .image_encoder import Qwen25VLEncoder
|
| 13 |
+
from .configuration_helium1_casa import Helium1CASAConfig
|
| 14 |
+
from .language_helium1_casa import (
|
| 15 |
+
CausalHeliumOutput,
|
| 16 |
+
Helium1CASAAttention,
|
| 17 |
+
Helium1ForCausalLM,
|
| 18 |
+
Helium1RMSNorm,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def meta_project(
|
| 23 |
+
logits: torch.Tensor | list[torch.Tensor],
|
| 24 |
+
projector: torch.nn.Module,
|
| 25 |
+
norm: torch.nn.Module | None = None,
|
| 26 |
+
) -> torch.Tensor | list[torch.Tensor]:
|
| 27 |
+
"""Projection operation that handles both tensors and list of tensors
|
| 28 |
+
|
| 29 |
+
Outputs either a (N, S, D) tensors (same resolution images) or a list of N (S, D) tensors (where
|
| 30 |
+
S can be a different sequence length per image)
|
| 31 |
+
"""
|
| 32 |
+
split_sizes: list[int] | None = None
|
| 33 |
+
if not isinstance(logits, torch.Tensor):
|
| 34 |
+
split_sizes = [_x.shape[0] for _x in logits]
|
| 35 |
+
logits = torch.cat(logits, dim=0)[None, :, :]
|
| 36 |
+
logits = type_cast(torch.Tensor, logits)
|
| 37 |
+
logits = projector(logits)
|
| 38 |
+
|
| 39 |
+
assert isinstance(logits, torch.Tensor)
|
| 40 |
+
if norm is not None:
|
| 41 |
+
logits = norm(logits)
|
| 42 |
+
if split_sizes is not None:
|
| 43 |
+
return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
|
| 44 |
+
return logits
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ImageProjection(torch.nn.Module):
|
| 48 |
+
"""Takes in a batch or sequence of images and returns embeddings
|
| 49 |
+
which are then fed to the LM.
|
| 50 |
+
|
| 51 |
+
:param config: KyuteyeConfig object
|
| 52 |
+
:param lm_model_dim: Output dimension (number of channels) for this module
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.config = config
|
| 58 |
+
self.out_dim = lm_model_dim
|
| 59 |
+
visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
|
| 60 |
+
|
| 61 |
+
self.enc = Qwen25VLEncoder(visual=visual)
|
| 62 |
+
# Projection layer
|
| 63 |
+
self.proj_extra = self.init_proj_module()
|
| 64 |
+
# Output normalizations
|
| 65 |
+
self.norm_extra = Helium1RMSNorm(self.out_dim)
|
| 66 |
+
|
| 67 |
+
def init_proj_module(self) -> torch.nn.Module:
|
| 68 |
+
"""Init the project module for the inserted and/or cross-attended image tokens"""
|
| 69 |
+
if self.config.vision_config.out_dim == self.out_dim:
|
| 70 |
+
return torch.nn.Identity()
|
| 71 |
+
return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)
|
| 72 |
+
|
| 73 |
+
def forward(
|
| 74 |
+
self, x: torch.Tensor | list[torch.Tensor]
|
| 75 |
+
) -> dict[
|
| 76 |
+
str,
|
| 77 |
+
torch.Tensor | list[torch.Tensor],
|
| 78 |
+
]:
|
| 79 |
+
"""Image embedding mapping
|
| 80 |
+
|
| 81 |
+
:param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
|
| 82 |
+
with shape (C, H, W) (or (H, W, C) in the case of Qwen)
|
| 83 |
+
|
| 84 |
+
:return: Either a tensor with shape (num_total_image, S, D) or, if images
|
| 85 |
+
can have different seq length, a list of `num_total_images` Tensors with shape
|
| 86 |
+
(S, D)
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
# Apply image encoder
|
| 90 |
+
og_dtype = x[0].dtype
|
| 91 |
+
encoded = self.enc(x)["image_embeds"]
|
| 92 |
+
encoded = [_x.to(og_dtype) for _x in encoded]
|
| 93 |
+
if all(x.shape[0] == encoded[0].shape[0] for x in encoded):
|
| 94 |
+
encoded = torch.stack(encoded, dim=0)
|
| 95 |
+
|
| 96 |
+
# Extra projection
|
| 97 |
+
image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
|
| 98 |
+
|
| 99 |
+
# Apply different projection for extra vs cross attended tokens
|
| 100 |
+
return {"image_embeds": image_embeds}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class V2Helium1(Helium1ForCausalLM): # pyright: ignore[reportIncompatibleMethodOverride]
|
| 104 |
+
config_class = Helium1CASAConfig
|
| 105 |
+
|
| 106 |
+
def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
|
| 107 |
+
del kwargs
|
| 108 |
+
super().__init__(config)
|
| 109 |
+
self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)
|
| 110 |
+
|
| 111 |
+
def get_device(self) -> str:
|
| 112 |
+
"""Return the device type of the model"""
|
| 113 |
+
return next(self.parameters()).device.type
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def token_dim(self) -> int:
|
| 117 |
+
"""Returns the number of dimensions for the token representation"""
|
| 118 |
+
return self.config.hidden_size
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def rotary_embed(self) -> Callable:
|
| 122 |
+
"""Returns the rotary embedding function of the underlying model"""
|
| 123 |
+
return self.model.rotary_emb
|
| 124 |
+
|
| 125 |
+
def _update_model_kwargs_for_generation(
|
| 126 |
+
self,
|
| 127 |
+
outputs: Any,
|
| 128 |
+
model_kwargs: dict[str, Any],
|
| 129 |
+
is_encoder_decoder: bool = False,
|
| 130 |
+
num_new_tokens: int = 1,
|
| 131 |
+
):
|
| 132 |
+
"""This is required to handle multiple gen calls for subtitles"""
|
| 133 |
+
# Call parent to get default updates
|
| 134 |
+
model_kwargs = super()._update_model_kwargs_for_generation(
|
| 135 |
+
outputs, model_kwargs, is_encoder_decoder, num_new_tokens
|
| 136 |
+
)
|
| 137 |
+
# Used by prepare_inputs_for_generation
|
| 138 |
+
model_kwargs["__is_first_gen_call__"] = False
|
| 139 |
+
return model_kwargs
|
| 140 |
+
|
| 141 |
+
def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride]
|
| 142 |
+
self,
|
| 143 |
+
input_ids: torch.Tensor,
|
| 144 |
+
past_key_values: DynamicCache | None = None,
|
| 145 |
+
**kwargs: Any,
|
| 146 |
+
):
|
| 147 |
+
__is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
|
| 148 |
+
if past_key_values is not None and (
|
| 149 |
+
kwargs.get("cache_position") is None
|
| 150 |
+
or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
|
| 151 |
+
):
|
| 152 |
+
# We're continuing from a cached state
|
| 153 |
+
past_length = past_key_values._seen_tokens
|
| 154 |
+
kwargs["cache_position"] = torch.arange(
|
| 155 |
+
past_length,
|
| 156 |
+
past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
|
| 157 |
+
dtype=torch.long,
|
| 158 |
+
device=input_ids.device,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
return super().prepare_inputs_for_generation(
|
| 162 |
+
type_cast(torch.LongTensor, input_ids),
|
| 163 |
+
past_key_values=past_key_values,
|
| 164 |
+
**kwargs,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def prepare_multimodal_inputs(
|
| 168 |
+
self,
|
| 169 |
+
# text only training
|
| 170 |
+
input_ids: torch.Tensor | None = None,
|
| 171 |
+
inputs_embeds: torch.Tensor | None = None,
|
| 172 |
+
attention_mask: torch.Tensor | None = None,
|
| 173 |
+
image_embeds_insertion_points: list[torch.Tensor] | None = None,
|
| 174 |
+
labels: torch.Tensor | None = None,
|
| 175 |
+
# image values
|
| 176 |
+
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
|
| 177 |
+
pre_image_tokens: list[int] | None = None,
|
| 178 |
+
post_image_tokens: list[int] | None = None,
|
| 179 |
+
**_kwargs: Any,
|
| 180 |
+
) -> dict:
|
| 181 |
+
"""Get a batch data mixing text and image data"""
|
| 182 |
+
del _kwargs
|
| 183 |
+
|
| 184 |
+
processed_inputs = {
|
| 185 |
+
"input_ids": input_ids,
|
| 186 |
+
"inputs_embeds": inputs_embeds,
|
| 187 |
+
"labels": labels,
|
| 188 |
+
"attention_mask": attention_mask,
|
| 189 |
+
"image_embeds_insertion_points": image_embeds_insertion_points,
|
| 190 |
+
}
|
| 191 |
+
if pixel_values is not None:
|
| 192 |
+
processed_inputs.update(self.image_prefix(pixel_values))
|
| 193 |
+
assert "image_embeds" in processed_inputs
|
| 194 |
+
assert (
|
| 195 |
+
isinstance(processed_inputs["image_embeds"], torch.Tensor)
|
| 196 |
+
and processed_inputs["image_embeds"].ndim == 3
|
| 197 |
+
) or (
|
| 198 |
+
isinstance(processed_inputs["image_embeds"], list)
|
| 199 |
+
and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Add kwargs necessary to compute cu_seqlens windows for CASA
|
| 203 |
+
processed_inputs["casa_windows_info"] = {
|
| 204 |
+
"num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
|
| 205 |
+
"num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
return processed_inputs
|
| 209 |
+
|
| 210 |
+
def forward( # pyright: ignore[reportIncompatibleMethodOverride]
|
| 211 |
+
self,
|
| 212 |
+
input_ids: torch.Tensor | None = None,
|
| 213 |
+
inputs_embeds: torch.Tensor | None = None,
|
| 214 |
+
attention_mask: torch.Tensor | None = None,
|
| 215 |
+
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
|
| 216 |
+
return_loss: bool = True,
|
| 217 |
+
labels: torch.Tensor | None = None,
|
| 218 |
+
image_embeds_insertion_points: list[torch.Tensor] | None = None,
|
| 219 |
+
pre_image_tokens: list[int] | None = None,
|
| 220 |
+
post_image_tokens: list[int] | None = None,
|
| 221 |
+
**kwargs: Any,
|
| 222 |
+
) -> CausalHeliumOutput:
|
| 223 |
+
"""Multi modal forward pass"""
|
| 224 |
+
assert input_ids is not None or inputs_embeds is not None
|
| 225 |
+
|
| 226 |
+
if self.training:
|
| 227 |
+
assert return_loss is True, (
|
| 228 |
+
"Helium models always compute its own labels/losses in train mode"
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Case 1: For first generation call we need to compute pixel values and CASA states
|
| 232 |
+
if kwargs.get("__is_first_gen_call__", True):
|
| 233 |
+
processed_inputs = self.prepare_multimodal_inputs(
|
| 234 |
+
input_ids=input_ids,
|
| 235 |
+
inputs_embeds=inputs_embeds,
|
| 236 |
+
attention_mask=attention_mask,
|
| 237 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 238 |
+
pixel_values=pixel_values,
|
| 239 |
+
labels=labels,
|
| 240 |
+
pre_image_tokens=pre_image_tokens,
|
| 241 |
+
post_image_tokens=post_image_tokens,
|
| 242 |
+
)
|
| 243 |
+
processed_inputs.pop("inputs_embeds", None)
|
| 244 |
+
else:
|
| 245 |
+
processed_inputs = {
|
| 246 |
+
"inputs_embeds": self.model.embed_tokens(input_ids),
|
| 247 |
+
"attention_mask": attention_mask,
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
# For Helium prefix, we need to update the positions by the number
|
| 251 |
+
# of image tokens inserted in the first call
|
| 252 |
+
if (
|
| 253 |
+
not self.config.casa_attention
|
| 254 |
+
and (cp := kwargs.get("cache_position", None)) is not None
|
| 255 |
+
and pixel_values is not None
|
| 256 |
+
):
|
| 257 |
+
start = kwargs["cache_position"][0].item()
|
| 258 |
+
num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
|
| 259 |
+
num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] # type: ignore
|
| 260 |
+
kwargs["cache_position"] = torch.arange(
|
| 261 |
+
start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
|
| 262 |
+
start + num_tokens + num_image_tokens,
|
| 263 |
+
dtype=cp.dtype,
|
| 264 |
+
device=cp.device,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
kwargs.pop("__is_first_gen_call__", True)
|
| 268 |
+
out = super().forward(
|
| 269 |
+
**processed_inputs, # type: ignore
|
| 270 |
+
**kwargs,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
return out
|
| 274 |
+
|
| 275 |
+
@torch.no_grad()
|
| 276 |
+
def generate_from_image( # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride]
|
| 277 |
+
self,
|
| 278 |
+
input_ids: torch.Tensor | None = None,
|
| 279 |
+
inputs_embeds: torch.Tensor | None = None,
|
| 280 |
+
attention_mask: torch.Tensor | None = None,
|
| 281 |
+
image_embeds_insertion_points: list[torch.Tensor] | None = None,
|
| 282 |
+
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
|
| 283 |
+
reset_streaming: bool = True,
|
| 284 |
+
**kwargs: Any,
|
| 285 |
+
) -> "GenerateOutput | torch.LongTensor":
|
| 286 |
+
assert input_ids is not None and inputs_embeds is None, (
|
| 287 |
+
"Input IDs must be provided for generation"
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# init self-attention KVCache
|
| 291 |
+
if kwargs.get("past_key_values", None) is None:
|
| 292 |
+
kwargs["past_key_values"] = DynamicCache()
|
| 293 |
+
|
| 294 |
+
# To avoid generate warning
|
| 295 |
+
if kwargs.get("pad_token_id", None) is None:
|
| 296 |
+
kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
|
| 297 |
+
if isinstance(kwargs["pad_token_id"], (list, tuple)):
|
| 298 |
+
kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
|
| 299 |
+
|
| 300 |
+
self.start_casa_streaming_states()
|
| 301 |
+
outputs = self.generate(
|
| 302 |
+
input_ids,
|
| 303 |
+
attention_mask=attention_mask,
|
| 304 |
+
pixel_values=pixel_values,
|
| 305 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 306 |
+
use_cache=True,
|
| 307 |
+
**kwargs,
|
| 308 |
+
)
|
| 309 |
+
if reset_streaming:
|
| 310 |
+
self.reset_casa_streaming_states()
|
| 311 |
+
return outputs
|
| 312 |
+
|
| 313 |
+
def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
|
| 314 |
+
def __reset__(m: torch.nn.Module):
|
| 315 |
+
if isinstance(m, Helium1CASAAttention):
|
| 316 |
+
m._set_streaming(False, ())
|
| 317 |
+
m.reset_streaming()
|
| 318 |
+
if clean_cache:
|
| 319 |
+
del m.streaming_state.k
|
| 320 |
+
del m.streaming_state.v
|
| 321 |
+
del m.streaming_state.casa_handler
|
| 322 |
+
|
| 323 |
+
self.apply(__reset__)
|
| 324 |
+
|
| 325 |
+
def start_casa_streaming_states(self) -> None:
|
| 326 |
+
def __start__(m: torch.nn.Module):
|
| 327 |
+
if isinstance(m, Helium1CASAAttention):
|
| 328 |
+
m._set_streaming(True, ())
|
| 329 |
+
|
| 330 |
+
self.apply(__start__)
|
processing.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=no-member # avoid weird pylint warnings from SentencePieceProcessor
|
| 2 |
+
"""Text and Image processor for CASA models using Qwen2.5_VL image encoder"""
|
| 3 |
+
|
| 4 |
+
from math import ceil
|
| 5 |
+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload
|
| 6 |
+
from typing import cast as type_cast
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torchvision.transforms.v2 as T
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torchvision.transforms import InterpolationMode
|
| 13 |
+
from torchvision.transforms.functional import to_tensor as pil_to_tensor
|
| 14 |
+
from torchvision.transforms.v2 import functional as F
|
| 15 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 16 |
+
from transformers.processing_utils import ProcessorMixin
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
|
| 20 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
ImageMessage = TypedDict(
|
| 24 |
+
"ImageMessage",
|
| 25 |
+
{
|
| 26 |
+
"type": Literal["image"],
|
| 27 |
+
"image": str | Image.Image | None,
|
| 28 |
+
},
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
TextMessage = TypedDict(
|
| 32 |
+
"TextMessage",
|
| 33 |
+
{
|
| 34 |
+
"type": Literal["text"],
|
| 35 |
+
"text": str,
|
| 36 |
+
},
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
MessageContent = list[ImageMessage | TextMessage]
|
| 40 |
+
|
| 41 |
+
Message = TypedDict(
|
| 42 |
+
"Message",
|
| 43 |
+
{
|
| 44 |
+
"role": Literal["system", "user", "assistant"],
|
| 45 |
+
"content": MessageContent,
|
| 46 |
+
},
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
ProcessorInput = list[list[Message]] | list[Message]
|
| 50 |
+
|
| 51 |
+
__INTERP_NAME_TO_MODE__ = {
|
| 52 |
+
"nearest": InterpolationMode.NEAREST,
|
| 53 |
+
"bilinear": InterpolationMode.BILINEAR,
|
| 54 |
+
"bicubic": InterpolationMode.BICUBIC,
|
| 55 |
+
"lanczos": InterpolationMode.LANCZOS,
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
__INTERP_INT_TO_MODE__ = {
|
| 59 |
+
0: InterpolationMode.NEAREST,
|
| 60 |
+
2: InterpolationMode.BILINEAR,
|
| 61 |
+
3: InterpolationMode.BICUBIC,
|
| 62 |
+
4: InterpolationMode.BOX,
|
| 63 |
+
5: InterpolationMode.HAMMING,
|
| 64 |
+
1: InterpolationMode.LANCZOS,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@overload
|
| 69 |
+
def universal_resize(
|
| 70 |
+
img: Image.Image,
|
| 71 |
+
size: tuple[int, int],
|
| 72 |
+
interpolation: str | InterpolationMode | int = "bilinear",
|
| 73 |
+
antialias: bool = True,
|
| 74 |
+
) -> Image.Image: ...
|
| 75 |
+
@overload
|
| 76 |
+
def universal_resize(
|
| 77 |
+
img: torch.Tensor,
|
| 78 |
+
size: tuple[int, int],
|
| 79 |
+
interpolation: str | InterpolationMode | int = "bilinear",
|
| 80 |
+
antialias: bool = True,
|
| 81 |
+
) -> torch.Tensor: ...
|
| 82 |
+
def universal_resize(
|
| 83 |
+
img: Image.Image | torch.Tensor,
|
| 84 |
+
size: tuple[int, int],
|
| 85 |
+
interpolation: str | InterpolationMode | int = "bilinear",
|
| 86 |
+
antialias: bool = True,
|
| 87 |
+
) -> Image.Image | torch.Tensor:
|
| 88 |
+
"""Resize that works for PIL.Image, CHW tensor, or BCHW tensor"""
|
| 89 |
+
if isinstance(interpolation, str):
|
| 90 |
+
interpolation = __INTERP_NAME_TO_MODE__[interpolation]
|
| 91 |
+
elif isinstance(interpolation, int):
|
| 92 |
+
interpolation = __INTERP_INT_TO_MODE__[interpolation]
|
| 93 |
+
|
| 94 |
+
return F.resize(
|
| 95 |
+
img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@overload
|
| 100 |
+
def convert_to_rgb(img: Image.Image) -> Image.Image: ...
|
| 101 |
+
@overload
|
| 102 |
+
def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
|
| 103 |
+
def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
|
| 104 |
+
"""Convert any image to RGB in a way that does not throw PIL warning"""
|
| 105 |
+
if isinstance(img, torch.Tensor):
|
| 106 |
+
return img
|
| 107 |
+
if img.mode == "RGB": # no changes
|
| 108 |
+
return img
|
| 109 |
+
if img.mode == "P": # palette images need to be converted to RGBA first
|
| 110 |
+
return img.convert("RGBA").convert("RGB")
|
| 111 |
+
return img.convert("RGB")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class QwenImageProcessor(BaseImageProcessor):
|
| 115 |
+
"""Resizing for the Qwen2.5VL encoder. Note that the normalization is
|
| 116 |
+
handled in the image_encoder in the model forward"""
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
img_size: int = 448,
|
| 121 |
+
interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
|
| 122 |
+
max_ratio: int = 10,
|
| 123 |
+
round_to_patch_size: int = 56,
|
| 124 |
+
use_fast: bool = True,
|
| 125 |
+
**kwargs: Any,
|
| 126 |
+
) -> None:
|
| 127 |
+
# this will also be used in V2llms to determine whether to remove
|
| 128 |
+
# the temporal conv
|
| 129 |
+
self._num_target_channels = 588
|
| 130 |
+
self._merge_size = 2
|
| 131 |
+
self._patch_size = 14
|
| 132 |
+
super().__init__(
|
| 133 |
+
use_fast=use_fast,
|
| 134 |
+
do_normalize=False,
|
| 135 |
+
**kwargs,
|
| 136 |
+
)
|
| 137 |
+
self.img_size = img_size
|
| 138 |
+
self.interpolation = interpolation
|
| 139 |
+
self.max_ratio = max_ratio
|
| 140 |
+
self.round_to_patch_size = round_to_patch_size
|
| 141 |
+
|
| 142 |
+
def resize_transform(
|
| 143 |
+
self, img: Image.Image | torch.Tensor, img_size: int | None = None
|
| 144 |
+
) -> Image.Image | torch.Tensor:
|
| 145 |
+
if img_size is None:
|
| 146 |
+
img_size = self.img_size
|
| 147 |
+
max_area = img_size**2
|
| 148 |
+
if isinstance(img, Image.Image):
|
| 149 |
+
img = convert_to_rgb(img)
|
| 150 |
+
w_og, h_og = img.size
|
| 151 |
+
else:
|
| 152 |
+
h_og, w_og = img.shape[-2:]
|
| 153 |
+
w, h = w_og, h_og
|
| 154 |
+
|
| 155 |
+
# Qwen requires max ratio of 10 between max and min sizes
|
| 156 |
+
if self.max_ratio > 0:
|
| 157 |
+
w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)
|
| 158 |
+
|
| 159 |
+
# resize to max area
|
| 160 |
+
current_area = w * h
|
| 161 |
+
if current_area > max_area:
|
| 162 |
+
scale = (max_area / current_area) ** 0.5
|
| 163 |
+
w, h = int(w * scale), int(h * scale)
|
| 164 |
+
|
| 165 |
+
# resize to patch size
|
| 166 |
+
if self.round_to_patch_size > 0:
|
| 167 |
+
w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
|
| 168 |
+
h = ceil((h / self.round_to_patch_size)) * self.round_to_patch_size
|
| 169 |
+
|
| 170 |
+
# resize
|
| 171 |
+
if w != w_og or h != h_og:
|
| 172 |
+
img = universal_resize(img, (h, w), self.interpolation)
|
| 173 |
+
if isinstance(img, torch.Tensor):
|
| 174 |
+
img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
|
| 175 |
+
return img
|
| 176 |
+
|
| 177 |
+
def __process_one__(
|
| 178 |
+
self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
"""Same operation as __process_one_with_processor__ but without going through numpy"""
|
| 181 |
+
video_or_img = self.resize_transform(video_or_img, img_size)
|
| 182 |
+
if isinstance(video_or_img, Image.Image):
|
| 183 |
+
video_or_img = pil_to_tensor(video_or_img)
|
| 184 |
+
assert isinstance(video_or_img, torch.Tensor)
|
| 185 |
+
if video_or_img.ndim == 3:
|
| 186 |
+
video_or_img = video_or_img[None]
|
| 187 |
+
assert video_or_img.ndim == 4 and video_or_img.shape[1] == 3, (
|
| 188 |
+
f"Invalid shape {video_or_img.shape}."
|
| 189 |
+
)
|
| 190 |
+
t, c, h, w = video_or_img.shape
|
| 191 |
+
p = self._patch_size
|
| 192 |
+
m = self._merge_size
|
| 193 |
+
|
| 194 |
+
# Convert to RGB
|
| 195 |
+
if c == 1:
|
| 196 |
+
video_or_img = video_or_img.expand((-1, 3, -1, -1))
|
| 197 |
+
if c == 4:
|
| 198 |
+
video_or_img = video_or_img[:, :3]
|
| 199 |
+
c = video_or_img.shape[1]
|
| 200 |
+
assert c == 3, "Expecting RGB image in QwenNormalize"
|
| 201 |
+
|
| 202 |
+
# Reshape to t h w c' format
|
| 203 |
+
h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
|
| 204 |
+
rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)
|
| 205 |
+
|
| 206 |
+
video_or_img = rearrange(
|
| 207 |
+
video_or_img,
|
| 208 |
+
"t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
|
| 209 |
+
**rearrange_dict,
|
| 210 |
+
)
|
| 211 |
+
assert video_or_img.shape[-1] == self._num_target_channels, (
|
| 212 |
+
f"{video_or_img.shape[-1]} != {self._num_target_channels}"
|
| 213 |
+
)
|
| 214 |
+
video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))
|
| 215 |
+
|
| 216 |
+
return video_or_img
|
| 217 |
+
|
| 218 |
+
@overload
|
| 219 |
+
def process_images(
|
| 220 |
+
self, image: Image.Image | torch.Tensor, img_size: int | None = None
|
| 221 |
+
) -> torch.Tensor: ...
|
| 222 |
+
@overload
|
| 223 |
+
def process_images(
|
| 224 |
+
self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
|
| 225 |
+
) -> list[torch.Tensor]: ...
|
| 226 |
+
def process_images(
|
| 227 |
+
self,
|
| 228 |
+
image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
|
| 229 |
+
img_size: int | None = None,
|
| 230 |
+
) -> torch.Tensor | list[torch.Tensor]:
|
| 231 |
+
if isinstance(image, list):
|
| 232 |
+
return [self.__process_one__(_x, img_size) for _x in image]
|
| 233 |
+
return self.__process_one__(image, img_size)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ProcessorOutput(dict):
|
| 237 |
+
input_ids: torch.Tensor
|
| 238 |
+
attention_mask: torch.Tensor
|
| 239 |
+
image_embeds_insertion_points: list[torch.Tensor] | None
|
| 240 |
+
pixel_values: torch.Tensor | list[torch.Tensor] | None
|
| 241 |
+
|
| 242 |
+
def to(
|
| 243 |
+
self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
|
| 244 |
+
) -> "ProcessorOutput":
|
| 245 |
+
return ProcessorOutput(
|
| 246 |
+
{
|
| 247 |
+
"input_ids": self["input_ids"].to(device),
|
| 248 |
+
"attention_mask": self["attention_mask"].to(device),
|
| 249 |
+
"image_embeds_insertion_points": self["image_embeds_insertion_points"],
|
| 250 |
+
"pixel_values": (
|
| 251 |
+
self["pixel_values"].to(dtype).to(device)
|
| 252 |
+
if isinstance(self["pixel_values"], torch.Tensor)
|
| 253 |
+
else [x.to(dtype).to(device) for x in self["pixel_values"]]
|
| 254 |
+
if self["pixel_values"] is not None
|
| 255 |
+
else None
|
| 256 |
+
),
|
| 257 |
+
}
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class BaseProcessor(ProcessorMixin):
|
| 262 |
+
def __init__(
|
| 263 |
+
self,
|
| 264 |
+
tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
|
| 265 |
+
pre_image_tokens: tuple[int, ...] = (),
|
| 266 |
+
post_image_tokens: tuple[int, ...] = (),
|
| 267 |
+
system_start_tokens: tuple[int, ...] = (),
|
| 268 |
+
system_end_tokens: tuple[int, ...] = (),
|
| 269 |
+
user_start_tokens: tuple[int, ...] = (),
|
| 270 |
+
user_end_tokens: tuple[int, ...] = (),
|
| 271 |
+
asst_start_tokens: tuple[int, ...] = (),
|
| 272 |
+
asst_end_tokens: tuple[int, ...] = (),
|
| 273 |
+
allow_system_prompt: bool = True,
|
| 274 |
+
pad_token: int = 0,
|
| 275 |
+
bos_token: int | None = None,
|
| 276 |
+
) -> None:
|
| 277 |
+
self.pre_image_tokens = list(pre_image_tokens)
|
| 278 |
+
self.post_image_tokens = list(post_image_tokens)
|
| 279 |
+
self.system_start_tokens = list(system_start_tokens)
|
| 280 |
+
self.system_end_tokens = list(system_end_tokens)
|
| 281 |
+
self.user_start_tokens = list(user_start_tokens)
|
| 282 |
+
self.user_end_tokens = list(user_end_tokens)
|
| 283 |
+
self.asst_start_tokens = list(asst_start_tokens)
|
| 284 |
+
self.asst_end_tokens = list(asst_end_tokens)
|
| 285 |
+
self._allow_system_prompt = allow_system_prompt
|
| 286 |
+
self.tokenizer = tokenizer
|
| 287 |
+
self._image_processor = None
|
| 288 |
+
self._pad_token = pad_token
|
| 289 |
+
self.bos_token = bos_token
|
| 290 |
+
|
| 291 |
+
@property
|
| 292 |
+
def image_processor(self) -> QwenImageProcessor:
|
| 293 |
+
assert self._image_processor is not None
|
| 294 |
+
return self._image_processor
|
| 295 |
+
|
| 296 |
+
def _process_content(
|
| 297 |
+
self,
|
| 298 |
+
message_content: MessageContent,
|
| 299 |
+
role: Literal["system", "user", "assistant"],
|
| 300 |
+
tokenized_messages: list[torch.Tensor],
|
| 301 |
+
insertion_points: list[int],
|
| 302 |
+
image_list: list[torch.Tensor | None],
|
| 303 |
+
token_count: int,
|
| 304 |
+
img_size: int | None = None,
|
| 305 |
+
**kwargs: Any,
|
| 306 |
+
) -> int:
|
| 307 |
+
mapping = {
|
| 308 |
+
"user": (self.user_start_tokens, self.user_end_tokens),
|
| 309 |
+
"assistant": (self.asst_start_tokens, self.asst_end_tokens),
|
| 310 |
+
"system": (self.system_start_tokens, self.system_end_tokens),
|
| 311 |
+
}
|
| 312 |
+
if role.lower() not in mapping:
|
| 313 |
+
raise ValueError(f"Unknown role '{role}' encountered in messages.")
|
| 314 |
+
start_tokens, end_tokens = mapping[role.lower()]
|
| 315 |
+
# 1) Add the start tokens
|
| 316 |
+
if start_tokens:
|
| 317 |
+
tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long))
|
| 318 |
+
token_count += len(start_tokens)
|
| 319 |
+
# 2) Process the message content one by one (potentially interleaved image and text)
|
| 320 |
+
for part in message_content:
|
| 321 |
+
elt_type = part["type"]
|
| 322 |
+
if elt_type == "image":
|
| 323 |
+
part = cast(ImageMessage, part)
|
| 324 |
+
self._process_image_message(
|
| 325 |
+
part,
|
| 326 |
+
tokenized_messages,
|
| 327 |
+
image_list,
|
| 328 |
+
img_size=img_size,
|
| 329 |
+
)
|
| 330 |
+
token_count += len(self.pre_image_tokens)
|
| 331 |
+
insertion_points.append(token_count)
|
| 332 |
+
token_count += len(self.post_image_tokens)
|
| 333 |
+
else:
|
| 334 |
+
part = cast(TextMessage, part)
|
| 335 |
+
self._process_text_message(
|
| 336 |
+
part["text"],
|
| 337 |
+
role=role,
|
| 338 |
+
token_list=tokenized_messages,
|
| 339 |
+
**kwargs,
|
| 340 |
+
)
|
| 341 |
+
token_count += tokenized_messages[-1].size(0)
|
| 342 |
+
# 3) Add the end tokens
|
| 343 |
+
if end_tokens:
|
| 344 |
+
tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long))
|
| 345 |
+
token_count += len(end_tokens)
|
| 346 |
+
return token_count
|
| 347 |
+
|
| 348 |
+
def _process_text_message(
|
| 349 |
+
self,
|
| 350 |
+
message: str,
|
| 351 |
+
role: Literal["system", "user", "assistant"],
|
| 352 |
+
token_list: list[torch.Tensor],
|
| 353 |
+
**kwargs: Any,
|
| 354 |
+
) -> None:
|
| 355 |
+
if role.lower() == "system" and not self._allow_system_prompt:
|
| 356 |
+
raise ValueError("System prompts are not allowed in this tokenizer configuration.")
|
| 357 |
+
tokens = self.tokenizer.encode(
|
| 358 |
+
message, add_special_tokens=False, return_tensors="pt", **kwargs
|
| 359 |
+
)
|
| 360 |
+
tokens = cast(torch.Tensor, tokens)
|
| 361 |
+
token_list.append(tokens.flatten().to(torch.long))
|
| 362 |
+
|
| 363 |
+
def _process_image_message(
|
| 364 |
+
self,
|
| 365 |
+
message: ImageMessage,
|
| 366 |
+
token_list: list[torch.Tensor],
|
| 367 |
+
image_list: list[torch.Tensor | None],
|
| 368 |
+
img_size: int | None = None,
|
| 369 |
+
) -> None:
|
| 370 |
+
img = message["image"]
|
| 371 |
+
if img is None:
|
| 372 |
+
image_list.append(None)
|
| 373 |
+
else:
|
| 374 |
+
image_list.append(
|
| 375 |
+
self.image_processor.process_images(
|
| 376 |
+
self._load_image(img), img_size=img_size
|
| 377 |
+
).squeeze(0)
|
| 378 |
+
)
|
| 379 |
+
if self.pre_image_tokens:
|
| 380 |
+
token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long))
|
| 381 |
+
|
| 382 |
+
if self.post_image_tokens:
|
| 383 |
+
token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long))
|
| 384 |
+
|
| 385 |
+
def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
|
| 386 |
+
if isinstance(image_path_or_image, str):
|
| 387 |
+
return Image.open(image_path_or_image).convert("RGB")
|
| 388 |
+
return image_path_or_image
|
| 389 |
+
|
| 390 |
+
def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
|
| 391 |
+
return torch.nn.functional.pad(
|
| 392 |
+
tokens,
|
| 393 |
+
(0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
|
| 394 |
+
value=pad_value,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
def pad_tokenized_messages(
|
| 398 |
+
self,
|
| 399 |
+
tokenized_messages_batch: list[torch.Tensor],
|
| 400 |
+
image_insertion_points_batch: list[torch.Tensor] | None = None,
|
| 401 |
+
) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
|
| 402 |
+
max_len = max(len(x) for x in tokenized_messages_batch)
|
| 403 |
+
if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
|
| 404 |
+
image_insertion_points_batch = [
|
| 405 |
+
x + max_len - len(tokenized_messages_batch[idx])
|
| 406 |
+
for idx, x in enumerate(image_insertion_points_batch)
|
| 407 |
+
]
|
| 408 |
+
input_ids = torch.stack(
|
| 409 |
+
[
|
| 410 |
+
self._maybe_pad(s, max_len - s.size(0), self._pad_token)
|
| 411 |
+
for s in tokenized_messages_batch
|
| 412 |
+
],
|
| 413 |
+
dim=0,
|
| 414 |
+
)
|
| 415 |
+
attention_mask = torch.stack(
|
| 416 |
+
[
|
| 417 |
+
self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
|
| 418 |
+
for s in tokenized_messages_batch
|
| 419 |
+
],
|
| 420 |
+
dim=0,
|
| 421 |
+
)
|
| 422 |
+
return input_ids, attention_mask, image_insertion_points_batch
|
| 423 |
+
|
| 424 |
+
def tokenize_messages(
|
| 425 |
+
self,
|
| 426 |
+
messages: ProcessorInput,
|
| 427 |
+
suppress_bos_token: bool = False,
|
| 428 |
+
**kwargs: Any,
|
| 429 |
+
) -> ProcessorOutput | None:
|
| 430 |
+
"""Tokenize a batch of messages into token IDs suitable for Helium1 CASA model.
|
| 431 |
+
|
| 432 |
+
Args:
|
| 433 |
+
messages (list[list[dict[str, str]]] | list[dict[str, str]]): Batch of message lists (or single list of messages),
|
| 434 |
+
where each message is a list of dictionaries with 'role' and 'content' keys.
|
| 435 |
+
continue_final_message (bool, optional): If True, the final message in each list will not have an end token added.
|
| 436 |
+
Defaults to False.
|
| 437 |
+
suppress_bos_token (bool, optional): If True, the beginning-of-sequence token will not be added.
|
| 438 |
+
Defaults to False.
|
| 439 |
+
**kwargs: Additional keyword arguments passed to the underlying encode method.
|
| 440 |
+
"""
|
| 441 |
+
if not messages:
|
| 442 |
+
return None
|
| 443 |
+
if isinstance(messages[0], dict):
|
| 444 |
+
messages = [messages] # type: ignore[assignment]
|
| 445 |
+
|
| 446 |
+
messages = cast(list[list[Message]], messages)
|
| 447 |
+
image_insertion_points_batch = []
|
| 448 |
+
tokenized_messages_batch = []
|
| 449 |
+
image_list: list[torch.Tensor | None] = []
|
| 450 |
+
for msgs in messages:
|
| 451 |
+
# msgs.append({
|
| 452 |
+
# "role": "assistant",
|
| 453 |
+
# "content": [{"type": "text", "text": ""}]
|
| 454 |
+
# })
|
| 455 |
+
tokenized_messages = []
|
| 456 |
+
if not suppress_bos_token and self.bos_token is not None:
|
| 457 |
+
tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
|
| 458 |
+
insertion_points = []
|
| 459 |
+
token_count = 0
|
| 460 |
+
for msg in msgs:
|
| 461 |
+
token_count = self._process_content(
|
| 462 |
+
msg["content"],
|
| 463 |
+
role=msg["role"],
|
| 464 |
+
tokenized_messages=tokenized_messages,
|
| 465 |
+
insertion_points=insertion_points,
|
| 466 |
+
image_list=image_list,
|
| 467 |
+
token_count=token_count,
|
| 468 |
+
**kwargs,
|
| 469 |
+
)
|
| 470 |
+
tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
|
| 471 |
+
image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))
|
| 472 |
+
|
| 473 |
+
if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
|
| 474 |
+
# Remove the assistant end tokens from the final message
|
| 475 |
+
end_token_len = len(self.asst_end_tokens)
|
| 476 |
+
tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
|
| 477 |
+
if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
|
| 478 |
+
# Remove the assistant end tokens from the final message
|
| 479 |
+
end_token_len = len(self.asst_end_tokens)
|
| 480 |
+
tokenized_messages_batch[-1] = torch.cat(
|
| 481 |
+
[
|
| 482 |
+
tokenized_messages_batch[-1],
|
| 483 |
+
torch.Tensor(self.asst_start_tokens).to(torch.long),
|
| 484 |
+
]
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
|
| 488 |
+
tokenized_messages_batch, image_insertion_points_batch
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if image_list:
|
| 492 |
+
assert sum(img is None for img in image_list) % len(image_list) == 0, (
|
| 493 |
+
"Either all or no image must be None."
|
| 494 |
+
)
|
| 495 |
+
pixel_values: None | torch.Tensor | list[torch.Tensor]
|
| 496 |
+
if image_list[0] is None:
|
| 497 |
+
pixel_values = None
|
| 498 |
+
else:
|
| 499 |
+
pixel_values = cast(list[torch.Tensor], image_list)
|
| 500 |
+
return ProcessorOutput(
|
| 501 |
+
input_ids=input_ids,
|
| 502 |
+
image_embeds_insertion_points=image_embeds_insertion_points,
|
| 503 |
+
attention_mask=attention_mask,
|
| 504 |
+
pixel_values=pixel_values,
|
| 505 |
+
)
|
processing_helium1_casa.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 2 |
+
|
| 3 |
+
from .processing import BaseProcessor, QwenImageProcessor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Helium1CASAProcessor(BaseProcessor):
|
| 7 |
+
attributes = ["tokenizer"]
|
| 8 |
+
tokenizer_class = "PreTrainedTokenizerFast"
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 13 |
+
pre_image_tokens: tuple[int, ...] = tuple(),
|
| 14 |
+
post_image_tokens: tuple[int, ...] = tuple(),
|
| 15 |
+
system_start_tokens: tuple[int, ...] = tuple(),
|
| 16 |
+
system_end_tokens: tuple[int, ...] = tuple(),
|
| 17 |
+
user_start_tokens: tuple[int, ...] = (104,),
|
| 18 |
+
user_end_tokens: tuple[int, ...] = (105,),
|
| 19 |
+
asst_start_tokens: tuple[int, ...] = (102,),
|
| 20 |
+
asst_end_tokens: tuple[int, ...] = (103,),
|
| 21 |
+
bos_token: int = 1,
|
| 22 |
+
image_size: int = 896,
|
| 23 |
+
):
|
| 24 |
+
super().__init__(
|
| 25 |
+
tokenizer=tokenizer,
|
| 26 |
+
pre_image_tokens=pre_image_tokens,
|
| 27 |
+
post_image_tokens=post_image_tokens,
|
| 28 |
+
system_start_tokens=system_start_tokens,
|
| 29 |
+
system_end_tokens=system_end_tokens,
|
| 30 |
+
user_start_tokens=user_start_tokens,
|
| 31 |
+
user_end_tokens=user_end_tokens,
|
| 32 |
+
asst_start_tokens=asst_start_tokens,
|
| 33 |
+
asst_end_tokens=asst_end_tokens,
|
| 34 |
+
allow_system_prompt=False,
|
| 35 |
+
bos_token=bos_token,
|
| 36 |
+
)
|
| 37 |
+
self._image_processor = QwenImageProcessor(img_size=image_size)
|
processor_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_helium1_casa.Helium1CASAProcessor"
|
| 4 |
+
},
|
| 5 |
+
"bos_token": 1,
|
| 6 |
+
"image_size": 896,
|
| 7 |
+
"post_image_tokens": [],
|
| 8 |
+
"pre_image_tokens": [],
|
| 9 |
+
"processor_class": "Helium1CASAProcessor"
|
| 10 |
+
}
|
readme_images/CASA.png
ADDED
|
Git LFS Details
|
readme_images/casa_explainer.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5ee73e4e8ea65ebc8d3e53d6468a3d2688d5f656259bcae73bffd843e5a0e69
|
| 3 |
+
size 299297
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer.model
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html class="">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8" />
|
| 5 |
+
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no" />
|
| 7 |
+
|
| 8 |
+
<meta name="description" content="We’re on a journey to advance and democratize artificial intelligence through open source and open science." />
|
| 9 |
+
|
| 10 |
+
<meta property="fb:app_id" content="1321688464574422" />
|
| 11 |
+
|
| 12 |
+
<meta name="twitter:card" content="summary_large_image" />
|
| 13 |
+
|
| 14 |
+
<meta name="twitter:site" content="@huggingface" />
|
| 15 |
+
|
| 16 |
+
<meta name="twitter:image" content="https://cdn-thumbnails.huggingface.co/social-thumbnails/models/kyutai/helium-1-2b.png" />
|
| 17 |
+
|
| 18 |
+
<meta property="og:title" content="tokenizer.model · kyutai/helium-1-2b at main" />
|
| 19 |
+
|
| 20 |
+
<meta property="og:type" content="website" />
|
| 21 |
+
|
| 22 |
+
<meta property="og:url" content="https://huggingface.co/kyutai/helium-1-2b/blob/main/tokenizer.model" />
|
| 23 |
+
|
| 24 |
+
<meta property="og:image" content="https://cdn-thumbnails.huggingface.co/social-thumbnails/models/kyutai/helium-1-2b.png" />
|
| 25 |
+
|
| 26 |
+
<link rel="stylesheet" href="/front/build/kube-02d86c8/style.css" />
|
| 27 |
+
|
| 28 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" />
|
| 29 |
+
|
| 30 |
+
<link
|
| 31 |
+
href="https://fonts.googleapis.com/css2?family=Source+Sans+Pro:ital,wght@0,200;0,300;0,400;0,600;0,700;1,200;1,300;1,400;1,600;1,700&display=swap"
|
| 32 |
+
rel="stylesheet"
|
| 33 |
+
/>
|
| 34 |
+
|
| 35 |
+
<link
|
| 36 |
+
href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600;700&display=swap"
|
| 37 |
+
rel="stylesheet"
|
| 38 |
+
/>
|
| 39 |
+
|
| 40 |
+
<link
|
| 41 |
+
rel="preload"
|
| 42 |
+
href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.12.0/katex.min.css"
|
| 43 |
+
as="style"
|
| 44 |
+
onload="this.onload=null;this.rel='stylesheet'"
|
| 45 |
+
/>
|
| 46 |
+
|
| 47 |
+
<noscript>
|
| 48 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.12.0/katex.min.css" />
|
| 49 |
+
</noscript>
|
| 50 |
+
<script>const guestTheme = document.cookie.match(/theme=(\w+)/)?.[1]; document.documentElement.classList.toggle('dark', guestTheme === 'dark' || ( (!guestTheme || guestTheme === 'system') && window.matchMedia('(prefers-color-scheme: dark)').matches));</script>
|
| 51 |
+
<link rel="canonical" href="https://huggingface.co/kyutai/helium-1-2b/blob/main/tokenizer.model">
|
| 52 |
+
<title>tokenizer.model · kyutai/helium-1-2b at main</title>
|
| 53 |
+
|
| 54 |
+
<script defer src="/js/script.js"></script>
|
| 55 |
+
|
| 56 |
+
<script>
|
| 57 |
+
(window.plausible =
|
| 58 |
+
window.plausible ||
|
| 59 |
+
function () {
|
| 60 |
+
(plausible.q = plausible.q || []).push(arguments);
|
| 61 |
+
}),
|
| 62 |
+
(plausible.init =
|
| 63 |
+
plausible.init ||
|
| 64 |
+
function (i) {
|
| 65 |
+
plausible.o = i || {};
|
| 66 |
+
});
|
| 67 |
+
plausible.init({
|
| 68 |
+
customProperties: {
|
| 69 |
+
loggedIn: "false",
|
| 70 |
+
},
|
| 71 |
+
endpoint: "/api/event",
|
| 72 |
+
});
|
| 73 |
+
</script>
|
| 74 |
+
|
| 75 |
+
<script>
|
| 76 |
+
window.hubConfig = {"features":{"signupDisabled":false},"sshGitUrl":"git@hf.co","moonHttpUrl":"https:\/\/huggingface.co","captchaApiKey":"bd5f2066-93dc-4bdd-a64b-a24646ca3859","datasetViewerPublicUrl":"https:\/\/datasets-server.huggingface.co","stripePublicKey":"pk_live_x2tdjFXBCvXo2FFmMybezpeM00J6gPCAAc","environment":"production","userAgent":"HuggingFace (production)","spacesIframeDomain":"hf.space","spacesApiUrl":"https:\/\/api.hf.space","docSearchKey":"ece5e02e57300e17d152c08056145326e90c4bff3dd07d7d1ae40cf1c8d39cb6","logoDev":{"apiUrl":"https:\/\/img.logo.dev\/","apiKey":"pk_UHS2HZOeRnaSOdDp7jbd5w"}};
|
| 77 |
+
</script>
|
| 78 |
+
<script type="text/javascript" src="https://de5282c3ca0c.edge.sdk.awswaf.com/de5282c3ca0c/526cf06acb0d/challenge.js" defer></script>
|
| 79 |
+
</head>
|
| 80 |
+
<body class="flex flex-col min-h-dvh bg-white dark:bg-gray-950 text-black ViewerBlobPage">
|
| 81 |
+
<div class="flex min-h-dvh flex-col"><div class="SVELTE_HYDRATER contents" data-target="DeviceProvider" data-props="{}"></div>
|
| 82 |
+
<div class="SVELTE_HYDRATER contents" data-target="SystemThemeMonitor" data-props="{"isLoggedIn":false}"></div>
|
| 83 |
+
|
| 84 |
+
<div class="SVELTE_HYDRATER contents" data-target="MainHeader" data-props="{"classNames":"","isWide":false,"isZh":false,"isPro":false}"><header class="border-b border-gray-100 "><div class="w-full px-4 container flex h-16 items-center"><div class="flex flex-1 items-center"><a class="mr-5 flex flex-none items-center lg:mr-6" href="/"><img alt="Hugging Face's logo" class="w-7 md:mr-2" src="/front/assets/huggingface_logo-noborder.svg">
|
| 85 |
+
<span class="hidden whitespace-nowrap text-lg font-bold md:block">Hugging Face</span></a>
|
| 86 |
+
<div class="relative flex-1 lg:max-w-sm mr-2 sm:mr-4 md:mr-3 xl:mr-6"><input autocomplete="off" class="w-full dark:bg-gray-950 pl-8 form-input-alt h-9 pr-3 focus:shadow-xl " name="" placeholder="Search models, datasets, users..." spellcheck="false" type="text" value="">
|
| 87 |
+
<svg class="absolute left-2.5 text-gray-400 top-1/2 transform -translate-y-1/2" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M30 28.59L22.45 21A11 11 0 1 0 21 22.45L28.59 30zM5 14a9 9 0 1 1 9 9a9 9 0 0 1-9-9z" fill="currentColor"></path></svg>
|
| 88 |
+
</div>
|
| 89 |
+
<div class="flex flex-none items-center justify-center p-0.5 place-self-stretch lg:hidden"><button class="relative z-40 flex h-6 w-8 items-center justify-center" type="button"><svg width="1em" height="1em" viewBox="0 0 10 10" class="text-xl" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" preserveAspectRatio="xMidYMid meet" fill="currentColor"><path fill-rule="evenodd" clip-rule="evenodd" d="M1.65039 2.9999C1.65039 2.8066 1.80709 2.6499 2.00039 2.6499H8.00039C8.19369 2.6499 8.35039 2.8066 8.35039 2.9999C8.35039 3.1932 8.19369 3.3499 8.00039 3.3499H2.00039C1.80709 3.3499 1.65039 3.1932 1.65039 2.9999ZM1.65039 4.9999C1.65039 4.8066 1.80709 4.6499 2.00039 4.6499H8.00039C8.19369 4.6499 8.35039 4.8066 8.35039 4.9999C8.35039 5.1932 8.19369 5.3499 8.00039 5.3499H2.00039C1.80709 5.3499 1.65039 5.1932 1.65039 4.9999ZM2.00039 6.6499C1.80709 6.6499 1.65039 6.8066 1.65039 6.9999C1.65039 7.1932 1.80709 7.3499 2.00039 7.3499H8.00039C8.19369 7.3499 8.35039 7.1932 8.35039 6.9999C8.35039 6.8066 8.19369 6.6499 8.00039 6.6499H2.00039Z"></path></svg>
|
| 90 |
+
</button>
|
| 91 |
+
|
| 92 |
+
</div></div>
|
| 93 |
+
<nav aria-label="Main" class="ml-auto hidden lg:block"><ul class="flex items-center gap-x-1 2xl:gap-x-2"><li class="hover:text-indigo-700"><a class="group flex items-center px-2 py-0.5 dark:text-gray-300 dark:hover:text-gray-100" href="/models"><svg class="mr-1.5 text-gray-400 group-hover:text-indigo-500" style="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path class="uim-quaternary" d="M20.23 7.24L12 12L3.77 7.24a1.98 1.98 0 0 1 .7-.71L11 2.76c.62-.35 1.38-.35 2 0l6.53 3.77c.29.173.531.418.7.71z" opacity=".25" fill="currentColor"></path><path class="uim-tertiary" d="M12 12v9.5a2.09 2.09 0 0 1-.91-.21L4.5 17.48a2.003 2.003 0 0 1-1-1.73v-7.5a2.06 2.06 0 0 1 .27-1.01L12 12z" opacity=".5" fill="currentColor"></path><path class="uim-primary" d="M20.5 8.25v7.5a2.003 2.003 0 0 1-1 1.73l-6.62 3.82c-.275.13-.576.198-.88.2V12l8.23-4.76c.175.308.268.656.27 1.01z" fill="currentColor"></path></svg>
|
| 94 |
+
Models</a>
|
| 95 |
+
</li><li class="hover:text-red-700"><a class="group flex items-center px-2 py-0.5 dark:text-gray-300 dark:hover:text-gray-100" href="/datasets"><svg class="mr-1.5 text-gray-400 group-hover:text-red-500" style="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 25 25"><ellipse cx="12.5" cy="5" fill="currentColor" fill-opacity="0.25" rx="7.5" ry="2"></ellipse><path d="M12.5 15C16.6421 15 20 14.1046 20 13V20C20 21.1046 16.6421 22 12.5 22C8.35786 22 5 21.1046 5 20V13C5 14.1046 8.35786 15 12.5 15Z" fill="currentColor" opacity="0.5"></path><path d="M12.5 7C16.6421 7 20 6.10457 20 5V11.5C20 12.6046 16.6421 13.5 12.5 13.5C8.35786 13.5 5 12.6046 5 11.5V5C5 6.10457 8.35786 7 12.5 7Z" fill="currentColor" opacity="0.5"></path><path d="M5.23628 12C5.08204 12.1598 5 12.8273 5 13C5 14.1046 8.35786 15 12.5 15C16.6421 15 20 14.1046 20 13C20 12.8273 19.918 12.1598 19.7637 12C18.9311 12.8626 15.9947 13.5 12.5 13.5C9.0053 13.5 6.06886 12.8626 5.23628 12Z" fill="currentColor"></path></svg>
|
| 96 |
+
Datasets</a>
|
| 97 |
+
</li><li class="hover:text-blue-700"><a class="group flex items-center px-2 py-0.5 dark:text-gray-300 dark:hover:text-gray-100" href="/spaces"><svg class="mr-1.5 text-gray-400 group-hover:text-blue-500" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" viewBox="0 0 25 25"><path opacity=".5" d="M6.016 14.674v4.31h4.31v-4.31h-4.31ZM14.674 14.674v4.31h4.31v-4.31h-4.31ZM6.016 6.016v4.31h4.31v-4.31h-4.31Z" fill="currentColor"></path><path opacity=".75" fill-rule="evenodd" clip-rule="evenodd" d="M3 4.914C3 3.857 3.857 3 4.914 3h6.514c.884 0 1.628.6 1.848 1.414a5.171 5.171 0 0 1 7.31 7.31c.815.22 1.414.964 1.414 1.848v6.514A1.914 1.914 0 0 1 20.086 22H4.914A1.914 1.914 0 0 1 3 20.086V4.914Zm3.016 1.102v4.31h4.31v-4.31h-4.31Zm0 12.968v-4.31h4.31v4.31h-4.31Zm8.658 0v-4.31h4.31v4.31h-4.31Zm0-10.813a2.155 2.155 0 1 1 4.31 0 2.155 2.155 0 0 1-4.31 0Z" fill="currentColor"></path><path opacity=".25" d="M16.829 6.016a2.155 2.155 0 1 0 0 4.31 2.155 2.155 0 0 0 0-4.31Z" fill="currentColor"></path></svg>
|
| 98 |
+
Spaces</a>
|
| 99 |
+
</li><li class="max-xl:hidden relative"><div class="relative ">
|
| 100 |
+
<button class="group flex items-center px-2 py-0.5 dark:text-gray-300 hover:text-yellow-700 dark:hover:text-gray-100 " type="button">
|
| 101 |
+
<svg class="mr-1.5 mr-1.5 text-gray-400 text-yellow-500! group-hover:text-yellow-500" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path><path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path></svg>
|
| 102 |
+
Community
|
| 103 |
+
</button>
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
</div>
|
| 107 |
+
</li><li class="hover:text-yellow-700"><a class="group flex items-center px-2 py-0.5 dark:text-gray-300 dark:hover:text-gray-100" href="/docs"><svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="mr-1.5 text-gray-400 group-hover:text-yellow-500" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 16 16"><path d="m2.28 3.7-.3.16a.67.67 0 0 0-.34.58v8.73l.01.04.02.07.01.04.03.06.02.04.02.03.04.06.05.05.04.04.06.04.06.04.08.04.08.02h.05l.07.02h.11l.04-.01.07-.02.03-.01.07-.03.22-.12a5.33 5.33 0 0 1 5.15.1.67.67 0 0 0 .66 0 5.33 5.33 0 0 1 5.33 0 .67.67 0 0 0 1-.58V4.36a.67.67 0 0 0-.34-.5l-.3-.17v7.78a.63.63 0 0 1-.87.59 4.9 4.9 0 0 0-4.35.35l-.65.39a.29.29 0 0 1-.15.04.29.29 0 0 1-.16-.04l-.65-.4a4.9 4.9 0 0 0-4.34-.34.63.63 0 0 1-.87-.59V3.7Z" fill="currentColor" class="dark:opacity-40"></path><path fill-rule="evenodd" clip-rule="evenodd" d="M8 3.1a5.99 5.99 0 0 0-5.3-.43.66.66 0 0 0-.42.62v8.18c0 .45.46.76.87.59a4.9 4.9 0 0 1 4.34.35l.65.39c.05.03.1.04.16.04.05 0 .1-.01.15-.04l.65-.4a4.9 4.9 0 0 1 4.35-.34.63.63 0 0 0 .86-.59V3.3a.67.67 0 0 0-.41-.62 5.99 5.99 0 0 0-5.3.43l-.3.17L8 3.1Zm.73 1.87a.43.43 0 1 0-.86 0v5.48a.43.43 0 0 0 .86 0V4.97Z" fill="currentColor" class="opacity-40 dark:opacity-100"></path><path d="M8.73 4.97a.43.43 0 1 0-.86 0v5.48a.43.43 0 1 0 .86 0V4.96Z" fill="currentColor" class="dark:opacity-40"></path></svg>
|
| 108 |
+
Docs</a>
|
| 109 |
+
</li><li class="hover:text-black dark:hover:text-white max-2xl:hidden"><a class="group flex items-center px-2 py-0.5 dark:text-gray-300 dark:hover:text-gray-100" href="/enterprise"><svg class="mr-1.5 text-gray-400 group-hover:text-black dark:group-hover:text-white" xmlns="http://www.w3.org/2000/svg" fill="none" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 12 12"><path fill-rule="evenodd" clip-rule="evenodd" d="M4.9 1.35a3.16 3.16 0 0 0-2.8 2.07L.37 8.58C0 9.71.7 10.65 1.86 10.65H7.3a3.2 3.2 0 0 0 2.84-2.07l1.67-5.16c.36-1.13-.3-2.07-1.46-2.07H4.91Zm.4 2.07L3.57 8.47h3.57l.36-1.12H5.4l.28-.91h1.75l.4-1.1H6.07l.3-.83h2l.36-1.1H5.27h.04Z" fill="currentColor"></path></svg>
|
| 110 |
+
Enterprise</a>
|
| 111 |
+
</li>
|
| 112 |
+
|
| 113 |
+
<li><a class="group flex items-center px-2 py-0.5 dark:text-gray-300 dark:hover:text-gray-100" href="/pricing">Pricing
|
| 114 |
+
</a></li>
|
| 115 |
+
|
| 116 |
+
<li><div class="relative group">
|
| 117 |
+
<button class="px-2 py-0.5 hover:text-gray-500 dark:hover:text-gray-600 flex items-center " type="button">
|
| 118 |
+
<svg class=" text-gray-500 w-5 group-hover:text-gray-400 dark:text-gray-300 dark:group-hover:text-gray-100" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" viewBox="0 0 32 18" preserveAspectRatio="xMidYMid meet"><path fill-rule="evenodd" clip-rule="evenodd" d="M14.4504 3.30221C14.4504 2.836 14.8284 2.45807 15.2946 2.45807H28.4933C28.9595 2.45807 29.3374 2.836 29.3374 3.30221C29.3374 3.76842 28.9595 4.14635 28.4933 4.14635H15.2946C14.8284 4.14635 14.4504 3.76842 14.4504 3.30221Z" fill="currentColor"></path><path fill-rule="evenodd" clip-rule="evenodd" d="M14.4504 9.00002C14.4504 8.53382 14.8284 8.15588 15.2946 8.15588H28.4933C28.9595 8.15588 29.3374 8.53382 29.3374 9.00002C29.3374 9.46623 28.9595 9.84417 28.4933 9.84417H15.2946C14.8284 9.84417 14.4504 9.46623 14.4504 9.00002Z" fill="currentColor"></path><path fill-rule="evenodd" clip-rule="evenodd" d="M14.4504 14.6978C14.4504 14.2316 14.8284 13.8537 15.2946 13.8537H28.4933C28.9595 13.8537 29.3374 14.2316 29.3374 14.6978C29.3374 15.164 28.9595 15.542 28.4933 15.542H15.2946C14.8284 15.542 14.4504 15.164 14.4504 14.6978Z" fill="currentColor"></path><path fill-rule="evenodd" clip-rule="evenodd" d="M1.94549 6.87377C2.27514 6.54411 2.80962 6.54411 3.13928 6.87377L6.23458 9.96907L9.32988 6.87377C9.65954 6.54411 10.194 6.54411 10.5237 6.87377C10.8533 7.20343 10.8533 7.73791 10.5237 8.06756L6.23458 12.3567L1.94549 8.06756C1.61583 7.73791 1.61583 7.20343 1.94549 6.87377Z" fill="currentColor"></path></svg>
|
| 119 |
+
|
| 120 |
+
</button>
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
</div></li>
|
| 124 |
+
<li><hr class="h-5 w-0.5 border-none bg-gray-100 dark:bg-gray-800"></li>
|
| 125 |
+
<li><a class="block cursor-pointer whitespace-nowrap px-2 py-0.5 hover:text-gray-500 dark:text-gray-300 dark:hover:text-gray-100" href="/login">Log In
|
| 126 |
+
</a></li>
|
| 127 |
+
<li><a class="whitespace-nowrap rounded-full border border-transparent bg-gray-900 px-3 py-1 leading-none text-white hover:border-black hover:bg-white hover:text-black" href="/join">Sign Up
|
| 128 |
+
</a></li></ul></nav></div></header></div>
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
<div class="SVELTE_HYDRATER contents" data-target="SSOBanner" data-props="{}"></div>
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
<main class="flex flex-1 flex-col">
|
| 136 |
+
<div class="SVELTE_HYDRATER contents" data-target="ModelHeader" data-props="{"activeTab":"files","author":{"_id":"6683d6350b54a28aff6645fe","avatarUrl":"https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/8xGdIOlfkopZfhbMitw_k.jpeg","fullname":"Kyutai","name":"kyutai","type":"org","isHf":false,"isHfAdmin":false,"isMod":false,"isEnterprise":false,"followerCount":886},"canReadRepoSettings":false,"canWriteRepoContent":false,"canDisable":false,"model":{"author":"kyutai","cardData":{"library_name":"transformers","license":"cc-by-sa-4.0","language":["bg","cs","da","de","el","en","es","et","fi","fr","ga","hr","hu","it","lt","lv","mt","nl","pl","pt","ro","sk","sl","sv"],"pipeline_tag":"text-generation"},"cardExists":true,"config":{"architectures":["LlamaForCausalLM"],"model_type":"llama"},"createdAt":"2025-04-30T13:59:54.000Z","discussionsDisabled":false,"discussionsSorting":"recently-created","downloads":28282,"downloadsAllTime":456198,"id":"kyutai/helium-1-2b","isLikedByUser":false,"availableInferenceProviders":[],"inference":"","lastModified":"2025-04-30T14:38:01.000Z","likes":42,"pipeline_tag":"text-generation","library_name":"transformers","librariesOther":[],"trackDownloads":true,"model-index":null,"private":false,"repoType":"model","gated":false,"pwcLink":{"error":"Unknown error, can't generate link to Papers With Code."},"tags":["transformers","safetensors","llama","text-generation","bg","cs","da","de","el","en","es","et","fi","fr","ga","hr","hu","it","lt","lv","mt","nl","pl","pt","ro","sk","sl","sv","license:cc-by-sa-4.0","text-generation-inference","endpoints_compatible","region:us"],"tag_objs":[{"id":"text-generation","label":"Text Generation","type":"pipeline_tag","subType":"nlp"},{"id":"transformers","label":"Transformers","type":"library"},{"id":"safetensors","label":"Safetensors","type":"library"},{"id":"bg","label":"Bulgarian","type":"language"},{"id":"cs","label":"Czech","type":"language"},{"id":"da","label":"Danish","type":"language"},{"id":"de","label":"German","type":"language"},{"id":"el","label":"Greek","type":"language"},{"id":"en","label":"English","type":"language"},{"id":"es","label":"Spanish","type":"language"},{"id":"et","label":"Estonian","type":"language"},{"id":"fi","label":"Finnish","type":"language"},{"id":"fr","label":"French","type":"language"},{"id":"ga","label":"Irish","type":"language"},{"id":"hr","label":"Croatian","type":"language"},{"id":"hu","label":"Hungarian","type":"language"},{"id":"it","label":"Italian","type":"language"},{"id":"lt","label":"Lithuanian","type":"language"},{"id":"lv","label":"Latvian","type":"language"},{"id":"mt","label":"Maltese","type":"language"},{"id":"nl","label":"Dutch","type":"language"},{"id":"pl","label":"Polish","type":"language"},{"id":"pt","label":"Portuguese","type":"language"},{"id":"ro","label":"Romanian","type":"language"},{"id":"sk","label":"Slovak","type":"language"},{"id":"sl","label":"Slovenian","type":"language"},{"id":"sv","label":"Swedish","type":"language"},{"id":"llama","label":"llama","type":"other","clickable":true},{"id":"text-generation-inference","label":"text-generation-inference","type":"other","clickable":true},{"id":"endpoints_compatible","label":"Inference Endpoints","type":"other","clickable":true},{"id":"license:cc-by-sa-4.0","label":"cc-by-sa-4.0","type":"license"},{"type":"region","label":"🇺🇸 Region: US","id":"region:us"}],"transformersInfo":{"auto_model":"AutoModelForCausalLM","pipeline_tag":"text-generation","processor":"AutoTokenizer"},"safetensors":{"parameters":{"BF16":2023868416},"total":2023868416,"sharded":false},"hasBlockedOids":false,"region":"us","isQuantized":false},"discussionsStats":{"closed":1,"open":0,"total":1},"query":{},"inferenceContextData":{"billableEntities":[],"entityName2Providers":{}}}"><header class="bg-linear-to-t border-b border-gray-100 pt-6 sm:pt-9 from-purple-500/8 dark:from-purple-500/20 to-white to-70% dark:to-gray-950"><div class="container relative "><h1 class="flex flex-wrap items-center max-md:leading-tight mb-3 text-lg max-sm:gap-y-1.5 md:text-xl">
|
| 137 |
+
<div class="group flex flex-none items-center"><div class="relative mr-1 flex items-center">
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
<span class="inline-block "><span class="contents"><a href="/kyutai" class="text-gray-400 hover:text-blue-600"><img alt="" class="size-3.5 rounded-sm flex-none select-none" src="https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/8xGdIOlfkopZfhbMitw_k.jpeg" crossorigin="anonymous"></a></span>
|
| 142 |
+
</span></div>
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
<span class="inline-block "><span class="contents"><a href="/kyutai" class="text-gray-400 hover:text-blue-600">kyutai</a></span>
|
| 146 |
+
</span>
|
| 147 |
+
<div class="mx-0.5 text-gray-300">/</div></div>
|
| 148 |
+
|
| 149 |
+
<div class="max-w-full "><a class="break-words font-mono font-semibold hover:text-blue-600 " href="/kyutai/helium-1-2b">helium-1-2b</a>
|
| 150 |
+
<button class="text-sm mr-4 focus:outline-hidden inline-flex cursor-pointer items-center text-sm mx-0.5 text-gray-600 " title="Copy model name to clipboard" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
|
| 151 |
+
</button></div>
|
| 152 |
+
<div class="inline-flex items-center overflow-hidden whitespace-nowrap rounded-md border bg-white text-sm leading-none text-gray-500 mr-2"><button class="relative flex items-center overflow-hidden from-red-50 to-transparent dark:from-red-900 px-1.5 py-1 hover:bg-linear-to-t focus:outline-hidden" title="Like"><svg class="left-1.5 absolute" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32" fill="currentColor"><path d="M22.45,6a5.47,5.47,0,0,1,3.91,1.64,5.7,5.7,0,0,1,0,8L16,26.13,5.64,15.64a5.7,5.7,0,0,1,0-8,5.48,5.48,0,0,1,7.82,0L16,10.24l2.53-2.58A5.44,5.44,0,0,1,22.45,6m0-2a7.47,7.47,0,0,0-5.34,2.24L16,7.36,14.89,6.24a7.49,7.49,0,0,0-10.68,0,7.72,7.72,0,0,0,0,10.82L16,29,27.79,17.06a7.72,7.72,0,0,0,0-10.82A7.49,7.49,0,0,0,22.45,4Z"></path></svg>
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
<span class="ml-4 pl-0.5 ">like</span></button>
|
| 156 |
+
<button class="focus:outline-hidden flex items-center border-l px-1.5 py-1 text-gray-400 hover:bg-gray-50 focus:bg-gray-100 dark:hover:bg-gray-900 dark:focus:bg-gray-800" title="See users who liked this repository">42</button></div>
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
<div class="relative flex items-center gap-1.5 "><div class="mr-2 inline-flex h-6 items-center overflow-hidden whitespace-nowrap rounded-md border text-sm text-gray-500"><button class="focus:outline-hidden relative flex h-full max-w-56 items-center gap-1.5 overflow-hidden px-1.5 hover:bg-gray-50 focus:bg-gray-100 dark:hover:bg-gray-900 dark:focus:bg-gray-800" type="button" ><div class="flex h-full flex-1 items-center justify-center ">Follow</div>
|
| 160 |
+
<img alt="" class="rounded-xs size-3 flex-none select-none" src="https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/8xGdIOlfkopZfhbMitw_k.jpeg" loading="lazy">
|
| 161 |
+
<span class="truncate">Kyutai</span></button>
|
| 162 |
+
<button class="focus:outline-hidden flex h-full items-center border-l pl-1.5 pr-1.5 text-gray-400 hover:bg-gray-50 focus:bg-gray-100 dark:hover:bg-gray-900 dark:focus:bg-gray-800" title="Show Kyutai's followers" type="button">886</button></div>
|
| 163 |
+
|
| 164 |
+
</div>
|
| 165 |
+
|
| 166 |
+
</h1>
|
| 167 |
+
<div class="mb-3 flex flex-wrap md:mb-4">
|
| 168 |
+
<a class="mb-1 mr-1 md:mb-1.5 md:mr-1.5 rounded-lg" href="/models?pipeline_tag=text-generation"><div class="tag tag-white "><div class="tag-ico -ml-2 tag-ico-red"><svg class="-mr-0.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 18 18"><path d="M16.2607 8.08202L14.468 6.28928C14.3063 6.12804 14.0873 6.03749 13.859 6.03749C13.6307 6.03749 13.4117 6.12804 13.25 6.28928L5.6375 13.904V16.9125H8.64607L16.2607 9.30002C16.422 9.13836 16.5125 8.91935 16.5125 8.69102C16.5125 8.4627 16.422 8.24369 16.2607 8.08202V8.08202ZM8.1953 15.825H6.725V14.3547L11.858 9.22118L13.3288 10.6915L8.1953 15.825ZM14.0982 9.92262L12.6279 8.45232L13.8606 7.21964L15.3309 8.68994L14.0982 9.92262Z"></path><path d="M6.18125 9.84373H7.26875V6.03748H8.9V4.94998H4.55V6.03748H6.18125V9.84373Z"></path><path d="M4.55 11.475H2.375V2.775H11.075V4.95H12.1625V2.775C12.1625 2.48658 12.0479 2.20997 11.844 2.00602C11.64 1.80208 11.3634 1.6875 11.075 1.6875H2.375C2.08658 1.6875 1.80997 1.80208 1.60602 2.00602C1.40207 2.20997 1.2875 2.48658 1.2875 2.775V11.475C1.2875 11.7634 1.40207 12.04 1.60602 12.244C1.80997 12.4479 2.08658 12.5625 2.375 12.5625H4.55V11.475Z"></path></svg></div>
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
<span>Text Generation</span>
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
</div></a>
|
| 176 |
+
<a class="mb-1 mr-1 md:mb-1.5 md:mr-1.5 rounded-lg" href="/models?library=transformers"><div class="tag tag-white "><svg class="text-black inline-block text-sm" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" preserveAspectRatio="xMidYMid meet" width="1em" height="1em" viewBox="0 0 95 88"><path fill="#fff" d="M94.25 70.08a8.28 8.28 0 0 1-.43 6.46 10.57 10.57 0 0 1-3 3.6 25.18 25.18 0 0 1-5.7 3.2 65.74 65.74 0 0 1-7.56 2.65 46.67 46.67 0 0 1-11.42 1.68c-5.42.05-10.09-1.23-13.4-4.5a40.4 40.4 0 0 1-10.14.03c-3.34 3.25-7.99 4.52-13.39 4.47a46.82 46.82 0 0 1-11.43-1.68 66.37 66.37 0 0 1-7.55-2.65c-2.28-.98-4.17-2-5.68-3.2a10.5 10.5 0 0 1-3.02-3.6c-.99-2-1.18-4.3-.42-6.46a8.54 8.54 0 0 1-.33-5.63c.25-.95.66-1.83 1.18-2.61a8.67 8.67 0 0 1 2.1-8.47 8.23 8.23 0 0 1 2.82-2.07 41.75 41.75 0 1 1 81.3-.12 8.27 8.27 0 0 1 3.11 2.19 8.7 8.7 0 0 1 2.1 8.47c.52.78.93 1.66 1.18 2.61a8.61 8.61 0 0 1-.32 5.63Z"></path><path fill="#FFD21E" d="M47.21 76.5a34.75 34.75 0 1 0 0-69.5 34.75 34.75 0 0 0 0 69.5Z"></path><path fill="#FF9D0B" d="M81.96 41.75a34.75 34.75 0 1 0-69.5 0 34.75 34.75 0 0 0 69.5 0Zm-73.5 0a38.75 38.75 0 1 1 77.5 0 38.75 38.75 0 0 1-77.5 0Z"></path><path fill="#3A3B45" d="M58.5 32.3c1.28.44 1.78 3.06 3.07 2.38a5 5 0 1 0-6.76-2.07c.61 1.15 2.55-.72 3.7-.32ZM34.95 32.3c-1.28.44-1.79 3.06-3.07 2.38a5 5 0 1 1 6.76-2.07c-.61 1.15-2.56-.72-3.7-.32Z"></path><path fill="#FF323D" d="M46.96 56.29c9.83 0 13-8.76 13-13.26 0-2.34-1.57-1.6-4.09-.36-2.33 1.15-5.46 2.74-8.9 2.74-7.19 0-13-6.88-13-2.38s3.16 13.26 13 13.26Z"></path><path fill="#3A3B45" fill-rule="evenodd" d="M39.43 54a8.7 8.7 0 0 1 5.3-4.49c.4-.12.81.57 1.24 1.28.4.68.82 1.37 1.24 1.37.45 0 .9-.68 1.33-1.35.45-.7.89-1.38 1.32-1.25a8.61 8.61 0 0 1 5 4.17c3.73-2.94 5.1-7.74 5.1-10.7 0-2.34-1.57-1.6-4.09-.36l-.14.07c-2.31 1.15-5.39 2.67-8.77 2.67s-6.45-1.52-8.77-2.67c-2.6-1.29-4.23-2.1-4.23.29 0 3.05 1.46 8.06 5.47 10.97Z" clip-rule="evenodd"></path><path fill="#FF9D0B" d="M70.71 37a3.25 3.25 0 1 0 0-6.5 3.25 3.25 0 0 0 0 6.5ZM24.21 37a3.25 3.25 0 1 0 0-6.5 3.25 3.25 0 0 0 0 6.5ZM17.52 48c-1.62 0-3.06.66-4.07 1.87a5.97 5.97 0 0 0-1.33 3.76 7.1 7.1 0 0 0-1.94-.3c-1.55 0-2.95.59-3.94 1.66a5.8 5.8 0 0 0-.8 7 5.3 5.3 0 0 0-1.79 2.82c-.24.9-.48 2.8.8 4.74a5.22 5.22 0 0 0-.37 5.02c1.02 2.32 3.57 4.14 8.52 6.1 3.07 1.22 5.89 2 5.91 2.01a44.33 44.33 0 0 0 10.93 1.6c5.86 0 10.05-1.8 12.46-5.34 3.88-5.69 3.33-10.9-1.7-15.92-2.77-2.78-4.62-6.87-5-7.77-.78-2.66-2.84-5.62-6.25-5.62a5.7 5.7 0 0 0-4.6 2.46c-1-1.26-1.98-2.25-2.86-2.82A7.4 7.4 0 0 0 17.52 48Zm0 4c.51 0 1.14.22 1.82.65 2.14 1.36 6.25 8.43 7.76 11.18.5.92 1.37 1.31 2.14 1.31 1.55 0 2.75-1.53.15-3.48-3.92-2.93-2.55-7.72-.68-8.01.08-.02.17-.02.24-.02 1.7 0 2.45 2.93 2.45 2.93s2.2 5.52 5.98 9.3c3.77 3.77 3.97 6.8 1.22 10.83-1.88 2.75-5.47 3.58-9.16 3.58-3.81 0-7.73-.9-9.92-1.46-.11-.03-13.45-3.8-11.76-7 .28-.54.75-.76 1.34-.76 2.38 0 6.7 3.54 8.57 3.54.41 0 .7-.17.83-.6.79-2.85-12.06-4.05-10.98-8.17.2-.73.71-1.02 1.44-1.02 3.14 0 10.2 5.53 11.68 5.53.11 0 .2-.03.24-.1.74-1.2.33-2.04-4.9-5.2-5.21-3.16-8.88-5.06-6.8-7.33.24-.26.58-.38 1-.38 3.17 0 10.66 6.82 10.66 6.82s2.02 2.1 3.25 2.1c.28 0 .52-.1.68-.38.86-1.46-8.06-8.22-8.56-11.01-.34-1.9.24-2.85 1.31-2.85Z"></path><path fill="#FFD21E" d="M38.6 76.69c2.75-4.04 2.55-7.07-1.22-10.84-3.78-3.77-5.98-9.3-5.98-9.3s-.82-3.2-2.69-2.9c-1.87.3-3.24 5.08.68 8.01 3.91 2.93-.78 4.92-2.29 2.17-1.5-2.75-5.62-9.82-7.76-11.18-2.13-1.35-3.63-.6-3.13 2.2.5 2.79 9.43 9.55 8.56 11-.87 1.47-3.93-1.71-3.93-1.71s-9.57-8.71-11.66-6.44c-2.08 2.27 1.59 4.17 6.8 7.33 5.23 3.16 5.64 4 4.9 5.2-.75 1.2-12.28-8.53-13.36-4.4-1.08 4.11 11.77 5.3 10.98 8.15-.8 2.85-9.06-5.38-10.74-2.18-1.7 3.21 11.65 6.98 11.76 7.01 4.3 1.12 15.25 3.49 19.08-2.12Z"></path><path fill="#FF9D0B" d="M77.4 48c1.62 0 3.07.66 4.07 1.87a5.97 5.97 0 0 1 1.33 3.76 7.1 7.1 0 0 1 1.95-.3c1.55 0 2.95.59 3.94 1.66a5.8 5.8 0 0 1 .8 7 5.3 5.3 0 0 1 1.78 2.82c.24.9.48 2.8-.8 4.74a5.22 5.22 0 0 1 .37 5.02c-1.02 2.32-3.57 4.14-8.51 6.1-3.08 1.22-5.9 2-5.92 2.01a44.33 44.33 0 0 1-10.93 1.6c-5.86 0-10.05-1.8-12.46-5.34-3.88-5.69-3.33-10.9 1.7-15.92 2.78-2.78 4.63-6.87 5.01-7.77.78-2.66 2.83-5.62 6.24-5.62a5.7 5.7 0 0 1 4.6 2.46c1-1.26 1.98-2.25 2.87-2.82A7.4 7.4 0 0 1 77.4 48Zm0 4c-.51 0-1.13.22-1.82.65-2.13 1.36-6.25 8.43-7.76 11.18a2.43 2.43 0 0 1-2.14 1.31c-1.54 0-2.75-1.53-.14-3.48 3.91-2.93 2.54-7.72.67-8.01a1.54 1.54 0 0 0-.24-.02c-1.7 0-2.45 2.93-2.45 2.93s-2.2 5.52-5.97 9.3c-3.78 3.77-3.98 6.8-1.22 10.83 1.87 2.75 5.47 3.58 9.15 3.58 3.82 0 7.73-.9 9.93-1.46.1-.03 13.45-3.8 11.76-7-.29-.54-.75-.76-1.34-.76-2.38 0-6.71 3.54-8.57 3.54-.42 0-.71-.17-.83-.6-.8-2.85 12.05-4.05 10.97-8.17-.19-.73-.7-1.02-1.44-1.02-3.14 0-10.2 5.53-11.68 5.53-.1 0-.19-.03-.23-.1-.74-1.2-.34-2.04 4.88-5.2 5.23-3.16 8.9-5.06 6.8-7.33-.23-.26-.57-.38-.98-.38-3.18 0-10.67 6.82-10.67 6.82s-2.02 2.1-3.24 2.1a.74.74 0 0 1-.68-.38c-.87-1.46 8.05-8.22 8.55-11.01.34-1.9-.24-2.85-1.31-2.85Z"></path><path fill="#FFD21E" d="M56.33 76.69c-2.75-4.04-2.56-7.07 1.22-10.84 3.77-3.77 5.97-9.3 5.97-9.3s.82-3.2 2.7-2.9c1.86.3 3.23 5.08-.68 8.01-3.92 2.93.78 4.92 2.28 2.17 1.51-2.75 5.63-9.82 7.76-11.18 2.13-1.35 3.64-.6 3.13 2.2-.5 2.79-9.42 9.55-8.55 11 .86 1.47 3.92-1.71 3.92-1.71s9.58-8.71 11.66-6.44c2.08 2.27-1.58 4.17-6.8 7.33-5.23 3.16-5.63 4-4.9 5.2.75 1.2 12.28-8.53 13.36-4.4 1.08 4.11-11.76 5.3-10.97 8.15.8 2.85 9.05-5.38 10.74-2.18 1.69 3.21-11.65 6.98-11.76 7.01-4.31 1.12-15.26 3.49-19.08-2.12Z"></path></svg>
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
<span>Transformers</span>
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
</div></a>
|
| 184 |
+
<a class="mb-1 mr-1 md:mb-1.5 md:mr-1.5 rounded-lg" href="/models?library=safetensors"><div class="tag tag-white "><svg class="text-black inline-block text-sm" viewBox="0 0 57 44" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet"><path d="M36.816 20.1474L54.9918 27.4409C55.5142 27.6506 55.9623 28.0112 56.2788 28.4766C56.5954 28.9421 56.7661 29.4913 56.7691 30.0542C56.7722 30.6171 56.6074 31.1682 56.2959 31.637C55.9844 32.1059 55.5402 32.4713 55.0201 32.6866L29.953 43.0646C29.2593 43.3518 28.4799 43.3518 27.7862 43.0646L2.71624 32.6894C2.19613 32.4741 1.75197 32.1087 1.44044 31.6398C1.12892 31.171 0.964165 30.62 0.967204 30.057C0.970244 29.4941 1.14094 28.9449 1.45751 28.4794C1.77408 28.014 2.22216 27.6534 2.74456 27.4437L21.2404 20.0227C22.2997 19.5979 25.6477 20.8441 28.8682 20.8555C32.3096 20.8668 35.6292 19.6715 36.816 20.1474ZM11.3042 30.1119L28.8682 37.3828L46.435 30.1119L28.8682 23.0619L11.3042 30.1119ZM29.9247 0.388251L54.9918 10.4462C55.5142 10.6559 55.9623 11.0165 56.2788 11.482C56.5954 11.9474 56.7661 12.4967 56.7691 13.0596C56.7722 13.6225 56.6074 14.1735 56.2959 14.6424C55.9844 15.1112 55.5402 15.4766 55.0201 15.6919L29.953 26.07C29.2593 26.3572 28.4799 26.3572 27.7862 26.07L2.71624 15.6948C2.19613 15.4795 1.75197 15.1141 1.44044 14.6452C1.12892 14.1763 0.964165 13.6253 0.967204 13.0624C0.970244 12.4995 1.14094 11.9503 1.45751 11.4848C1.77408 11.0193 2.22216 10.6588 2.74456 10.4491L27.8117 0.388251C28.4896 0.1157 29.2467 0.1157 29.9247 0.388251ZM11.3042 13.1172L28.8682 20.3881L46.435 13.1172L28.8682 6.06729L11.3042 13.1172Z" fill="currentColor"></path></svg>
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
<span>Safetensors</span>
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
</div></a>
|
| 192 |
+
|
| 193 |
+
<button class="mb-1 mr-1 md:mb-1.5 md:mr-1.5 rounded-lg" type="button"><div class="tag tag-white ">
|
| 194 |
+
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="text-green-600/80" preserveAspectRatio="xMidYMid meet" width="1em" height="1em" viewBox="0 0 10 10"><path fill-rule="evenodd" clip-rule="evenodd" d="M0.625 5C0.625 6.16032 1.08594 7.27312 1.90641 8.09359C2.72688 8.91406 3.83968 9.375 5 9.375C6.16032 9.375 7.27312 8.91406 8.09359 8.09359C8.91406 7.27312 9.375 6.16032 9.375 5C9.375 3.83968 8.91406 2.72688 8.09359 1.90641C7.27312 1.08594 6.16032 0.625 5 0.625C3.83968 0.625 2.72688 1.08594 1.90641 1.90641C1.08594 2.72688 0.625 3.83968 0.625 5ZM7.64365 7.48027C7.61734 7.50832 7.59054 7.53598 7.56326 7.56326C7.13828 7.98824 6.61864 8.2968 6.0539 8.46842C6.29802 8.11949 6.49498 7.64804 6.63475 7.09483C7.00845 7.18834 7.35014 7.3187 7.64365 7.48027ZM8.10076 6.87776C8.37677 6.42196 8.55005 5.90894 8.60556 5.37499H6.86808C6.85542 5.71597 6.82551 6.04557 6.77971 6.35841C7.25309 6.47355 7.68808 6.6414 8.062 6.85549C8.07497 6.86283 8.08789 6.87025 8.10076 6.87776ZM6.03795 6.22536C6.07708 5.95737 6.1044 5.67232 6.11705 5.37499H3.88295C3.89666 5.69742 3.92764 6.00542 3.9722 6.29287C4.37075 6.21726 4.79213 6.17749 5.224 6.17749C5.50054 6.17749 5.77294 6.19376 6.03795 6.22536ZM4.1261 7.02673C4.34894 7.84835 4.68681 8.375 5 8.375C5.32122 8.375 5.66839 7.82101 5.8908 6.963C5.67389 6.93928 5.45082 6.92699 5.224 6.92699C4.84316 6.92699 4.47332 6.96176 4.1261 7.02673ZM3.39783 7.21853C3.53498 7.71842 3.72038 8.14579 3.9461 8.46842C3.42141 8.30898 2.93566 8.03132 2.52857 7.65192C2.77253 7.48017 3.06711 7.33382 3.39783 7.21853ZM3.23916 6.48077C3.18263 6.13193 3.14625 5.76074 3.13192 5.37499H1.39444C1.4585 5.99112 1.67936 6.57938 2.03393 7.08403C2.3706 6.83531 2.78055 6.63162 3.23916 6.48077ZM1.39444 4.62499H3.13192C3.14615 4.24204 3.18211 3.87344 3.23794 3.52681C2.77814 3.37545 2.36731 3.17096 2.03024 2.92123C1.67783 3.42469 1.45828 4.011 1.39444 4.62499ZM2.5237 2.35262C2.76812 2.52552 3.06373 2.67281 3.39584 2.78875C3.53318 2.28573 3.71928 1.85578 3.9461 1.53158C3.41932 1.69166 2.93178 1.97089 2.5237 2.35262ZM3.97101 3.71489C3.92709 4.00012 3.89654 4.30547 3.88295 4.62499H6.11705C6.10453 4.33057 6.07761 4.04818 6.03909 3.78248C5.77372 3.81417 5.50093 3.83049 5.224 3.83049C4.79169 3.83049 4.3699 3.79065 3.97101 3.71489ZM5.8928 3.04476C5.67527 3.06863 5.45151 3.08099 5.224 3.08099C4.84241 3.08099 4.47186 3.04609 4.12405 2.98086C4.34686 2.1549 4.68584 1.625 5 1.625C5.32218 1.625 5.67048 2.18233 5.8928 3.04476ZM6.78083 3.6493C6.826 3.95984 6.85552 4.28682 6.86808 4.62499H8.60556C8.55029 4.09337 8.37827 3.58251 8.10436 3.1282C8.0903 3.1364 8.07618 3.14449 8.062 3.15249C7.68838 3.36641 7.25378 3.53417 6.78083 3.6493ZM7.64858 2.52499C7.35446 2.68754 7.0117 2.81868 6.63664 2.91268C6.49676 2.35623 6.29913 1.88209 6.0539 1.53158C6.61864 1.7032 7.13828 2.01176 7.56326 2.43674C7.59224 2.46572 7.62068 2.49514 7.64858 2.52499Z" fill="currentColor"></path></svg>
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
<span>24 languages</span>
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
</div></button>
|
| 202 |
+
<a class="mb-1 mr-1 md:mb-1.5 md:mr-1.5 rounded-lg" href="/models?other=llama"><div class="tag tag-white ">
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
<span>llama</span>
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
</div></a>
|
| 210 |
+
<a class="mb-1 mr-1 md:mb-1.5 md:mr-1.5 rounded-lg" href="/models?other=text-generation-inference"><div class="tag tag-white ">
|
| 211 |
+
<svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 12 12"><path fill="#23B0FF" d="m9.6 3.6-3.2-2a1 1 0 0 0-1.1 0L2 3.7a1 1 0 0 0-.3 1.6H10a1 1 0 0 0-.3-1.6Z"></path><path fill="#2094FF" d="m6.7 9.7 3.2-4.5-.4-.8H5.7v4.8l1 .5Z"></path><path fill="#6BCAFF" d="M4.9 9.7 1.7 5.2l.4-.8h3.8v4.8l-1 .5Z"></path><path fill="#000" fill-rule="evenodd" d="M9.9 3.2c.8.5 1 1.5.5 2.3L7 10c-.6.9-2 .9-2.6 0L1.3 5.5c-.5-.8-.3-1.8.5-2.3l3.2-2c.5-.3 1.2-.3 1.7 0l3.2 2ZM6.4 5h3l-3 4.2V5ZM5.3 5h-3l3 4.2V5Zm3.8-1L6 2a.5.5 0 0 0-.5 0L2.6 4H9Z" clip-rule="evenodd"></path></svg>
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
<span>text-generation-inference</span>
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
</div></a><div class="relative inline-block ">
|
| 219 |
+
<button class="group mr-1 mb-1 md:mr-1.5 md:mb-1.5 rounded-full rounded-br-none " type="button">
|
| 220 |
+
<div slot="button"><div class="tag rounded-full tag-white relative rounded-br-none pr-2.5">
|
| 221 |
+
<svg class="text-xs text-gray-900" width="1em" height="1em" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1.46009 5.0945V6.88125C1.46009 7.25201 1.75937 7.55129 2.13012 7.55129C2.50087 7.55129 2.80016 7.25201 2.80016 6.88125V5.0945C2.80016 4.72375 2.50087 4.42446 2.13012 4.42446C1.75937 4.42446 1.46009 4.72375 1.46009 5.0945ZM4.14022 5.0945V6.88125C4.14022 7.25201 4.4395 7.55129 4.81026 7.55129C5.18101 7.55129 5.48029 7.25201 5.48029 6.88125V5.0945C5.48029 4.72375 5.18101 4.42446 4.81026 4.42446C4.4395 4.42446 4.14022 4.72375 4.14022 5.0945ZM1.23674 9.78473H8.38377C8.75452 9.78473 9.0538 9.48545 9.0538 9.1147C9.0538 8.74395 8.75452 8.44466 8.38377 8.44466H1.23674C0.865993 8.44466 0.566711 8.74395 0.566711 9.1147C0.566711 9.48545 0.865993 9.78473 1.23674 9.78473ZM6.82036 5.0945V6.88125C6.82036 7.25201 7.11964 7.55129 7.49039 7.55129C7.86114 7.55129 8.16042 7.25201 8.16042 6.88125V5.0945C8.16042 4.72375 7.86114 4.42446 7.49039 4.42446C7.11964 4.42446 6.82036 4.72375 6.82036 5.0945ZM4.39484 0.623142L0.865993 2.48137C0.682851 2.57517 0.566711 2.76725 0.566711 2.97273C0.566711 3.28094 0.816857 3.53109 1.12507 3.53109H8.49991C8.80365 3.53109 9.0538 3.28094 9.0538 2.97273C9.0538 2.76725 8.93766 2.57517 8.75452 2.48137L5.22568 0.623142C4.9666 0.484669 4.65391 0.484669 4.39484 0.623142V0.623142Z" fill="currentColor"></path></svg>
|
| 222 |
+
|
| 223 |
+
<span class="-mr-1 text-gray-400">License:</span>
|
| 224 |
+
|
| 225 |
+
<span>cc-by-sa-4.0</span>
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
<div class="border-br-gray-200 absolute bottom-0.5 right-0.5 h-1 w-1 border-[3px] border-l-transparent border-t-transparent border-b-gray-200 border-r-gray-200 group-hover:border-b-gray-400 group-hover:border-r-gray-400 dark:border-b-gray-700 dark:border-r-gray-700 group-hover:dark:border-b-gray-400 group-hover:dark:border-r-gray-400"></div></div></div>
|
| 229 |
+
</button>
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
</div></div>
|
| 233 |
+
|
| 234 |
+
<div class="flex flex-col-reverse lg:flex-row lg:items-center lg:justify-between"><div class="-mb-px flex h-12 items-center overflow-x-auto overflow-y-hidden ">
|
| 235 |
+
<a class="tab-alternate" href="/kyutai/helium-1-2b"><svg class="mr-1.5 text-gray-400 flex-none" style="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path class="uim-quaternary" d="M20.23 7.24L12 12L3.77 7.24a1.98 1.98 0 0 1 .7-.71L11 2.76c.62-.35 1.38-.35 2 0l6.53 3.77c.29.173.531.418.7.71z" opacity=".25" fill="currentColor"></path><path class="uim-tertiary" d="M12 12v9.5a2.09 2.09 0 0 1-.91-.21L4.5 17.48a2.003 2.003 0 0 1-1-1.73v-7.5a2.06 2.06 0 0 1 .27-1.01L12 12z" opacity=".5" fill="currentColor"></path><path class="uim-primary" d="M20.5 8.25v7.5a2.003 2.003 0 0 1-1 1.73l-6.62 3.82c-.275.13-.576.198-.88.2V12l8.23-4.76c.175.308.268.656.27 1.01z" fill="currentColor"></path></svg>
|
| 236 |
+
Model card
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
</a><a class="tab-alternate active" href="/kyutai/helium-1-2b/tree/main"><svg class="mr-1.5 text-gray-400 flex-none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path class="uim-tertiary" d="M21 19h-8a1 1 0 0 1 0-2h8a1 1 0 0 1 0 2zm0-4h-8a1 1 0 0 1 0-2h8a1 1 0 0 1 0 2zm0-8h-8a1 1 0 0 1 0-2h8a1 1 0 0 1 0 2zm0 4h-8a1 1 0 0 1 0-2h8a1 1 0 0 1 0 2z" opacity=".5" fill="currentColor"></path><path class="uim-primary" d="M9 19a1 1 0 0 1-1-1V6a1 1 0 0 1 2 0v12a1 1 0 0 1-1 1zm-6-4.333a1 1 0 0 1-.64-1.769L3.438 12l-1.078-.898a1 1 0 0 1 1.28-1.538l2 1.667a1 1 0 0 1 0 1.538l-2 1.667a.999.999 0 0 1-.64.231z" fill="currentColor"></path></svg>
|
| 241 |
+
<span class="xl:hidden">Files</span>
|
| 242 |
+
<span class="hidden xl:inline">Files and versions</span>
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
<span class="inline-block "><span class="contents"><div slot="anchor" class="shadow-purple-500/10 ml-2 inline-flex -translate-y-px items-center gap-0.5 rounded-md border bg-white px-1 py-0.5 align-middle text-xs font-semibold leading-none text-gray-800 shadow-sm dark:border-gray-700 dark:bg-gradient-to-b dark:from-gray-925 dark:to-gray-925 dark:text-gray-300"><svg class="size-3 " xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 12 12"><path fill-rule="evenodd" clip-rule="evenodd" d="M6.14 3.64 5.1 4.92 2.98 2.28h2.06l1.1 1.36Zm0 4.72-1.1 1.36H2.98l2.13-2.64 1.03 1.28Zm4.9 1.36L8.03 6l3-3.72H8.96L5.97 6l3 3.72h2.06Z" fill="#7875FF"></path><path d="M4.24 6 2.6 8.03.97 6 2.6 3.97 4.24 6Z" fill="#FF7F41" opacity="1"></path></svg>
|
| 248 |
+
<span>xet</span>
|
| 249 |
+
</div></span>
|
| 250 |
+
</span>
|
| 251 |
+
</a><a class="tab-alternate" href="/kyutai/helium-1-2b/discussions"><svg class="mr-1.5 text-gray-400 flex-none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path><path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path></svg>
|
| 252 |
+
Community
|
| 253 |
+
<div class="ml-1.5 flex h-4 min-w-[1rem] items-center justify-center rounded px-1 text-xs leading-none shadow-sm bg-gray-200 text-gray-600 dark:bg-gray-900 dark:text-gray-500">1</div>
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
</a></div>
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
<div class="relative mb-1.5 flex flex-wrap gap-1.5 sm:flex-nowrap lg:mb-0"><div class="order-last sm:order-first"><div class="relative ">
|
| 262 |
+
<button class="btn px-1.5 py-1.5 " type="button">
|
| 263 |
+
|
| 264 |
+
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="p-0.5" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><circle cx="16" cy="7" r="3" fill="currentColor"></circle><circle cx="16" cy="16" r="3" fill="currentColor"></circle><circle cx="16" cy="25" r="3" fill="currentColor"></circle></svg>
|
| 265 |
+
|
| 266 |
+
</button>
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
</div></div>
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
<div class="flex-none w-full sm:w-auto"><div class="relative ">
|
| 285 |
+
<button class="text-sm btn btn w-full cursor-pointer text-sm" type="button">
|
| 286 |
+
<svg class="mr-1.5 " xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><rect x="6.34" y="19" width="11.31" height="2" transform="translate(-10.63 14.34) rotate(-45)"></rect><path d="M17,30a1,1,0,0,1-.37-.07,1,1,0,0,1-.62-.79l-1-7,2-.28.75,5.27L21,24.52V17a1,1,0,0,1,.29-.71l4.07-4.07A8.94,8.94,0,0,0,28,5.86V4H26.14a8.94,8.94,0,0,0-6.36,2.64l-4.07,4.07A1,1,0,0,1,15,11H7.48L4.87,14.26l5.27.75-.28,2-7-1a1,1,0,0,1-.79-.62,1,1,0,0,1,.15-1l4-5A1,1,0,0,1,7,9h7.59l3.77-3.78A10.92,10.92,0,0,1,26.14,2H28a2,2,0,0,1,2,2V5.86a10.92,10.92,0,0,1-3.22,7.78L23,17.41V25a1,1,0,0,1-.38.78l-5,4A1,1,0,0,1,17,30Z"></path></svg>
|
| 287 |
+
Deploy
|
| 288 |
+
<svg class="-mr-1 text-gray-500 " xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path d="M16.293 9.293L12 13.586L7.707 9.293l-1.414 1.414L12 16.414l5.707-5.707z" fill="currentColor"></path></svg></button>
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
</div>
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
</div>
|
| 303 |
+
<div class="relative flex-auto sm:flex-none">
|
| 304 |
+
<button class="from-gray-800! to-black! text-white! gap-1! border-gray-800! dark:border-gray-900! btn w-full cursor-pointer text-sm" type="button">
|
| 305 |
+
<svg class="mr-1.5 mr-0.5! " xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path fill="currentColor" d="M28 4H4a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h8v4H8v2h16v-2h-4v-4h8a2 2 0 0 0 2-2V6a2 2 0 0 0-2-2ZM18 28h-4v-4h4Zm10-6H4V6h24Z"></path></svg>
|
| 306 |
+
Use this model
|
| 307 |
+
<svg class="-mr-1 text-gray-500 " xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path d="M16.293 9.293L12 13.586L7.707 9.293l-1.414 1.414L12 16.414l5.707-5.707z" fill="currentColor"></path></svg></button>
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
</div>
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
</div>
|
| 315 |
+
</div></div></header>
|
| 316 |
+
</div>
|
| 317 |
+
|
| 318 |
+
<div class="container relative flex flex-col md:grid md:space-y-0 w-full md:grid-cols-12 space-y-4 md:gap-6 mb-16"><section class="pt-8 border-gray-100 col-span-full"><div class="SVELTE_HYDRATER contents" data-target="ViewerHeader" data-props="{"context":{"repo":{"name":"kyutai/helium-1-2b","type":"model"},"rev":"main","path":"tokenizer.model","subpaths":[{"dir":"tokenizer.model"}]},"refs":{"branches":[{"name":"main","ref":"refs/heads/main","targetCommit":"5764947fc2e782982c24f363eacd9baea3e821f8"}],"tags":[],"converts":[]},"view":"blob","isMac":false}"><header class="flex flex-wrap items-center justify-start pb-2 md:justify-end lg:flex-nowrap"><div class="grow max-md:flex max-md:w-full max-md:items-start max-md:justify-between"><div class="relative mr-4 flex min-w-0 basis-auto flex-wrap items-center gap-x-3 md:grow md:basis-full lg:basis-auto lg:flex-nowrap"><div class="relative mb-2">
|
| 319 |
+
<button class="text-sm md:text-base btn w-full cursor-pointer text-sm" type="button">
|
| 320 |
+
<svg class="mr-1.5 text-gray-700 dark:text-gray-400" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24" style="transform: rotate(360deg);"><path d="M13 14c-3.36 0-4.46 1.35-4.82 2.24C9.25 16.7 10 17.76 10 19a3 3 0 0 1-3 3a3 3 0 0 1-3-3c0-1.31.83-2.42 2-2.83V7.83A2.99 2.99 0 0 1 4 5a3 3 0 0 1 3-3a3 3 0 0 1 3 3c0 1.31-.83 2.42-2 2.83v5.29c.88-.65 2.16-1.12 4-1.12c2.67 0 3.56-1.34 3.85-2.23A3.006 3.006 0 0 1 14 7a3 3 0 0 1 3-3a3 3 0 0 1 3 3c0 1.34-.88 2.5-2.09 2.86C17.65 11.29 16.68 14 13 14m-6 4a1 1 0 0 0-1 1a1 1 0 0 0 1 1a1 1 0 0 0 1-1a1 1 0 0 0-1-1M7 4a1 1 0 0 0-1 1a1 1 0 0 0 1 1a1 1 0 0 0 1-1a1 1 0 0 0-1-1m10 2a1 1 0 0 0-1 1a1 1 0 0 0 1 1a1 1 0 0 0 1-1a1 1 0 0 0-1-1z" fill="currentColor"></path></svg>
|
| 321 |
+
main
|
| 322 |
+
<svg class="-mr-1 text-gray-500 " xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path d="M16.293 9.293L12 13.586L7.707 9.293l-1.414 1.414L12 16.414l5.707-5.707z" fill="currentColor"></path></svg></button>
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
</div>
|
| 326 |
+
<div class="relative mb-2 flex flex-wrap items-center"><a class="truncate text-gray-800 hover:underline" href="/kyutai/helium-1-2b/tree/main">helium-1-2b</a>
|
| 327 |
+
<span class="mx-1 text-gray-300">/</span>
|
| 328 |
+
<span class="dark:text-gray-300">tokenizer.model</span>
|
| 329 |
+
<button class="text-xs ml-2 focus:outline-hidden inline-flex cursor-pointer items-center text-sm mx-0.5 text-gray-600 " title="Copy path" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
|
| 330 |
+
</button></div>
|
| 331 |
+
</div>
|
| 332 |
+
</div>
|
| 333 |
+
|
| 334 |
+
</header></div>
|
| 335 |
+
<div class="SVELTE_HYDRATER contents" data-target="LastCommit" data-props="{"commitLast":{"date":"2025-04-30T14:01:50.000Z","verified":"verified","subject":"Upload tokenizer.model with huggingface_hub","authors":[{"_id":"6355a3c1805be5a8f30fea49","avatar":"https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/ONMEctCWAeAgF2eZ307si.jpeg","isHf":false,"user":"lmz"}],"commit":{"id":"b8d50a6775dfd77d956b7cd18928736dccd17fe7","parentIds":["b3d4f57a13777182735134b6aaf4b610767cd08c"]},"title":"Upload tokenizer.model with huggingface_hub"},"repo":{"name":"kyutai/helium-1-2b","type":"model"}}"><div class="from-gray-100-to-white bg-linear-to-t flex flex-wrap items-baseline gap-y-1 rounded-t-lg border border-b-0 px-3 py-2 dark:border-gray-800"><img class="mr-2.5 mt-0.5 h-4 w-4 self-center rounded-full" alt="lmz's picture" src="https://cdn-avatars.huggingface.co/v1/production/uploads/6355a3c1805be5a8f30fea49/ONMEctCWAeAgF2eZ307si.jpeg">
|
| 336 |
+
<div class="mr-4 flex flex-none items-center truncate"><a class="hover:underline" href="/lmz">lmz
|
| 337 |
+
</a>
|
| 338 |
+
|
| 339 |
+
</div>
|
| 340 |
+
<div class="mr-4 truncate font-mono text-xs text-gray-500 hover:prose-a:underline sm:text-sm"><!-- HTML_TAG_START -->Upload tokenizer.model with huggingface_hub<!-- HTML_TAG_END --></div>
|
| 341 |
+
<a class="rounded-sm border bg-gray-50 px-1.5 text-sm hover:underline dark:border-gray-800 dark:bg-gray-900" href="/kyutai/helium-1-2b/commit/b8d50a6775dfd77d956b7cd18928736dccd17fe7">b8d50a6</a>
|
| 342 |
+
<span class="mx-2 text-green-500 dark:text-green-600 px-1.5 border-green-100 dark:border-green-800 rounded-full border text-xs uppercase" title="This commit is signed and the signature is verified">verified</span>
|
| 343 |
+
<time class="ml-auto hidden flex-none truncate pl-2 text-gray-500 dark:text-gray-400 lg:block" datetime="2025-04-30T14:01:50" title="Wed, 30 Apr 2025 14:01:50 GMT">7 months ago</time></div></div>
|
| 344 |
+
<div class="relative flex flex-wrap items-center border px-3 py-1.5 text-sm text-gray-800 dark:border-gray-800 dark:bg-gray-900 ">
|
| 345 |
+
<a class="group my-1 mr-4 flex items-center " download="" href="/kyutai/helium-1-2b/resolve/main/tokenizer.model?download=true"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" viewBox="0 0 32 32"><path fill="currentColor" d="M26 24v4H6v-4H4v4a2 2 0 0 0 2 2h20a2 2 0 0 0 2-2v-4zm0-10l-1.41-1.41L17 20.17V2h-2v18.17l-7.59-7.58L6 14l10 10l10-10z"></path></svg>
|
| 346 |
+
download</span>
|
| 347 |
+
|
| 348 |
+
</a><div class="SVELTE_HYDRATER contents" data-target="CopyButton" data-props="{"value":"https://huggingface.co/kyutai/helium-1-2b/resolve/main/tokenizer.model","style":"blank","label":"Copy download link","classNames":"my-1 mr-4 flex items-center no-underline hover:underline"}"><button class="my-1 mr-4 flex items-center no-underline hover:underline " title="Copy download link" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
|
| 349 |
+
<span class="ml-1.5 ">Copy download link</span></button></div><a class="group my-1 mr-4 flex items-center " href="/kyutai/helium-1-2b/commits/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32" style="transform: rotate(360deg);"><path d="M16 4C9.383 4 4 9.383 4 16s5.383 12 12 12s12-5.383 12-12S22.617 4 16 4zm0 2c5.535 0 10 4.465 10 10s-4.465 10-10 10S6 21.535 6 16S10.465 6 16 6zm-1 2v9h7v-2h-5V8z" fill="currentColor"></path></svg>
|
| 350 |
+
history</span>
|
| 351 |
+
|
| 352 |
+
</a><a class="group my-1 mr-4 flex items-center " href="/kyutai/helium-1-2b/blame/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32" style="transform: rotate(360deg);"><path d="M16 2a14 14 0 1 0 14 14A14 14 0 0 0 16 2zm0 26a12 12 0 1 1 12-12a12 12 0 0 1-12 12z" fill="currentColor"></path><path d="M11.5 11a2.5 2.5 0 1 0 2.5 2.5a2.48 2.48 0 0 0-2.5-2.5z" fill="currentColor"></path><path d="M20.5 11a2.5 2.5 0 1 0 2.5 2.5a2.48 2.48 0 0 0-2.5-2.5z" fill="currentColor"></path></svg>
|
| 353 |
+
blame</span>
|
| 354 |
+
|
| 355 |
+
</a><a class="group my-1 mr-4 flex items-center text-green-600 dark:text-green-500" href="/kyutai/helium-1-2b/edit/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M2 26h28v2H2z" fill="currentColor"></path><path d="M25.4 9c.8-.8.8-2 0-2.8l-3.6-3.6c-.8-.8-2-.8-2.8 0l-15 15V24h6.4l15-15zm-5-5L24 7.6l-3 3L17.4 7l3-3zM6 22v-3.6l10-10l3.6 3.6l-10 10H6z" fill="currentColor"></path></svg>
|
| 356 |
+
contribute</span>
|
| 357 |
+
|
| 358 |
+
</a><a class="group my-1 mr-4 flex items-center " href="/kyutai/helium-1-2b/delete/main/tokenizer.model"><span class="flex items-center group-hover:underline"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M12 12h2v12h-2z" fill="currentColor"></path><path d="M18 12h2v12h-2z" fill="currentColor"></path><path d="M4 6v2h2v20a2 2 0 0 0 2 2h16a2 2 0 0 0 2-2V8h2V6zm4 22V8h16v20z" fill="currentColor"></path><path d="M12 2h8v2h-8z" fill="currentColor"></path></svg>
|
| 359 |
+
delete</span>
|
| 360 |
+
|
| 361 |
+
</a>
|
| 362 |
+
|
| 363 |
+
<div class="mr-4 flex items-center"><div class="SVELTE_HYDRATER contents" data-target="ScanStatusBadge" data-props="{"classNames":"mr-2","scanStatus":{"status":"safe","protectAiScan":{"status":"unscanned","message":null,"reportLink":"https://protectai.com/insights/models/kyutai/helium-1-2b/5764947fc2e782982c24f363eacd9baea3e821f8/files?blob-id=48679a193304ea9e6dda4c4de9be4d4db590c249&utm_source=huggingface"},"avScan":{"status":"safe","message":"No security issues detected","reportLink":"https://fdtn.ai/ai-supply-chain/hugging-face?utm_source=huggingface","reportLabel":"Learn more at Cisco Foundation AI"},"pickleImportScan":{"status":"unscanned","pickleImports":[],"version":"0.0.0"},"virusTotalScan":{"status":"safe","message":"0/76 engines detect it as malicious.","reportLink":"https://www.virustotal.com/gui/file/abb8879fdb2001dfae68d0bbdccbe92ae1593bad518abb34c9513f27904ee303?utm_source=huggingface","reportLabel":"See more details on VirusTotal"},"jFrogScan":{"status":"unscanned","message":"Not a machine-learning model","reportLink":"","reportLabel":""}},"repo":{"name":"kyutai/helium-1-2b","type":"model"},"revision":"main","filePath":"tokenizer.model","openByDefault":false}"><div class="sm:relative mr-2"><button class="flex h-[1.125rem] select-none items-center gap-0.5 rounded border pl-0.5 pr-0.5 text-xs leading-tight text-gray-400 hover:cursor-pointer text-gray-400 hover:border-gray-200 hover:bg-gray-50 hover:text-gray-500 dark:border-gray-800 dark:hover:bg-gray-800 dark:hover:text-gray-200 "><svg class="flex-none" width="1em" height="1em" viewBox="0 0 22 28" fill="none" xmlns="http://www.w3.org/2000/svg"><path fill-rule="evenodd" clip-rule="evenodd" d="M15.3634 10.3639C15.8486 10.8491 15.8486 11.6357 15.3634 12.1209L10.9292 16.5551C10.6058 16.8785 10.0814 16.8785 9.7579 16.5551L7.03051 13.8277C6.54532 13.3425 6.54532 12.5558 7.03051 12.0707C7.51569 11.5855 8.30234 11.5855 8.78752 12.0707L9.7579 13.041C10.0814 13.3645 10.6058 13.3645 10.9292 13.041L13.6064 10.3639C14.0916 9.8787 14.8782 9.8787 15.3634 10.3639Z" fill="currentColor"></path><path fill-rule="evenodd" clip-rule="evenodd" d="M10.6666 27.12C4.93329 25.28 0 19.2267 0 12.7867V6.52001C0 5.40001 0.693334 4.41334 1.73333 4.01334L9.73333 1.01334C10.3333 0.786673 11 0.786673 11.6 1.02667L19.6 4.02667C20.1083 4.21658 20.5465 4.55701 20.8562 5.00252C21.1659 5.44803 21.3324 5.97742 21.3333 6.52001V12.7867C21.3333 19.24 16.4 25.28 10.6666 27.12Z" fill="currentColor" fill-opacity="0.22"></path><path d="M10.0845 1.94967L10.0867 1.94881C10.4587 1.8083 10.8666 1.81036 11.2286 1.95515L11.2387 1.95919L11.2489 1.963L19.2489 4.963L19.25 4.96342C19.5677 5.08211 19.8416 5.29488 20.0351 5.57333C20.2285 5.85151 20.3326 6.18203 20.3333 6.52082C20.3333 6.52113 20.3333 6.52144 20.3333 6.52176L20.3333 12.7867C20.3333 18.6535 15.8922 24.2319 10.6666 26.0652C5.44153 24.2316 1 18.6409 1 12.7867V6.52001C1 5.82357 1.42893 5.20343 2.08883 4.94803L10.0845 1.94967Z" stroke="currentColor" stroke-opacity="0.30" stroke-width="2"></path></svg>
|
| 364 |
+
|
| 365 |
+
<span class="mr-0.5 max-sm:hidden">Safe</span></button>
|
| 366 |
+
|
| 367 |
+
</div></div>
|
| 368 |
+
</div>
|
| 369 |
+
|
| 370 |
+
<div class="flex items-center gap-x-3 dark:text-gray-300 sm:ml-auto">
|
| 371 |
+
1.14 MB</div></div>
|
| 372 |
+
|
| 373 |
+
<div class="relative min-h-[100px] rounded-b-lg border border-t-0 leading-tight dark:border-gray-800 dark:bg-gray-925">
|
| 374 |
+
<div class="p-4 py-8 text-center">This file is stored with
|
| 375 |
+
<a class="underline" href="https://huggingface.co/docs/hub/xet/index">Xet</a>
|
| 376 |
+
. It is too big to display, but you can still
|
| 377 |
+
<a download class="underline" href="/kyutai/helium-1-2b/resolve/main/tokenizer.model">download</a>
|
| 378 |
+
it.
|
| 379 |
+
</div>
|
| 380 |
+
<div class="bg-linear-to-br from-gray-50-to-white relative border-t p-4"><div class="text-smd mb-2 flex items-baseline"><h3 class="font-semibold">Large File Pointer Details</h3>
|
| 381 |
+
<span class="ml-2">(</span>
|
| 382 |
+
<a href="/kyutai/helium-1-2b/raw/main/tokenizer.model" class="flex items-center underline decoration-gray-400 hover:decoration-gray-700 dark:decoration-gray-500 dark:hover:decoration-gray-300" target="_blank"><svg class="mr-0.5 text-xs" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M25.7 9.3l-7-7A.908.908 0 0 0 18 2H8a2.006 2.006 0 0 0-2 2v24a2.006 2.006 0 0 0 2 2h16a2.006 2.006 0 0 0 2-2V10a.908.908 0 0 0-.3-.7zM18 4.4l5.6 5.6H18zM24 28H8V4h8v6a2.006 2.006 0 0 0 2 2h6z" fill="currentColor"></path></svg> Raw pointer file
|
| 383 |
+
</a>
|
| 384 |
+
<span>)</span></div>
|
| 385 |
+
<dl class="break-words font-mono text-[0.8rem]"><div class="mr-1 flex md:mb-1"><dt class="mr-1.5 font-semibold">SHA256:</dt>
|
| 386 |
+
<dd class="truncate">abb8879fdb2001dfae68d0bbdccbe92ae1593bad518abb34c9513f27904ee303</dd>
|
| 387 |
+
</div><div class="flex flex-wrap"><dt class="mr-1.5 font-semibold">Pointer size:</dt>
|
| 388 |
+
<dd>132 Bytes</dd>
|
| 389 |
+
|
| 390 |
+
<div class="px-1.5 opacity-40">·</div>
|
| 391 |
+
|
| 392 |
+
<dt class="mr-1.5 font-semibold">Size of remote file:</dt>
|
| 393 |
+
<dd>1.14 MB</dd>
|
| 394 |
+
|
| 395 |
+
<div class="px-1.5 opacity-40">·</div>
|
| 396 |
+
<dt class="mr-1.5 font-semibold">Xet hash:</dt>
|
| 397 |
+
<dd class="truncate">bca9ac44cc00b884e9fb49abf7fa3576e32e568e788ff68ce6cee89d24e2b8e4</dd></div></dl>
|
| 398 |
+
<p class="mt-2 text-sm text-gray-500">Xet efficiently stores Large Files inside Git, intelligently splitting files into unique chunks and
|
| 399 |
+
accelerating uploads and downloads.
|
| 400 |
+
<a class="underline" href="/join/xet" target="_blank">More info</a>.</p></div>
|
| 401 |
+
</div></section></div></main>
|
| 402 |
+
|
| 403 |
+
</div>
|
| 404 |
+
<script>
|
| 405 |
+
import("\/front\/build\/kube-02d86c8\/index.js"); window.moonSha = "kube-02d86c8\/"; window.__hf_deferred =
|
| 406 |
+
{};
|
| 407 |
+
</script>
|
| 408 |
+
<!-- Stripe -->
|
| 409 |
+
<script>
|
| 410 |
+
if (["hf.co", "huggingface.co"].includes(window.location.hostname)) {
|
| 411 |
+
const script = document.createElement("script");
|
| 412 |
+
script.src = "https://js.stripe.com/v3/";
|
| 413 |
+
script.async = true;
|
| 414 |
+
document.head.appendChild(script);
|
| 415 |
+
}
|
| 416 |
+
</script>
|
| 417 |
+
</body>
|
| 418 |
+
</html>
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 3 |
+
"additional_special_tokens": [
|
| 4 |
+
"<|im_sp_00|>",
|
| 5 |
+
"<|im_sp_01|>",
|
| 6 |
+
"<|im_sp_02|>",
|
| 7 |
+
"<|im_sp_94|>",
|
| 8 |
+
"<|im_sp_95|>",
|
| 9 |
+
"<|im_sp_96|>",
|
| 10 |
+
"<|im_sp_97|>",
|
| 11 |
+
"<|im_sp_98|>",
|
| 12 |
+
"<|im_sp_99|>"
|
| 13 |
+
]
|
| 14 |
+
}
|
utils.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=protected-access
|
| 2 |
+
"""Utils to handle CASA layers construction"""
|
| 3 |
+
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from dataclasses import dataclass, fields
|
| 6 |
+
from typing import Any, Callable, Generic, TypeVar
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def delta_w_factory(
|
| 12 |
+
org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
|
| 13 |
+
) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 14 |
+
"""Factory for building linear op where the weights are the sum of two layers' weights"""
|
| 15 |
+
|
| 16 |
+
def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
|
| 17 |
+
nonlocal org_lin, new_lin
|
| 18 |
+
bias = None if org_lin.bias is None else org_lin.bias + new_lin.bias
|
| 19 |
+
return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)
|
| 20 |
+
|
| 21 |
+
return _delta_w_fwd
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class StreamingState:
|
| 26 |
+
"""Streaming State used by CASA layers at inference to save
|
| 27 |
+
e.g. the offset, the KV Cache and other persistent states"""
|
| 28 |
+
|
| 29 |
+
offset: int = 0
|
| 30 |
+
|
| 31 |
+
def _is_valid_field(self, key: str) -> bool:
|
| 32 |
+
return key in {x.name for x in fields(self)}
|
| 33 |
+
|
| 34 |
+
def _init_field(self, key: str) -> None:
|
| 35 |
+
"""Init function for non-arggment dependent defauls"""
|
| 36 |
+
assert self._is_valid_field(key)
|
| 37 |
+
if key == "offset":
|
| 38 |
+
self.offset = 0
|
| 39 |
+
else:
|
| 40 |
+
# for fields which should be set explicitly and cannot be auto-initialized
|
| 41 |
+
setattr(self, key, None)
|
| 42 |
+
|
| 43 |
+
def init(self) -> None:
|
| 44 |
+
for key in [x.name for x in fields(self)]:
|
| 45 |
+
self._init_field(key)
|
| 46 |
+
|
| 47 |
+
def _reset_field(self, name: str) -> None:
|
| 48 |
+
"""Resets the given field"""
|
| 49 |
+
self._init_field(name)
|
| 50 |
+
|
| 51 |
+
def reset(self) -> None:
|
| 52 |
+
for f in fields(self):
|
| 53 |
+
self._reset_field(f.name)
|
| 54 |
+
|
| 55 |
+
def _get_field(self, f: str) -> Any:
|
| 56 |
+
"""Get field and init if not"""
|
| 57 |
+
assert self._is_valid_field(f)
|
| 58 |
+
if getattr(self, f) is None:
|
| 59 |
+
self._init_field(f)
|
| 60 |
+
return getattr(self, f)
|
| 61 |
+
|
| 62 |
+
def _set_field(self, f: str, value: Any) -> None:
|
| 63 |
+
assert self._is_valid_field(f)
|
| 64 |
+
setattr(self, f, value)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class StreamingModule(torch.nn.Module, Generic[StreamingStateT]): # pylint: disable=abstract-method
|
| 71 |
+
"""Overrides Audiocraft's Streaming modules with additional small utils"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, state_class: type) -> None:
|
| 74 |
+
torch.nn.Module.__init__(self)
|
| 75 |
+
self.is_streaming: bool = False
|
| 76 |
+
self.enable_viz: tuple[str, ...] = ()
|
| 77 |
+
self._streaming_state: StreamingStateT = state_class()
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def streaming_state(self) -> StreamingStateT:
|
| 81 |
+
return self._streaming_state
|
| 82 |
+
|
| 83 |
+
def _apply_named_streaming(self, fn: Callable):
|
| 84 |
+
"""Apply function to all streaming modules"""
|
| 85 |
+
for name, module in self.named_modules():
|
| 86 |
+
if isinstance(module, StreamingModule):
|
| 87 |
+
fn(name, module)
|
| 88 |
+
|
| 89 |
+
def reset_streaming(self):
|
| 90 |
+
"""Reset the streaming state."""
|
| 91 |
+
|
| 92 |
+
def _reset(_: str, module: StreamingModule):
|
| 93 |
+
module._streaming_state.reset()
|
| 94 |
+
|
| 95 |
+
self._apply_named_streaming(_reset)
|
| 96 |
+
|
| 97 |
+
def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()):
|
| 98 |
+
"""Set all streaming modules in streaming mode"""
|
| 99 |
+
|
| 100 |
+
def _set_streaming(_, module: StreamingModule) -> None:
|
| 101 |
+
module.is_streaming = streaming
|
| 102 |
+
module.enable_viz = viz
|
| 103 |
+
if streaming:
|
| 104 |
+
module.streaming_state.init()
|
| 105 |
+
|
| 106 |
+
self._apply_named_streaming(_set_streaming)
|
| 107 |
+
|
| 108 |
+
@contextmanager
|
| 109 |
+
def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
|
| 110 |
+
"""Context manager to enter streaming mode. Reset streaming state on exit."""
|
| 111 |
+
self._set_streaming(stream, viz)
|
| 112 |
+
try:
|
| 113 |
+
yield
|
| 114 |
+
finally:
|
| 115 |
+
self._set_streaming(False, ())
|
| 116 |
+
self.reset_streaming()
|