File size: 3,714 Bytes
18f4d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import annotations

import argparse
from statistics import mean

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rotorquant_weights import load_quantized_package, dequantize_to_state_dict


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the baseline-vs-quantized check.

    Returns:
        A namespace with three attributes:
          * ``quantized`` - required; forwarded to ``load_quantized_package``.
          * ``max_new_tokens`` - int, defaults to 48; generation budget per prompt.
          * ``dtype`` - one of "float32", "float16", "bfloat16"; defaults to "float32".
    """
    parser = argparse.ArgumentParser(
        description="Validate baseline vs quantized behavior",
    )
    parser.add_argument("--quantized", required=True)
    parser.add_argument("--max-new-tokens", default=48, type=int)
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float32", "float16", "bfloat16"],
    )
    parsed = parser.parse_args()
    return parsed


def str_to_dtype(s: str) -> torch.dtype:
    """Translate a ``--dtype`` CLI choice into the matching ``torch.dtype``.

    Raises:
        KeyError: if ``s`` is not "float32", "float16" or "bfloat16".
    """
    lookup = dict(
        float32=torch.float32,
        float16=torch.float16,
        bfloat16=torch.bfloat16,
    )
    return lookup[s]


def make_inputs(tokenizer, prompt: str):
    """Render *prompt* through the tokenizer's chat template, then tokenize it.

    A fixed system message is placed before the user turn. The template is
    rendered to text (``tokenize=False``) with ``add_generation_prompt=True``,
    which per the Hugging Face chat-template API appends the cue for an
    assistant reply.

    Returns:
        The result of ``tokenizer([text], return_tensors="pt")``, i.e. a
        batch containing exactly one rendered conversation.
    """
    system_turn = {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}
    user_turn = {"role": "user", "content": prompt}
    rendered = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    single_batch = [rendered]
    return tokenizer(single_batch, return_tensors="pt")


def token_match_ratio(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the fraction of positions at which two token sequences agree.

    Tokens are compared position by position over the shared prefix. The
    number of equal positions is divided by the length of the *longer*
    sequence, so positions present in only one sequence count as mismatches.
    If one generation stops early (e.g. on an end-of-sequence token), its
    missing tail therefore lowers the score instead of being silently ignored.

    Args:
        a, b: token-id tensors. Any shape is accepted; both are flattened.

    Returns:
        A float in [0, 1]. Returns 1.0 when both sequences are empty.
    """
    a = a.reshape(-1)
    b = b.reshape(-1)
    longest = max(a.numel(), b.numel())
    if longest == 0:
        # Two empty generations are trivially identical.
        return 1.0
    shared = min(a.numel(), b.numel())
    matches = (a[:shared] == b[:shared]).sum().item()
    return matches / longest


def _load_model(model_id: str, dtype: torch.dtype):
    """Load *model_id* with ``from_pretrained`` in *dtype* and put it in eval mode."""
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        low_cpu_mem_usage=True,
    ).eval()


def _compare_on_prompt(baseline, quantized, tokenizer, prompt: str, max_new_tokens: int) -> tuple[float, float]:
    """Compare both models on one prompt.

    Returns:
        ``(cosine, token_match)``:
          * ``cosine`` - cosine similarity of the last-position logits.
          * ``token_match`` - ``token_match_ratio`` of the newly generated
            tokens under greedy decoding (``do_sample=False``).
    """
    inputs = make_inputs(tokenizer, prompt)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        b_logits = baseline(**inputs).logits[:, -1, :].float()
        q_logits = quantized(**inputs).logits[:, -1, :].float()

    cos = torch.nn.functional.cosine_similarity(b_logits, q_logits, dim=-1).mean().item()

    with torch.no_grad():
        b_gen = baseline.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        q_gen = quantized.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # generate() returns prompt + continuation; keep only the new tokens.
    b_new = b_gen[:, prompt_len:].reshape(-1)
    q_new = q_gen[:, prompt_len:].reshape(-1)
    return cos, token_match_ratio(b_new, q_new)


def main() -> None:
    """Compare a baseline model against its dequantized counterpart.

    Steps:
      1. Load the quantized package and read ``model_id`` from it.
      2. Load the baseline model and tokenizer for that id.
      3. Load a second model instance and overwrite its weights with the
         dequantized state dict.
      4. For each fixed prompt, print the cosine similarity of the
         last-token logits and the greedy-generation token-match ratio.
      5. Print both means.

    Raises:
        RuntimeError: if the dequantized state dict has missing or
            unexpected keys relative to the model.
    """
    args = parse_args()
    dtype = str_to_dtype(args.dtype)

    pkg = load_quantized_package(args.quantized)
    model_id = pkg["model_id"]

    print(f"Loading baseline model: {model_id}")
    baseline = _load_model(model_id, dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print("Loading quantized model")
    quantized = _load_model(model_id, dtype)
    q_state = dequantize_to_state_dict(pkg, dtype=dtype, device="cpu")
    missing, unexpected = quantized.load_state_dict(q_state, strict=False)
    # load_state_dict copies values into the model's existing parameters, so
    # q_state is now a redundant full copy of the weights. Neither it nor pkg
    # is used below, so release them before evaluation to lower peak memory.
    del q_state, pkg
    if missing or unexpected:
        raise RuntimeError(f"State dict mismatch. Missing={missing}, unexpected={unexpected}")

    prompts = [
        "Explain what quantization means in machine learning.",
        "Write three bullet points about the Moon landing.",
        "Give me a Python function that computes Fibonacci iteratively.",
        "Summarize why smaller language models can still be useful.",
    ]

    cosine_scores = []
    token_match_scores = []

    for i, prompt in enumerate(prompts, start=1):
        cos, tm = _compare_on_prompt(baseline, quantized, tokenizer, prompt, args.max_new_tokens)
        cosine_scores.append(cos)
        token_match_scores.append(tm)
        print(f"[{i}] cosine(last-token-logits)={cos:.6f}, token-match={tm:.4f}")

    print("\n=== Aggregate ===")
    print(f"Mean cosine similarity: {mean(cosine_scores):.6f}")
    print(f"Mean token match ratio: {mean(token_match_scores):.4f}")


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()