File size: 3,525 Bytes
3977e64 ba9a967 3977e64 be0a757 ba9a967 3383b9c ba9a967 3977e64 be0a757 ba9a967 3977e64 9e6e352 3977e64 9e6e352 3977e64 9e6e352 3977e64 9e6e352 3977e64 9e6e352 3977e64 9e6e352 3977e64 ba9a967 04a6795 ba9a967 3977e64 9e6e352 3977e64 9e6e352 3383b9c 9e6e352 3383b9c 9e6e352 5cad3e1 3383b9c 9e6e352 5cad3e1 3383b9c 9e6e352 3383b9c 9e6e352 3383b9c 2aef358 9e6e352 3383b9c 2aef358 9e6e352 3383b9c 9e6e352 3383b9c 3977e64 ba9a967 3977e64 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | from flask import Flask, render_template, request, flash, jsonify
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os, json
app = Flask(__name__)
# Random per-process secret: flash() sessions are invalidated on every restart.
app.secret_key = os.urandom(24)
# Module-level state for the currently loaded model; populated by index() on POST.
ee_model = None       # AutoModelForCausalLM instance, or None until a model is loaded
ee_tokenizer = None   # AutoTokenizer matching ee_model
ee_config = None      # parsed contents of the repo's ee_config.json
ee_model_name = None  # HF repo id of the loaded model (as submitted in the form)
# Hugging Face Spaces exposes the public hostname via SPACE_HOST;
# fall back to the local dev URL when running outside Spaces.
SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the control page; on POST, load the requested model.

    POST form fields:
        ee_model_name: HF repo id of the sigma-encrypted model to serve.
        hf_token:      HF access token used to log in before downloading.

    Side effects: populates the module globals ``ee_model``, ``ee_tokenizer``,
    ``ee_config`` and ``ee_model_name``; flashes success/error messages for
    the template. Always returns the rendered index page.
    """
    global ee_model, ee_tokenizer, ee_config, ee_model_name
    if request.method == "POST":
        # .get avoids Flask's opaque 400 response when a field is missing;
        # a blank value simply fails downstream with a flashed error instead.
        ee_model_name = request.form.get("ee_model_name", "").strip()
        hf_token = request.form.get("hf_token", "").strip()
        try:
            login(token=hf_token)
            # NOTE(review): trust_remote_code executes Python shipped inside
            # the model repo — only load repos the operator trusts.
            ee_model = AutoModelForCausalLM.from_pretrained(
                ee_model_name, torch_dtype=torch.float16,
                device_map="auto", trust_remote_code=True
            )
            ee_tokenizer = AutoTokenizer.from_pretrained(
                ee_model_name, trust_remote_code=True
            )
            from huggingface_hub import hf_hub_download
            config_path = hf_hub_download(ee_model_name, "ee_config.json")
            with open(config_path) as f:
                ee_config = json.load(f)
            flash(f"✅ Model loaded: {ee_model_name}", "success")
            flash("Point your Client Space to this Space's URL below.", "info")
        except Exception as e:
            flash(f"Error: {str(e)}", "danger")
    return render_template(
        "index.html",
        server_ready=(ee_model is not None),
        # Fix: gate on ee_model — the same readiness signal as server_ready —
        # instead of ee_config, so the displayed name can never disagree with
        # the "ready" state (previously they could diverge).
        model_name=ee_model_name if ee_model is not None else None,
        space_url=SPACE_URL,
    )
@app.route("/generate", methods=["POST"])
def generate():
    """
    Receives sigma-encrypted embeddings + optional past_key_values.
    Returns the last hidden state (still in sigma-space).
    Does NOT run lm_head — that stays on the client.
    Server never sees token IDs, logits, or plaintext.

    JSON body:
        inputs_embeds:    nested list, presumably (batch, seq, hidden) — TODO confirm
        attention_mask:   optional nested list; defaults to all-ones over seq
        past_key_values:  optional nested lists of per-layer (key, value) tensors

    Returns 400 if no model is loaded, 500 (with traceback) on any failure.
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet"}), 400
    try:
        data = request.json
        # Cast incoming tensors to the model's own dtype/device (fp16 on GPU).
        model_dtype = next(ee_model.parameters()).dtype
        inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
            dtype=model_dtype, device=ee_model.device
        )
        attention_mask = torch.tensor(
            data.get("attention_mask", [[1] * inputs_embeds.shape[1]])
        ).to(device=ee_model.device)
        past_key_values = None
        if data.get("past_key_values"):
            past_key_values = tuple(
                tuple(
                    torch.tensor(t).to(dtype=model_dtype, device=ee_model.device)
                    for t in layer
                )
                for layer in data["past_key_values"]
            )
        with torch.no_grad():
            out = ee_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                # Bug fix: past_key_values was reconstructed above but never
                # forwarded, so the client-supplied cache was silently ignored.
                past_key_values=past_key_values,
                use_cache=False,
                output_hidden_states=True,
            )
        # Return final hidden state in sigma-space — client applies sigma_inv + lm_head.
        # Note: with use_cache=False no new KV cache is produced or returned.
        last_hidden = out.hidden_states[-1]  # presumably (1, seq_len, hidden) — verify against client
        return jsonify({
            "last_hidden": last_hidden.cpu().tolist(),
        })
    except Exception as e:
        import traceback
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
if __name__ == "__main__":
    # Listen on all interfaces; 7860 is the conventional HF Spaces port.
    host, port = "0.0.0.0", 7860
    app.run(host=host, port=port)