broadfield-dev committed on
Commit 85036d4 · verified · 1 Parent(s): 7f05023

Create app.py

Files changed (1)
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download

app = Flask(__name__)

_cache = {}


def get_sigma(hidden_size: int, seed: int):
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)
    return torch.tensor(sigma, dtype=torch.long), torch.tensor(sigma_inv, dtype=torch.long)


def load_client_components(ee_model_name: str):
    """
    Client holds:
      - tokenizer (from the original model)
      - embed_tokens (original, unmodified)
      - lm_head (original, unmodified)
      - hidden_size

    embed_tokens and lm_head never leave the client.
    The server only has the transformer body with permuted weights.
    sigma is derived from the seed and also never leaves the client.
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)

    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)

    # Load the original model to extract embed_tokens + lm_head, then discard the rest
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    embed_layer = original_model.model.embed_tokens
    lm_head = original_model.lm_head
    final_norm = original_model.model.norm  # final RMSNorm before lm_head
    embed_layer.eval()
    lm_head.eval()
    final_norm.eval()
    del original_model

    _cache[ee_model_name] = (tokenizer, embed_layer, lm_head, final_norm, hidden_size)
    return tokenizer, embed_layer, lm_head, final_norm, hidden_size


def generate_tokens(
    server_url, tokenizer, embed_layer, lm_head, final_norm,
    sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens
):
    """
    Token-by-token generation loop:
      1. Client embeds the current tokens → applies sigma → sends to the server
      2. Server returns the last hidden state (sigma-space) + KV cache
      3. Client applies sigma_inv → runs final_norm + lm_head → next token
      4. Repeat until eos or max_new_tokens
    """
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    generated_ids = []
    past_key_values = None

    # First forward pass: send the full prompt embeddings
    with torch.no_grad():
        plain_embeds = embed_layer(input_ids)          # (1, seq, hidden)
        encrypted_embeds = plain_embeds[..., sigma_t]  # encrypt (permute hidden dims)
        encrypted_embeds = encrypted_embeds.to(torch.float16)

    current_mask = attention_mask

    for step in range(max_new_tokens):
        payload = {
            "inputs_embeds": encrypted_embeds.tolist(),
            "attention_mask": current_mask.tolist(),
        }
        if past_key_values is not None:
            payload["past_key_values"] = past_key_values

        resp = requests.post(f"{server_url}/generate", json=payload, timeout=120)
        if not resp.ok:
            raise RuntimeError(f"Server {resp.status_code}: {resp.text[:400]}")

        body = resp.json()
        if "error" in body:
            raise RuntimeError(f"Server error: {body['error']}")

        # Decrypt: apply sigma_inv to get the plain hidden state
        last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)  # (1, seq, hidden)
        past_key_values = body["past_key_values"]

        # Take only the last position
        last_pos = last_hidden[:, -1:, :]          # (1, 1, hidden) sigma-space
        plain_hidden = last_pos[..., sigma_inv_t]  # (1, 1, hidden) plain-space

        # Client-side: final norm + lm_head → logits
        with torch.no_grad():
            normed = final_norm(plain_hidden)
            logits = lm_head(normed)  # (1, 1, vocab)

        next_token_id = logits[0, -1, :].argmax().item()
        generated_ids.append(next_token_id)

        if next_token_id == tokenizer.eos_token_id:
            break

        # Prepare the next step: embed + encrypt the single new token
        next_id_tensor = torch.tensor([[next_token_id]])
        with torch.no_grad():
            next_plain_embed = embed_layer(next_id_tensor)  # (1, 1, hidden)
            encrypted_embeds = next_plain_embed[..., sigma_t].to(torch.float16)

        # Extend the attention mask by 1 so it covers the cached tokens plus the new one
        current_mask = torch.cat(
            [current_mask, torch.ones(1, 1, dtype=attention_mask.dtype)], dim=-1
        )

    return generated_ids


@app.route("/", methods=["GET", "POST"])
def index():
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        server_url = request.form["server_url"].rstrip("/")
        ee_model_name = request.form["ee_model_name"].strip()
        ee_seed = int(request.form["ee_seed"])
        prompt = request.form["prompt"].strip()
        max_tokens = int(request.form.get("max_tokens", 256))

        try:
            tokenizer, embed_layer, lm_head, final_norm, hidden_size = \
                load_client_components(ee_model_name)

            sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)

            # Apply the chat template
            messages = [{"role": "user", "content": prompt}]
            formatted = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            gen_ids = generate_tokens(
                server_url, tokenizer, embed_layer, lm_head, final_norm,
                sigma_t, sigma_inv_t, formatted, max_tokens
            )

            result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url}. Is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
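
Note on the permutation scheme: the "encryption" used above round-trips exactly, since indexing with sigma and then with sigma_inv returns the original tensor. That is why the client can recover plain-space hidden states from the sigma-space output the server returns. A minimal self-contained check (the hidden_size and seed values here are illustrative, not taken from any real ee_config.json):

import numpy as np
import torch

hidden_size, seed = 8, 1234
rng = np.random.default_rng(seed)
sigma = torch.tensor(rng.permutation(hidden_size), dtype=torch.long)
sigma_inv = torch.tensor(np.argsort(sigma.numpy()), dtype=torch.long)

x = torch.randn(1, 1, hidden_size)   # a plain-space hidden vector
enc = x[..., sigma]                  # sigma-space: what the client would send
dec = enc[..., sigma_inv]            # back to plain space on the client
assert torch.equal(dec, x)           # the permutation round-trips exactly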
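
This commit only adds the client; the server Space it calls is not part of the diff. Purely for orientation, here is a hypothetical sketch of a /generate endpoint that would satisfy the request/response contract generate_tokens assumes (JSON fields inputs_embeds, attention_mask and past_key_values in; last_hidden and past_key_values out, or an error field). The server_model object, how it is loaded, and the legacy tuple KV-cache wire format are all assumptions, not something this repository defines:

from flask import Flask, request, jsonify
import torch

server_app = Flask(__name__)
server_model = None  # the permuted transformer body, assumed to be loaded at startup


def kv_to_lists(past):
    # Serialize a legacy tuple-of-(key, value) KV cache into nested lists for JSON.
    # Newer transformers versions return Cache objects instead of tuples.
    return [[k.tolist(), v.tolist()] for k, v in past]


def kv_from_lists(past):
    # Rebuild tensors from the nested lists the client echoes back each step.
    return tuple(
        (torch.tensor(k, dtype=torch.float16), torch.tensor(v, dtype=torch.float16))
        for k, v in past
    )


@server_app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    try:
        inputs_embeds = torch.tensor(data["inputs_embeds"], dtype=torch.float16)
        attention_mask = torch.tensor(data["attention_mask"])
        past = kv_from_lists(data["past_key_values"]) if "past_key_values" in data else None
        with torch.no_grad():
            out = server_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past,
                use_cache=True,
            )
        return jsonify({
            "last_hidden": out.last_hidden_state.float().tolist(),
            "past_key_values": kv_to_lists(out.past_key_values),
        })
    except Exception as e:
        return jsonify({"error": str(e)})

Echoing the full KV cache through JSON every step keeps the server stateless but is heavy; a real deployment would more likely keep the cache on the server keyed by a session id.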
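
Usage note: python app.py serves the form on port 7860. The template referenced by render_template is expected to live at templates/client.html alongside this file and to post the server_url, ee_model_name, ee_seed, prompt and optional max_tokens fields that index() reads.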