import asyncio
import json
import logging
import threading
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from src.core.config import settings
from src.core.engine import engine
from src.utils.helpers import get_clean_text
# Configure basic logging at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the request handlers in this file.
logger = logging.getLogger(__name__)
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
@router.post("/chat/completions")
async def chat_completions(request: Request):
    """Serve a chat-completion request as one JSON reply or as an SSE stream.

    The JSON request body is read for ``messages``, ``stream``, ``stop``,
    ``max_tokens``, ``temperature``, ``top_p``, ``tools`` and ``tool_choice``.
    When ``stream`` is falsy, ``engine.generate`` runs in the default executor
    and its result is returned as a ``JSONResponse``. Otherwise the chunks it
    yields are relayed as ``data: ...`` server-sent events, terminated by
    ``data: [DONE]``.

    Raises:
        HTTPException: 500 if ``engine.llm`` is falsy or non-streaming
            generation fails; 400 if the body is not valid JSON, is not a
            JSON object, or ``messages`` / ``stop`` have an unusable type.
    """
    if not engine.llm:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        data = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON") from exc

    # Valid JSON that is not an object (a list, a bare string, ...) has no
    # ``.get`` and would otherwise surface as an unhandled 500.
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    # ``"messages": null`` is treated the same as a missing key.
    raw_messages = data.get("messages") or []
    if not isinstance(raw_messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")

    # Collect the messages, preserving the auxiliary fields used for tool calling.
    messages = []
    for m in raw_messages:
        if not isinstance(m, dict):
            raise HTTPException(
                status_code=400, detail="Each message must be a JSON object"
            )
        msg = {
            "role": m.get("role", "user"),
            "content": get_clean_text(m.get("content")),
        }
        if "tool_calls" in m:
            msg["tool_calls"] = m["tool_calls"]
        if "tool_call_id" in m:
            msg["tool_call_id"] = m["tool_call_id"]
        messages.append(msg)

    # Parameters supplied by the client.
    is_stream = data.get("stream", False)

    # ``stop`` may be absent, null, a single string or a list of strings.
    stop = data.get("stop")
    if stop is None:
        stop = []
    elif isinstance(stop, str):
        stop = [stop]
    elif isinstance(stop, list):
        stop = list(stop)  # copy so the parsed request body is not mutated
    else:
        raise HTTPException(
            status_code=400, detail="'stop' must be a string or a list of strings"
        )

    # Always stop on these end-of-sequence markers, without duplicating
    # any the client already supplied.
    default_stops = ("<|im_end|>", "<|endoftext|>", "<|file_sep|>")
    for s in default_stops:
        if s not in stop:
            stop.append(s)

    gen_kwargs = {
        "max_tokens": data.get("max_tokens", settings.DEFAULT_MAX_TOKENS),
        "temperature": data.get("temperature", settings.DEFAULT_TEMP),
        "top_p": data.get("top_p", 0.95),
        "stop": stop,
        "stream": is_stream,
        "tools": data.get("tools", None),
        "tool_choice": data.get("tool_choice", None),
    }

    # The client wants the whole answer at once (stream is falsy).
    if not is_stream:
        loop = asyncio.get_running_loop()
        try:
            # Run the synchronous call in the thread pool so the event loop
            # is not blocked.
            response = await loop.run_in_executor(
                None, lambda: engine.generate(messages, **gen_kwargs)
            )
            return JSONResponse(content=response)
        except Exception as e:
            logger.exception("Sync generation error: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    # Streaming mode (stream is truthy).
    abort_event = threading.Event()
    # Built up front so the worker thread does not mutate ``gen_kwargs``.
    stream_kwargs = {**gen_kwargs, "abort_event": abort_event}

    async def stream_generator():
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def worker():
            # Runs in an executor thread; hands chunks to the event loop
            # through the queue. ``None`` marks normal completion and a
            # ``{"error": ...}`` dict marks a failure.
            try:
                for chunk in engine.generate(messages, **stream_kwargs):
                    loop.call_soon_threadsafe(queue.put_nowait, chunk)
                loop.call_soon_threadsafe(queue.put_nowait, None)
            except Exception as e:
                if not abort_event.is_set():
                    logger.exception("Generation error: %s", e)
                loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})

        loop.run_in_executor(None, worker)

        try:
            while True:
                if await request.is_disconnected():
                    break
                try:
                    # Short timeout so a disconnect is noticed even while
                    # the worker is not producing chunks.
                    chunk = await asyncio.wait_for(queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    continue

                if chunk is None:
                    yield "data: [DONE]\n\n"
                    break

                if isinstance(chunk, dict) and "error" in chunk:
                    if abort_event.is_set():
                        break
                    err_json = json.dumps(
                        {"error": {"message": chunk["error"], "type": "internal_error"}}
                    )
                    yield f"data: {err_json}\n\n"
                    break

                yield f"data: {json.dumps(chunk)}\n\n"
        finally:
            # Signal the worker on every exit path: client disconnect,
            # task cancellation, or the generator being closed while
            # suspended at a ``yield`` (GeneratorExit). Setting the event
            # after normal completion is harmless.
            abort_event.set()

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "X-Accel-Buffering": "no",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|