broadfield-dev commited on
Commit
1c023ec
·
verified ·
1 Parent(s): 41550e9

Update debug_ee.py

Browse files
Files changed (1) hide show
  1. debug_ee.py +71 -74
debug_ee.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- EE Sanity Check + Layer Diagnostics
 
3
  Usage:
4
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
5
  """
@@ -16,97 +17,93 @@ def get_sigma(hidden_size, seed):
16
 
17
  def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
18
  print(f"\n{'='*60}")
 
 
 
 
 
 
19
  tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
20
  inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
21
 
22
- print("[1] Loading models...")
23
- orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
24
- ee = AutoModelForCausalLM.from_pretrained(ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
25
- orig.eval(); ee.eval()
 
26
 
27
  hidden_size = orig.config.hidden_size
28
  sigma, sigma_inv = get_sigma(hidden_size, seed)
29
- print(f"hidden_size={hidden_size}, seed={seed}")
30
 
31
- # --- CHECK 1: Embed layers ---
32
- embed_match = torch.allclose(orig.model.embed_tokens.weight.data, ee.model.embed_tokens.weight.data, atol=1e-3)
 
 
33
  print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # --- LAYER DIFF: print every layer that differs and HOW ---
36
- print("\n[LAYER DIFF] Comparing every named parameter...")
37
- ROPE_OUTPUT_LAYERS = {"q_proj", "k_proj"}
38
- issues = []
39
- for (name_o, param_o), (name_e, param_e) in zip(orig.named_parameters(), ee.named_parameters()):
40
- assert name_o == name_e
41
- if torch.allclose(param_o.data, param_e.data, atol=1e-3):
42
- continue # unchanged — skip
43
-
44
- basename = name_o.split(".")[-1] # "weight", "bias"
45
- layer = name_o.split(".")[-2] # "q_proj", "embed_tokens", etc.
46
- shape = tuple(param_o.shape)
47
-
48
- # Check what the transform DID to this param
49
- changed_cols = changed_rows = False
50
- if param_o.dim() == 2:
51
- if not torch.allclose(param_o.data, param_e.data[:, np.argsort(sigma_inv)], atol=1e-3):
52
- pass
53
- # Did it permute cols?
54
- reconstructed_cols = param_e.data[:, np.argsort(sigma_inv)]
55
- changed_cols = torch.allclose(param_o.data, reconstructed_cols, atol=1e-3)
56
- # Did it permute rows?
57
- reconstructed_rows = param_e.data[np.argsort(sigma_inv), :]
58
- changed_rows = torch.allclose(param_o.data, reconstructed_rows, atol=1e-3)
59
- # Did it permute both?
60
- reconstructed_both = param_e.data[np.argsort(sigma_inv), :][:, np.argsort(sigma_inv)]
61
- changed_both = torch.allclose(param_o.data, reconstructed_both, atol=1e-3)
62
-
63
- what = []
64
- if changed_both: what = ["BOTH rows+cols"]
65
- elif changed_cols: what = ["cols only"]
66
- elif changed_rows: what = ["rows only"]
67
- else: what = ["UNKNOWN permutation"]
68
-
69
- flag = ""
70
- if layer in ROPE_OUTPUT_LAYERS and ("BOTH" in what[0] or "rows" in what[0]):
71
- flag = " ⚠️ BAD: RoPE layer has rows permuted!"
72
- issues.append(f"{name_o}: rows permuted on RoPE layer")
73
- elif layer not in ROPE_OUTPUT_LAYERS and shape[0] == hidden_size and shape[1] == hidden_size and "BOTH" not in what[0]:
74
- flag = f" ⚠️ BAD: square hidden layer should have BOTH permuted"
75
- issues.append(f"{name_o}: square layer missing full permutation")
76
-
77
- print(f" {layer:20s} {str(shape):20s} → {what[0]}{flag}")
78
-
79
- elif param_o.dim() == 1:
80
- print(f" {layer:20s} {str(shape):20s} → 1D (norm/bias)")
81
-
82
- # --- CHECK 4: Logits ---
83
- print("\n[CHECK 4] Equivariance test...")
84
  with torch.no_grad():
85
- plain_embeds = orig.model.embed_tokens(inputs.input_ids)
86
  encrypted_embeds = plain_embeds[..., sigma]
87
- orig_logits = orig(inputs_embeds=plain_embeds).logits
88
- ee_logits = ee(inputs_embeds=encrypted_embeds).logits
89
 
 
 
90
  max_diff = (orig_logits - ee_logits).abs().max().item()
91
- match = max_diff < 0.5
92
- print(f" Max logit diff: {max_diff:.4f} → {'✅ PASS' if match else '❌ FAIL'}")
93
-
94
- # --- CHECK 5: Decode ---
95
- print("\n[CHECK 5] Greedy decode (10 tokens)...")
 
 
 
 
 
 
 
96
  with torch.no_grad():
97
- orig_ids = orig.generate(inputs.input_ids, max_new_tokens=10, do_sample=False)
98
  ee_ids = ee.generate(inputs_embeds=encrypted_embeds,
99
  attention_mask=inputs.attention_mask,
100
  max_new_tokens=10, do_sample=False,
101
  pad_token_id=tokenizer.eos_token_id)
102
- print(f" Original : {repr(tokenizer.decode(orig_ids[0], skip_special_tokens=True))}")
103
- print(f" EE model : {repr(tokenizer.decode(ee_ids[0], skip_special_tokens=True))}")
104
-
105
- if issues:
106
- print(f"\n⚠️ {len(issues)} issue(s) found:")
107
- for i in issues: print(f" - {i}")
108
- else:
109
- print("\n✅ No layer issues detected")
 
 
 
 
 
 
 
 
 
110
 
111
  if __name__ == "__main__":
112
 
 
1
  """
