LH-Tech-AI committed
Commit b30d9e6 · verified · 1 Parent(s): 382d966

Update app.py

Files changed (1)
  1. app.py +36 -12
app.py CHANGED
@@ -1,7 +1,6 @@
 from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse, FileResponse
+from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
 import onnxruntime as ort
 import numpy as np
 import tiktoken
@@ -10,22 +9,33 @@ import os
 
 app = FastAPI()
 
-# Allow CORS
-app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
+# IMPORTANT: allows your external frontend to access the API
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # later you can restrict this to your own domain
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 # Load model & tokenizer
 tokenizer = tiktoken.get_encoding("gpt2")
 MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
-# Uses optimized CPU settings
-session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
+
+# Optimized session options for CPU
+sess_options = ort.SessionOptions()
+sess_options.intra_op_num_threads = 2  # HF Spaces usually provide 2 cores
+session = ort.InferenceSession(MODEL_PATH, sess_options, providers=['CPUExecutionProvider'])
 
 def top_k_sample(logits, k=50, temp=0.7):
-    logits = logits / temp
-    # NumPy is MUCH faster than JS loops
+    logits = logits / max(temp, 1e-6)
+    # Only consider the top-k values (saves a lot of sorting time)
     top_k_indices = np.argpartition(logits, -k)[-k:]
     top_k_logits = logits[top_k_indices]
+
+    # Numerically stable softmax
     exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
     probs = exp_logits / np.sum(exp_logits)
+
     return int(np.random.choice(top_k_indices, p=probs))
 
 @app.post("/chat")
@@ -34,20 +44,34 @@ async def chat(request: Request):
     prompt = f"Instruction:\n{data['prompt']}\n\nResponse:\n"
     tokens = tokenizer.encode(prompt)
 
+    max_len = int(data.get('maxLen', 100))
+    temp = float(data.get('temp', 0.7))
+    top_k = int(data.get('topK', 40))
+
     async def generate():
         nonlocal tokens
-        for _ in range(data.get('maxLen', 100)):
+        for _ in range(max_len):
+            # Limit the context to the last 1024 tokens
             ctx = tokens[-1024:]
+            # Right-aligned padding
             padded = np.zeros((1, 1024), dtype=np.int64)
             padded[0, -len(ctx):] = ctx
 
+            # Inference
             outputs = session.run(None, {'input': padded})
+            # Take only the logits of the last token
            logits = outputs[0][0, -1, :50304]
 
-            next_token = top_k_sample(logits, k=data.get('topK', 25), temp=data.get('temp', 0.5))
-            if next_token == 50256: break
+            next_token = top_k_sample(logits, k=top_k, temp=temp)
+
+            if next_token == 50256:  # EOS
+                break
 
             tokens.append(next_token)
             yield f"data: {json.dumps({'token': tokenizer.decode([next_token])})}\n\n"
 
-    return StreamingResponse(generate(), media_type="text/event-stream")
+    return StreamingResponse(generate(), media_type="text/event-stream")
+
+@app.get("/")
+def health():
+    return {"status": "SmaLLMPro API is online"}