File size: 11,790 Bytes
4c989c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
"""
quantize_pixie.py β€” ONNX quantization for XLM-RoBERTa-family embedding models.

Produces three self-contained ONNX variants from a float32 source model:
  model_quantized.onnx   β€” INT8 dynamic (all weights including word embeddings)
  model_int4.onnx        β€” INT4 MatMul (MatMulNBits) + INT8 word embedding
  model_int4_full.onnx   β€” INT4 MatMul + INT4 word embedding (opset 21, smallest)

Usage:
    python quantize_pixie.py \\
        --input  onnx/model.onnx \\
        --outdir onnx/ \\
        [--block-size 32]

    # Or via environment variables:
    PIXIE_INPUT=onnx/model.onnx PIXIE_OUTDIR=onnx/ python quantize_pixie.py

The input model is expected to reside in the same directory as its companion
data file (model.onnx_data) when using the default HuggingFace layout.
"""

import argparse
import os
import struct
from pathlib import Path

import numpy as np
import onnx
import onnx.version_converter
from onnxruntime.quantization import (
    QuantType,
    quantize_dynamic,
)
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer


# ── helpers ──────────────────────────────────────────────────────────────────

def _load(path: Path) -> onnx.ModelProto:
    """Load an ONNX model, handling both inline and external initializers.

    The model is read with ``onnx.load`` left at its default
    ``load_external_data=True``, so externally stored tensors are resolved from
    the locations recorded inside the model, relative to the model's own
    directory. This avoids guessing the sidecar's filename: a model whose data
    file is not called ``<stem>.onnx_data`` is still loaded completely, and a
    model with only inline tensors is unaffected.

    Args:
        path: Location of the ``.onnx`` file.

    Returns:
        The parsed model.
    """
    return onnx.load(str(path))


def _save_temp(model: onnx.ModelProto, path: Path) -> None:
    """Write *model* to *path* with ``onnx.save`` using its default options.

    No external-data arguments are passed to ``onnx.save``. The callers in this
    file use the result as a scratch input for the quantization tools.
    """
    destination = str(path)
    onnx.save(model, destination)


def _find_gather_input_name(model: onnx.ModelProto) -> str | None:
    """Return the initializer name feeding the first weight-backed Gather node.

    Nodes are scanned in graph order. A Gather only qualifies when its data
    input (``input[0]``) names one of the graph's initializers; Gathers that
    index into a computed tensor (for example the output of a ``Shape`` node)
    are skipped, because they cannot be the embedding lookup and their input
    has no initializer to quantize.

    Args:
        model: Model whose ``graph.node`` and ``graph.initializer`` are scanned.

    Returns:
        The initializer name, or ``None`` when no Gather reads from an
        initializer.
    """
    initializer_names = {init.name for init in model.graph.initializer}
    return next(
        (
            node.input[0]
            for node in model.graph.node
            if node.op_type == "Gather"
            and node.input
            and node.input[0] in initializer_names
        ),
        None,
    )


# ── INT8 dynamic quantization ─────────────────────────────────────────────────

def make_int8(src: Path, dst: Path) -> None:
    """
    INT8 dynamic quantization — all weight tensors (MatMul + Gather).

    Runs onnxruntime's ``quantize_dynamic`` with ``QuantType.QInt8`` weights and
    no operator filter, so the word-embedding Gather is quantized along with
    the MatMul weights. (The original author quotes the embedding table
    shrinking from ~977 MB FP32 to ~244 MB for the target model.)

    Args:
        src: Path of the FP32 source model.
        dst: Path the quantized model is written to.
    """
    print(f"  INT8: {src.name} → {dst.name}")
    quantize_dynamic(str(src), str(dst), weight_type=QuantType.QInt8)
    size_mb = dst.stat().st_size / 1024**2
    print(f"  INT8 done  ({size_mb:.0f} MB)")


# ── INT4 MatMulNBits quantization ─────────────────────────────────────────────

