# Author: broadfield-dev
# Last commit: 315b928 "Update debug_ee.py" (verified)
"""
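EE sanity check: verify that the EE model, fed sigma-permuted input embeddings,
reproduces the logits and greedy decode of the original model on plain embeddings.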
Usage:
python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
"""
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
def get_sigma(hidden_size, seed):
    """Build the seeded permutation ``sigma`` over ``range(hidden_size)``.

    Deterministic: a fresh ``numpy.random.default_rng`` is created from *seed*
    on every call, so equal ``(hidden_size, seed)`` pairs always yield the
    same ordering.

    Args:
        hidden_size: number of indices to permute.
        seed: seed for NumPy's default bit generator.

    Returns:
        A NumPy integer array containing each of ``0 .. hidden_size - 1``
        exactly once.
    """
    return np.random.default_rng(seed=seed).permutation(hidden_size)
def _greedy_decode(model, prompt_embeds, embed_new_tokens, max_new_tokens, eos_token_id):
    """Greedy-decode up to ``max_new_tokens`` tokens, embedding each new token client-side.

    ``model.generate(inputs_embeds=...)`` only uses the supplied embeddings for
    the first step; later tokens are embedded by the model's own embedding
    layer, so an EE model would stop receiving sigma-permuted inputs after
    step one. Here every step feeds embeddings built by ``embed_new_tokens``,
    mirroring what a client would send. No KV cache is used: the whole
    sequence is re-run each step, which is fine for a handful of tokens.

    Args:
        model: causal LM accepting ``inputs_embeds``.
        prompt_embeds: embeddings of the prompt as fed to *model*
            (assumed shape (1, seq, hidden) — single prompt).
        embed_new_tokens: callable mapping a (1, 1) id tensor to its embedding.
        max_new_tokens: maximum number of tokens to generate.
        eos_token_id: decoding stops after emitting this id (ignored if None).

    Returns:
        List of generated token ids (prompt excluded, EOS included if emitted).
    """
    embeds = prompt_embeds
    generated = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            next_id = model(inputs_embeds=embeds).logits[:, -1, :].argmax(dim=-1)
            token = next_id.item()  # batch size is 1 (single prompt)
            generated.append(token)
            if eos_token_id is not None and token == eos_token_id:
                break
            new_embed = embed_new_tokens(next_id.unsqueeze(-1))
            embeds = torch.cat([embeds, new_embed], dim=1)
    return generated


def run_check(original_name, ee_name, seed, prompt="Hello, how are you?", *,
              logit_atol=0.5, max_new_tokens=10):
    """Compare an EE model against its original model on one prompt and print a report.

    The client side is simulated by embedding the prompt with the ORIGINAL
    model's embedding table. For the EE model, the hidden dimension is then
    permuted with ``get_sigma(hidden_size, seed)``. Checks performed:

    1. the two models' input embedding tables are identical;
    2/3. plain forward on the original vs. permuted forward on the EE model;
    4. the max absolute logit difference is below ``logit_atol``;
    5. greedy decoding of up to ``max_new_tokens`` tokens gives identical
       token ids. Every new token is re-embedded and re-permuted client-side.

    Args:
        original_name: model id or path of the original model.
        ee_name: model id or path of the EE model.
        seed: seed for the permutation (presumably the one used to build the
            EE model — confirm).
        prompt: text prompt to run.
        logit_atol: maximum tolerated absolute logit difference.
        max_new_tokens: number of tokens for the greedy-decode comparison.

    Returns:
        True if all checks passed, False otherwise. Details are printed.
    """
    print(f"\n{'='*60}")
    print(f"Original : {original_name}")
    print(f"EE model : {ee_name}")
    print(f"Seed : {seed}")
    print(f"Prompt : {prompt}")
    print('='*60)

    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    print("\n[1] Loading models...")
    orig = AutoModelForCausalLM.from_pretrained(
        original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    ee = AutoModelForCausalLM.from_pretrained(
        ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    orig.eval()
    ee.eval()

    sigma_t = torch.tensor(get_sigma(orig.config.hidden_size, seed), dtype=torch.long)
    orig_embed = orig.get_input_embeddings()

    def embed_plain(ids):
        # Client-side embedding always uses the ORIGINAL model's table.
        return orig_embed(ids)

    def embed_encrypted(ids):
        # Client-side "encryption": permute the hidden dimension with sigma.
        return orig_embed(ids)[..., sigma_t]

    # --- CHECK 1: Embed layers must be identical ---
    orig_w = orig_embed.weight.data
    ee_w = ee.get_input_embeddings().weight.data
    # Shape guard: allclose raises instead of returning False on mismatched shapes.
    embed_match = orig_w.shape == ee_w.shape and torch.allclose(orig_w, ee_w, atol=1e-3)
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
    if not embed_match:
        print(" ⚠️ Embed was permuted — client-side encryption will be double-permuted")

    # --- CHECK 2 & 3: Forward pass with plain vs. encrypted embeds ---
    print("\n[CHECK 2] Running plain forward on original...")
    with torch.no_grad():
        plain_embeds = embed_plain(input_ids)
        orig_logits = orig(inputs_embeds=plain_embeds).logits
    print("[CHECK 3] Running encrypted forward on EE model...")
    with torch.no_grad():
        encrypted_embeds = embed_encrypted(input_ids)
        ee_logits = ee(inputs_embeds=encrypted_embeds).logits

    # --- CHECK 4: Logits ---
    max_diff = (orig_logits - ee_logits).abs().max().item()
    logits_match = max_diff < logit_atol
    print(f"\n[CHECK 4] Logits match (atol={logit_atol}): {logits_match}")
    print(f" Max logit diff: {max_diff:.4f}")
    if not logits_match:
        print(" ⚠️ Equivariance BROKEN")

    # --- CHECK 5: Greedy decode, with sigma applied to every EE input step ---
    print(f"\n[CHECK 5] Greedy decode comparison ({max_new_tokens} tokens)...")
    eos_id = tokenizer.eos_token_id
    orig_ids = _greedy_decode(orig, plain_embeds, embed_plain, max_new_tokens, eos_id)
    ee_ids = _greedy_decode(ee, encrypted_embeds, embed_encrypted, max_new_tokens, eos_id)
    orig_text = tokenizer.decode(orig_ids, skip_special_tokens=True)
    ee_text = tokenizer.decode(ee_ids, skip_special_tokens=True)
    ids_match = orig_ids == ee_ids
    print(f" Original output : {repr(orig_text)}")
    print(f" EE model output : {repr(ee_text)}")
    print(f" Match: {orig_text == ee_text}")
    print(f" orig_ids: {orig_ids}")
    print(f" ee_ids: {ee_ids}")
    print(f" Token IDs match: {ids_match}")

    results = (
        ("CHECK 1 (embeddings)", embed_match),
        ("CHECK 4 (logits)", logits_match),
        ("CHECK 5 (greedy decode)", ids_match),
    )
    failed = [name for name, ok in results if not ok]
    if not failed:
        print("\n✅ All checks passed — EE transform is correct")
    else:
        print(f"\n⚠️ Failed: {', '.join(failed)}")
    print(f"\n{'='*60}\n")
    return not failed
def parse_args(argv=None):
    """Parse the command-line options described in the module docstring.

    Defaults reproduce the values the script previously hard-coded, so running
    it with no arguments behaves exactly as before.

    Args:
        argv: list of arguments to parse (program name excluded); ``None``
            means ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        ``argparse.Namespace`` with ``original``, ``ee``, ``seed`` and ``prompt``.
    """
    parser = argparse.ArgumentParser(description="EE sanity check")
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B",
                        help="original model id or path")
    parser.add_argument("--ee", default="broadfield-dev/Qwen3-0.6B-dp-ee",
                        help="EE model id or path")
    parser.add_argument("--seed", type=int, default=424242,
                        help="seed for the hidden-dimension permutation")
    parser.add_argument("--prompt", default="Hello, how are you?",
                        help="prompt to run through both models")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)