| from __future__ import annotations |
|
|
| import argparse |
| from statistics import mean |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from rotorquant_weights import load_quantized_package, dequantize_to_state_dict |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse and return the command-line options for the validation run."""
    parser = argparse.ArgumentParser(description="Validate baseline vs quantized behavior")
    # Path to the quantized weight package produced by rotorquant_weights.
    parser.add_argument("--quantized", required=True)
    parser.add_argument("--max-new-tokens", type=int, default=48)
    parser.add_argument(
        "--dtype",
        choices=["float32", "float16", "bfloat16"],
        default="float32",
    )
    return parser.parse_args()
|
|
|
|
def str_to_dtype(s: str) -> torch.dtype:
    """Map a dtype name from the CLI (e.g. "float16") to the torch dtype."""
    dtype_by_name = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    # KeyError on an unknown name; argparse choices already restrict input.
    return dtype_by_name[s]
|
|
|
|
def make_inputs(tokenizer, prompt: str):
    """Wrap *prompt* in the model's chat template and tokenize it.

    Returns the tokenizer's batched (size-1) PyTorch input dict, ready to be
    splatted into the model / `generate`.
    """
    conversation = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return tokenizer([rendered], return_tensors="pt")
|
|
|
|
def token_match_ratio(a: torch.Tensor, b: torch.Tensor) -> float:
    """Fraction of positions where *a* and *b* agree over their common prefix.

    Only the first min(len(a), len(b)) elements are compared; vacuously 1.0
    when either tensor is empty.
    """
    overlap = min(a.numel(), b.numel())
    if not overlap:
        return 1.0
    agree = torch.eq(a[:overlap], b[:overlap])
    return agree.float().mean().item()
|
|
|
|
def main() -> None:
    """Compare a baseline checkpoint against its quantized counterpart.

    Loads both models on CPU at the requested dtype, overwrites the second
    copy's weights with the dequantized package, then for each validation
    prompt reports (a) cosine similarity of the final-token logits and
    (b) the fraction of greedily generated tokens that match.
    """
    args = parse_args()
    dtype = str_to_dtype(args.dtype)

    pkg = load_quantized_package(args.quantized)
    model_id = pkg["model_id"]

    def fresh_model():
        # Both copies start from the same pretrained checkpoint; the
        # quantized copy has its weights replaced below.
        return AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=True,
        ).eval()

    print(f"Loading baseline model: {model_id}")
    baseline = fresh_model()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Loading quantized model")
    quantized = fresh_model()
    q_state = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    # strict=False so we can report both missing and unexpected keys in one
    # error message instead of torch's default failure.
    missing, unexpected = quantized.load_state_dict(q_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")

    prompts = [
        "Explain what quantization means in machine learning.",
        "Write three bullet points about the Moon landing.",
        "Give me a Python function that computes Fibonacci iteratively.",
        "Summarize why smaller language models can still be useful.",
    ]

    cosine_scores: list[float] = []
    token_match_scores: list[float] = []

    for idx, prompt in enumerate(prompts, start=1):
        batch = make_inputs(tokenizer, prompt)

        # Single forward pass per model; compare logits of the final position.
        with torch.no_grad():
            ref_logits = baseline(**batch).logits[:, -1, :].float()
            quant_logits = quantized(**batch).logits[:, -1, :].float()

        cos = torch.nn.functional.cosine_similarity(ref_logits, quant_logits, dim=-1).mean().item()
        cosine_scores.append(cos)

        # Greedy decoding (do_sample=False) keeps the comparison deterministic.
        with torch.no_grad():
            ref_out = baseline.generate(**batch, max_new_tokens=args.max_new_tokens, do_sample=False)
            quant_out = quantized.generate(**batch, max_new_tokens=args.max_new_tokens, do_sample=False)

        # Strip the prompt tokens so only newly generated tokens are compared.
        prompt_len = batch["input_ids"].shape[1]
        tm = token_match_ratio(
            ref_out[:, prompt_len:].reshape(-1),
            quant_out[:, prompt_len:].reshape(-1),
        )
        token_match_scores.append(tm)

        print(f"[{idx}] cosine(last-token-logits)={cos:.6f}, token-match={tm:.4f}")

    print("\n=== Aggregate ===")
    print(f"Mean cosine similarity: {mean(cosine_scores):.6f}")
    print(f"Mean token match ratio: {mean(token_match_scores):.4f}")
|
|
|
|
# Script entry point: run the validation only when executed directly,
# so the module can be imported without side effects.
if __name__ == "__main__":
    main()
|
|