def _apply_matmul_nbits(src_model: onnx.ModelProto, block_size: int) -> onnx.ModelProto:
    """Apply MatMulNBits quantization to the MatMul weights of *src_model*.

    The model is round-tripped through a private scratch directory: it is
    written to disk, handed to ``MatMulNBitsQuantizer`` by path, and the result
    is saved without external data and loaded back into memory. The scratch
    directory and everything in it is removed on exit, including when any
    step raises.

    Args:
        src_model: Model to quantize.
        block_size: Quantization block size passed to the quantizer.

    Returns:
        The quantized model, fully loaded in memory.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as workdir:
        tmp_in = Path(workdir) / "input.onnx"
        tmp_out = Path(workdir) / "output.onnx"

        _save_temp(src_model, tmp_in)
        quantizer = MatMulNBitsQuantizer(
            str(tmp_in),
            block_size=block_size,
            is_symmetric=True,
            nodes_to_exclude=[],
        )
        quantizer.process()
        quantizer.model.save_model_to_file(str(tmp_out), use_external_data_format=False)
        # Load before the directory is deleted; the returned proto is in memory.
        return onnx.load(str(tmp_out))


def make_int4_int8_emb(src: Path, dst: Path, block_size: int = 32) -> None:
    """
    Two-pass: INT4 MatMul (MatMulNBits) + INT8 word embedding.

    Pass 1 — MatMulNBitsQuantizer packs transformer MatMul weights to 4-bit.
    Pass 2 — quantize_dynamic(op_types_to_quantize=["Gather"], QInt8) quantizes
             the word embedding table (250,002 × 1024 for the target model,
             per the original author) from FP32 to INT8.

    Args:
        src: Path of the FP32 source model.
        dst: Path the quantized model is written to.
        block_size: Block size forwarded to the MatMulNBits pass.
    """
    import tempfile

    print(f"  INT4+INT8 emb: {src.name} → {dst.name}")
    model = _load(src)

    # Pass 1: INT4 MatMul
    print("    Pass 1: MatMulNBits INT4 ...")
    matmul_model = _apply_matmul_nbits(model, block_size=block_size)

    # Pass 2: INT8 Gather (word embedding table only).
    # quantize_dynamic takes file paths, so the intermediate model goes through
    # a scratch directory that is removed even if quantization fails.
    print("    Pass 2: INT8 Gather ...")
    with tempfile.TemporaryDirectory() as workdir:
        tmp = Path(workdir) / "matmul_int4.onnx"
        _save_temp(matmul_model, tmp)
        quantize_dynamic(
            str(tmp), str(dst),
            op_types_to_quantize=["Gather"],
            weight_type=QuantType.QInt8,
        )
    print(f"  INT4+INT8 emb done  ({dst.stat().st_size / 1024**2:.0f} MB)")


# ── INT4 full (word embeddings packed as INT4 nibbles) ────────────────────────

def _pack_int4_rows(weight: np.ndarray) -> tuple[bytes, np.ndarray]:
    """
    Pack a 2-D float tensor as per-row symmetric INT4.

    Each row r is quantized with scale = max(|row_r|) / 7 (the maximum is
    floored at 1e-9 so all-zero rows do not divide by zero). Values are
    rounded, clamped to [-7, 7] and stored as two's-complement 4-bit nibbles
    (e.g. -1 -> 0xF), i.e. a signed INT4 encoding with zero point 0.

    Nibbles are packed two per byte over the row-major *flattened* tensor:
    the even-indexed element goes in the low nibble and the following element
    in the high nibble. If the total element count is odd, a single zero nibble
    is appended at the very end. This is the layout ONNX uses for 4-bit
    ``raw_data``. For an even ``dim`` it is identical to packing each row
    separately.

    Args:
        weight: Array of shape (rows, dim).

    Returns:
        packed_bytes — ceil(rows * dim / 2) raw bytes
        scales       — float32 scale per row, shape (rows,)

    Raises:
        ValueError: If *weight* is not 2-D.
    """
    if weight.ndim != 2:
        raise ValueError(
            f"_pack_int4_rows expects a 2-D array, got shape {weight.shape}"
        )

    # initial=0.0 keeps the reduction defined for zero-width rows; absolute
    # values are never negative, so it cannot change a real maximum.
    abs_max = np.abs(weight).max(axis=1, keepdims=True, initial=0.0).clip(min=1e-9)
    scales = (abs_max / 7.0).squeeze(1).astype(np.float32)
    quantized = np.round(weight / abs_max * 7.0).clip(-7, 7).astype(np.int8)

    # Reinterpret int8 as uint8 and keep the low 4 bits: two's-complement nibble.
    nibbles = (quantized.view(np.uint8) & 0x0F).reshape(-1)
    if nibbles.size % 2:
        nibbles = np.concatenate([nibbles, np.zeros(1, dtype=np.uint8)])
    packed = nibbles[0::2] | (nibbles[1::2] << 4)
    return packed.tobytes(), scales


def make_int4_full(src: Path, dst: Path, block_size: int = 32) -> None:
    """
    INT4 full: INT4 MatMul (MatMulNBits) + INT4 word embedding (DequantizeLinear).

    The word embedding Gather is rewired to read from:
        INT4_packed_tensor → DequantizeLinear(axis=0, scale=per_row) → FP32 lookup
    The model is converted to ONNX opset 21, which this pipeline relies on for
    4-bit DequantizeLinear inputs.

    Build from the FP32 source (not from model_int4.onnx which, per the
    original author, already has an INT8 DequantizeLinear node on the Gather
    path, causing a type conflict).

    Args:
        src: Path of the FP32 source model.
        dst: Path the quantized model is written to. It is only written after
            the ONNX checker accepts the result.
        block_size: Block size forwarded to the MatMulNBits pass.

    Raises:
        RuntimeError: If no embedding Gather or its initializer can be found.
    """
    import shutil
    import tempfile

    print(f"  INT4 full: {src.name} → {dst.name}")
    model = _load(src)

    # Step 1: INT4 MatMul
    print("    Step 1: MatMulNBits INT4 ...")
    matmul_model = _apply_matmul_nbits(model, block_size=block_size)

    # Step 2: Migrate to opset 21 (required for INT4 DequantizeLinear)
    print("    Step 2: Opset 14 → 21 ...")
    matmul_model = onnx.version_converter.convert_version(matmul_model, 21)

    # Step 3: Find the Gather (word embedding) node and swap its weight
    print("    Step 3: INT4-pack word embedding table ...")
    graph = matmul_model.graph

    # Locate embedding initializer name
    embed_init_name = _find_gather_input_name(matmul_model)
    if embed_init_name is None:
        raise RuntimeError("Could not find Gather (word embedding) node in graph.")

    # Extract current FP32 embedding tensor
    embed_init = next(
        (init for init in graph.initializer if init.name == embed_init_name), None
    )
    if embed_init is None:
        raise RuntimeError(f"Initializer '{embed_init_name}' not found.")

    weight_fp32 = np.array(
        onnx.numpy_helper.to_array(embed_init), dtype=np.float32
    )
    packed_bytes, scales = _pack_int4_rows(weight_fp32)

    # Replace the FP32 initializer with packed INT4
    graph.initializer.remove(embed_init)
    int4_name = embed_init_name + "_int4"
    scales_name = embed_init_name + "_scales"

    # The packer emits two's-complement nibbles in [-7, 7] with an implicit
    # zero point of 0, so the tensor must be declared as signed INT4. The
    # previous literal 17 is FLOAT8E4M3FN in TensorProto.DataType, not a 4-bit
    # type. Declaring UINT4 would also be wrong: negatives would decode as 9..15.
    int4_tensor = onnx.TensorProto()
    int4_tensor.name = int4_name
    int4_tensor.data_type = onnx.TensorProto.INT4
    int4_tensor.dims.extend(list(weight_fp32.shape))
    int4_tensor.raw_data = packed_bytes
    graph.initializer.append(int4_tensor)

    # Per-row scale tensor (float32)
    scales_tensor = onnx.numpy_helper.from_array(scales, name=scales_name)
    graph.initializer.append(scales_tensor)

    # DequantizeLinear(axis=0): one scale per embedding row, no zero-point input
    dql_out_name = embed_init_name + "_dq"
    dql_node = onnx.helper.make_node(
        "DequantizeLinear",
        inputs=[int4_name, scales_name],
        outputs=[dql_out_name],
        axis=0,
    )

    # Reroute: Gather now reads from dql_out instead of the removed initializer
    for node in graph.node:
        if node.op_type == "Gather" and node.input[0] == embed_init_name:
            node.input[0] = dql_out_name

    # Insert DequantizeLinear just before the first rerouted Gather so the
    # node list stays topologically ordered.
    gather_idx = next(
        i for i, n in enumerate(graph.node)
        if n.op_type == "Gather" and n.input[0] == dql_out_name
    )
    graph.node.insert(gather_idx, dql_node)

    # Save to a scratch location, validate, and only then publish to dst so a
    # failed check never leaves a broken file at the destination.
    with tempfile.TemporaryDirectory() as workdir:
        tmp = Path(workdir) / "model_int4_full.onnx"
        onnx.save(matmul_model, str(tmp))
        onnx.checker.check_model(str(tmp))
        shutil.copy(tmp, dst)
    print(f"  INT4 full done  ({dst.stat().st_size / 1024**2:.0f} MB)")


# ── entry point ──────────────────────────────────────────────────────────────

def parse_args() -> argparse.Namespace:
    """Parse the command line, falling back to PIXIE_* environment variables.

    ``--input`` defaults to ``$PIXIE_INPUT`` (``None`` when unset) and
    ``--outdir`` to ``$PIXIE_OUTDIR`` (``"."`` when unset). The module
    docstring is shown verbatim as the help description.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    env_input = os.environ.get("PIXIE_INPUT")
    env_outdir = os.environ.get("PIXIE_OUTDIR", ".")
    variant_choices = ["int8", "int4", "int4_full", "all"]

    parser.add_argument(
        "--input",
        default=env_input,
        help="Path to the FP32 source model.onnx (may have companion .onnx_data)",
    )
    parser.add_argument(
        "--outdir",
        default=env_outdir,
        help="Output directory for quantized models (default: cwd)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=32,
        help="Block size for MatMulNBits INT4 (default: 32)",
    )
    parser.add_argument(
        "--variants",
        nargs="+",
        choices=variant_choices,
        default=["all"],
        help="Which variants to produce (default: all)",
    )
    return parser.parse_args()


