# Source: Hugging Face upload by user "cnmoro" — "Upload 29 files", commit 18f4d80 (verified).
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM
from rotorquant_weights import (
quantize_state_dict,
save_quantized_package,
save_report,
estimate_bits_per_weight,
)
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI for the RotorQuant weight-quantization script.

    Returns the parsed namespace; all options have defaults so the script can
    run with no arguments at all.
    """
    parser = argparse.ArgumentParser(
        description="Quantize a HF model with RotorQuant-weight codec"
    )

    # Model selection and output locations.
    parser.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--output", default="artifacts/qwen2.5-0.5b-rotorq3.pt")
    parser.add_argument("--report", default="artifacts/qwen2.5-0.5b-rotorq3-report.json")

    # Core codec parameters.
    parser.add_argument("--bits", type=int, default=3)
    parser.add_argument("--block-size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    parser.add_argument("--min-ndim", type=int, default=2)

    # Tensor filtering: explicit skips and name-fragment inclusion.
    parser.add_argument(
        "--skip-name",
        action="append",
        default=[],
        help="Exact tensor names to keep unquantized (repeatable).",
    )
    parser.add_argument(
        "--include-name-contains",
        action="append",
        default=[],
        help="Only quantize tensors whose name contains at least one provided fragment (repeatable).",
    )

    # Fidelity/overhead trade-off knobs.
    parser.add_argument("--lowrank-rank", type=int, default=0, help="Optional residual low-rank correction rank.")
    parser.add_argument("--rotor-angle-scale", type=float, default=1.0, help="Scale for rotor angle; 0.0 disables rotation.")
    parser.add_argument("--rowwise", action="store_true", help="Quantize per-row (higher overhead, sometimes higher fidelity).")
    parser.add_argument("--outlier-frac", type=float, default=0.0, help="Store top-k residual outliers per row in fp16.")

    return parser.parse_args()
def str_to_dtype(s: str) -> torch.dtype:
    """Translate a CLI dtype name into the matching ``torch.dtype``.

    Raises ``KeyError`` for any name outside the three supported values
    (argparse ``choices`` already restricts the input in practice).
    """
    table: dict[str, torch.dtype] = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return table[s]
def main() -> None:
    """CLI entry point: load a HF causal-LM, quantize its state dict with the
    RotorQuant weight codec, and write the quantized package plus a JSON report.
    """
    args = parse_args()
    torch_dtype = str_to_dtype(args.dtype)

    print(f"Loading model: {args.model_id} (dtype={torch_dtype})")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=torch_dtype,
        device_map=None,  # no device placement; tensors stay on CPU below
        low_cpu_mem_usage=True,
    )
    model.eval()

    # Detach to CPU so the codec works on plain tensors with no autograd graph.
    state = {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()}
    print(f"State dict tensors: {len(state)}")

    pkg = quantize_state_dict(
        state,
        bits=args.bits,
        block_size=args.block_size,
        seed=args.seed,
        min_ndim=args.min_ndim,
        verbose=True,
        skip_names=args.skip_name,
        lowrank_rank=args.lowrank_rank,
        rotor_angle_scale=args.rotor_angle_scale,
        rowwise=args.rowwise,
        include_if_name_contains=args.include_name_contains,
        outlier_frac=args.outlier_frac,
    )
    # Record provenance so the package is self-describing.
    pkg["model_id"] = args.model_id
    pkg["source_dtype"] = args.dtype

    output_path = Path(args.output)
    report_path = Path(args.report)
    save_quantized_package(pkg, output_path)
    save_report(pkg, report_path)
    bpw = estimate_bits_per_weight(pkg)

    print(f"Saved quantized package: {output_path}")
    print(f"Saved report: {report_path}")
    print(f"Estimated effective bits/weight: {bpw:.4f}")
    print(f"Quantized tensors: {len(pkg['quantized'])}, passthrough: {len(pkg['passthrough'])}")


if __name__ == "__main__":
    main()