File size: 3,810 Bytes
f9aca5d
010db11
 
54f6dae
 
 
 
42fa16e
54f6dae
 
 
 
 
f9aca5d
54f6dae
 
 
 
f9aca5d
42fa16e
 
 
54f6dae
 
42fa16e
 
 
3b802b2
010db11
54f6dae
f9aca5d
54f6dae
f9aca5d
54f6dae
6eebe14
010db11
6eebe14
 
010db11
6eebe14
54f6dae
42fa16e
f9aca5d
6eebe14
 
 
 
 
 
 
 
047143f
6eebe14
 
 
 
 
 
 
 
 
a8a31d7
2b63295
 
010db11
2b63295
 
 
 
 
 
 
6eebe14
2b63295
6eebe14
2b63295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eebe14
2b63295
 
010db11
54f6dae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import threading
import time
from typing import Any, Dict, Generator, List, Optional

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from src.core.config import settings


class ModelEngine:
    """Lock-serialized wrapper around a single llama_cpp ``Llama`` model.

    The model file is fetched with ``hf_hub_download`` and loaded once at
    construction time. All inference goes through ``generate``, which holds
    ``self.lock`` for the duration of a request so only one request uses the
    model at a time; a request can be cancelled via an optional
    ``threading.Event`` both while queued and while streaming.
    """

    # Seconds a queued request blocks on the lock before re-checking its
    # abort event.
    _LOCK_POLL_SECONDS = 0.5

    def __init__(self):
        self.llm = None
        # Exception from the last failed load, if any; chained onto the
        # RuntimeError raised by generate() so the root cause stays visible.
        self._load_error: Optional[BaseException] = None
        self.lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Fetch ``settings.FILENAME`` from ``settings.REPO_ID`` and build the model.

        Failures are deliberately not re-raised: the engine stays usable as an
        object with ``self.llm = None``, the error is printed, and it is kept
        in ``self._load_error`` for ``generate`` to report.
        """
        try:
            print(f"Downloading/Loading model: {settings.REPO_ID}...")
            model_path = hf_hub_download(
                repo_id=settings.REPO_ID, filename=settings.FILENAME
            )
            self.llm = Llama(
                model_path=model_path,
                n_ctx=settings.CONTEXT_SIZE,
                n_threads=settings.N_THREADS,
                n_gpu_layers=settings.N_GPU_LAYERS,
                chat_format="chatml",
                verbose=True,
            )
            self._load_error = None
            print("Model loaded successfully!")
        except Exception as e:
            self._load_error = e
            print(f"CRITICAL ERROR loading model: {e}")

    @staticmethod
    def _is_aborted(abort_event: Optional[threading.Event]) -> bool:
        """Return True when an abort event was supplied and has been set."""
        return abort_event is not None and abort_event.is_set()

    def _acquire_lock(self, abort_event: Optional[threading.Event]) -> bool:
        """Wait for ``self.lock``, giving up if ``abort_event`` becomes set.

        Returns True when the caller now holds the lock (and must release it),
        False when the request was aborted while still waiting.
        """
        while True:
            if self._is_aborted(abort_event):
                print("Request aborted while waiting in queue.")
                return False
            if self.lock.acquire(timeout=self._LOCK_POLL_SECONDS):
                return True

    @staticmethod
    def _build_llama_kwargs(
        messages: List[Dict[str, Any]], stream_mode: bool, kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Translate generate()'s keyword options into create_chat_completion args.

        A missing option and an explicit ``None`` are treated the same way:
        both fall back to the default instead of failing in int()/float().
        """
        max_tokens = kwargs.get("max_tokens")
        temperature = kwargs.get("temperature")
        top_p = kwargs.get("top_p")
        repeat_penalty = kwargs.get("repeat_penalty")
        stop = kwargs.get("stop")

        llama_kwargs: Dict[str, Any] = {
            "messages": messages,
            "max_tokens": int(
                settings.DEFAULT_MAX_TOKENS if max_tokens is None else max_tokens
            ),
            "temperature": float(
                settings.DEFAULT_TEMP if temperature is None else temperature
            ),
            "top_p": float(0.95 if top_p is None else top_p),
            "repeat_penalty": float(1.15 if repeat_penalty is None else repeat_penalty),
            "stop": [] if stop is None else stop,
            "stream": stream_mode,
        }

        # Forward tool-calling options only when the caller actually set them.
        if kwargs.get("tools"):
            llama_kwargs["tools"] = kwargs["tools"]
        if kwargs.get("tool_choice"):
            llama_kwargs["tool_choice"] = kwargs["tool_choice"]
        return llama_kwargs

    def _stream(
        self, llama_kwargs: Dict[str, Any], abort_event: Optional[threading.Event]
    ) -> Generator[Dict[str, Any], None, None]:
        """Yield completion chunks while holding the model lock.

        The lock is taken on first iteration, not when the generator is created,
        and is released when iteration ends, is aborted, or the generator is closed.
        """
        if not self._acquire_lock(abort_event):
            return
        response = None
        try:
            if self._is_aborted(abort_event):
                return
            response = self.llm.create_chat_completion(**llama_kwargs)
            for chunk in response:
                if self._is_aborted(abort_event):
                    print("Request aborted during generation.")
                    break
                yield chunk
        finally:
            try:
                # Close llama_cpp's chunk generator while we still own the lock;
                # otherwise it would be finalized later, possibly while another
                # request is already running on the model.
                close = getattr(response, "close", None)
                if callable(close):
                    close()
            finally:
                self.lock.release()

    def _complete(
        self, llama_kwargs: Dict[str, Any], abort_event: Optional[threading.Event]
    ):
        """Return a full (non-streaming) completion, or None if aborted first."""
        if not self._acquire_lock(abort_event):
            return None
        try:
            if self._is_aborted(abort_event):
                return None
            return self.llm.create_chat_completion(**llama_kwargs)
        finally:
            self.lock.release()

    def generate(
        self,
        messages: List[Dict[str, Any]],
        abort_event: Optional[threading.Event] = None,
        **kwargs,
    ):
        """Run a chat completion on the loaded model.

        Args:
            messages: Chat messages passed unchanged to
                ``Llama.create_chat_completion``.
            abort_event: Optional event. Once set, a queued request gives up
                and a streaming request stops yielding further chunks.
            **kwargs: ``max_tokens``, ``temperature``, ``top_p``,
                ``repeat_penalty`` and ``stop`` fall back to their defaults
                when missing or ``None``. ``tools``/``tool_choice`` are
                forwarded only when truthy. ``stream`` (default True) selects
                the return type.

        Returns:
            In streaming mode, a generator of completion chunks. Otherwise,
            the result of ``create_chat_completion``, or ``None`` if the
            request was aborted before generation started.

        Raises:
            RuntimeError: if the model is not loaded (chained to the load error,
                when one was recorded).
        """
        if self.llm is None:
            raise RuntimeError("Model not loaded") from self._load_error

        stream_mode = bool(kwargs.get("stream", True))
        llama_kwargs = self._build_llama_kwargs(messages, stream_mode, kwargs)

        if stream_mode:
            return self._stream(llama_kwargs, abort_event)
        # Non-streaming (agent) mode returns the completed response directly.
        return self._complete(llama_kwargs, abort_event)


# Shared module-level instance. Constructing it calls _load_model(), so importing
# this module downloads/loads the model as a side effect. A load failure is only
# printed (the instance is still created with llm=None); generate() then raises
# RuntimeError("Model not loaded").
engine = ModelEngine()