broadfield-dev committed on
Commit
9272618
·
verified ·
1 Parent(s): 07ee289

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -42
app.py CHANGED
@@ -8,28 +8,30 @@ from huggingface_hub import hf_hub_download
8
 
9
  app = Flask(__name__)
10
 
11
- # Cache per EE model name so repeated requests don't re-download
12
  _cache = {}
13
 
14
 
15
- def get_sigma(hidden_size: int, seed: int):
16
- """Derive the hidden-dimension permutation from the secret seed."""
 
 
 
 
17
  rng = np.random.default_rng(seed)
18
  return rng.permutation(hidden_size)
19
 
20
 
21
  def load_client_components(ee_model_name: str):
22
  """
23
- Load and cache everything the client needs:
24
  - ee_config → hidden_size + original model name
25
- - tokenizer → from the EE model
26
- - embed_layer → from the ORIGINAL (unmodified) model
27
-
28
- Why we need the original embed layer:
29
- The EE model's weights were permuted with sigma, but its embedding table was
30
- NOT permuted (it maps token IDs → plain vectors). The client must embed with
31
- the original model and then apply sigma to produce the scrambled vectors the
32
- EE model expects.
33
  """
34
  if ee_model_name in _cache:
35
  return _cache[ee_model_name]
@@ -43,16 +45,16 @@ def load_client_components(ee_model_name: str):
43
 
44
  tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
45
 
46
- # We only need embed_tokens — load the full model then discard everything else
47
  original_model = AutoModelForCausalLM.from_pretrained(
48
  original_model_name,
49
- torch_dtype=torch.float32, # float32 for precision on CPU
50
  device_map="cpu",
51
  trust_remote_code=True,
52
  )
53
  embed_layer = original_model.model.embed_tokens
54
  embed_layer.eval()
55
- del original_model # free RAM — we only keep the embed layer
56
 
57
  _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
58
  return tokenizer, embed_layer, hidden_size
@@ -68,57 +70,48 @@ def index():
68
  form_data = request.form.to_dict()
69
  server_url = request.form["server_url"].rstrip("/")
70
  ee_model_name = request.form["ee_model_name"].strip()
71
- ee_seed = int(request.form["ee_seed"])
72
  prompt = request.form["prompt"].strip()
73
  max_tokens = int(request.form.get("max_tokens", 256))
74
 
75
  try:
76
  tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
77
 
78
- # --- Step 1: tokenize ---
 
 
79
  inputs = tokenizer(prompt, return_tensors="pt")
80
- input_ids = inputs.input_ids # (1, seq_len)
81
 
82
- # --- Step 2: embed with ORIGINAL model's embed layer ---
83
  with torch.no_grad():
84
- plain_embeds = embed_layer(input_ids) # (1, seq_len, hidden)
85
 
86
- # --- Step 3: ENCRYPT — permute hidden dim with secret sigma ---
87
- # The EE model's weight matrices were pre-permuted with sigma,
88
- # so feeding sigma-permuted embeddings is equivalent to feeding
89
- # plain embeddings to the original model.
90
  sigma = get_sigma(hidden_size, ee_seed)
91
- encrypted_embeds = plain_embeds[..., sigma] # (1, seq_len, hidden)
92
-
93
- # Match server model dtype (float16)
94
  encrypted_embeds = encrypted_embeds.to(torch.float16)
95
 
96
- # --- Step 4: send to server ---
97
  payload = {
98
  "encrypted_embeds": encrypted_embeds.tolist(),
99
- "attention_mask": inputs.attention_mask.tolist(),
100
- "max_new_tokens": max_tokens,
101
  }
102
 
103
- resp = requests.post(
104
- f"{server_url}/generate",
105
- json=payload,
106
- timeout=300,
107
- )
108
 
109
  if not resp.ok:
110
- raise RuntimeError(
111
- f"Server {resp.status_code}: {resp.text[:600]}"
112
- )
113
 
114
  body = resp.json()
115
  if "error" in body:
116
  raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")
117
 
118
- # --- Step 5: decode ---
119
- # No decryption needed on the output — the EE model's lm_head was
120
- # also permuted so output logits correctly map to the real vocabulary.
121
- # We skip special tokens and strip the prompt echo if present.
122
  gen_ids = body["generated_ids"]
