"""
EE Sanity Check
Usage:
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
"""
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse

def get_sigma(hidden_size, seed):
    """Derive the secret hidden-dimension permutation sigma from a seed.
    The client and the EE export must agree on this seed for the checks to pass."""
    rng = np.random.default_rng(seed)
    return rng.permutation(hidden_size)
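
# Why the checks below can succeed at all: for a permutation sigma, permuting
# the input columns of a weight matrix undoes a permutation of the input
# vector, i.e. W[:, sigma] @ x[sigma] == W @ x. A minimal sketch of that
# identity on toy shapes (illustration only, not the actual EE export code):
def _toy_equivariance_demo(hidden_size=8, seed=0):
    sigma = torch.from_numpy(np.random.default_rng(seed).permutation(hidden_size))
    W = torch.randn(hidden_size, hidden_size)
    x = torch.randn(hidden_size)
    # Summing over a reindexed axis visits the same terms, so both sides match.
    assert torch.allclose(W[:, sigma] @ x[sigma], W @ x, atol=1e-5)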

def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
    print(f"\n{'='*60}")
    print(f"Original : {original_name}")
    print(f"EE model : {ee_name}")
    print(f"Seed     : {seed}")
    print(f"Prompt   : {prompt}")
    print('='*60)

    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids

    print("\n[1] Loading models...")
    orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    ee   = AutoModelForCausalLM.from_pretrained(ee_name,       torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    orig.eval(); ee.eval()

    hidden_size = orig.config.hidden_size
    sigma = get_sigma(hidden_size, seed)
    sigma_t = torch.tensor(sigma, dtype=torch.long)

    # --- CHECK 1: Embed layers must be identical ---
    embed_match = torch.allclose(
        orig.model.embed_tokens.weight.data,
        ee.model.embed_tokens.weight.data, atol=1e-3
    )
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
    if not embed_match:
        print("  ⚠️  Embed was permuted — client-side encryption will be double-permuted")

    # --- CHECK 2 & 3: Forward pass with encrypted embeds ---
    print("\n[CHECK 2] Running plain forward on original...")
    with torch.no_grad():
        plain_embeds = orig.model.embed_tokens(input_ids)  # use ORIGINAL embed
        orig_logits = orig(inputs_embeds=plain_embeds).logits

    print("[CHECK 3] Running encrypted forward on EE model...")
    with torch.no_grad():
        # Client encrypts: take plain embeds, apply sigma
        encrypted_embeds = plain_embeds[..., sigma_t]
        ee_logits = ee(inputs_embeds=encrypted_embeds).logits
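        # In the deployment this script models, only encrypted_embeds would
        # leave the client; plain_embeds stay local (assumption based on the
        # "client encrypts" framing above).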

    # --- CHECK 4: Logits ---
    atol = 0.5  # loose bound: permuted weights change float reduction order
    max_diff = (orig_logits - ee_logits).abs().max().item()
    match = max_diff < atol
    print(f"\n[CHECK 4] Logits match (atol={atol}): {match}")
    print(f"  Max logit diff: {max_diff:.4f}")
    if not match:
        print("  ⚠️  Equivariance BROKEN")

    # --- CHECK 5: Greedy decode ---
    # Both models must use inputs_embeds (not input_ids).
    # Original uses plain embeds, EE uses sigma-encrypted embeds.
    # Their outputs should be identical token sequences.
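    # Note: when generate() receives inputs_embeds and no input_ids, recent
    # transformers versions return only the newly generated token ids, so
    # both decodes below cover just the continuation, not the prompt.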
    print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
    with torch.no_grad():
        orig_ids = orig.generate(
            inputs_embeds=plain_embeds,
            attention_mask=inputs.attention_mask,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
        ee_ids = ee.generate(
            inputs_embeds=encrypted_embeds,
            attention_mask=inputs.attention_mask,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    orig_text = tokenizer.decode(orig_ids[0], skip_special_tokens=True)
    ee_text   = tokenizer.decode(ee_ids[0],   skip_special_tokens=True)
    print(f"  Original output : {repr(orig_text)}")
    print(f"  EE model output : {repr(ee_text)}")
    print(f"  Match: {orig_text == ee_text}")

    if orig_text == ee_text:
        print("\n✅ All checks passed — EE transform is correct")
    else:
        print("\n⚠️  Text differs despite logits matching.")
        print("  This usually means floating point drift in autoregressive generation.")
        print("  Check if token IDs match even if decoded text differs slightly:")
        print(f"  orig_ids: {orig_ids[0].tolist()}")
        print(f"  ee_ids:   {ee_ids[0].tolist()}")
        ids_match = orig_ids[0].tolist() == ee_ids[0].tolist()
        print(f"  Token IDs match: {ids_match}")

    print(f"\n{'='*60}\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="EE sanity check")
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--ee",       default="broadfield-dev/Qwen3-0.6B-dp-ee")
    parser.add_argument("--seed",     type=int, default=424242)
    parser.add_argument("--prompt",   default="Hello, how are you?")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)