File size: 4,683 Bytes
acdc12e
 
 
 
 
 
53f92c7
b30d9e6
53f92c7
 
acdc12e
53f92c7
 
acdc12e
b30d9e6
 
acdc12e
b30d9e6
 
 
53f92c7
acdc12e
 
53f92c7
acdc12e
b30d9e6
acdc12e
 
 
 
 
 
 
 
53f92c7
acdc12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b30d9e6
acdc12e
 
 
 
 
b30d9e6
acdc12e
 
 
53f92c7
b30d9e6
acdc12e
 
 
53f92c7
 
 
acdc12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b30d9e6
 
acdc12e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import json
import asyncio
import numpy as np
import onnxruntime as ort
import tiktoken
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# App initialization
app = FastAPI()

# CORS so the external frontend can call this API from any origin.
# NOTE(review): allow_origins=["*"] is wide open — fine for a demo Space,
# tighten to the real frontend origin for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# 1. Model & tokenizer setup
# GPT-2 BPE tokenizer; the model vocab is padded to 50304 (model emits more
# logits than real tokens — see the :VOCAB_SIZE slice in /chat).
TOKENIZER = tiktoken.get_encoding("gpt2")
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
VOCAB_SIZE = 50304

# 2. ONNX Runtime optimization
# HF Free Spaces provide 2 vCPUs. We cap the thread pools at 2 to avoid
# context-switching overhead from oversubscription.
options = ort.SessionOptions()
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

# Load the session (import-time side effect: reads the .onnx file from disk).
print(f"🚀 Lade Modell {MODEL_PATH} mit CPU-Optimierung...")
session = ort.InferenceSession(
    MODEL_PATH, 
    sess_options=options, 
    providers=['CPUExecutionProvider']
)

def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
    """Sample the next token id via repetition penalty, temperature, top-k.

    Args:
        logits: 1-D array of unnormalized scores, one per vocabulary token.
        k: number of top candidates to sample from (clamped to len(logits)).
        temp: softmax temperature; lower sharpens the distribution.
        penalty: repetition penalty > 1.0 discourages already-seen tokens.
        history: iterable of previously generated token ids, or None.

    Returns:
        int: index of the sampled token.
    """
    # Copy so the caller's array is never mutated as a side effect.
    logits = np.array(logits, dtype=np.float32, copy=True)

    # 1. Repetition penalty (CTRL-style, Keskar et al. 2019).
    # BUGFIX: plain division is wrong for negative logits — dividing a
    # negative score by penalty > 1 moves it TOWARD zero and therefore
    # INCREASES its probability. Divide positive logits, multiply negative.
    if history is not None and penalty != 1.0:
        seen = np.unique(history)
        # Keep only indices that actually exist in the logits vector.
        seen = seen[(seen >= 0) & (seen < len(logits))]
        vals = logits[seen]
        logits[seen] = np.where(vals > 0, vals / penalty, vals * penalty)

    # 2. Temperature scaling (guard against division by zero).
    logits = logits / max(temp, 1e-6)

    # 3. Top-k via argpartition — O(n) selection instead of a full sort.
    k = min(k, len(logits))  # clamp so argpartition never gets k > n
    top_k_idx = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_idx]

    # 4. Numerically stable softmax over the k candidates.
    shifted = top_k_logits - np.max(top_k_logits)
    exp_logits = np.exp(shifted)
    probs = exp_logits / np.sum(exp_logits)

    # 5. Draw one candidate according to the renormalized probabilities.
    return int(np.random.choice(top_k_idx, p=probs))

@app.post("/chat")
async def chat(request: Request):
    """Stream generated tokens for a prompt as Server-Sent Events.

    Expects a JSON body with optional keys: prompt (str), maxLen (int),
    temp (float), topK (int), penalty (float). Returns a
    text/event-stream response with one `data: {"token": ...}` frame per
    generated token, or a JSON error object on failure.
    """
    try:
        data = await request.json()
        user_prompt = data.get('prompt', '')
        max_len = int(data.get('maxLen', 100))
        temp = float(data.get('temp', 0.7))
        top_k = int(data.get('topK', 25))
        repetition_penalty = float(data.get('penalty', 1.2))

        # Alpaca instruction format
        full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
        tokens = TOKENIZER.encode(full_prompt)

        async def generate():
            nonlocal tokens
            # Keep history as a plain list: appending is O(1), whereas the
            # previous np.append copied the whole array every step (O(n^2)).
            history = list(tokens)

            for _ in range(max_len):
                # 1. Context handling: fixed 1024-token window, right-aligned
                # (zero padding on the left).
                ctx = tokens[-1024:]
                input_array = np.zeros((1, 1024), dtype=np.int64)
                input_array[0, -len(ctx):] = ctx

                # 2. Inference. session.run() is synchronous and CPU-heavy;
                # BUGFIX: running it inline froze the event loop for the whole
                # forward pass (the old sleep(0.01) only yielded BETWEEN
                # steps). to_thread keeps the server responsive during it.
                outputs = await asyncio.to_thread(
                    session.run, None, {'input': input_array}
                )

                # 3. Logits of the last position, restricted to the real
                # vocabulary (model output is padded beyond VOCAB_SIZE).
                logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

                # 4. Sampling
                next_token = fast_top_k_sample(
                    logits,
                    k=top_k,
                    temp=temp,
                    penalty=repetition_penalty,
                    history=np.asarray(history, dtype=np.int32)
                )

                if next_token == 50256:  # GPT-2 EOS token
                    break

                # 5. Update state
                tokens.append(next_token)
                history.append(next_token)

                # 6. Stream to client as an SSE data frame
                yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    except Exception as e:
        # Best-effort error reporting; kept as-is so the frontend contract
        # (JSON error object instead of a stream) is unchanged.
        print(f"Error: {e}")
        return {"error": str(e)}

@app.get("/")
async def health():
    """Liveness probe: reports engine status and the configured thread count."""
    payload = {
        "status": "SmaLLMPro INT8 Engine Online",
        "threads": options.intra_op_num_threads,
    }
    return payload

if __name__ == "__main__":
    # Run the ASGI app directly; port 7860 matches the HF Spaces convention
    # mentioned in the setup comments above.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)