"""
EE Sanity Check
Usage:
python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
"""
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
def get_sigma(hidden_size, seed):
    """Derive the secret permutation sigma over the hidden dimension.

    Deterministic for a given seed: the client and the EE-model builder
    must derive the exact same permutation independently.

    Args:
        hidden_size: size of the model's hidden dimension to permute.
        seed: integer seed for numpy's default_rng.

    Returns:
        A numpy array containing a permutation of range(hidden_size).
    """
    return np.random.default_rng(seed).permutation(hidden_size)
def run_check(original_name, ee_name, seed, prompt="Hello, how are you?", atol=0.5):
    """Sanity-check that an embedding-permuted (EE) model matches the original.

    Loads both models on CPU in float32, derives the permutation sigma from
    *seed*, then verifies:
      1. the embedding tables are identical (EE transform must not touch them),
      2-4. a forward pass on sigma-permuted embeds reproduces the original
           model's logits within *atol*,
      5. greedy decoding from embeds yields the same text on both models.

    Args:
        original_name: HF hub id / path of the unmodified base model.
        ee_name: HF hub id / path of the permutation-transformed model.
        seed: integer seed used to derive the permutation sigma.
        prompt: text fed through both models.
        atol: absolute tolerance on the max logit difference in CHECK 4
            (default 0.5, the value previously hard-coded in the comparison).
    """
    print(f"\n{'='*60}")
    print(f"Original : {original_name}")
    print(f"EE model : {ee_name}")
    print(f"Seed : {seed}")
    print(f"Prompt : {prompt}")
    print('='*60)
    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids
    print("\n[1] Loading models...")
    # float32 on CPU: avoids fp16/bf16 drift so the logit comparison is meaningful.
    orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    ee = AutoModelForCausalLM.from_pretrained(ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    orig.eval(); ee.eval()
    hidden_size = orig.config.hidden_size
    sigma = get_sigma(hidden_size, seed)
    sigma_t = torch.tensor(sigma, dtype=torch.long)
    # --- CHECK 1: Embed layers must be identical ---
    embed_match = torch.allclose(
        orig.model.embed_tokens.weight.data,
        ee.model.embed_tokens.weight.data, atol=1e-3
    )
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
    if not embed_match:
        print(" ⚠️ Embed was permuted — client-side encryption will be double-permuted")
    # --- CHECK 2 & 3: Forward pass with encrypted embeds ---
    print("\n[CHECK 2] Running plain forward on original...")
    with torch.no_grad():
        plain_embeds = orig.model.embed_tokens(input_ids)  # use ORIGINAL embed
        orig_logits = orig(inputs_embeds=plain_embeds).logits
    print("[CHECK 3] Running encrypted forward on EE model...")
    with torch.no_grad():
        # Client encrypts: take plain embeds, apply sigma along the hidden dim
        encrypted_embeds = plain_embeds[..., sigma_t]
        ee_logits = ee(inputs_embeds=encrypted_embeds).logits
    # --- CHECK 4: Logits ---
    max_diff = (orig_logits - ee_logits).abs().max().item()
    match = max_diff < atol
    # BUGFIX: the message previously hard-coded "atol=0.1" while the code
    # compared against 0.5 — report the tolerance actually used.
    print(f"\n[CHECK 4] Logits match (atol={atol}): {match}")
    print(f" Max logit diff: {max_diff:.4f}")
    if not match:
        print(" ⚠️ Equivariance BROKEN")
    # --- CHECK 5: Greedy decode ---
    # Both models must use inputs_embeds (not input_ids).
    # Original uses plain embeds, EE uses sigma-encrypted embeds.
    # Their outputs should be identical token sequences.
    print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
    with torch.no_grad():
        orig_ids = orig.generate(
            inputs_embeds=plain_embeds,
            attention_mask=inputs.attention_mask,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
        ee_ids = ee.generate(
            inputs_embeds=encrypted_embeds,
            attention_mask=inputs.attention_mask,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    orig_text = tokenizer.decode(orig_ids[0], skip_special_tokens=True)
    ee_text = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
    print(f" Original output : {repr(orig_text)}")
    print(f" EE model output : {repr(ee_text)}")
    print(f" Match: {orig_text == ee_text}")
    if orig_text == ee_text:
        print("\n✅ All checks passed — EE transform is correct")
    else:
        print("\n⚠️ Text differs despite logits matching.")
        print(" This usually means floating point drift in autoregressive generation.")
        print(" Check if token IDs match even if decoded text differs slightly:")
        print(f" orig_ids: {orig_ids[0].tolist()}")
        print(f" ee_ids: {ee_ids[0].tolist()}")
        ids_match = orig_ids[0].tolist() == ee_ids[0].tolist()
        print(f" Token IDs match: {ids_match}")
    print(f"\n{'='*60}\n")
if __name__ == "__main__":
    # Restores the CLI advertised in the module docstring (it had been
    # commented out twice as dead triple-quoted strings in favor of
    # hard-coded values). Defaults equal the previously hard-coded values,
    # so running with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(description="EE sanity check")
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--ee", default="broadfield-dev/Qwen3-0.6B-dp-ee")
    parser.add_argument("--seed", type=int, default=424242)
    parser.add_argument("--prompt", default="Hello, how are you?")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)