broadfield-dev committed on
Commit
9737a84
·
verified ·
1 Parent(s): 6ad0d9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -36
app.py CHANGED
@@ -8,73 +8,101 @@ from huggingface_hub import hf_hub_download
8
 
9
  app = Flask(__name__)
10
 
 
 
 
 
11
  def get_sigma(hidden_size: int, seed: int):
12
- """Client-side encryption key from secret seed"""
13
  rng = np.random.default_rng(seed)
14
- sigma = rng.permutation(hidden_size)
15
- return sigma
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.route("/", methods=["GET", "POST"])
18
  def index():
19
  result = None
20
  error = None
 
21
 
22
  if request.method == "POST":
23
- server_url = request.form["server_url"].rstrip("/")
 
24
  ee_model_name = request.form["ee_model_name"].strip()
25
- ee_seed = int(request.form["ee_seed"])
26
- prompt = request.form["prompt"].strip()
27
- max_tokens = int(request.form.get("max_tokens", 256))
28
 
29
  try:
30
- # 1. Load config to know hidden_size + original model
31
- config_path = hf_hub_download(ee_model_name, "ee_config.json")
32
- with open(config_path) as f:
33
- ee_config = json.load(f)
34
- hidden_size = ee_config["hidden_size"]
35
- original_model_name = ee_config["original_model"]
36
-
37
- # 2. Generate encryption permutation (this is your secret key in action)
38
- sigma = get_sigma(hidden_size, ee_seed)
39
 
40
- # 3. Load tokenizer
41
- tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)
42
-
43
- # 4. Load ORIGINAL (clean) embedding layer
44
- embed_model = AutoModelForCausalLM.from_pretrained(
45
- original_model_name,
46
- torch_dtype=torch.float16,
47
- device_map="cpu",
48
- trust_remote_code=True
49
- )
50
- embed_layer = embed_model.model.embed_tokens
51
 
52
- # 5. Tokenize + compute normal embeddings
53
  inputs = tokenizer(prompt, return_tensors="pt")
 
 
54
  with torch.no_grad():
55
- normal_embeds = embed_layer(inputs.input_ids) # shape: (1, seq_len, hidden_size)
56
 
57
- # 6. === EXPLICIT ENCRYPTION (this is the key step you asked for) ===
58
- # Permute the hidden dimension according to the secret sigma
59
- encrypted_embeds = normal_embeds[..., sigma] # now scrambled — provider sees nothing
60
 
61
- # 7. Send ONLY encrypted embeddings to server
62
  payload = {
63
  "encrypted_embeds": encrypted_embeds.tolist(),
64
  "attention_mask": inputs.attention_mask.tolist(),
65
- "max_new_tokens": max_tokens
66
  }
67
 
68
- resp = requests.post(f"{server_url}/generate", json=payload, timeout=300)
 
 
 
 
69
  resp.raise_for_status()
70
 
71
  gen_ids = resp.json()["generated_ids"]
72
  result = tokenizer.decode(gen_ids, skip_special_tokens=True)
73
 
 
 
 
 
74
  except Exception as e:
75
  error = str(e)
76
 
77
- return render_template("client.html", result=result, error=error)
 
78
 
79
  if __name__ == "__main__":
80
  app.run(host="0.0.0.0", port=7860)
 
8
 
9
  app = Flask(__name__)
10
 
11
+ # Cache tokenizer/embed layer so repeated requests don't reload from scratch
12
+ _cache = {}
13
+
14
+
15
def get_sigma(hidden_size: int, seed: int):
    """Derive the client-side encryption permutation from the secret seed.

    The seed fully determines the permutation, so any party holding the
    seed can re-derive the same sigma without it ever being transmitted.
    """
    generator = np.random.default_rng(seed)
    sigma = generator.permutation(hidden_size)
    return sigma
19
+
20
+
21
def load_client_components(ee_model_name: str):
    """Load (and cache) the tokenizer + original embedding layer for an EE model.

    Args:
        ee_model_name: Hub repo id of the encrypted-embeddings model; its
            ``ee_config.json`` names the original (clean) base model.

    Returns:
        Tuple ``(tokenizer, embed_layer, hidden_size)``. Results are memoized
        in the module-level ``_cache`` so repeated requests skip the expensive
        download/load path.
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    # 1. Fetch EE config to discover hidden_size + the original model name.
    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)

    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    # 2. Tokenizer comes from the EE model itself.
    tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True)

    # 3. Load the ORIGINAL (clean) model just for its input embeddings.
    #    CPU is fine — we never run a forward pass through the full model.
    embed_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        trust_remote_code=True,
    )
    # Fix: use the official accessor instead of reaching into
    # embed_model.model.embed_tokens, which only works for Llama-style
    # architectures (e.g. GPT-2 keeps embeddings at .transformer.wte).
    embed_layer = embed_model.get_input_embeddings()

    _cache[ee_model_name] = (tokenizer, embed_layer, hidden_size)
    return tokenizer, embed_layer, hidden_size
48
+
49
 
50
@app.route("/", methods=["GET", "POST"])
def index():
    """Client UI: encrypt prompt embeddings locally, send them to the server.

    GET renders an empty form. POST tokenizes the prompt, computes clean
    embeddings, permutes the hidden dimension with the secret seed (the
    encryption step), posts the scrambled embeddings to the remote
    ``/generate`` endpoint, and decodes the returned token ids.
    """
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        try:
            # Fix: parse the form INSIDE the try so bad input (missing field,
            # non-numeric seed/max_tokens) renders as an error message in the
            # template instead of an unhandled exception / error page.
            server_url = request.form["server_url"].rstrip("/")
            ee_model_name = request.form["ee_model_name"].strip()
            ee_seed = int(request.form["ee_seed"])
            prompt = request.form["prompt"].strip()
            max_tokens = int(request.form.get("max_tokens", 256))

            tokenizer, embed_layer, hidden_size = load_client_components(ee_model_name)

            # Derive the encryption key (a permutation of the hidden dim).
            sigma = get_sigma(hidden_size, ee_seed)

            # Tokenize
            inputs = tokenizer(prompt, return_tensors="pt")

            # Compute plain embeddings (no gradients needed).
            with torch.no_grad():
                normal_embeds = embed_layer(inputs.input_ids)  # (1, seq_len, hidden)

            # Encrypt: permute the hidden dimension — the server only ever
            # sees the scrambled vectors, never the plain embeddings.
            encrypted_embeds = normal_embeds[..., sigma]

            # Send ONLY the encrypted embeddings to the server.
            payload = {
                "encrypted_embeds": encrypted_embeds.tolist(),
                "attention_mask": inputs.attention_mask.tolist(),
                "max_new_tokens": max_tokens,
            }

            resp = requests.post(
                f"{server_url}/generate",
                json=payload,
                timeout=300,
            )
            resp.raise_for_status()

            gen_ids = resp.json()["generated_ids"]
            result = tokenizer.decode(gen_ids, skip_special_tokens=True)

        except requests.exceptions.ConnectionError:
            # server_url is always bound here: requests.post is the only
            # statement that can raise ConnectionError.
            error = f"Could not connect to server at {server_url}. Is it running?"
        except requests.exceptions.HTTPError as e:
            error = f"Server returned an error: {e.response.status_code} — {e.response.text}"
        except Exception as e:
            error = str(e)

    return render_template("client.html", result=result, error=error, form=form_data)
105
+
106
 
107
# Script entry point: serve the Flask client UI on all interfaces, port 7860
# (the standard port for a Hugging Face Space).
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)