import os
import json
import asyncio
import numpy as np
import onnxruntime as ort
import tiktoken
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# App initialisation
app = FastAPI()
# CORS: allow the external frontend (served from another origin) to call this API.
# NOTE(review): wildcard origins are fine for a public demo, but lock this down
# if the API ever carries credentials.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# 1. Model & tokenizer setup
TOKENIZER = tiktoken.get_encoding("gpt2")
# Relative path — resolved against the current working directory at startup.
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
# Logits are trimmed to this size before sampling (see the /chat handler).
VOCAB_SIZE = 50304
# 2. ONNX Runtime tuning
# HF free Spaces provide 2 vCPUs. Limit the thread pools to match, to avoid
# context-switching overhead from oversubscription.
options = ort.SessionOptions()
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
# Load the inference session (CPU-only provider).
print(f"🚀 Lade Modell {MODEL_PATH} mit CPU-Optimierung...")
session = ort.InferenceSession(
MODEL_PATH,
sess_options=options,
providers=['CPUExecutionProvider']
)
def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
    """Sample the next token id from raw logits with top-k + temperature.

    Args:
        logits: 1-D float array of raw scores. Penalized entries are
            modified in place.
        k: number of highest-scoring candidates to sample from; clamped
            to [1, len(logits)] so oversized user values cannot crash
            ``np.argpartition``.
        temp: softmax temperature; lower values are greedier.
        penalty: CTRL-style repetition penalty (> 1 discourages repeats,
            1.0 disables the step).
        history: array of already-generated token ids, or None.

    Returns:
        int: the sampled token index.
    """
    # 1. Repetition penalty: make already-seen tokens less likely.
    if history is not None and penalty != 1.0:
        unique_history = np.unique(history)
        # Ignore ids beyond the (trimmed) vocab range.
        seen = unique_history[unique_history < len(logits)]
        # BUG FIX: dividing a NEGATIVE logit by penalty > 1 raises it
        # (making the token MORE likely). CTRL-style penalty divides
        # positive logits and multiplies negative ones.
        seen_logits = logits[seen]
        logits[seen] = np.where(seen_logits > 0,
                                seen_logits / penalty,
                                seen_logits * penalty)
    # 2. Temperature scaling (floor temp to avoid division by zero).
    logits = logits / max(temp, 1e-6)
    # 3. Top-k via argpartition — O(n), no full sort of the vocab.
    k = max(1, min(k, logits.size))
    top_k_idx = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_idx]
    # 4. Numerically stable softmax over the k candidates.
    shifted_logits = top_k_logits - np.max(top_k_logits)
    exp_logits = np.exp(shifted_logits)
    probs = exp_logits / np.sum(exp_logits)
    # 5. Sample one candidate proportionally to its probability.
    return int(np.random.choice(top_k_idx, p=probs))
@app.post("/chat")
async def chat(request: Request):
    """Stream a completion token-by-token as Server-Sent Events.

    Expects a JSON body with optional keys: ``prompt`` (str),
    ``maxLen`` (int), ``temp`` (float), ``topK`` (int), ``penalty``
    (float). Returns a ``text/event-stream`` response, or a JSON
    ``{"error": ...}`` payload if anything fails up front.
    """
    try:
        data = await request.json()
        user_prompt = data.get('prompt', '')
        max_len = int(data.get('maxLen', 100))
        temp = float(data.get('temp', 0.7))
        top_k = int(data.get('topK', 25))
        repetition_penalty = float(data.get('penalty', 1.2))
        # Alpaca instruction format
        full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
        tokens = TOKENIZER.encode(full_prompt)
        async def generate():
            nonlocal tokens
            # Full generation history, kept for the repetition penalty.
            history = np.array(tokens, dtype=np.int32)
            for _ in range(max_len):
                # 1. Context handling: fixed 1024-token window,
                #    right-aligned with zero padding on the left.
                ctx = tokens[-1024:]
                input_array = np.zeros((1, 1024), dtype=np.int64)
                input_array[0, -len(ctx):] = ctx
                # 2. Inference is the CPU bottleneck. BUG FIX: run the
                #    blocking ONNX call in a worker thread instead of on
                #    the event loop, so other requests (and the SSE
                #    flush) are not stalled for the whole forward pass.
                outputs = await asyncio.to_thread(
                    session.run, None, {'input': input_array}
                )
                # 3. Logits of the last position, trimmed to the real vocab.
                logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)
                # 4. Sampling
                next_token = fast_top_k_sample(
                    logits,
                    k=top_k,
                    temp=temp,
                    penalty=repetition_penalty,
                    history=history
                )
                if next_token == 50256:  # GPT-2 <|endoftext|> — stop generating
                    break
                # 5. Update state
                tokens.append(next_token)
                history = np.append(history, next_token)
                # 6. Stream the decoded token to the client.
                yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"
        return StreamingResponse(generate(), media_type="text/event-stream")
    except Exception as e:
        print(f"Error: {e}")
        return {"error": str(e)}
@app.get("/")
async def health():
    """Liveness probe: report engine status and the configured thread count."""
    payload = {
        "status": "SmaLLMPro INT8 Engine Online",
        "threads": options.intra_op_num_threads,
    }
    return payload
if __name__ == "__main__":
    # Local/dev entry point: serve on 7860, the default HF Spaces port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)