| from __future__ import annotations |
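
# Quantize the weights of a Hugging Face causal LM with the local
# RotorQuant-weight codec, then save the quantized package and a JSON report.
#
# Illustrative invocation (the script filename is a placeholder; flags and
# defaults match parse_args() below):
#
#   python quantize_rotorq.py --model-id Qwen/Qwen2.5-0.5B-Instruct \
#       --bits 3 --block-size 128 --output artifacts/qwen2.5-0.5b-rotorq3.pt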
|
| import argparse |
| from pathlib import Path |
|
| import torch |
| from transformers import AutoModelForCausalLM |
|
| from rotorquant_weights import ( |
| quantize_state_dict, |
| save_quantized_package, |
| save_report, |
| estimate_bits_per_weight, |
| ) |
|
|
| def parse_args() -> argparse.Namespace: |
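    """Build and parse the command-line interface for a quantization run."""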
| p = argparse.ArgumentParser(description="Quantize a HF model with RotorQuant-weight codec") |
| p.add_argument("--model-id", default="Qwen/Qwen2.5-0.5B-Instruct") |
| p.add_argument("--output", default="artifacts/qwen2.5-0.5b-rotorq3.pt") |
| p.add_argument("--report", default="artifacts/qwen2.5-0.5b-rotorq3-report.json") |
| p.add_argument("--bits", type=int, default=3) |
| p.add_argument("--block-size", type=int, default=128) |
| p.add_argument("--seed", type=int, default=1337) |
| p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32") |
| p.add_argument("--min-ndim", type=int, default=2) |
| p.add_argument( |
| "--skip-name", |
| action="append", |
| default=[], |
| help="Exact tensor names to keep unquantized (repeatable).", |
| ) |
| p.add_argument("--lowrank-rank", type=int, default=0, help="Optional residual low-rank correction rank.") |
| p.add_argument("--rotor-angle-scale", type=float, default=1.0, help="Scale for rotor angle; 0.0 disables rotation.") |
| p.add_argument("--rowwise", action="store_true", help="Quantize per-row (higher overhead, sometimes higher fidelity).") |
| p.add_argument("--outlier-frac", type=float, default=0.0, help="Store top-k residual outliers per row in fp16.") |
| p.add_argument( |
| "--include-name-contains", |
| action="append", |
| default=[], |
| help="Only quantize tensors whose name contains at least one provided fragment (repeatable).", |
| ) |
| return p.parse_args() |
|
|
| def str_to_dtype(s: str) -> torch.dtype: |
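    """Map a CLI dtype name to the corresponding torch.dtype."""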
| return { |
| "float32": torch.float32, |
| "float16": torch.float16, |
| "bfloat16": torch.bfloat16, |
| }[s] |
|
|
| def main() -> None: |
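    """Load the model, quantize its state dict, and write the package and report."""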
| args = parse_args() |
| dtype = str_to_dtype(args.dtype) |
|
| print(f"Loading model: {args.model_id} (dtype={dtype})") |
| model = AutoModelForCausalLM.from_pretrained( |
| args.model_id, |
| torch_dtype=dtype, |
| device_map=None, |
| low_cpu_mem_usage=True, |
| ) |
| model.eval() |
|
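    # Snapshot the state dict as detached CPU tensors before handing it to the codec.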
| state = {k: v.detach().cpu() for k, v in model.state_dict().items()} |
| print(f"State dict tensors: {len(state)}") |
|
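    # Quantize the full state dict; tensors excluded by the ndim/name filters
    # are kept in the package unquantized (see the passthrough count below).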
| pkg = quantize_state_dict( |
| state, |
| bits=args.bits, |
| block_size=args.block_size, |
| seed=args.seed, |
| min_ndim=args.min_ndim, |
| verbose=True, |
| skip_names=args.skip_name, |
| lowrank_rank=args.lowrank_rank, |
| rotor_angle_scale=args.rotor_angle_scale, |
| rowwise=args.rowwise, |
| include_if_name_contains=args.include_name_contains, |
| outlier_frac=args.outlier_frac, |
| ) |
| pkg["model_id"] = args.model_id |
| pkg["source_dtype"] = args.dtype |
|
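    # Persist the quantized package and the JSON report.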
| output_path = Path(args.output) |
| report_path = Path(args.report) |
| save_quantized_package(pkg, output_path) |
| save_report(pkg, report_path) |
|
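    # Print a short summary, including the codec's estimate of effective bits per weight.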
| bpw = estimate_bits_per_weight(pkg) |
| print(f"Saved quantized package: {output_path}") |
| print(f"Saved report: {report_path}") |
| print(f"Estimated effective bits/weight: {bpw:.4f}") |
| print(f"Quantized tensors: {len(pkg['quantized'])}, passthrough: {len(pkg['passthrough'])}") |
|
|
| if __name__ == "__main__": |
| main() |