2
+ EE Sanity Check Script
3
+ Run this locally (not on HF Spaces) to verify the transform is correct.
4
  Usage:
5
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
6
  """
 
17
 
18
  def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
19
  print(f"\n{'='*60}")
20
+ print(f"Original : {original_name}")
21
+ print(f"EE model : {ee_name}")
22
+ print(f"Seed : {seed}")
23
+ print(f"Prompt : {prompt}")
24
+ print('='*60)
25
+
26
  tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
27
  inputs = tokenizer(prompt, return_tensors="pt")
28
+ input_ids = inputs.input_ids
29
+
30
+ print("\n[1] Loading original model...")
31
+ orig = AutoModelForCausalLM.from_pretrained(
32
+ original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
33
+ )
34
+ orig.eval()
35
 
36
+ print("[2] Loading EE model...")
37
+ ee = AutoModelForCausalLM.from_pretrained(
38
+ ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
39
+ )
40
+ ee.eval()
41
 
42
  hidden_size = orig.config.hidden_size
43
  sigma, sigma_inv = get_sigma(hidden_size, seed)
 
44
 
45
+ # --- Check 1: Does the EE embed layer match original? ---
46
+ orig_embed = orig.model.embed_tokens.weight.data
47
+ ee_embed = ee.model.embed_tokens.weight.data
48
+ embed_match = torch.allclose(orig_embed, ee_embed, atol=1e-3)
49
  print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
50
+ if not embed_match:
51
+ diff = (orig_embed - ee_embed).abs().max().item()
52
+ print(f" ⚠️ Max diff: {diff:.6f} — EE embed was permuted, this BREAKS client-side encryption")
53
+ print(f" → Re-run transform with the embed layer skipped (see transform_fix.py)")
54
+
55
+ # --- Check 2: Run plain forward on original ---
56
+ print("\n[CHECK 2] Running plain forward on original...")
57
+ with torch.no_grad():
58
+ plain_embeds = orig.model.embed_tokens(input_ids)
59
+ orig_out = orig(inputs_embeds=plain_embeds, output_hidden_states=False)
60
+ orig_logits = orig_out.logits # (1, seq, vocab)
61
 
62
+ # --- Check 3: Run encrypted forward on EE model ---
63
+ print("[CHECK 3] Running encrypted forward on EE model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  with torch.no_grad():
 
65
  encrypted_embeds = plain_embeds[..., sigma]
66
+ ee_out = ee(inputs_embeds=encrypted_embeds, output_hidden_states=False)
67
+ ee_logits = ee_out.logits
68
 
69
+ # --- Check 4: Do logits match? ---
70
+ logit_match = torch.allclose(orig_logits, ee_logits, atol=1e-1)
71
  max_diff = (orig_logits - ee_logits).abs().max().item()
72
+ print(f"\n[CHECK 4] Logits match (atol=0.1): {logit_match}")
73
+ print(f" Max logit diff: {max_diff:.4f}")
74
+ if not logit_match:
75
+ print(" ⚠️ Logits differ equivariance is BROKEN")
76
+ # Find where it breaks — check RoPE suspicion
77
+ print("\n Diagnosing: checking if RoPE is the culprit...")
78
+ print(" RoPE applies rotation in head_dim space (64), not hidden space (1024)")
79
+ print(" If q_proj/k_proj output is permuted (because output==hidden_size),")
80
+ print(" the head_dim slices fed to RoPE will be scrambled → broken attention")
81
+
82
+ # --- Check 5: Greedy decode comparison ---
83
+ print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
84
  with torch.no_grad():
85
+ orig_ids = orig.generate(input_ids, max_new_tokens=10, do_sample=False)
86
  ee_ids = ee.generate(inputs_embeds=encrypted_embeds,
87
  attention_mask=inputs.attention_mask,
88
  max_new_tokens=10, do_sample=False,
89
  pad_token_id=tokenizer.eos_token_id)
90
+
91
+ orig_text = tokenizer.decode(orig_ids[0], skip_special_tokens=True)
92
+ ee_text = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
93
+ print(f" Original output : {repr(orig_text)}")
94
+ print(f" EE model output : {repr(ee_text)}")
95
+ print(f" Match: {orig_text == ee_text}")
96
+
97
+ if orig_text != ee_text:
98
+ print("\n ⚠️ OUTPUTS DIFFER. Most likely causes in order:")
99
+ print(" 1. Embed layer was permuted in EE model (Check 1 above)")
100
+ print(" 2. RoPE disruption — q_proj/k_proj output rows were permuted")
101
+ print(" FIX: do NOT permute output rows of q_proj and k_proj")
102
+ print(" because their outputs are split into heads for RoPE rotation")
103
+ print(" 3. Model on Hub is stale — re-run transform and re-push")
104
+
105
+ print(f"\n{'='*60}\n")
106
+ return embed_match and logit_match
107
 
108
  if __name__ == "__main__":
109