broadfield-dev committed
Commit eae97bc · verified · 1 parent: 995d4dc

Update app.py

Files changed (1): app.py (+34, -56)
app.py CHANGED
@@ -7,7 +7,6 @@ import json
 from huggingface_hub import hf_hub_download
 
 app = Flask(__name__)
-
 _cache = {}
 
 
@@ -19,17 +18,6 @@ def get_sigma(hidden_size: int, seed: int):
 
 
 def load_client_components(ee_model_name: str):
-    """
-    Client holds:
-      - tokenizer (from original model)
-      - embed_tokens (original, unmodified)
-      - lm_head (original, unmodified)
-      - hidden_size
-
-    embed_tokens and lm_head never leave the client.
-    The server only has the transformer body with permuted weights.
-    sigma is derived from the seed; it also never leaves the client.
-    """
     if ee_model_name in _cache:
         return _cache[ee_model_name]
 
@@ -42,7 +30,6 @@ def load_client_components(ee_model_name: str):
 
     tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)
 
-    # Load original model to extract embed + lm_head, then discard the rest
    original_model = AutoModelForCausalLM.from_pretrained(
         original_model_name,
         torch_dtype=torch.float32,
@@ -51,7 +38,7 @@
     )
     embed_layer = original_model.model.embed_tokens
     lm_head = original_model.lm_head
-    final_norm = original_model.model.norm  # final RMSNorm before lm_head
+    final_norm = original_model.model.norm
     embed_layer.eval()
     lm_head.eval()
     final_norm.eval()
@@ -61,39 +48,37 @@
     return tokenizer, embed_layer, lm_head, final_norm, hidden_size
 
 
-def generate_tokens(
-    server_url, tokenizer, embed_layer, lm_head, final_norm,
-    sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens
-):
+def generate_tokens(server_url, tokenizer, embed_layer, lm_head, final_norm,
+                    sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens):
     """
-    Token-by-token generation loop:
-      1. Client embeds current tokens -> applies sigma -> sends to server
-      2. Server returns last hidden state (sigma-space) + KV cache
-      3. Client applies sigma_inv -> runs final_norm + lm_head -> next token
-      4. Repeat until eos or max_tokens
+    Token-by-token generation. No KV cache: the client accumulates all
+    embeddings and sends the full growing sequence each step.
+
+    Each step:
+      1. Encrypt all token embeddings so far with sigma
+      2. Send to server -> get back last hidden state (sigma-space)
+      3. Decrypt last position: apply sigma_inv
+      4. Run final_norm + lm_head locally -> next token
     """
     inputs = tokenizer(formatted_prompt, return_tensors="pt")
-    input_ids = inputs.input_ids
-    attention_mask = inputs.attention_mask
-
-    generated_ids = []
-    past_key_values = None
+    input_ids = inputs.input_ids  # (1, seq_len)
 
-    # First forward: send full prompt embeddings
+    # Build initial encrypted embeddings for full prompt
     with torch.no_grad():
-        plain_embeds = embed_layer(input_ids)  # (1, seq, hidden)
-        encrypted_embeds = plain_embeds[..., sigma_t]  # encrypt
-        encrypted_embeds = encrypted_embeds.to(torch.float16)
+        all_plain_embeds = embed_layer(input_ids)  # (1, seq_len, hidden)
 
-    current_mask = attention_mask
+    generated_ids = []
 
     for step in range(max_new_tokens):
+        # Encrypt the full sequence so far
+        all_encrypted = all_plain_embeds[..., sigma_t].to(torch.float16)  # (1, seq, hidden)
+        seq_len = all_encrypted.shape[1]
+        attention_mask = torch.ones(1, seq_len, dtype=torch.long)
+
         payload = {
-            "inputs_embeds": encrypted_embeds.tolist(),
-            "attention_mask": current_mask.tolist(),
+            "inputs_embeds": all_encrypted.tolist(),
+            "attention_mask": attention_mask.tolist(),
         }
-        if past_key_values is not None:
-            payload["past_key_values"] = past_key_values
 
         resp = requests.post(f"{server_url}/generate", json=payload, timeout=120)
         if not resp.ok:
@@ -103,18 +88,15 @@ def generate_tokens(
         if "error" in body:
             raise RuntimeError(f"Server error: {body['error']}")
 
-        # Decrypt: apply sigma_inv to get plain hidden state
+        # Decrypt last position only
         last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)  # (1, seq, hidden)
-        past_key_values = body.get("past_key_values")  # may be None
+        last_pos_sigma = last_hidden[:, -1:, :]  # (1, 1, hidden) sigma-space
+        last_pos_plain = last_pos_sigma[..., sigma_inv_t]  # (1, 1, hidden) plain-space
 
-        # Take only the last position
-        last_pos = last_hidden[:, -1:, :]  # (1, 1, hidden) sigma-space
-        plain_hidden = last_pos[..., sigma_inv_t]  # (1, 1, hidden) plain-space
-
-        # Client-side: final norm + lm_head -> logits
+        # Client-side: final norm + lm_head -> next token
         with torch.no_grad():
-            normed = final_norm(plain_hidden)
-            logits = lm_head(normed)  # (1, 1, vocab)
+            normed = final_norm(last_pos_plain)
+            logits = lm_head(normed)  # (1, 1, vocab)
 
         next_token_id = logits[0, -1, :].argmax().item()
         generated_ids.append(next_token_id)
@@ -122,14 +104,11 @@ def generate_tokens(
         if next_token_id == tokenizer.eos_token_id:
             break
 
-        # Prepare next step: embed + encrypt the single new token
+        # Append new token's plain embedding to the growing sequence
         next_id_tensor = torch.tensor([[next_token_id]])
         with torch.no_grad():
-            next_plain_embed = embed_layer(next_id_tensor)  # (1, 1, hidden)
-            encrypted_embeds = next_plain_embed[..., sigma_t].to(torch.float16)
-
-        # Extend attention mask by 1
-        current_mask = torch.ones(1, 1, dtype=attention_mask.dtype)
+            next_embed = embed_layer(next_id_tensor)  # (1, 1, hidden)
+            all_plain_embeds = torch.cat([all_plain_embeds, next_embed], dim=1)
 
     return generated_ids
 
@@ -137,11 +116,11 @@ def generate_tokens(
 @app.route("/", methods=["GET", "POST"])
 def index():
     result = None
-    error = None
+    error = None
     form_data = {}
 
     if request.method == "POST":
-        form_data = request.form.to_dict()
+        form_data = request.form.to_dict()
         server_url = request.form["server_url"].rstrip("/")
         ee_model_name = request.form["ee_model_name"].strip()
         ee_seed = int(request.form["ee_seed"])
@@ -154,8 +133,7 @@ def index():
 
         sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)
 
-        # Apply chat template
-        messages = [{"role": "user", "content": prompt}]
+        messages = [{"role": "user", "content": prompt}]
         formatted = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
 
7
  from huggingface_hub import hf_hub_download
8
 
9
  app = Flask(__name__)
 
10
  _cache = {}
11
 
12
 
 
18
 
19
 
20
  def load_client_components(ee_model_name: str):
 
 
 
 
 
 
 
 
 
 
 
21
  if ee_model_name in _cache:
22
  return _cache[ee_model_name]
23
 
 
30
 
31
  tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)
32
 
 
33
  original_model = AutoModelForCausalLM.from_pretrained(
34
  original_model_name,
35
  torch_dtype=torch.float32,
 
38
  )
39
  embed_layer = original_model.model.embed_tokens
40
  lm_head = original_model.lm_head
41
+ final_norm = original_model.model.norm
42
  embed_layer.eval()
43
  lm_head.eval()
44
  final_norm.eval()
 
48
  return tokenizer, embed_layer, lm_head, final_norm, hidden_size
49
 
50
 
51
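Note on get_sigma, which is not shown in this diff: from its call sites, sigma_t is used to fancy-index the hidden dimension (x[..., sigma_t]) for encryption and sigma_inv_t to undo it, so it plausibly returns a seeded random permutation of the hidden dimensions and its inverse. A minimal sketch under that assumption; the actual implementation in app.py may differ:

import torch

def get_sigma(hidden_size: int, seed: int):
    # Hypothetical reconstruction: a permutation of the hidden dims,
    # reproducible from the seed so client and server tooling agree.
    gen = torch.Generator().manual_seed(seed)
    sigma = torch.randperm(hidden_size, generator=gen)  # encrypt: x[..., sigma]
    sigma_inv = torch.argsort(sigma)                    # decrypt: y[..., sigma_inv]
    return sigma, sigma_inv

Since x[sigma][sigma_inv] == x for any permutation, decryption in the generation loop is exact up to the float16 round-trip applied before sending.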
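The removed docstring stated that the server holds only "the transformer body with permuted weights". The reason a permutation can act as the key is equivariance: conjugate a weight matrix by the same permutation and it maps sigma-space inputs to sigma-space outputs. A toy check of that identity for a single linear map (illustrative only, not code from this commit):

import torch

hidden = 8
gen = torch.Generator().manual_seed(0)
sigma = torch.randperm(hidden, generator=gen)
sigma_inv = torch.argsort(sigma)

W = torch.randn(hidden, hidden)   # stand-in for one weight matrix
x = torch.randn(hidden)           # stand-in for one activation vector

W_perm = W[sigma][:, sigma]       # permute rows and columns by sigma

y_sigma = W_perm @ x[sigma]       # server-side math, entirely in sigma-space
y_plain = W @ x                   # reference plain-space result

assert torch.allclose(y_sigma[sigma_inv], y_plain, atol=1e-5)

For the full transformer body to stay equivalent, the same conjugation has to be applied consistently to every weight that touches the hidden dimension.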
  if "error" in body:
89
  raise RuntimeError(f"Server error: {body['error']}")
90
 
91
+ # Decrypt last position only
92
  last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32) # (1, seq, hidden)
93
+ last_pos_sigma = last_hidden[:, -1:, :] # (1, 1, hidden) sigma-space
94
+ last_pos_plain = last_pos_sigma[..., sigma_inv_t] # (1, 1, hidden) plain-space
95
 
96
+ # Client-side: final norm + lm_head β†’ next token
 
 
 
 
97
  with torch.no_grad():
98
+ normed = final_norm(last_pos_plain)
99
+ logits = lm_head(normed) # (1, 1, vocab)
100
 
101
  next_token_id = logits[0, -1, :].argmax().item()
102
  generated_ids.append(next_token_id)
 
104
  if next_token_id == tokenizer.eos_token_id:
105
  break
106
 
107
+ # Append new token's plain embedding to the growing sequence
108
  next_id_tensor = torch.tensor([[next_token_id]])
109
  with torch.no_grad():
110
+ next_embed = embed_layer(next_id_tensor) # (1, 1, hidden)
111
+ all_plain_embeds = torch.cat([all_plain_embeds, next_embed], dim=1)
 
 
 
112
 
113
  return generated_ids
114
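Only the client side appears in this commit. The /generate endpoint it calls would need to accept {"inputs_embeds", "attention_mask"} and return JSON with "last_hidden" (or "error" on failure). A hypothetical minimal counterpart; the checkpoint path and loading details below are assumptions, not code from this repo:

import torch
from flask import Flask, request, jsonify
from transformers import AutoModel

app = Flask(__name__)
# Assumption: the EE repo hosts the transformer body with permuted weights,
# and that body omits the final norm, since the client applies final_norm itself.
body = AutoModel.from_pretrained("user/ee-model-body", torch_dtype=torch.float16)
body.eval()

@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    try:
        inputs_embeds = torch.tensor(data["inputs_embeds"], dtype=torch.float16)
        attention_mask = torch.tensor(data["attention_mask"], dtype=torch.long)
        with torch.no_grad():
            out = body(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        # Hidden states stay in sigma-space; only the client can decrypt them.
        return jsonify({"last_hidden": out.last_hidden_state.float().tolist()})
    except Exception as e:
        return jsonify({"error": str(e)})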
 
 
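End to end, a client call after this commit looks roughly like this (model name, seed, and server URL are placeholders):

tokenizer, embed_layer, lm_head, final_norm, hidden_size = \
    load_client_components("user/ee-model")
sigma_t, sigma_inv_t = get_sigma(hidden_size, 1234)

messages = [{"role": "user", "content": "Hello!"}]
formatted = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

ids = generate_tokens("http://localhost:8000", tokenizer, embed_layer, lm_head,
                      final_norm, sigma_t, sigma_inv_t, formatted, max_new_tokens=64)
print(tokenizer.decode(ids, skip_special_tokens=True))

Note the tradeoff the commit makes: without past_key_values, every step re-sends and re-processes the whole sequence, so traffic and server compute grow quadratically with output length, in exchange for a simpler, stateless protocol.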