gregtatum committed
Commit c133d59 · 1 Parent(s): 08d598e
Files changed (1)
  1. scripts/build_models.py +238 -32
scripts/build_models.py CHANGED
@@ -1,3 +1,6 @@
+import shutil
+from textwrap import dedent
+from typing import Any
 import numpy as np
 from zstandard import ZstdCompressor
 from pathlib import Path
@@ -5,62 +8,265 @@ import io
 from sentence_transformers import SentenceTransformer
 from torch.nn import EmbeddingBag
 import torch
+from model2vec import StaticModel
+from tokenizers import Tokenizer

+models_path = Path("models")


+def zst_compress_file(input: Path):
+    cctx = ZstdCompressor()
+    output = input.parent / f"{input.name}.zst"
+    print(f"Compressing {output}")
+    with open(input, "rb") as fin, open(output, "wb") as fout:
+        cctx.copy_stream(fin, fout)


 def save_data(path: Path, tensor: torch.Tensor):
-    """Writes out the static embeddings to a .npy.zst file"""
-    assert str(path).endswith(".npy.zst")
+    """Writes out the static embeddings to a .npy and .npy.zst file"""
     buffer = io.BytesIO()
-    np.save(buffer, tensor.detach().numpy())

-    with (
-        open(path, "wb") as outfile,
-        ZstdCompressor().stream_writer(outfile) as writer,
-    ):
-        writer.write(buffer.getvalue())
+    if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        # Store as the raw bytes.
+        np.save(buffer, tensor.detach().view(torch.uint8).numpy())
+    else:
+        np.save(buffer, tensor.detach().numpy())

+    print(f"Saving {path}")
+    with (open(path, "wb") as outfile,):
+        outfile.write(buffer.getvalue())

+    zst_compress_file(path)


-model_path = Path("model")
-model_name = "sentence-transformers/static-similarity-mrl-multilingual-v1"
-vocab_size = 105_879
-dimensions = 1024


-def load_embeddings():
-    model = SentenceTransformer(model_name, device="cpu")
-    embedding_bag: EmbeddingBag = model[0].embedding  # type: ignore
-    embeddings = torch.Tensor(embedding_bag.weight)

-    print(embeddings.shape)
-    assert embeddings.shape == torch.Size([vocab_size, dimensions])

-    print("float32")
-    print(f" 1024 dim - {embeddings.shape[0] * 1024 * 4 / 1024 / 1024:,.1f} MiB")
-    print(f" 512 dim - {embeddings.shape[0] * 512 * 4 / 1024 / 1024:,.1f} MiB")
-    print(f" 256 dim - {embeddings.shape[0] * 256 * 4 / 1024 / 1024:,.1f} MiB")

-    print("float16")
-    print(f" 1024 dim - {embeddings.shape[0] * 1024 * 2 / 1024 / 1024:,.1f} MiB")
-    print(f" 512 dim - {embeddings.shape[0] * 512 * 2 / 1024 / 1024:,.1f} MiB")
-    print(f" 256 dim - {embeddings.shape[0] * 256 * 2 / 1024 / 1024:,.1f} MiB")
+def quantization_loss_mse(tensor: torch.Tensor, dtype: torch.dtype):
+    """
+    Compute reconstruction loss when converting embeddings to a datatype and back using
+    the mean squared error, which punishes big errors more than small ones.
+    """

+    # Original → quantize → dequantize
+    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

+    # Mean squared error
+    return torch.mean((tensor - roundtrip) ** 2).item()


+def quantization_loss_mae(tensor: torch.Tensor, dtype: torch.dtype):
+    """
+    Compute reconstruction loss when converting embeddings to a datatype and back using
+    the mean absolute error, which is less sensitive to outliers than MSE.
+    """

+    # Original → quantize → dequantize
+    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

+    # Mean absolute error
+    return torch.mean(torch.abs(tensor - roundtrip)).item()


+def quantization_loss_cosine(tensor: torch.Tensor, dtype: torch.dtype):
+    """
+    Compute reconstruction loss when converting embeddings to a datatype and back using
+    cosine similarity. This measures whether the embedding directions are preserved
+    after quantization, independent of their magnitudes.
+    """

+    # Original → quantize → dequantize
+    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

+    # Flatten both to 2D (num_vectors, dimensions) in case tensor is 1D or higher-D
+    if tensor.ndim == 1:
+        orig = tensor.unsqueeze(0)
+        recon = roundtrip.unsqueeze(0)
+    else:
+        orig = tensor.view(tensor.shape[0], -1)
+        recon = roundtrip.view(roundtrip.shape[0], -1)

+    # Cosine similarity per vector, then average
+    cos = torch.nn.functional.cosine_similarity(orig, recon, dim=1)
+    return cos.mean().item()


+def export_embeddings(
+    hf_org: str, hf_repo: str, model_path: Path, embeddings: torch.Tensor
+) -> None:
+    vocab_size, dimensions = embeddings.shape

+    # This logic can always be adjusted for models with different shapes.
+    assert (
+        embeddings.dtype == torch.float32
+    ), f"The embeddings {embeddings.dtype} are assumed to be float32."
+    assert (
+        dimensions <= 1024
+    ), f"The embedding {dimensions} dimension is assumed to be at most 1024."

+    norms = torch.norm(embeddings, dim=1)  # shape: [vocab_size]

+    print(f" - vocab size {vocab_size:,.0f}")
+    print(f" - embedding dimension {dimensions:,.0f}")
+    print(f" - vector length (mean): {norms.mean().item():.2f}")
+    print(f" - vector length (median): {norms.median().item():.2f}")
+    print(f" - stddev ±{norms.std().item():.2f}")
+    print(f" - value (mean): {embeddings.mean().item():.2f}")
+    print(f" - value (median): {embeddings.median().item():.2f}")
+    print(f" - stddev ±{embeddings.std().item():.2f}")

+    model_path.mkdir(exist_ok=True, parents=True)

+    with (model_path / "README.md").open("wt") as file:
+        file.write(
+            dedent(
+                f"""
+                # [{hf_org}/{hf_repo}](https://huggingface.co/{hf_org}/{hf_repo})

+                Beyond the vocab size and embedding size, these are stats for the length
+                of the vectors and the distribution of the values.

+                | item | metric | value |
+                | --------------| ----------------------- | ----- |
+                | vocab | size | {vocab_size:,.0f} |
+                | embedding | dimensions | {dimensions:,.0f} |
+                | vector length | mean | {norms.mean().item():.2f} |
+                | vector length | median | {norms.median().item():.2f} |
+                | vector length | stddev | {norms.std().item():.2f} |
+                | values | mean | {embeddings.mean().item():.2f} |
+                | values | median | {embeddings.median().item():.2f} |
+                | values | stddev | {embeddings.std().item():.2f} |

+                ## Quantization Loss

+                | Precision | Metric | Value |
+                | ------------- | ------ | ----- |
+                | fp16 | mse | {quantization_loss_mse(embeddings, torch.float16):.5f} |
+                | fp8 e4m3 | mse | {quantization_loss_mse(embeddings, torch.float8_e4m3fn):.5f} |
+                | fp8 e5m2 | mse | {quantization_loss_mse(embeddings, torch.float8_e5m2):.5f} |
+                | fp16 | mae | {quantization_loss_mae(embeddings, torch.float16):.5f} |
+                | fp8 e4m3 | mae | {quantization_loss_mae(embeddings, torch.float8_e4m3fn):.5f} |
+                | fp8 e5m2 | mae | {quantization_loss_mae(embeddings, torch.float8_e5m2):.5f} |
+                | fp16 | cosine | {quantization_loss_cosine(embeddings, torch.float16):.5f} |
+                | fp8 e4m3 | cosine | {quantization_loss_cosine(embeddings, torch.float8_e4m3fn):.5f} |
+                | fp8 e5m2 | cosine | {quantization_loss_cosine(embeddings, torch.float8_e5m2):.5f} |

+                When embeddings are quantized to lower precision (e.g. FP8) and then dequantized
+                back to `float32`, some information is inevitably lost. To measure how much the
+                quantized embeddings differ from the originals, we report three complementary
+                metrics:

+                - **MSE (Mean Squared Error)** — emphasizes large errors by squaring the
+                  differences. Useful for detecting whether any values are badly distorted.
+                - **MAE (Mean Absolute Error)** — the average absolute difference between
+                  original and quantized values. Easier to interpret, less sensitive to outliers.
+                - **Cosine Similarity** — measures how well the *direction* of embedding vectors
+                  is preserved after quantization, independent of scale. This is especially
+                  relevant when embeddings are used for similarity search or retrieval.

+                Together, these metrics provide a more complete picture of quantization quality
+                than any one alone.

+                ### Interpreting Quantization Loss

+                - **Cosine similarity** is the most important metric for embedding use-cases
+                  such as similarity search, clustering, or retrieval. Values close to 1.0
+                  mean that embedding directions are preserved after quantization, so model
+                  quality is likely to hold up.

+                - **MSE and MAE** measure raw element-wise reconstruction error. They provide
+                  a sense of how much the numerical values change, but these shifts often have
+                  limited impact on cosine similarity once embeddings are pooled and
+                  normalized.

+                - **FP16** is effectively lossless and can be treated as a baseline.

+                - **FP8 E4M3** typically offers better precision (lower MSE/MAE) when values
+                  stay within a moderate range, making it a strong default for static
+                  embeddings.

+                - **FP8 E5M2** trades some precision for greater dynamic range. It can be
+                  preferable if embeddings occasionally contain very large values, but it will
+                  usually show higher MSE/MAE than E4M3.

+                In practice, if cosine similarity remains very close to 1.0, quantization is
+                unlikely to harm downstream tasks, even if MSE/MAE look relatively large.
+                """
+            ).strip()
+        )

     for dim in (1024, 512, 384, 256, 128):
+        if dim > dimensions:
+            print(f"Skipping output of {dim} as the max dimension is {dimensions}")
+            continue

         truncated = embeddings[:, :dim]
         assert truncated.shape == torch.Size([vocab_size, dim])

-        save_data(model_path / f"static-embeddings.{dim}.fp32.npy.zst", embeddings)
+        save_data(model_path / f"fp32.d{dim}.npy", truncated)
+        save_data(
+            model_path / f"fp16.d{dim}.npy",
+            truncated.to(dtype=torch.float16),
+        )
         save_data(
-            model_path / f"static-embeddings.{dim}.fp16.npy.zst",
-            embeddings.to(dtype=torch.float16),
+            model_path / f"fp8_e5m2.d{dim}.npy",
+            truncated.to(dtype=torch.float8_e5m2),
         )
         save_data(
-            model_path / f"static-embeddings.{dim}.int8.npy.zst",
-            embeddings.to(dtype=torch.int8),
+            model_path / f"fp8_e4m3.d{dim}.npy",
+            truncated.to(dtype=torch.float8_e4m3fn),
         )


+def export_tokenizer(model_path: Path, tokenizer: Tokenizer) -> None:
+    tokenizer_path = model_path / "tokenizer.json"
+    print(f"Exporting tokenizer: {tokenizer_path}")
+    tokenizer.save(str(tokenizer_path))
+    zst_compress_file(tokenizer_path)


+def export_sentence_transformers(hf_org: str, hf_repo: str) -> None:
+    """Extract the embeddings and tokenizer from SentenceTransformers"""

+    model_name = f"{hf_org}/{hf_repo}"
+    print("Processing", model_name)

+    model = SentenceTransformer(f"{hf_org}/{hf_repo}", device="cpu")
+    embedding_bag: EmbeddingBag = model[0].embedding  # type: ignore
+    model_path = models_path / hf_org / hf_repo

+    export_embeddings(hf_org, hf_repo, model_path, torch.Tensor(embedding_bag.weight))
+    export_tokenizer(model_path, model.tokenizer)


+def export_model2vec(hf_org: str, hf_repo: str) -> None:
+    """Extract the embeddings and tokenizer from model2vec"""

+    model = StaticModel.from_pretrained(f"{hf_org}/{hf_repo}")
+    model_path = models_path / hf_org / hf_repo
+    export_embeddings(hf_org, hf_repo, model_path, torch.from_numpy(model.embedding))
+    export_tokenizer(model_path, model.tokenizer)


 def main() -> None:
-    load_embeddings()
+    # Static embedders that use sentence_transformers models.
+    sentence_transformers_models = [
+        ("sentence-transformers", "static-similarity-mrl-multilingual-v1"),
+        ("sentence-transformers", "static-retrieval-mrl-en-v1"),
+    ]
+    # Static embedders that use model2vec.
+    model2vec_models = [
+        ("minishlab", "potion-multilingual-128M"),
+        ("minishlab", "potion-retrieval-32M"),
+    ]

+    if models_path.exists():
+        print(f"Removing the old models folder: {models_path}")
+        shutil.rmtree(models_path)
+    models_path.mkdir()

+    for hf_org, hf_repo in sentence_transformers_models:
+        export_sentence_transformers(hf_org, hf_repo)

+    for hf_org, hf_repo in model2vec_models:
+        export_model2vec(hf_org, hf_repo)


 if __name__ == "__main__":
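
Not part of the commit, but useful context for consumers of the exported artifacts: a minimal sketch of reading one of the .npy.zst files back into a tensor, assuming the models/&lt;org&gt;/&lt;repo&gt;/&lt;precision&gt;.d&lt;dim&gt;.npy.zst layout written by export_embeddings above. The helper name read_embeddings and the example paths are illustrative only. The fp8 exports were written as raw uint8 bytes (see save_data), so they have to be re-viewed as the matching float8 dtype after loading.

import io
from pathlib import Path

import numpy as np
import torch
from zstandard import ZstdDecompressor


def read_embeddings(path: Path, fp8_dtype: torch.dtype | None = None) -> torch.Tensor:
    """Decompress a .npy.zst file and return the embeddings as a torch.Tensor."""
    with open(path, "rb") as infile:
        # Decompress the whole payload, then hand the .npy bytes to numpy.
        raw = ZstdDecompressor().stream_reader(infile).read()
    tensor = torch.from_numpy(np.load(io.BytesIO(raw)))

    if fp8_dtype is not None:
        # The fp8 files hold raw uint8 bytes; reinterpret them and upcast to float32.
        tensor = tensor.view(fp8_dtype).to(torch.float32)

    return tensor


# Hypothetical usage against the layout produced by this script.
base = Path("models/sentence-transformers/static-similarity-mrl-multilingual-v1")
fp16_256 = read_embeddings(base / "fp16.d256.npy.zst")
fp8_256 = read_embeddings(base / "fp8_e4m3.d256.npy.zst", fp8_dtype=torch.float8_e4m3fn)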
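
Also not from the commit: a quick sketch for previewing the three round-trip metrics that feed the generated README tables, by calling the helpers defined above. It assumes the snippet is run from the scripts/ directory so build_models is importable, and it uses a random tensor as a stand-in for a real [vocab_size, dimensions] embedding table.

import torch

from build_models import (
    quantization_loss_cosine,
    quantization_loss_mae,
    quantization_loss_mse,
)

# Stand-in data; swap in real embeddings to reproduce the README numbers.
embeddings = torch.randn(1_000, 256)

for label, dtype in [
    ("fp16", torch.float16),
    ("fp8 e4m3", torch.float8_e4m3fn),
    ("fp8 e5m2", torch.float8_e5m2),
]:
    mse = quantization_loss_mse(embeddings, dtype)
    mae = quantization_loss_mae(embeddings, dtype)
    cosine = quantization_loss_cosine(embeddings, dtype)
    print(f"{label:>8}: mse={mse:.5f} mae={mae:.5f} cosine={cosine:.5f}")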
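
Finally, a tiny sketch (again, not part of the commit) that backs the E4M3 vs. E5M2 discussion in the generated README: torch.finfo reports the dynamic range and precision of each float8 format, which is why E4M3 usually wins on MSE/MAE while E5M2 tolerates larger magnitudes.

import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    # E4M3 spends more bits on the mantissa (precision), E5M2 on the exponent (range).
    print(f"{str(dtype):>22}: max={info.max} eps={info.eps}")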