import os
import json
import asyncio

import numpy as np
import onnxruntime as ort
import tiktoken
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# App initialization
app = FastAPI()

# CORS for the external frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# 1. Model & tokenizer setup
TOKENIZER = tiktoken.get_encoding("gpt2")
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
VOCAB_SIZE = 50304

# 2. ONNX Runtime optimization
# HF free Spaces provide 2 vCPUs. We cap the thread counts
# to avoid context-switching overhead.
options = ort.SessionOptions()
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

# Load the inference session
print(f"🚀 Loading model {MODEL_PATH} with CPU optimizations...")
session = ort.InferenceSession(
    MODEL_PATH,
    sess_options=options,
    providers=['CPUExecutionProvider']
)


def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
    """Top-k sampling with repetition penalty, implemented in NumPy."""
    # 1. Repetition penalty (optional, only applied if a history is provided)
    if history is not None and penalty != 1.0:
        # Penalize already-generated tokens directly in the logits.
        # CTRL-style: divide positive logits and multiply negative ones,
        # so the penalty always lowers the token's probability.
        unique_history = np.unique(history)
        # Only penalize valid token indices
        valid_indices = unique_history[unique_history < len(logits)]
        hit = logits[valid_indices]
        logits[valid_indices] = np.where(hit > 0, hit / penalty, hit * penalty)

    # 2. Temperature scaling
    logits = logits / max(temp, 1e-6)

    # 3. Top-k via argpartition (faster than a full sort):
    #    finds the k largest values without sorting the rest
    top_k_idx = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_idx]

    # 4. Softmax
    shifted_logits = top_k_logits - np.max(top_k_logits)
    exp_logits = np.exp(shifted_logits)
    probs = exp_logits / np.sum(exp_logits)

    # 5. Sample
    choice = np.random.choice(top_k_idx, p=probs)
    return int(choice)


@app.post("/chat")
async def chat(request: Request):
    try:
        data = await request.json()
        user_prompt = data.get('prompt', '')
        max_len = int(data.get('maxLen', 100))
        temp = float(data.get('temp', 0.7))
        top_k = int(data.get('topK', 25))
        repetition_penalty = float(data.get('penalty', 1.2))

        # Alpaca instruction format
        full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
        tokens = TOKENIZER.encode(full_prompt)

        async def generate():
            nonlocal tokens
            # Keep the full history around for the repetition penalty
            history = np.array(tokens, dtype=np.int32)

            for _ in range(max_len):
                # 1. Context handling: always exactly 1024 positions
                #    (tokens right-aligned, zero-padded on the left)
                ctx = tokens[-1024:]
                input_array = np.zeros((1, 1024), dtype=np.int64)
                input_array[0, -len(ctx):] = ctx

                # 2. Inference (synchronous call inside the async generator).
                #    This is the bottleneck where the CPU does the actual work.
                outputs = session.run(None, {'input': input_array})

                # 3. Extract logits (last position, first VOCAB_SIZE entries)
                logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

                # 4. Sampling
                next_token = fast_top_k_sample(
                    logits,
                    k=top_k,
                    temp=temp,
                    penalty=repetition_penalty,
                    history=history
                )

                if next_token == 50256:  # EOS token
                    break

                # 5. Update state
                tokens.append(next_token)
                history = np.append(history, next_token)

                # 6. Stream the token to the client
                yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"

                # Brief pause so control returns to the event loop
                await asyncio.sleep(0.01)

        return StreamingResponse(generate(), media_type="text/event-stream")

    except Exception as e:
        print(f"Error: {e}")
        return {"error": str(e)}


@app.get("/")
async def health():
    return {"status": "SmaLLMPro INT8 Engine Online", "threads": options.intra_op_num_threads}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
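
# Example request for local testing (a minimal sketch, assuming the server is
# started via the uvicorn call above and reachable on localhost:7860; the JSON
# field names match the keys read in /chat, and -N keeps curl from buffering
# the SSE stream):
#
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain INT8 quantization in one sentence.", "maxLen": 64, "temp": 0.7, "topK": 25, "penalty": 1.2}'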