File size: 3,525 Bytes
3977e64
 
 
ba9a967
3977e64
 
 
 
 
 
 
 
be0a757
ba9a967
 
3383b9c
ba9a967
3977e64
 
 
be0a757
ba9a967
3977e64
9e6e352
 
3977e64
9e6e352
 
3977e64
9e6e352
 
 
 
 
 
 
3977e64
9e6e352
 
 
 
3977e64
9e6e352
 
3977e64
9e6e352
 
3977e64
ba9a967
 
 
04a6795
ba9a967
 
 
3977e64
 
 
9e6e352
 
 
 
 
 
3977e64
9e6e352
3383b9c
 
 
 
 
9e6e352
3383b9c
9e6e352
5cad3e1
3383b9c
9e6e352
5cad3e1
3383b9c
9e6e352
 
 
 
 
 
 
 
 
3383b9c
 
9e6e352
 
3383b9c
2aef358
9e6e352
3383b9c
 
2aef358
9e6e352
 
 
 
 
3383b9c
 
9e6e352
3383b9c
3977e64
ba9a967
3977e64
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from flask import Flask, render_template, request, flash, jsonify
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os, json

# Flask app; the random secret key signs the session cookie used by
# flash() messages — it only needs to live as long as this process.
app = Flask(__name__)
app.secret_key = os.urandom(24)

# Module-level model state, populated by index() after a successful load.
ee_model = None       # AutoModelForCausalLM instance, or None until loaded
ee_tokenizer = None   # AutoTokenizer matching ee_model
ee_config = None      # dict parsed from the repo's ee_config.json
ee_model_name = None  # Hub repo id of the currently loaded model

# Public URL of this deployment. SPACE_HOST is injected by Hugging Face
# Spaces; fall back to the local dev address when running outside a Space.
SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"


@app.route("/", methods=["GET", "POST"])
def index():
    """Status page; on POST, load a model from the Hugging Face Hub.

    POST form fields:
        ee_model_name -- Hub repo id of the model to load.
        hf_token      -- Hugging Face access token for `login()`.

    All pieces (model, tokenizer, ee_config.json) are loaded into locals
    first and committed to the module-level globals only on full success,
    so a partial failure (e.g. the config download raising after the model
    already loaded) cannot leave the globals in a mixed old/new state
    where ``server_ready`` and ``model_name`` disagree.
    """
    global ee_model, ee_tokenizer, ee_config, ee_model_name

    if request.method == "POST":
        requested_name = request.form["ee_model_name"].strip()
        hf_token       = request.form["hf_token"].strip()

        try:
            login(token=hf_token)

            model = AutoModelForCausalLM.from_pretrained(
                requested_name, torch_dtype=torch.float16,
                device_map="auto", trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(
                requested_name, trust_remote_code=True
            )

            from huggingface_hub import hf_hub_download
            config_path = hf_hub_download(requested_name, "ee_config.json")
            with open(config_path) as f:
                config = json.load(f)

            # Everything loaded — commit atomically.
            ee_model, ee_tokenizer, ee_config = model, tokenizer, config
            ee_model_name = requested_name

            flash(f"✅ Model loaded: {ee_model_name}", "success")
            flash("Point your Client Space to this Space's URL below.", "info")

        except Exception as e:
            flash(f"Error: {str(e)}", "danger")

    return render_template(
        "index.html",
        server_ready=(ee_model is not None),
        model_name=ee_model_name if ee_config else None,
        space_url=SPACE_URL,
    )


@app.route("/generate", methods=["POST"])
def generate():
    """
    Receives sigma-encrypted embeddings + optional past_key_values.

    Runs the transformer body and returns the final hidden state, still in
    sigma-space. lm_head is NOT run here — the client applies sigma_inv +
    lm_head locally, so the server never sees token IDs, logits, or
    plaintext.

    NOTE: ``use_cache=False``, so no NEW KV cache is returned; any
    ``past_key_values`` the client supplies serve as attention context for
    this single forward pass only.

    JSON body:
        inputs_embeds    -- nested lists, (batch, seq_len, hidden).
        attention_mask   -- optional; defaults to all-ones over seq_len
                            (assumes batch size 1 in that default).
        past_key_values  -- optional; per-layer (key, value) tensors.

    Returns JSON {"last_hidden": ...} on success, or {"error": ...} with
    status 400 (no model loaded) / 500 (any other failure).
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet"}), 400

    try:
        data = request.json
        # Cast incoming tensors to the model's own dtype/device.
        model_dtype = next(ee_model.parameters()).dtype

        inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
            dtype=model_dtype, device=ee_model.device
        )

        attention_mask = torch.tensor(
            data.get("attention_mask", [[1] * inputs_embeds.shape[1]])
        ).to(device=ee_model.device)

        past_key_values = None
        if data.get("past_key_values"):
            past_key_values = tuple(
                tuple(
                    torch.tensor(t).to(dtype=model_dtype, device=ee_model.device)
                    for t in layer
                )
                for layer in data["past_key_values"]
            )

        with torch.no_grad():
            out = ee_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                # BUG FIX: past_key_values was reconstructed above but never
                # passed to the model, so the client's cache was ignored.
                past_key_values=past_key_values,
                use_cache=False,
                output_hidden_states=True,
            )

        # Final hidden state in sigma-space — client applies sigma_inv + lm_head.
        last_hidden = out.hidden_states[-1]  # (batch, seq_len, hidden)

        return jsonify({
            "last_hidden": last_hidden.cpu().tolist(),
        })

    except Exception as e:
        import traceback
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500


if __name__ == "__main__":
    # 0.0.0.0:7860 — the bind address/port Hugging Face Spaces expects.
    app.run(host="0.0.0.0", port=7860)