File size: 3,810 Bytes
f9aca5d 010db11 54f6dae 42fa16e 54f6dae f9aca5d 54f6dae f9aca5d 42fa16e 54f6dae 42fa16e 3b802b2 010db11 54f6dae f9aca5d 54f6dae f9aca5d 54f6dae 6eebe14 010db11 6eebe14 010db11 6eebe14 54f6dae 42fa16e f9aca5d 6eebe14 047143f 6eebe14 a8a31d7 2b63295 010db11 2b63295 6eebe14 2b63295 6eebe14 2b63295 6eebe14 2b63295 010db11 54f6dae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | import threading
import time
from typing import Any, Dict, Generator, List, Optional
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from src.core.config import settings
class ModelEngine:
    """Serialised access to a single ``llama_cpp.Llama`` model.

    The model is loaded once in the constructor. Every call to
    :meth:`generate` goes through one shared ``threading.Lock`` so that only
    one completion runs at a time; a request that is still waiting for the
    lock (or already generating) can be cancelled via a ``threading.Event``.
    """

    # Seconds a single lock-acquire attempt blocks before the abort event is
    # polled again while a request waits in the queue.
    _LOCK_POLL_SECONDS = 0.5

    def __init__(self) -> None:
        """Create the lock and try to load the model immediately."""
        # Stays None when loading fails; generate() checks for that.
        self.llm = None
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self) -> None:
        """Resolve the model file with ``hf_hub_download`` and load it.

        Repository id, file name and runtime parameters are read from
        ``settings``. Load errors are reported on stdout and deliberately not
        re-raised: ``self.llm`` is left as ``None`` and :meth:`generate`
        raises ``RuntimeError`` when it is called later.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                chat_format="chatml",
                verbose=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            # Best-effort boundary: keep the process alive without a model.
            print(f"CRITICAL ERROR loading model: {e}")

    def generate(
        self,
        messages: List[Dict[str, Any]],
        abort_event: Optional[threading.Event] = None,
        **kwargs,
    ):
        """Run a chat completion, streaming by default.

        Args:
            messages: Chat messages forwarded unchanged to
                ``Llama.create_chat_completion``.
            abort_event: Optional event; once it is set the request is
                dropped while waiting for the lock, right after acquiring
                it, or between streamed chunks.
            **kwargs: Recognised options (a missing or ``None`` value selects
                the default): ``stream`` (default ``True``), ``max_tokens``
                (``settings.DEFAULT_MAX_TOKENS``), ``temperature``
                (``settings.DEFAULT_TEMP``), ``top_p`` (``0.95``),
                ``repeat_penalty`` (``1.15``), ``stop`` (``[]`` when absent),
                and ``tools`` / ``tool_choice`` (forwarded only when truthy).

        Returns:
            With streaming enabled, a generator that yields the chunks
            produced by ``create_chat_completion``. The lock is taken when
            iteration starts and is held until the generator finishes, is
            aborted, or is closed. With streaming disabled, the value
            returned by ``create_chat_completion``, or ``None`` when the
            request was aborted before generation started.

        Raises:
            RuntimeError: If the model is not loaded.
        """
        if not self.llm:
            raise RuntimeError("Model not loaded")

        llama_kwargs = self._build_llama_kwargs(messages, kwargs)

        # The ``yield`` lives in a separate generator method so that this
        # method stays a plain function: the check above raises eagerly and
        # the non-streaming path can return an ordinary value.
        if llama_kwargs["stream"]:
            return self._stream_completion(llama_kwargs, abort_event)
        # Streaming disabled (agent mode): return the plain result.
        return self._blocking_completion(llama_kwargs, abort_event)

    def _build_llama_kwargs(
        self, messages: List[Dict[str, Any]], options: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Translate caller options into ``create_chat_completion`` kwargs."""
        # ``None`` is treated like a missing option so that callers which
        # forward optional request fields verbatim do not crash int()/float().
        # The settings defaults are only read when they are actually needed.
        max_tokens = options.get("max_tokens")
        if max_tokens is None:
            max_tokens = settings.DEFAULT_MAX_TOKENS
        temperature = options.get("temperature")
        if temperature is None:
            temperature = settings.DEFAULT_TEMP
        top_p = options.get("top_p")
        if top_p is None:
            top_p = 0.95
        repeat_penalty = options.get("repeat_penalty")
        if repeat_penalty is None:
            repeat_penalty = 1.15

        llama_kwargs = {
            "messages": messages,
            "max_tokens": int(max_tokens),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "repeat_penalty": float(repeat_penalty),
            "stop": options.get("stop", []),
            "stream": options.get("stream", True),
        }
        # Pass tool-calling options through only when the caller set them.
        for name in ("tools", "tool_choice"):
            if options.get(name):
                llama_kwargs[name] = options[name]
        return llama_kwargs

    @staticmethod
    def _is_aborted(abort_event: Optional[threading.Event]) -> bool:
        """Return True when an abort event was supplied and is set."""
        return bool(abort_event) and abort_event.is_set()

    def _acquire_lock(self, abort_event: Optional[threading.Event]) -> bool:
        """Wait for the model lock, giving up once ``abort_event`` is set.

        Returns True when the lock is now held by the caller, False when the
        request was aborted while still waiting (the lock is not held then).
        """
        while not self._is_aborted(abort_event):
            # Short timeout so the abort event is re-checked periodically.
            if self.lock.acquire(timeout=self._LOCK_POLL_SECONDS):
                return True
        print("Request aborted while waiting in queue.")
        return False

    def _stream_completion(
        self,
        llama_kwargs: Dict[str, Any],
        abort_event: Optional[threading.Event],
    ) -> Generator[Any, None, None]:
        """Yield completion chunks while holding the model lock."""
        if not self._acquire_lock(abort_event):
            return
        try:
            # The event may have been set between the last poll and acquire.
            if self._is_aborted(abort_event):
                return
            for chunk in self.llm.create_chat_completion(**llama_kwargs):
                if self._is_aborted(abort_event):
                    print("Request aborted during generation.")
                    break
                yield chunk
        finally:
            # Also runs when the consumer closes the generator early.
            self.lock.release()

    def _blocking_completion(
        self,
        llama_kwargs: Dict[str, Any],
        abort_event: Optional[threading.Event],
    ):
        """Run one non-streaming completion while holding the model lock."""
        if not self._acquire_lock(abort_event):
            return None
        try:
            # The event may have been set between the last poll and acquire.
            if self._is_aborted(abort_event):
                return None
            return self.llm.create_chat_completion(**llama_kwargs)
        finally:
            self.lock.release()
# Module-level shared engine instance. Constructing it runs
# ModelEngine.__init__, which attempts to download/load the model as soon as
# this module is imported.
engine = ModelEngine()
|