broadfield-dev committed on
Commit
4da358a
·
verified ·
1 Parent(s): 9e6e352

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -86,6 +86,9 @@ def generate():
86
  for layer in data["past_key_values"]
87
  )
88
 
 
 
 
89
  with torch.no_grad():
90
  out = ee_model(
91
  inputs_embeds=inputs_embeds,
@@ -95,13 +98,16 @@ def generate():
95
  output_hidden_states=True,
96
  )
97
 
98
- # Return final hidden state in sigma-space — client decrypts + runs lm_head
99
  last_hidden = out.hidden_states[-1] # (1, seq_len, hidden)
100
 
101
- new_past = [
102
- [t.cpu().tolist() for t in layer]
103
- for layer in out.past_key_values
104
- ]
 
 
 
105
 
106
  return jsonify({
107
  "last_hidden": last_hidden.cpu().tolist(),
 
86
  for layer in data["past_key_values"]
87
  )
88
 
89
+ # Ensure model config has caching enabled
90
+ ee_model.config.use_cache = True
91
+
92
  with torch.no_grad():
93
  out = ee_model(
94
  inputs_embeds=inputs_embeds,
 
98
  output_hidden_states=True,
99
  )
100
 
101
+ # Final hidden state (sigma-space) — client decrypts + runs lm_head
102
  last_hidden = out.hidden_states[-1] # (1, seq_len, hidden)
103
 
104
+ # Serialize KV cache — guard against None (some models/configs don't return it)
105
+ new_past = None
106
+ if out.past_key_values is not None:
107
+ new_past = [
108
+ [t.cpu().tolist() for t in layer]
109
+ for layer in out.past_key_values
110
+ ]
111
 
112
  return jsonify({
113
  "last_hidden": last_hidden.cpu().tolist(),