| import threading |
| import time |
| from typing import Any, Dict, Generator, List, Optional |
|
|
| from huggingface_hub import hf_hub_download |
| from llama_cpp import Llama |
|
|
| from src.core.config import settings |
|
|
|
|
class ModelEngine:
    """Serve chat completions from one llama.cpp model, one request at a time.

    The model file is fetched with ``hf_hub_download`` and wrapped in a
    ``Llama`` instance when the engine is constructed. Every generation holds
    ``self.lock`` while it runs the model; callers waiting for the lock poll
    it so that a set ``abort_event`` can cancel them while they queue.
    """

    # Seconds to block on the lock before re-checking the abort event.
    LOCK_POLL_SECONDS = 0.5

    def __init__(self):
        """Create the request lock and try to load the model immediately."""
        # Stays None when loading fails; generate() then raises RuntimeError.
        self.llm = None
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Fetch the model file and build the ``Llama`` instance.

        Any failure is reported on stdout and leaves ``self.llm`` as ``None``
        instead of propagating, so constructing the engine itself never
        raises because of a load error.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                chat_format="chatml",
                verbose=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"CRITICAL ERROR loading model: {e}")

    def generate(
        self,
        messages: List[Dict[str, Any]],
        abort_event: Optional[threading.Event] = None,
        **kwargs,
    ):
        """Run a chat completion, streaming by default.

        Args:
            messages: Chat messages forwarded as ``messages`` to
                ``create_chat_completion``.
            abort_event: Optional event. Once it is set, a request still
                waiting for the lock gives up, and a streaming request stops
                before yielding its next chunk.
            **kwargs: Optional ``stream`` (default ``True``), ``max_tokens``,
                ``temperature``, ``top_p``, ``repeat_penalty``, ``stop``,
                ``tools`` and ``tool_choice``. A sampling option that is
                missing or ``None`` falls back to its default; ``tools`` and
                ``tool_choice`` are forwarded only when truthy.

        Returns:
            When ``stream`` is truthy, a generator yielding the chunks produced
            by ``create_chat_completion`` (it yields nothing if the request is
            aborted before generation starts). Otherwise the completion
            result, or ``None`` if the request was aborted before generation
            started.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if not self.llm:
            raise RuntimeError("Model not loaded")

        stream_mode = kwargs.get("stream", True)

        # Treat an explicit None like an absent option: int(None)/float(None)
        # would otherwise raise TypeError for callers that forward unset fields.
        given = {key: value for key, value in kwargs.items() if value is not None}

        llama_kwargs = {
            "messages": messages,
            # settings is consulted only when the caller supplied no value.
            "max_tokens": int(
                given["max_tokens"]
                if "max_tokens" in given
                else settings.DEFAULT_MAX_TOKENS
            ),
            "temperature": float(
                given["temperature"]
                if "temperature" in given
                else settings.DEFAULT_TEMP
            ),
            "top_p": float(given.get("top_p", 0.95)),
            "repeat_penalty": float(given.get("repeat_penalty", 1.15)),
            "stop": given.get("stop", []),
            "stream": stream_mode,
        }

        if kwargs.get("tools"):
            llama_kwargs["tools"] = kwargs["tools"]
        if kwargs.get("tool_choice"):
            llama_kwargs["tool_choice"] = kwargs["tool_choice"]

        if stream_mode:
            # Generator: the lock is only taken once the caller starts iterating.
            return self._stream(llama_kwargs, abort_event)
        return self._complete(llama_kwargs, abort_event)

    @staticmethod
    def _aborted(abort_event):
        """Return True when an abort event was supplied and is set."""
        return bool(abort_event and abort_event.is_set())

    def _wait_for_lock(self, abort_event):
        """Acquire ``self.lock``; return False if aborted before getting it."""
        while True:
            if self._aborted(abort_event):
                print("Request aborted while waiting in queue.")
                return False
            # Short timeout so the abort event is re-checked while queued.
            if self.lock.acquire(timeout=self.LOCK_POLL_SECONDS):
                return True

    def _stream(self, llama_kwargs, abort_event):
        """Yield completion chunks while holding the lock, stopping on abort."""
        if not self._wait_for_lock(abort_event):
            return
        try:
            # The event may have been set between the last poll and acquisition.
            if self._aborted(abort_event):
                return
            for chunk in self.llm.create_chat_completion(**llama_kwargs):
                if self._aborted(abort_event):
                    print("Request aborted during generation.")
                    break
                yield chunk
        finally:
            # Also runs when the consumer closes the generator early.
            self.lock.release()

    def _complete(self, llama_kwargs, abort_event):
        """Return the full completion while holding the lock, or None on abort."""
        if not self._wait_for_lock(abort_event):
            return None
        try:
            if self._aborted(abort_event):
                return None
            return self.llm.create_chat_completion(**llama_kwargs)
        finally:
            self.lock.release()
|
|
|
|
# Shared module-level engine; constructing it runs the model load right away.
engine: ModelEngine = ModelEngine()
|
|