Code_LLM / src /core /engine.py
AnatoliiG
fix repeat haluc
047143f
Raw
History Blame Contribute Delete
3.81 kB
import threading
import time
from typing import Any, Dict, Generator, List, Optional
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from src.core.config import settings
class ModelEngine:
    """Wrapper around a single llama.cpp chat model with serialised access.

    The model file is fetched with ``hf_hub_download`` and loaded on
    construction. Every inference call made through ``generate`` holds
    ``self.lock`` while it uses ``self.llm``, so only one request runs at a
    time. Other requests poll for the lock and can be cancelled while they
    wait.
    """

    # How long a queued request blocks on the lock before re-checking its
    # abort event.
    _LOCK_POLL_SECONDS = 0.5
    # Sampling defaults used when the caller does not supply a value.
    DEFAULT_TOP_P = 0.95
    DEFAULT_REPEAT_PENALTY = 1.15

    def __init__(self):
        self.llm = None  # set by _load_model; stays None if loading failed
        self.lock = threading.Lock()  # serialises all calls into self.llm
        self._load_model()

    def _load_model(self):
        """Fetch the model file from the Hugging Face Hub and load it.

        Errors are printed and swallowed on purpose so the process can still
        start. ``self.llm`` then stays ``None`` and ``generate`` raises
        ``RuntimeError("Model not loaded")``.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                chat_format="chatml",
                verbose=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"CRITICAL ERROR loading model: {e}")

    @staticmethod
    def _option(kwargs: Dict[str, Any], name: str, default: Any, cast) -> Any:
        """Return ``cast(kwargs[name])``, falling back to ``cast(default)``.

        An explicit ``None`` (e.g. JSON ``null`` from an OpenAI-style client)
        is treated like a missing key instead of crashing inside ``cast``.
        """
        value = kwargs.get(name)
        return cast(default if value is None else value)

    def _acquire_lock(self, abort_event: Optional[threading.Event]) -> bool:
        """Wait for ``self.lock`` unless the request is aborted first.

        Returns True once the lock is held; the caller must release it.
        Returns False if ``abort_event`` was set while waiting.
        """
        while True:
            if abort_event is not None and abort_event.is_set():
                print("Request aborted while waiting in queue.")
                return False
            # Bounded wait so a cancelled request notices its abort event.
            if self.lock.acquire(timeout=self._LOCK_POLL_SECONDS):
                return True

    def _stream(
        self,
        llama_kwargs: Dict[str, Any],
        abort_event: Optional[threading.Event],
    ) -> Generator[Any, None, None]:
        """Yield completion chunks while holding ``self.lock``.

        The lock is only requested once iteration starts. The upstream chunk
        iterator is closed explicitly before the lock is released. This
        covers normal completion, an abort, and the consumer closing this
        generator early. It ensures the upstream iterator is finalised while
        this request still owns the model.
        """
        if not self._acquire_lock(abort_event):
            return
        response = None
        try:
            if abort_event is not None and abort_event.is_set():
                return
            response = self.llm.create_chat_completion(**llama_kwargs)
            for chunk in response:
                if abort_event is not None and abort_event.is_set():
                    print("Request aborted during generation.")
                    break
                yield chunk
        finally:
            try:
                close = getattr(response, "close", None)
                if callable(close):
                    close()
            finally:
                # Always hand the model back, even if close() raised.
                self.lock.release()

    def generate(
        self,
        messages: List[Dict[str, Any]],
        abort_event: Optional[threading.Event] = None,
        **kwargs,
    ):
        """Run a chat completion on the loaded model.

        Args:
            messages: Chat messages forwarded unchanged to
                ``Llama.create_chat_completion``.
            abort_event: Optional event. Once set, a queued request stops
                waiting and a streaming request stops yielding chunks.
            **kwargs: Request options:
                - ``stream`` (default True)
                - ``max_tokens`` (default ``settings.DEFAULT_MAX_TOKENS``)
                - ``temperature`` (default ``settings.DEFAULT_TEMP``)
                - ``top_p`` (default 0.95)
                - ``repeat_penalty`` (default 1.15)
                - ``stop`` (default ``[]``)
                - optional ``tools`` / ``tool_choice``
                Numeric options passed as ``None`` use their defaults.

        Returns:
            When ``stream`` is truthy, a generator of completion chunks.
            Otherwise the completion result, or ``None`` if the request was
            aborted.

        Raises:
            RuntimeError: If the model failed to load.
        """
        if not self.llm:
            raise RuntimeError("Model not loaded")

        stream_mode = kwargs.get("stream", True)
        # Prepare the arguments for llama_cpp_python.
        llama_kwargs = {
            "messages": messages,
            "max_tokens": self._option(
                kwargs, "max_tokens", settings.DEFAULT_MAX_TOKENS, int
            ),
            "temperature": self._option(
                kwargs, "temperature", settings.DEFAULT_TEMP, float
            ),
            "top_p": self._option(kwargs, "top_p", self.DEFAULT_TOP_P, float),
            "repeat_penalty": self._option(
                kwargs, "repeat_penalty", self.DEFAULT_REPEAT_PENALTY, float
            ),
            "stop": kwargs.get("stop", []),
            "stream": stream_mode,
        }
        # Forward tool-calling options only when they were provided.
        if kwargs.get("tools"):
            llama_kwargs["tools"] = kwargs["tools"]
        if kwargs.get("tool_choice"):
            llama_kwargs["tool_choice"] = kwargs["tool_choice"]

        if stream_mode:
            # The yield lives in a separate generator so the non-streaming
            # branch below can still return a plain value.
            return self._stream(llama_kwargs, abort_event)

        # Streaming disabled (agent mode): return the whole completion.
        if not self._acquire_lock(abort_event):
            return None
        try:
            if abort_event is not None and abort_event.is_set():
                return None
            return self.llm.create_chat_completion(**llama_kwargs)
        finally:
            self.lock.release()
# Shared module-level instance. Importing this module constructs it, which
# runs ModelEngine._load_model, i.e. downloads/loads the model at import time.
engine = ModelEngine()