broadfield-dev committed
Commit 02f6c65 · verified · 1 parent: 9272618

Create debug_ee.py

Files changed (1): debug_ee.py (+117, −0)
debug_ee.py ADDED
@@ -0,0 +1,117 @@
+ """
+ EE Sanity Check Script
+ Run this locally (not on HF Spaces) to verify the transform is correct.
+ 
+ Usage:
+     python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
+ """
+ import argparse
+ 
+ import numpy as np
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ 
+ def get_sigma(hidden_size, seed):
+     """Derive the secret permutation and its inverse from a seed."""
+     rng = np.random.default_rng(seed)
+     sigma = rng.permutation(hidden_size)
+     sigma_inv = np.argsort(sigma)  # argsort of a permutation is its inverse
+     return sigma, sigma_inv
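+ 
+ # sigma scrambles the hidden dimension; sigma_inv undoes it:
+ # x[sigma][sigma_inv] == x for any vector x.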
+ 
+ def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
+     print(f"\n{'='*60}")
+     print(f"Original : {original_name}")
+     print(f"EE model : {ee_name}")
+     print(f"Seed     : {seed}")
+     print(f"Prompt   : {prompt}")
+     print('='*60)
+ 
+     tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs.input_ids
+ 
+     print("\n[1] Loading original model...")
+     orig = AutoModelForCausalLM.from_pretrained(
+         original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
+     )
+     orig.eval()
+ 
+     print("[2] Loading EE model...")
+     ee = AutoModelForCausalLM.from_pretrained(
+         ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
+     )
+     ee.eval()
+ 
+     hidden_size = orig.config.hidden_size
+     sigma, sigma_inv = get_sigma(hidden_size, seed)
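+     # The seed must match the one used by the transform; a different seed
+     # yields a different sigma, and every check below will fail.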
+ 
+     # --- Check 1: Does the EE embed layer match the original? ---
+     orig_embed = orig.model.embed_tokens.weight.data
+     ee_embed = ee.model.embed_tokens.weight.data
+     embed_match = torch.allclose(orig_embed, ee_embed, atol=1e-3)
+     print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
+     if not embed_match:
+         diff = (orig_embed - ee_embed).abs().max().item()
+         print(f"  ⚠️ Max diff: {diff:.6f} — EE embed was permuted, this BREAKS client-side encryption")
+         print("  → Re-run the transform with the embed layer skipped (see transform_fix.py)")
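+     # Client-side encryption depends on this: the client embeds tokens with
+     # the public embed table and applies sigma itself (as Check 3 does below),
+     # so the EE model must keep the embed layer unpermuted.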
+ 
+     # --- Check 2: Run a plain forward pass on the original ---
+     print("\n[CHECK 2] Running plain forward on original...")
+     with torch.no_grad():
+         plain_embeds = orig.model.embed_tokens(input_ids)
+         orig_out = orig(inputs_embeds=plain_embeds, output_hidden_states=False)
+         orig_logits = orig_out.logits  # (1, seq, vocab)
+ 
+     # --- Check 3: Run an encrypted forward pass on the EE model ---
+     print("[CHECK 3] Running encrypted forward on EE model...")
+     with torch.no_grad():
+         encrypted_embeds = plain_embeds[..., sigma]  # permute the hidden dimension
+         ee_out = ee(inputs_embeds=encrypted_embeds, output_hidden_states=False)
+         ee_logits = ee_out.logits
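+     # If the transform is equivariant, the two passes should agree
+     # logit-for-logit: ee(x[..., sigma]) ≈ orig(x).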
+ 
+     # --- Check 4: Do the logits match? ---
+     logit_match = torch.allclose(orig_logits, ee_logits, atol=1e-1)
+     max_diff = (orig_logits - ee_logits).abs().max().item()
+     print(f"\n[CHECK 4] Logits match (atol=0.1): {logit_match}")
+     print(f"  Max logit diff: {max_diff:.4f}")
+     if not logit_match:
+         print("  ⚠️ Logits differ — equivariance is BROKEN")
+         # Find where it breaks — check the RoPE suspicion first
+         print("\n  Diagnosing: checking if RoPE is the culprit...")
+         print("  RoPE applies rotation in head_dim space (64), not hidden space (1024)")
+         print("  If q_proj/k_proj output is permuted (because output == hidden_size),")
+         print("  the head_dim slices fed to RoPE will be scrambled → broken attention")
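+         # (Sketch of the failure mode: attention reshapes the projection to
+         # (batch, seq, heads, head_dim) and RoPE rotates pairs within each
+         # head_dim slice, so a permutation that crosses head boundaries
+         # scrambles exactly the structure RoPE assumes.)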
+ 
+     # --- Check 5: Greedy decode comparison ---
+     print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
+     with torch.no_grad():
+         orig_ids = orig.generate(input_ids, max_new_tokens=10, do_sample=False)
+         ee_ids = ee.generate(inputs_embeds=encrypted_embeds,
+                              attention_mask=inputs.attention_mask,
+                              max_new_tokens=10, do_sample=False,
+                              pad_token_id=tokenizer.eos_token_id)
+ 
+     # generate() echoes the prompt when called with input_ids but returns only
+     # the new tokens when called with inputs_embeds, so compare continuations.
+     orig_text = tokenizer.decode(orig_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
+     ee_text = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
+     print(f"  Original output : {orig_text!r}")
+     print(f"  EE model output : {ee_text!r}")
+     print(f"  Match: {orig_text == ee_text}")
+ 
+     if orig_text != ee_text:
+         print("\n  ⚠️ OUTPUTS DIFFER. Most likely causes, in order:")
+         print("  1. Embed layer was permuted in the EE model (Check 1 above)")
+         print("  2. RoPE disruption — q_proj/k_proj output rows were permuted")
+         print("     FIX: do NOT permute the output rows of q_proj and k_proj,")
+         print("     because their outputs are split into heads for the RoPE rotation")
+         print("  3. Model on the Hub is stale — re-run the transform and re-push")
+ 
+     print(f"\n{'='*60}\n")
+     return embed_match and logit_match
+ 
+ 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--original", required=True)
+     parser.add_argument("--ee", required=True)
+     parser.add_argument("--seed", type=int, required=True)
+     parser.add_argument("--prompt", default="Hello, how are you?")
+     args = parser.parse_args()
+     run_check(args.original, args.ee, args.seed, args.prompt)
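
For reference, the equivariance exercised by Checks 2–4 reduces to one identity: permuting a vector with sigma while permuting the consuming weight's input columns with the same sigma leaves the product unchanged. A minimal NumPy sketch (toy sizes; W, x, and the seed here are illustrative stand-ins, not the actual transform):

    import numpy as np

    rng = np.random.default_rng(424242)
    hidden = 8
    sigma = rng.permutation(hidden)           # secret permutation
    sigma_inv = np.argsort(sigma)             # its inverse

    W = rng.normal(size=(hidden, hidden))     # stand-in first-layer weight
    x = rng.normal(size=hidden)               # stand-in plain embedding

    x_enc = x[sigma]                          # "encrypt" the input
    W_ee = W[:, sigma]                        # adapt the layer to permuted input

    assert np.allclose(W_ee @ x_enc, W @ x)   # identical output
    assert np.allclose(x_enc[sigma_inv], x)   # sigma_inv undoes sigma

If Check 4 fails while this identity holds, the problem is in which weights were permuted (the embed layer, or the RoPE-sensitive q_proj/k_proj rows), not in the permutation math itself.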