PIXIE-Rune-v1.0-ONNX / scripts /quantize_pixie.py
cstr's picture
Add sanitized INT4/INT8 quantization script
4c989c0 verified
"""
quantize_pixie.py β€” ONNX quantization for XLM-RoBERTa-family embedding models.
Produces three self-contained ONNX variants from a float32 source model:
model_quantized.onnx β€” INT8 dynamic (all weights including word embeddings)
model_int4.onnx β€” INT4 MatMul (MatMulNBits) + INT8 word embedding
model_int4_full.onnx β€” INT4 MatMul + INT4 word embedding (opset 21, smallest)
Usage:
python quantize_pixie.py \\
--input onnx/model.onnx \\
--outdir onnx/ \\
[--block-size 32]
# Or via environment variables:
PIXIE_INPUT=onnx/model.onnx PIXIE_OUTDIR=onnx/ python quantize_pixie.py
The input model is expected to reside in the same directory as its companion
data file (model.onnx_data) when using the default HuggingFace layout.
"""
import argparse
import os
import struct
from pathlib import Path
import numpy as np
import onnx
import onnx.version_converter
from onnxruntime.quantization import (
QuantType,
quantize_dynamic,
)
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
# ── helpers ──────────────────────────────────────────────────────────────────
def _load(path: Path) -> onnx.ModelProto:
"""Load an ONNX model, handling both inline and external initializers."""
model = onnx.load(str(path), load_external_data=False)
data_file = path.with_suffix(".onnx_data")
if data_file.exists():
onnx.load_external_data_for_model(model, str(path.parent))
return model
def _save_temp(model: onnx.ModelProto, path: Path) -> None:
"""Save a model to disk, inlining all tensors (needed before quantization)."""
onnx.save(model, str(path))
def _find_gather_input_name(model: onnx.ModelProto) -> str | None:
"""Return the initializer name fed into the first Gather (word embedding) node."""
for node in model.graph.node:
if node.op_type == "Gather":
return node.input[0] # initializer with embedding weight
return None
# ── INT8 dynamic quantization ─────────────────────────────────────────────────
def make_int8(src: Path, dst: Path) -> None:
"""
INT8 dynamic quantization β€” all weight tensors (MatMul + Gather).
Uses onnxruntime quantize_dynamic with QInt8. The word embedding Gather
is included, bringing the ~977 MB FP32 embedding table down to ~244 MB.
"""
print(f" INT8: {src.name} β†’ {dst.name}")
quantize_dynamic(str(src), str(dst), weight_type=QuantType.QInt8)
print(f" INT8 done ({dst.stat().st_size / 1024**2:.0f} MB)")
# ── INT4 MatMulNBits quantization ─────────────────────────────────────────────
def _apply_matmul_nbits(src_model: onnx.ModelProto, block_size: int) -> onnx.ModelProto:
"""Apply MatMulNBits (INT4) to all MatMul weight tensors."""
import tempfile, copy
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
tmp_in = Path(f.name)
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
tmp_out = Path(f.name)
try:
_save_temp(src_model, tmp_in)
q = MatMulNBitsQuantizer(
str(tmp_in),
block_size=block_size,
is_symmetric=True,
nodes_to_exclude=[],
)
q.process()
q.model.save_model_to_file(str(tmp_out), use_external_data_format=False)
return onnx.load(str(tmp_out))
finally:
tmp_in.unlink(missing_ok=True)
tmp_out.unlink(missing_ok=True)
def make_int4_int8_emb(src: Path, dst: Path, block_size: int = 32) -> None:
"""
Two-pass: INT4 MatMul (MatMulNBits) + INT8 word embedding.
Pass 1 β€” MatMulNBitsQuantizer packs transformer MatMul weights to 4-bit.
Pass 2 β€” quantize_dynamic(op_types=["Gather"], QInt8) quantizes the
word embedding table (250,002 Γ— 1024) from FP32 to INT8.
"""
import tempfile
print(f" INT4+INT8 emb: {src.name} β†’ {dst.name}")
model = _load(src)
# Pass 1: INT4 MatMul
print(" Pass 1: MatMulNBits INT4 ...")
matmul_model = _apply_matmul_nbits(model, block_size=block_size)
# Pass 2: INT8 Gather (word embedding table only)
print(" Pass 2: INT8 Gather ...")
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
tmp = Path(f.name)
try:
_save_temp(matmul_model, tmp)
quantize_dynamic(
str(tmp), str(dst),
op_types_to_quantize=["Gather"],
weight_type=QuantType.QInt8,
)
finally:
tmp.unlink(missing_ok=True)
print(f" INT4+INT8 emb done ({dst.stat().st_size / 1024**2:.0f} MB)")
# ── INT4 full (word embeddings packed as INT4 nibbles) ────────────────────────
def _pack_int4_rows(weight: np.ndarray) -> tuple[bytes, np.ndarray]:
"""
Pack a 2-D float32 tensor as per-row symmetric INT4.
Each row r is quantized with scale = max(|row_r|) / 7.
Values are clamped to [-7, 7] and packed as two INT4 nibbles per byte
(little-endian nibble order: low nibble = even index, high nibble = odd).
Returns:
packed_bytes β€” raw bytes (vocab_size Γ— ceil(dim/2))
scales β€” float32 scale per row (vocab_size,)
"""
vocab, dim = weight.shape
abs_max = np.abs(weight).max(axis=1, keepdims=True).clip(min=1e-9)
scales = (abs_max / 7.0).squeeze(1).astype(np.float32)
quantized = np.round(weight / abs_max * 7.0).clip(-7, 7).astype(np.int8)
# Pack pairs of INT4 values into bytes
# Treat negative as unsigned 4-bit: -7..7 β†’ offset doesn't apply for symmetric
# Use unsigned nibbles with zero_point=0 (symmetric)
u4 = (quantized % 16).astype(np.uint8) # map negatives: e.g. -1 β†’ 15
padded = u4 if dim % 2 == 0 else np.pad(u4, ((0, 0), (0, 1)))
packed = padded[:, 0::2] | (padded[:, 1::2] << 4)
return packed.tobytes(), scales
def make_int4_full(src: Path, dst: Path, block_size: int = 32) -> None:
"""
INT4 full: INT4 MatMul (MatMulNBits) + INT4 word embedding (DequantizeLinear).
The word embedding Gather is replaced by:
INT4_packed_tensor β†’ DequantizeLinear(axis=0, scale=per_row) β†’ FP32 lookup
Requires ONNX opset 21 for the INT4 DequantizeLinear kernel in OnnxRuntime.
Build from the FP32 source (not from model_int4.onnx which already has an
INT8 DequantizeLinear node on the Gather output, causing a type conflict).
"""
import tempfile
print(f" INT4 full: {src.name} β†’ {dst.name}")
model = _load(src)
# Step 1: INT4 MatMul
print(" Step 1: MatMulNBits INT4 ...")
matmul_model = _apply_matmul_nbits(model, block_size=block_size)
# Step 2: Migrate to opset 21 (required for INT4 DequantizeLinear)
print(" Step 2: Opset 14 β†’ 21 ...")
matmul_model = onnx.version_converter.convert_version(matmul_model, 21)
# Step 3: Find and replace the Gather (word embedding) node
print(" Step 3: INT4-pack word embedding table ...")
graph = matmul_model.graph
# Locate embedding initializer name
embed_init_name = _find_gather_input_name(matmul_model)
if embed_init_name is None:
raise RuntimeError("Could not find Gather (word embedding) node in graph.")
# Extract current FP32 embedding tensor
embed_init = next(
(init for init in graph.initializer if init.name == embed_init_name), None
)
if embed_init is None:
raise RuntimeError(f"Initializer '{embed_init_name}' not found.")
weight_fp32 = np.array(
onnx.numpy_helper.to_array(embed_init), dtype=np.float32
)
packed_bytes, scales = _pack_int4_rows(weight_fp32)
# Replace the FP32 initializer with packed INT4
graph.initializer.remove(embed_init)
int4_name = embed_init_name + "_int4"
scales_name = embed_init_name + "_scales"
# INT4 tensor stored as raw bytes in ONNX (UINT4 = elem_type 17)
int4_tensor = onnx.TensorProto()
int4_tensor.name = int4_name
int4_tensor.data_type = 17 # UINT4
int4_tensor.dims.extend(list(weight_fp32.shape))
int4_tensor.raw_data = packed_bytes
graph.initializer.append(int4_tensor)
# Per-row scale tensor (float32)
scales_tensor = onnx.numpy_helper.from_array(scales, name=scales_name)
graph.initializer.append(scales_tensor)
# Insert DequantizeLinear(axis=0) between INT4 weights and the Gather node
dql_out_name = embed_init_name + "_dq"
dql_node = onnx.helper.make_node(
"DequantizeLinear",
inputs=[int4_name, scales_name],
outputs=[dql_out_name],
axis=0,
)
# Reroute: Gather now reads from dql_out instead of original initializer
for node in graph.node:
if node.op_type == "Gather" and node.input[0] == embed_init_name:
node.input[0] = dql_out_name
# Insert DequantizeLinear before the Gather node
gather_idx = next(
i for i, n in enumerate(graph.node)
if n.op_type == "Gather" and n.input[0] == dql_out_name
)
graph.node.insert(gather_idx, dql_node)
# Save
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
tmp = Path(f.name)
try:
onnx.save(matmul_model, str(tmp))
onnx.checker.check_model(str(tmp))
import shutil
shutil.copy(tmp, dst)
finally:
tmp.unlink(missing_ok=True)
print(f" INT4 full done ({dst.stat().st_size / 1024**2:.0f} MB)")
# ── entry point ──────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--input", default=os.environ.get("PIXIE_INPUT"),
help="Path to the FP32 source model.onnx (may have companion .onnx_data)")
p.add_argument("--outdir", default=os.environ.get("PIXIE_OUTDIR", "."),
help="Output directory for quantized models (default: cwd)")
p.add_argument("--block-size", type=int, default=32,
help="Block size for MatMulNBits INT4 (default: 32)")
p.add_argument("--variants", nargs="+",
choices=["int8", "int4", "int4_full", "all"],
default=["all"],
help="Which variants to produce (default: all)")
return p.parse_args()
def main() -> None:
args = parse_args()
if not args.input:
raise SystemExit("Error: --input or PIXIE_INPUT env var required.")
src = Path(args.input).resolve()
outdir = Path(args.outdir).resolve()
outdir.mkdir(parents=True, exist_ok=True)
variants = set(args.variants)
if "all" in variants:
variants = {"int8", "int4", "int4_full"}
print(f"Source : {src}")
print(f"Out dir: {outdir}")
print(f"Targets: {', '.join(sorted(variants))}")
print()
if "int8" in variants:
make_int8(src, outdir / "model_quantized.onnx")
if "int4" in variants:
make_int4_int8_emb(src, outdir / "model_int4.onnx",
block_size=args.block_size)
if "int4_full" in variants:
make_int4_full(src, outdir / "model_int4_full.onnx",
block_size=args.block_size)
print("\nAll done.")
if __name__ == "__main__":
main()