123
  result = tokenizer.decode(gen_ids, skip_special_tokens=True)
124
 
 
8
 
9
  app = Flask(__name__)
10
 
 
11
  _cache = {}
12
 
13
 
14
+ def get_sigma(hidden_size: int, seed: int) -> np.ndarray:
15
+ """
16
+ Derive the encryption permutation from the secret seed.
17
+ This is the CLIENT'S secret key — it never leaves this Space.
18
+ The server only ever sees embeddings already scrambled by sigma.
19
+ """
20
  rng = np.random.default_rng(seed)
21
  return rng.permutation(hidden_size)
22
 
23
 
24
  def load_client_components(ee_model_name: str):
25
  """
26
+ Load and cache:
27
  - ee_config → hidden_size + original model name
28
+ - tokenizer → from EE model
29
+ - embed_layer → from the ORIGINAL (untransformed) model
30
+
31
+ The original embed_layer is used to produce plain vectors from token IDs.
32
+ The client then applies sigma to those plain vectors before sending.
33
+ The server's EE model has weights permuted with sigma_inv, so:
34
+ EE_model(sigma(plain_embed(tokens))) == original_model(plain_embed(tokens))
 
35
  """
36
  if ee_model_name in _cache:
37
  return _cache[ee_model_name]
 
45
 
46
  tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
47
 
48
+ # Load ORIGINAL model just for its embed layer — discard everything else
49
  original_model = AutoModelForCausalLM.from_pretrained(
50
  original_model_name,
51
+ torch_dtype=torch.float32,
52
  device_map="cpu",
53
  trust_remote_code=True,
54
  )
55
  embed_layer = original_model.model.embed_tokens
56
  embed_layer.eval()
57
+ del original_model
58
 
59
  _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
60
  return tokenizer, embed_layer, hidden_size
 
70
  form_data = request.form.to_dict()
71
  server_url = request.form["server_url"].rstrip("/")
72
  ee_model_name = request.form["ee_model_name"].strip()
73
+ ee_seed = int(request.form["ee_seed"]) # SECRET — client only
74
  prompt = request.form["prompt"].strip()
75
  max_tokens = int(request.form.get("max_tokens", 256))
76
 
77
  try:
78
  tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)
79
 
80
+ # --- CLIENT-SIDE ENCRYPTION ---
81
+
82
+ # Step 1: tokenize
83
  inputs = tokenizer(prompt, return_tensors="pt")
 
84
 
85
+ # Step 2: embed with ORIGINAL model embed layer → plain vectors
86
  with torch.no_grad():
87
+ plain_embeds = embed_layer(inputs.input_ids) # (1, seq_len, hidden)
88
 
89
+ # Step 3: apply sigma permutation — this is the encryption
90
+ # The server NEVER sees plain_embeds, only the scrambled version.
91
+ # Without knowing the seed, the server cannot recover the original.
 
92
  sigma = get_sigma(hidden_size, ee_seed)
93
+ encrypted_embeds = plain_embeds[..., sigma] # (1, seq_len, hidden)
 
 
94
  encrypted_embeds = encrypted_embeds.to(torch.float16)
95
 
96
+ # --- SEND TO SERVER ---
97
  payload = {
98
  "encrypted_embeds": encrypted_embeds.tolist(),
99
+ "attention_mask": inputs.attention_mask.tolist(),
100
+ "max_new_tokens": max_tokens,
101
  }
102
 
103
+ resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
 
 
 
 
104
 
105
  if not resp.ok:
106
+ raise RuntimeError(f"Server {resp.status_code}: {resp.text[:600]}")
 
 
107
 
108
  body = resp.json()
109
  if "error" in body:
110
  raise RuntimeError(f"Server error: {body['error']}\n{body.get('traceback','')}")
111
 
112
+ # --- OUTPUT DECODING ---
113
+ # The EE model's lm_head rows are permuted with sigma_inv, so output
114
+ # logits correctly index the real vocabulary — decode normally.
 
115
  gen_ids = body["generated_ids"]
116
  result = tokenizer.decode(gen_ids, skip_special_tokens=True)
117