File size: 7,615 Bytes
4412065
 
380204e
 
 
 
 
4412065
 
 
 
 
380204e
4412065
 
 
 
380204e
4412065
 
 
 
 
380204e
4412065
380204e
4412065
 
 
 
 
380204e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4412065
380204e
 
 
 
 
 
 
 
 
 
 
 
4412065
 
 
 
 
380204e
 
 
 
 
 
 
 
 
4412065
 
 
 
 
 
380204e
 
 
 
 
4412065
380204e
4412065
 
 
380204e
4412065
 
 
 
380204e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4412065
380204e
4412065
 
 
 
380204e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4412065
 
380204e
4412065
380204e
 
 
 
 
 
 
4412065
380204e
 
 
 
 
 
 
4412065
380204e
 
 
 
 
 
 
 
 
 
 
 
 
 
4412065
 
 
 
 
 
 
380204e
 
 
4412065
 
 
 
 
380204e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Model loading and status management.

Supports two models:
- MiniCPM5-1B (text-only, fast)
- MiniCPM-V-4.6 (vision + text, image understanding)

Only one model is loaded at a time to conserve memory.
The model is loaded in a background thread on startup.
"""

from __future__ import annotations

import gc
import logging
import threading
from typing import Any

from code.config.constants import DEFAULT_MODEL_KEY, MODEL_CONFIGS

logger = logging.getLogger(__name__)

# ─── Module-level state ─────────────────────────────────────────────────
# The loader functions in this module rebind these names via ``global``
# declarations. Other modules should read them through the accessor functions
# (get_model, is_model_loaded, ...) rather than ``from ... import`` them,
# which would capture the binding as it was at import time.

# Key into MODEL_CONFIGS for the currently selected model.
_current_model_key: str = DEFAULT_MODEL_KEY
# Loaded model object, or None when nothing is held.
_model: Any = None
# Tokenizer (text models) or processor (VLMs), or None when nothing is held.
_tokenizer_or_processor: Any = None
# Set to True once a load completes successfully.
_model_loaded: bool = False
# True while load_model() is performing a load.
_model_loading: bool = False
# Message describing the most recent load failure, or None.
_load_error: str | None = None


def _unload_model() -> None:
    """Release the active model and tokenizer/processor and reclaim memory.

    Clears the module references and the loaded flag, then runs the garbage
    collector. When PyTorch is importable and CUDA is available, it also
    empties the CUDA allocator cache.
    """
    global _model, _tokenizer_or_processor, _model_loaded

    # Rebinding to None drops this module's references; the objects are
    # reclaimed once nothing else refers to them.
    _model = None
    _tokenizer_or_processor = None
    _model_loaded = False

    gc.collect()

    # PyTorch is optional here: without it there is no GPU cache to release.
    try:
        import torch
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.empty_cache()
    except ImportError:
        return


def load_model(model_key: str | None = None) -> None:
    """Load a model by key, releasing any previously held model first.

    Used as the thread target of ``start_background_load``. Failures are not
    raised to the caller. They are logged and stored in ``_load_error``, which
    ``get_model_status`` reports.

    Args:
        model_key: Key into ``MODEL_CONFIGS``. ``None`` means the currently
            selected key.

    The call is a no-op when another load is already in progress, or when
    the requested model is already loaded.
    """
    global _model, _tokenizer_or_processor, _model_loaded, _model_loading
    global _load_error, _current_model_key

    if model_key is None:
        model_key = _current_model_key

    if model_key not in MODEL_CONFIGS:
        _load_error = f"Unknown model: {model_key}"
        logger.error(_load_error)
        return

    # Skip if already loading or already loaded with the same key.
    if _model_loading:
        return
    if _model_loaded and _current_model_key == model_key:
        return

    _model_loading = True
    _load_error = None
    # Replaced by the configured model id below; used in the failure log.
    model_id = model_key

    # Everything after the flag is raised lives inside the try. The
    # ``finally`` must always clear ``_model_loading``; an exception escaping
    # here (e.g. a config entry without "id") would otherwise leave the
    # status stuck at "loading" forever.
    try:
        # Release whatever is still held. This is keyed on the objects
        # themselves rather than on ``_model_loaded`` or a key mismatch:
        # - a caller may already have changed the selected key or cleared
        #   the flag;
        # - a failed load can leave a tokenizer behind.
        # Either way, only one model should be in memory at a time.
        if _model is not None or _tokenizer_or_processor is not None:
            if _current_model_key != model_key:
                logger.info("Switching model from %s to %s", _current_model_key, model_key)
            _unload_model()

        _current_model_key = model_key
        config = MODEL_CONFIGS[model_key]
        model_id = config["id"]

        import torch

        has_cuda = torch.cuda.is_available()
        dtype = torch.float16 if has_cuda else torch.float32
        device_map = "auto" if has_cuda else None

        if config["type"] == "vlm":
            _load_vlm_model(model_id, dtype, device_map)
        else:
            _load_text_model(model_id, dtype, device_map)

        _model_loaded = True
        logger.info("%s model loaded successfully.", config["name"])

    except Exception as exc:
        # Fall back to the exception type so get_model_status() still reports
        # an error when the exception carries no message.
        _load_error = str(exc) or type(exc).__name__
        logger.exception("Failed to load model %s: %s", model_id, exc)
        # Drop half-loaded objects so the accessors never expose a tokenizer
        # or processor without its model.
        _unload_model()
    finally:
        _model_loading = False


def _load_text_model(model_id: str, dtype, device_map) -> None:
    """Load a causal language model and its tokenizer into the module globals.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        dtype: Forwarded to the model's ``from_pretrained`` as ``torch_dtype``.
        device_map: Forwarded to the model's ``from_pretrained``. When it is
            ``None``, the model is additionally moved to ``"cpu"`` explicitly.
    """
    global _model, _tokenizer_or_processor

    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("Loading %s (text model)...", model_id)

    # Both the tokenizer and the model are loaded with trust_remote_code.
    remote_opts = {"trust_remote_code": True}

    _tokenizer_or_processor = AutoTokenizer.from_pretrained(model_id, **remote_opts)
    _model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map=device_map,
        torch_dtype=dtype,
        **remote_opts,
    )

    if device_map is None:
        _model = _model.to("cpu")

    # Put the model in evaluation (inference) mode.
    _model.eval()


def _load_vlm_model(model_id: str, dtype, device_map) -> None:
    """Load a vision-language model and its processor into the module globals.

    Prefers ``AutoModelForImageTextToText``. If that import (or the
    ``AutoProcessor`` import) fails, it logs a warning and retries with
    ``AutoModel``.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the
            processor and the model.
        dtype: Forwarded to the model's ``from_pretrained`` as ``torch_dtype``.
        device_map: Forwarded to the model's ``from_pretrained``. When it is
            ``None``, the model is additionally moved to ``"cpu"`` explicitly.
    """
    global _model, _tokenizer_or_processor

    try:
        from transformers import AutoModelForImageTextToText as vlm_cls
        from transformers import AutoProcessor as processor_cls
    except ImportError:
        # Fallback for older transformers releases.
        logger.warning("AutoModelForImageTextToText not found, trying AutoModel...")
        from transformers import AutoModel as vlm_cls
        from transformers import AutoProcessor as processor_cls

    logger.info("Loading %s (VLM)...", model_id)

    remote_opts = {"trust_remote_code": True}

    _tokenizer_or_processor = processor_cls.from_pretrained(model_id, **remote_opts)
    _model = vlm_cls.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map=device_map,
        torch_dtype=dtype,
        **remote_opts,
    )

    if device_map is None:
        _model = _model.to("cpu")

    # Put the model in evaluation (inference) mode.
    _model.eval()


def start_background_load(model_key: str | None = None) -> threading.Thread:
    """Run ``load_model(model_key)`` on a new daemon thread and return that thread.

    The call does not wait for the load. Callers can poll
    ``get_model_status`` or ``join`` the returned thread.
    """
    worker = threading.Thread(target=load_model, kwargs={"model_key": model_key})
    # A daemon thread does not keep the interpreter alive at shutdown.
    worker.daemon = True
    worker.start()
    return worker


def switch_model(model_key: str) -> dict[str, Any]:
    """Select *model_key* and start loading it in the background.

    The previously held model is released before the background load starts,
    so two models are never resident at once. The call does not wait for the
    load to finish.

    Returns:
        A dict that always has ``success`` and ``message``. It also has
        ``model_key`` and ``model_name`` when a load for the requested model
        was started or is already under way.
    """
    # Every module-level name rebound here must be declared global. Without
    # the declaration, an assignment makes the name local to the whole
    # function, and reading it earlier raises UnboundLocalError.
    global _current_model_key, _load_error

    if model_key not in MODEL_CONFIGS:
        return {"success": False, "message": f"Unknown model: {model_key}"}

    config = MODEL_CONFIGS[model_key]
    switching = {
        "success": True,
        "message": f"Switching to {config['name']}...",
        "model_key": model_key,
        "model_name": config["name"],
    }

    if _current_model_key == model_key and _model_loaded:
        return {"success": True, "message": f"Already using {config['name']}"}

    if _model_loading:
        # load_model() returns without doing anything while another load is
        # in flight, so starting a thread now would silently be ignored.
        if _current_model_key == model_key:
            return switching
        return {
            "success": False,
            "message": "Another model is still loading; try again when it finishes.",
        }

    # Release the old model here, before the background load begins, so both
    # models are never in memory together. This also clears _model_loaded.
    _unload_model()
    _current_model_key = model_key
    _load_error = None

    start_background_load(model_key)
    return switching


def get_model_status() -> dict[str, Any]:
    """Describe the loading state of the currently selected model.

    Returns:
        A dict with these keys:

        - ``status``: one of ``"ready"``, ``"loading"``, ``"error"`` or
          ``"unknown"``;
        - ``message``: a human-readable description;
        - ``model_key``, ``model_name`` and ``model_type``: the selected
          model's details. An unknown key yields an empty name and type
          ``"text"``.
    """
    cfg = MODEL_CONFIGS.get(_current_model_key, {})

    # Precedence: loaded beats loading beats error beats not-initialized.
    if _model_loaded:
        status = "ready"
        message = f"{cfg.get('name', 'Model')} loaded and ready"
    elif _model_loading:
        status = "loading"
        message = f"Loading {cfg.get('name', 'model')}... (this may take a few minutes)"
    elif _load_error:
        status = "error"
        message = f"Model load error: {_load_error}"
    else:
        status = "unknown"
        message = "Model not initialized"

    return {
        "status": status,
        "message": message,
        "model_key": _current_model_key,
        "model_name": cfg.get("name", ""),
        "model_type": cfg.get("type", "text"),
    }


def get_model():
    """Give back the model object held by this module, which may be None."""
    model = _model
    return model


def get_tokenizer_or_processor():
    """Give back the held tokenizer (text models) or processor (VLMs), which may be None."""
    companion = _tokenizer_or_processor
    return companion


def is_model_loaded() -> bool:
    """Report whether the most recent load finished successfully."""
    return _model_loaded


def get_current_model_key() -> str:
    """Report the MODEL_CONFIGS key of the model currently selected."""
    key = _current_model_key
    return key


def get_current_model_type() -> str:
    """Report the selected model's configured type, defaulting to ``"text"``."""
    selected = MODEL_CONFIGS.get(_current_model_key, {})
    return selected.get("type", "text")