"""Compare a baseline Hugging Face model against its dequantized RotorQuant
counterpart: for each prompt, report the cosine similarity of the last-token
logits and the fraction of greedily decoded tokens that match."""

from __future__ import annotations

import argparse
from statistics import mean

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rotorquant_weights import dequantize_to_state_dict, load_quantized_package


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Validate baseline vs. quantized behavior")
    p.add_argument("--quantized", required=True, help="Path to the quantized package")
    p.add_argument("--max-new-tokens", type=int, default=48, help="Tokens to generate per prompt")
    p.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    return p.parse_args()


def str_to_dtype(s: str) -> torch.dtype:
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[s]


def make_inputs(tokenizer, prompt: str):
    # Wrap the raw prompt in the model's chat template so both models see
    # exactly the input format they were instruction-tuned on.
    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return tokenizer([text], return_tensors="pt")


def token_match_ratio(a: torch.Tensor, b: torch.Tensor) -> float:
    """Fraction of positions where the two token sequences agree,
    compared over the shorter of the two lengths."""
    n = min(a.numel(), b.numel())
    if n == 0:
        return 1.0
    return (a[:n] == b[:n]).float().mean().item()


def main() -> None:
    args = parse_args()
    dtype = str_to_dtype(args.dtype)

    pkg = load_quantized_package(args.quantized)
    model_id = pkg["model_id"]

    print(f"Loading baseline model: {model_id}")
    baseline = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # The "quantized" model is a second copy of the baseline whose weights are
    # overwritten with the dequantized tensors from the package, so any
    # behavioral difference is attributable to quantization error alone.
    print("Loading quantized model")
    quantized = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()
    q_state = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    # strict=False lets us collect the discrepancies; fail loudly on any.
    missing, unexpected = quantized.load_state_dict(q_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")

    prompts = [
        "Explain what quantization means in machine learning.",
        "Write three bullet points about the Moon landing.",
        "Give me a Python function that computes Fibonacci iteratively.",
        "Summarize why smaller language models can still be useful.",
    ]

    cosine_scores = []
    token_match_scores = []
    for i, prompt in enumerate(prompts, start=1):
        inputs = make_inputs(tokenizer, prompt)

        # Metric 1: cosine similarity between the next-token logit vectors.
        with torch.no_grad():
            b_logits = baseline(**inputs).logits[:, -1, :].float()
            q_logits = quantized(**inputs).logits[:, -1, :].float()
        cos = torch.nn.functional.cosine_similarity(b_logits, q_logits, dim=-1).mean().item()
        cosine_scores.append(cos)

        # Metric 2: agreement of the greedy (deterministic) continuations,
        # compared only over the newly generated tokens.
        with torch.no_grad():
            b_gen = baseline.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
            q_gen = quantized.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
        b_new = b_gen[:, inputs["input_ids"].shape[1]:].reshape(-1)
        q_new = q_gen[:, inputs["input_ids"].shape[1]:].reshape(-1)
        tm = token_match_ratio(b_new, q_new)
        token_match_scores.append(tm)

        print(f"[{i}] cosine(last-token-logits)={cos:.6f}, token-match={tm:.4f}")

    print("\n=== Aggregate ===")
    print(f"Mean cosine similarity: {mean(cosine_scores):.6f}")
    print(f"Mean token match ratio: {mean(token_match_scores):.4f}")


if __name__ == "__main__":
    main()
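
# Example invocation (a sketch; the script filename and the package path are
# hypothetical and depend on what your rotorquant_weights quantizer produces):
#
#   python validate_quantized.py \
#       --quantized out/model.rotorquant \
#       --max-new-tokens 48 \
#       --dtype float32
#
# A near-lossless quantization should report cosine similarity very close to
# 1.0 and a token-match ratio at or near 1.0 on every prompt. Because decoding
# is greedy (do_sample=False), both runs are deterministic, so any divergence
# reflects quantization error rather than sampling noise.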