def main() -> None:
    """Script entry point: parse options and build each requested variant.

    Exits with an error message when no input model was given. Otherwise it
    creates the output directory if needed, prints a short summary, and runs
    the builders in the fixed order int8, int4, int4_full.
    """
    args = parse_args()
    if not args.input:
        raise SystemExit("Error: --input or PIXIE_INPUT env var required.")

    src = Path(args.input).resolve()
    outdir = Path(args.outdir).resolve()
    outdir.mkdir(parents=True, exist_ok=True)

    # "all" expands to every concrete variant, whatever else was listed.
    requested = set(args.variants)
    selected = {"int8", "int4", "int4_full"} if "all" in requested else requested

    print(f"Source : {src}")
    print(f"Out dir: {outdir}")
    print(f"Targets: {', '.join(sorted(selected))}")
    print()

    # (variant key, output filename, builder, extra keyword arguments)
    with_block_size = {"block_size": args.block_size}
    pipeline = (
        ("int8", "model_quantized.onnx", make_int8, {}),
        ("int4", "model_int4.onnx", make_int4_int8_emb, with_block_size),
        ("int4_full", "model_int4_full.onnx", make_int4_full, with_block_size),
    )
    for variant, filename, builder, extra in pipeline:
        if variant in selected:
            builder(src, outdir / filename, **extra)

    print("\nAll done.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()