# app.py — SmaLLMPro INT8 inference API for a Hugging Face Space
# (source: LH-Tech-AI, revision acdc12e)
import os
import json
import asyncio
import numpy as np
import onnxruntime as ort
import tiktoken
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# Application initialization.
app = FastAPI()

# CORS for the external frontend (any origin may call this API).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# 1. Model & tokenizer setup.
TOKENIZER = tiktoken.get_encoding("gpt2")
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
# Usable vocabulary size; the logits are sliced to this width at inference
# time (see the /chat handler).
VOCAB_SIZE = 50304

# 2. ONNX Runtime tuning.
# HF free Spaces provide 2 vCPUs. We cap the thread pools to match,
# avoiding context-switching overhead from oversubscription.
options = ort.SessionOptions()
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

# Load the inference session (CPU only).
print(f"🚀 Lade Modell {MODEL_PATH} mit CPU-Optimierung...")
session = ort.InferenceSession(
    MODEL_PATH,
    sess_options=options,
    providers=['CPUExecutionProvider']
)
def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
    """Sample the next token id from raw logits (NumPy-only, no torch).

    Applies a repetition penalty, temperature scaling and top-k filtering,
    then draws one token from the resulting distribution.

    Args:
        logits: 1-D array of unnormalized scores. Not modified in place.
        k: number of highest-scoring tokens kept for sampling (clamped
           to the vocabulary size).
        temp: softmax temperature; values near 0 approach greedy decoding.
        penalty: repetition penalty; >1 discourages repeats, 1.0 disables.
        history: optional iterable of previously generated token ids.

    Returns:
        int: the sampled token index.
    """
    # Work on a float copy so the caller's array is never mutated.
    logits = np.asarray(logits, dtype=np.float32).copy()

    # 1. Repetition penalty (CTRL-style): divide positive logits and
    #    multiply negative ones so a repeated token always becomes LESS
    #    likely. Dividing a negative logit would shrink its magnitude
    #    and *raise* its probability — the opposite of the intent.
    if history is not None and penalty != 1.0:
        seen = np.unique(history)
        # Only penalize indices that exist in this logit vector.
        seen = seen[(seen >= 0) & (seen < len(logits))]
        hit = logits[seen]
        logits[seen] = np.where(hit > 0, hit / penalty, hit * penalty)

    # 2. Temperature scaling (guarded against division by zero).
    logits = logits / max(temp, 1e-6)

    # 3. Top-k via argpartition: O(n) selection of the k largest values,
    #    cheaper than a full sort. Clamp k so tiny vocabularies or large
    #    user-supplied k values cannot crash argpartition.
    k = max(1, min(k, len(logits)))
    top_k_idx = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_idx]

    # 4. Numerically stable softmax over the k kept logits.
    shifted = top_k_logits - np.max(top_k_logits)
    exp_logits = np.exp(shifted)
    probs = exp_logits / np.sum(exp_logits)

    # 5. Draw one token according to the top-k distribution.
    return int(np.random.choice(top_k_idx, p=probs))
@app.post("/chat")
async def chat(request: Request):
    """Stream generated tokens for a prompt as Server-Sent Events.

    Expects a JSON body with keys 'prompt' (str), 'maxLen' (int),
    'temp' (float), 'topK' (int) and 'penalty' (float); all except the
    prompt fall back to defaults. Responds with an SSE stream in which
    each event carries one decoded token as {'token': str}. On failure,
    returns a plain JSON object {'error': str} instead of a stream.
    """
    try:
        data = await request.json()
        user_prompt = data.get('prompt', '')
        max_len = int(data.get('maxLen', 100))
        temp = float(data.get('temp', 0.7))
        top_k = int(data.get('topK', 25))
        repetition_penalty = float(data.get('penalty', 1.2))
        # Alpaca instruction format the model was fine-tuned on.
        full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
        tokens = TOKENIZER.encode(full_prompt)

        async def generate():
            nonlocal tokens
            # Keep the full token history so the repetition penalty can
            # discourage anything already generated.
            history = np.array(tokens, dtype=np.int32)
            for _ in range(max_len):
                # 1. Context handling: fixed 1024-position window,
                #    right-aligned with zeros padding the left.
                #    NOTE(review): assumes the exported model tolerates
                #    token id 0 as left padding — confirm with the export.
                ctx = tokens[-1024:]
                input_array = np.zeros((1, 1024), dtype=np.int64)
                input_array[0, -len(ctx):] = ctx
                # 2. Inference (synchronous call inside an async
                #    generator) — this is the CPU-bound bottleneck.
                outputs = session.run(None, {'input': input_array})
                # 3. Extract logits of the last position, sliced to the
                #    first VOCAB_SIZE entries.
                logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)
                # 4. Sample the next token (temperature / top-k / penalty).
                next_token = fast_top_k_sample(
                    logits,
                    k=top_k,
                    temp=temp,
                    penalty=repetition_penalty,
                    history=history
                )
                if next_token == 50256:  # GPT-2 <|endoftext|> — stop here
                    break
                # 5. Update generation state.
                tokens.append(next_token)
                history = np.append(history, next_token)
                # 6. Stream the decoded token to the client as one SSE event.
                yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"
                # Brief pause so the event loop can flush the response.
                await asyncio.sleep(0.01)

        return StreamingResponse(generate(), media_type="text/event-stream")
    except Exception as e:
        print(f"Error: {e}")
        return {"error": str(e)}
@app.get("/")
async def health():
    """Liveness probe: report engine status and the configured thread count."""
    payload = {
        "status": "SmaLLMPro INT8 Engine Online",
        "threads": options.intra_op_num_threads,
    }
    return payload
if __name__ == "__main__":
    import uvicorn
    # Port 7860 is the port Hugging Face Spaces expose by default.
    uvicorn.run(app, host="0.0.0.0", port=7860)