| | """ |
| | EE Sanity Check |
| | Usage: |
| | python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242 |
| | """ |
| | import torch |
| | import numpy as np |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import argparse |
| |
|
| | def get_sigma(hidden_size, seed): |
| | rng = np.random.default_rng(seed) |
| | return rng.permutation(hidden_size) |
| |
|
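# Not exercised by the checks below: a minimal sketch of the inverse
# permutation, assuming "decryption" is simply applying sigma^-1 along the
# hidden dimension (the exact EE scheme may differ). Handy when inspecting
# intermediate activations by hand.
def invert_sigma(sigma):
    inv = np.empty_like(sigma)
    inv[sigma] = np.arange(len(sigma))  # inv[sigma[j]] = j, so y[inv] recovers x when y = x[sigma]
    return inv

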
def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
    print(f"\n{'='*60}")
    print(f"Original : {original_name}")
    print(f"EE model : {ee_name}")
    print(f"Seed     : {seed}")
    print(f"Prompt   : {prompt}")
    print('='*60)

    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids

| | print("\n[1] Loading models...") |
| | orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True) |
| | ee = AutoModelForCausalLM.from_pretrained(ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True) |
| | orig.eval(); ee.eval() |
| |
|
| | hidden_size = orig.config.hidden_size |
| | sigma = get_sigma(hidden_size, seed) |
| | sigma_t = torch.tensor(sigma, dtype=torch.long) |
| |
|
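    # CHECK 1: the EE checkpoint should ship the original embedding table.
    # If it were permuted as well, permuting the embeddings client-side
    # (CHECK 3) would apply the permutation twice.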
    embed_match = torch.allclose(
        orig.model.embed_tokens.weight.data,
        ee.model.embed_tokens.weight.data, atol=1e-3
    )
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
    if not embed_match:
        print("   ⚠️ Embed was permuted; client-side encryption will be double-permuted")

    print("\n[CHECK 2] Running plain forward on original...")
    with torch.no_grad():
        plain_embeds = orig.model.embed_tokens(input_ids)
        orig_logits = orig(inputs_embeds=plain_embeds).logits

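    # CHECK 2/3 together probe equivariance: the EE model fed sigma-permuted
    # embeddings should reproduce the original model's logits, i.e.
    # f_ee(x[..., sigma]) ≈ f_orig(x).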
| | print("[CHECK 3] Running encrypted forward on EE model...") |
| | with torch.no_grad(): |
| | |
| | encrypted_embeds = plain_embeds[..., sigma_t] |
| | ee_logits = ee(inputs_embeds=encrypted_embeds).logits |
| |
|
    # The original printed atol=0.1 but enforced 0.5; use one named tolerance.
    atol = 0.5
    max_diff = (orig_logits - ee_logits).abs().max().item()
    match = max_diff < atol
    print(f"\n[CHECK 4] Logits match (atol={atol}): {match}")
    print(f"   Max logit diff: {max_diff:.4f}")
    if not match:
        print("   ⚠️ Equivariance BROKEN")

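    # Extra probe (not in the original checklist): a large max-|diff| on a few
    # outlier logits need not change decoding, so also report how often the
    # per-position argmax tokens agree.
    argmax_agree = (orig_logits.argmax(-1) == ee_logits.argmax(-1)).float().mean().item()
    print(f"   Top-1 agreement across positions: {argmax_agree:.2%}")
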
| | print("\n[CHECK 5] Greedy decode comparison (10 tokens)...") |
| | with torch.no_grad(): |
| | orig_ids = orig.generate( |
| | inputs_embeds=plain_embeds, |
| | attention_mask=inputs.attention_mask, |
| | max_new_tokens=10, |
| | do_sample=False, |
| | pad_token_id=tokenizer.eos_token_id |
| | ) |
| | ee_ids = ee.generate( |
| | inputs_embeds=encrypted_embeds, |
| | attention_mask=inputs.attention_mask, |
| | max_new_tokens=10, |
| | do_sample=False, |
| | pad_token_id=tokenizer.eos_token_id |
| | ) |
| |
|
    orig_text = tokenizer.decode(orig_ids[0], skip_special_tokens=True)
    ee_text = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
    print(f"   Original output : {repr(orig_text)}")
    print(f"   EE model output : {repr(ee_text)}")
    print(f"   Match: {orig_text == ee_text}")

    if orig_text == ee_text:
        print("\n✅ All checks passed: EE transform is correct")
    else:
        print("\n⚠️ Decoded text differs.")
        print("   If CHECK 4 passed, this usually means floating-point drift in autoregressive generation.")
        print("   Check whether the token IDs match even if the decoded text differs slightly:")
        print(f"   orig_ids: {orig_ids[0].tolist()}")
        print(f"   ee_ids:   {ee_ids[0].tolist()}")
        ids_match = orig_ids[0].tolist() == ee_ids[0].tolist()
        print(f"   Token IDs match: {ids_match}")

    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    # CLI as documented in the module docstring; the previously hardcoded
    # values are kept as defaults so the script still runs with no flags.
    parser = argparse.ArgumentParser(description="EE sanity check")
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--ee", default="broadfield-dev/Qwen3-0.6B-dp-ee")
    parser.add_argument("--seed", type=int, default=424242)
    parser.add_argument("--prompt", default="Hello, how are you?")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)