| """ |
| quantize_pixie.py β ONNX quantization for XLM-RoBERTa-family embedding models. |
| |
| Produces three self-contained ONNX variants from a float32 source model: |
| model_quantized.onnx β INT8 dynamic (all weights including word embeddings) |
| model_int4.onnx β INT4 MatMul (MatMulNBits) + INT8 word embedding |
| model_int4_full.onnx β INT4 MatMul + INT4 word embedding (opset 21, smallest) |
| |
| Usage: |
| python quantize_pixie.py \\ |
| --input onnx/model.onnx \\ |
| --outdir onnx/ \\ |
| [--block-size 32] |
| |
| # Or via environment variables: |
| PIXIE_INPUT=onnx/model.onnx PIXIE_OUTDIR=onnx/ python quantize_pixie.py |
| |
| The input model is expected to reside in the same directory as its companion |
| data file (model.onnx_data) when using the default HuggingFace layout. |
| """ |
|
|
| import argparse |
| import os |
| import struct |
| from pathlib import Path |
|
|
| import numpy as np |
| import onnx |
| import onnx.version_converter |
| from onnxruntime.quantization import ( |
| QuantType, |
| quantize_dynamic, |
| ) |
| from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer |
|
|
|
|
| |
|
|
| def _load(path: Path) -> onnx.ModelProto: |
| """Load an ONNX model, handling both inline and external initializers.""" |
| model = onnx.load(str(path), load_external_data=False) |
| data_file = path.with_suffix(".onnx_data") |
| if data_file.exists(): |
| onnx.load_external_data_for_model(model, str(path.parent)) |
| return model |
|
|
|
|
| def _save_temp(model: onnx.ModelProto, path: Path) -> None: |
| """Save a model to disk, inlining all tensors (needed before quantization).""" |
| onnx.save(model, str(path)) |
|
|
|
|
| def _find_gather_input_name(model: onnx.ModelProto) -> str | None: |
| """Return the initializer name fed into the first Gather (word embedding) node.""" |
| for node in model.graph.node: |
| if node.op_type == "Gather": |
| return node.input[0] |
| return None |
|
|
|
|
| |
|
|
| def make_int8(src: Path, dst: Path) -> None: |
| """ |
| INT8 dynamic quantization β all weight tensors (MatMul + Gather). |
| |
| Uses onnxruntime quantize_dynamic with QInt8. The word embedding Gather |
| is included, bringing the ~977 MB FP32 embedding table down to ~244 MB. |
| """ |
| print(f" INT8: {src.name} β {dst.name}") |
| quantize_dynamic(str(src), str(dst), weight_type=QuantType.QInt8) |
| print(f" INT8 done ({dst.stat().st_size / 1024**2:.0f} MB)") |
|
|
|
|
| |
|
|
| def _apply_matmul_nbits(src_model: onnx.ModelProto, block_size: int) -> onnx.ModelProto: |
| """Apply MatMulNBits (INT4) to all MatMul weight tensors.""" |
| import tempfile, copy |
| with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: |
| tmp_in = Path(f.name) |
| with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: |
| tmp_out = Path(f.name) |
| try: |
| _save_temp(src_model, tmp_in) |
| q = MatMulNBitsQuantizer( |
| str(tmp_in), |
| block_size=block_size, |
| is_symmetric=True, |
| nodes_to_exclude=[], |
| ) |
| q.process() |
| q.model.save_model_to_file(str(tmp_out), use_external_data_format=False) |
| return onnx.load(str(tmp_out)) |
| finally: |
| tmp_in.unlink(missing_ok=True) |
| tmp_out.unlink(missing_ok=True) |
|
|
|
|
| def make_int4_int8_emb(src: Path, dst: Path, block_size: int = 32) -> None: |
| """ |
| Two-pass: INT4 MatMul (MatMulNBits) + INT8 word embedding. |
| |
| Pass 1 β MatMulNBitsQuantizer packs transformer MatMul weights to 4-bit. |
| Pass 2 β quantize_dynamic(op_types=["Gather"], QInt8) quantizes the |
| word embedding table (250,002 Γ 1024) from FP32 to INT8. |
| """ |
| import tempfile |
| print(f" INT4+INT8 emb: {src.name} β {dst.name}") |
| model = _load(src) |
|
|
| |
| print(" Pass 1: MatMulNBits INT4 ...") |
| matmul_model = _apply_matmul_nbits(model, block_size=block_size) |
|
|
| |
| print(" Pass 2: INT8 Gather ...") |
| with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: |
| tmp = Path(f.name) |
| try: |
| _save_temp(matmul_model, tmp) |
| quantize_dynamic( |
| str(tmp), str(dst), |
| op_types_to_quantize=["Gather"], |
| weight_type=QuantType.QInt8, |
| ) |
| finally: |
| tmp.unlink(missing_ok=True) |
| print(f" INT4+INT8 emb done ({dst.stat().st_size / 1024**2:.0f} MB)") |
|
|
|
|
| |
|
|
| def _pack_int4_rows(weight: np.ndarray) -> tuple[bytes, np.ndarray]: |
| """ |
| Pack a 2-D float32 tensor as per-row symmetric INT4. |
| |
| Each row r is quantized with scale = max(|row_r|) / 7. |
| Values are clamped to [-7, 7] and packed as two INT4 nibbles per byte |
| (little-endian nibble order: low nibble = even index, high nibble = odd). |
| |
| Returns: |
| packed_bytes β raw bytes (vocab_size Γ ceil(dim/2)) |
| scales β float32 scale per row (vocab_size,) |
| """ |
| vocab, dim = weight.shape |
| abs_max = np.abs(weight).max(axis=1, keepdims=True).clip(min=1e-9) |
| scales = (abs_max / 7.0).squeeze(1).astype(np.float32) |
| quantized = np.round(weight / abs_max * 7.0).clip(-7, 7).astype(np.int8) |
|
|
| |
| |
| |
| u4 = (quantized % 16).astype(np.uint8) |
| padded = u4 if dim % 2 == 0 else np.pad(u4, ((0, 0), (0, 1))) |
| packed = padded[:, 0::2] | (padded[:, 1::2] << 4) |
| return packed.tobytes(), scales |
|
|
|
|
| def make_int4_full(src: Path, dst: Path, block_size: int = 32) -> None: |
| """ |
| INT4 full: INT4 MatMul (MatMulNBits) + INT4 word embedding (DequantizeLinear). |
| |
| The word embedding Gather is replaced by: |
| INT4_packed_tensor β DequantizeLinear(axis=0, scale=per_row) β FP32 lookup |
| Requires ONNX opset 21 for the INT4 DequantizeLinear kernel in OnnxRuntime. |
| |
| Build from the FP32 source (not from model_int4.onnx which already has an |
| INT8 DequantizeLinear node on the Gather output, causing a type conflict). |
| """ |
| import tempfile |
| print(f" INT4 full: {src.name} β {dst.name}") |
| model = _load(src) |
|
|
| |
| print(" Step 1: MatMulNBits INT4 ...") |
| matmul_model = _apply_matmul_nbits(model, block_size=block_size) |
|
|
| |
| print(" Step 2: Opset 14 β 21 ...") |
| matmul_model = onnx.version_converter.convert_version(matmul_model, 21) |
|
|
| |
| print(" Step 3: INT4-pack word embedding table ...") |
| graph = matmul_model.graph |
|
|
| |
| embed_init_name = _find_gather_input_name(matmul_model) |
| if embed_init_name is None: |
| raise RuntimeError("Could not find Gather (word embedding) node in graph.") |
|
|
| |
| embed_init = next( |
| (init for init in graph.initializer if init.name == embed_init_name), None |
| ) |
| if embed_init is None: |
| raise RuntimeError(f"Initializer '{embed_init_name}' not found.") |
|
|
| weight_fp32 = np.array( |
| onnx.numpy_helper.to_array(embed_init), dtype=np.float32 |
| ) |
| packed_bytes, scales = _pack_int4_rows(weight_fp32) |
|
|
| |
| graph.initializer.remove(embed_init) |
| int4_name = embed_init_name + "_int4" |
| scales_name = embed_init_name + "_scales" |
|
|
| |
| int4_tensor = onnx.TensorProto() |
| int4_tensor.name = int4_name |
| int4_tensor.data_type = 17 |
| int4_tensor.dims.extend(list(weight_fp32.shape)) |
| int4_tensor.raw_data = packed_bytes |
| graph.initializer.append(int4_tensor) |
|
|
| |
| scales_tensor = onnx.numpy_helper.from_array(scales, name=scales_name) |
| graph.initializer.append(scales_tensor) |
|
|
| |
| dql_out_name = embed_init_name + "_dq" |
| dql_node = onnx.helper.make_node( |
| "DequantizeLinear", |
| inputs=[int4_name, scales_name], |
| outputs=[dql_out_name], |
| axis=0, |
| ) |
|
|
| |
| for node in graph.node: |
| if node.op_type == "Gather" and node.input[0] == embed_init_name: |
| node.input[0] = dql_out_name |
|
|
| |
| gather_idx = next( |
| i for i, n in enumerate(graph.node) |
| if n.op_type == "Gather" and n.input[0] == dql_out_name |
| ) |
| graph.node.insert(gather_idx, dql_node) |
|
|
| |
| with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: |
| tmp = Path(f.name) |
| try: |
| onnx.save(matmul_model, str(tmp)) |
| onnx.checker.check_model(str(tmp)) |
| import shutil |
| shutil.copy(tmp, dst) |
| finally: |
| tmp.unlink(missing_ok=True) |
| print(f" INT4 full done ({dst.stat().st_size / 1024**2:.0f} MB)") |
|
|
|
|
| |
|
|
| def parse_args() -> argparse.Namespace: |
| p = argparse.ArgumentParser(description=__doc__, |
| formatter_class=argparse.RawDescriptionHelpFormatter) |
| p.add_argument("--input", default=os.environ.get("PIXIE_INPUT"), |
| help="Path to the FP32 source model.onnx (may have companion .onnx_data)") |
| p.add_argument("--outdir", default=os.environ.get("PIXIE_OUTDIR", "."), |
| help="Output directory for quantized models (default: cwd)") |
| p.add_argument("--block-size", type=int, default=32, |
| help="Block size for MatMulNBits INT4 (default: 32)") |
| p.add_argument("--variants", nargs="+", |
| choices=["int8", "int4", "int4_full", "all"], |
| default=["all"], |
| help="Which variants to produce (default: all)") |
| return p.parse_args() |
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| if not args.input: |
| raise SystemExit("Error: --input or PIXIE_INPUT env var required.") |
|
|
| src = Path(args.input).resolve() |
| outdir = Path(args.outdir).resolve() |
| outdir.mkdir(parents=True, exist_ok=True) |
|
|
| variants = set(args.variants) |
| if "all" in variants: |
| variants = {"int8", "int4", "int4_full"} |
|
|
| print(f"Source : {src}") |
| print(f"Out dir: {outdir}") |
| print(f"Targets: {', '.join(sorted(variants))}") |
| print() |
|
|
| if "int8" in variants: |
| make_int8(src, outdir / "model_quantized.onnx") |
|
|
| if "int4" in variants: |
| make_int4_int8_emb(src, outdir / "model_int4.onnx", |
| block_size=args.block_size) |
|
|
| if "int4_full" in variants: |
| make_int4_full(src, outdir / "model_int4_full.onnx", |
| block_size=args.block_size) |
|
|
| print("\nAll done.") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|