# RotorQuant-ModelWeights-Runtime / validate_quantization.py
# (Hugging Face upload header: cnmoro — "Upload 29 files", revision 18f4d80, verified)
from __future__ import annotations
import argparse
from statistics import mean
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rotorquant_weights import load_quantized_package, dequantize_to_state_dict
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for a validation run.

    Returns:
        argparse.Namespace with ``quantized`` (required package path),
        ``max_new_tokens`` (greedy decode length, default 48), and
        ``dtype`` (one of float32/float16/bfloat16, default float32).
    """
    parser = argparse.ArgumentParser(description="Validate baseline vs quantized behavior")
    parser.add_argument("--quantized", required=True)
    parser.add_argument("--max-new-tokens", type=int, default=48)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    return parser.parse_args()
def str_to_dtype(s: str) -> torch.dtype:
    """Translate a dtype name from the CLI into the matching torch dtype.

    Raises:
        KeyError: if *s* is not one of float32/float16/bfloat16.
    """
    name_to_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return name_to_dtype[s]
def make_inputs(tokenizer, prompt: str):
    """Render *prompt* through the chat template and tokenize it.

    The fixed system message matches the Qwen chat convention; the
    tokenizer is expected to provide ``apply_chat_template`` and to be
    callable on a list of strings (Hugging Face tokenizer interface).
    Returns the tokenizer's PyTorch-tensor encoding of the rendered chat.
    """
    chat = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    return tokenizer([rendered], return_tensors="pt")
def token_match_ratio(a: torch.Tensor, b: torch.Tensor) -> float:
    """Fraction of positions where two token sequences agree.

    Comparison is truncated to the shorter sequence; when either sequence
    is empty the ratio is defined as 1.0 (nothing to disagree on).
    """
    overlap = min(a.numel(), b.numel())
    if overlap == 0:
        return 1.0
    agreement = (a[:overlap] == b[:overlap]).float()
    return agreement.mean().item()
def main() -> None:
    """Compare a baseline model against its RotorQuant-dequantized twin.

    Loads the baseline and a second copy of the same architecture on CPU,
    overwrites the copy's weights with tensors dequantized from the
    package, then reports (a) cosine similarity of last-token logits and
    (b) greedy-generation token agreement over a few fixed prompts.
    """
    args = parse_args()
    dtype = str_to_dtype(args.dtype)

    pkg = load_quantized_package(args.quantized)
    model_id = pkg["model_id"]

    print(f"Loading baseline model: {model_id}")
    baseline = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Loading quantized model")
    quantized = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()

    # Swap the twin's weights for the dequantized tensors; any leftover or
    # extra keys means the package does not match this architecture.
    q_state = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    missing, unexpected = quantized.load_state_dict(q_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")

    prompts = [
        "Explain what quantization means in machine learning.",
        "Write three bullet points about the Moon landing.",
        "Give me a Python function that computes Fibonacci iteratively.",
        "Summarize why smaller language models can still be useful.",
    ]

    cosine_scores: list[float] = []
    token_match_scores: list[float] = []
    for idx, prompt in enumerate(prompts, start=1):
        inputs = make_inputs(tokenizer, prompt)
        prompt_len = inputs["input_ids"].shape[1]

        # Last-token logit similarity: a cheap proxy for next-token agreement.
        with torch.no_grad():
            base_last = baseline(**inputs).logits[:, -1, :].float()
            quant_last = quantized(**inputs).logits[:, -1, :].float()
        cos = torch.nn.functional.cosine_similarity(base_last, quant_last, dim=-1).mean().item()
        cosine_scores.append(cos)

        # Greedy decoding is deterministic, so token-level agreement is meaningful.
        with torch.no_grad():
            base_out = baseline.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
            quant_out = quantized.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
        tm = token_match_ratio(
            base_out[:, prompt_len:].reshape(-1),
            quant_out[:, prompt_len:].reshape(-1),
        )
        token_match_scores.append(tm)
        print(f"[{idx}] cosine(last-token-logits)={cos:.6f}, token-match={tm:.4f}")

    print("\n=== Aggregate ===")
    print(f"Mean cosine similarity: {mean(cosine_scores):.6f}")
    print(f"Mean token match ratio: {mean(token_match_scores):.4f}")
# Script entry point: run validation only when executed directly, not on import.
if __name__ == "__main__":
    main()