"""Model loading and status management.
Supports two models:
- MiniCPM5-1B (text-only, fast)
- MiniCPM-V-4.6 (vision + text, image understanding)
Only one model is loaded at a time to conserve memory.
The model is loaded in a background thread on startup.
"""
from __future__ import annotations
import gc
import logging
import threading
from typing import Any
from code.config.constants import DEFAULT_MODEL_KEY, MODEL_CONFIGS
logger = logging.getLogger(__name__)
# ─── Module-level state ─────────────────────────────────────────────────
# Shared by the loader and status functions in this module. load_model()
# writes these and may run on a background thread (start_background_load).
# Key into MODEL_CONFIGS for the currently selected model.
_current_model_key: str = DEFAULT_MODEL_KEY
# Loaded model instance, or None when no model is held.
_model: Any = None
# Tokenizer (text models) or processor (VLMs) paired with _model, or None.
_tokenizer_or_processor: Any = None
# Set True by load_model() after a successful load; cleared on unload.
_model_loaded: bool = False
# True while load_model() is in progress.
_model_loading: bool = False
# Message from the most recent failed load request, or None.
_load_error: str | None = None
def _unload_model() -> None:
    """Drop the current model and tokenizer/processor and reclaim memory.

    Clears the loaded flag first, then releases both references, runs the
    garbage collector and, when torch is importable and CUDA is available,
    calls ``torch.cuda.empty_cache()`` to release PyTorch's cached GPU memory.
    Missing torch is tolerated (best-effort cleanup).
    """
    global _model, _tokenizer_or_processor, _model_loaded

    # Flip the flag before dropping references so callers that check
    # is_model_loaded() stop using the objects first.
    _model_loaded = False
    # Rebinding to None releases the reference just like `del`, but never
    # leaves the global undefined (a concurrent get_model() could otherwise
    # hit NameError between `del` and the reassignment).
    _model = None
    _tokenizer_or_processor = None

    gc.collect()
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Serializes load_model(). A non-blocking acquire replaces the previous
# check-then-set of the _model_loading flag, which two threads could both
# pass before either set it.
_load_lock = threading.Lock()


def load_model(model_key: str | None = None) -> None:
    """Load the model identified by *model_key*, replacing any loaded model.

    Runs synchronously; start_background_load() runs it on a daemon thread.
    The call is a no-op when another load is already in progress or when the
    requested model is already loaded. Load failures are not raised: they are
    logged and stored in ``_load_error`` for get_model_status().

    Args:
        model_key: Key into ``MODEL_CONFIGS``. ``None`` reuses the currently
            selected key.
    """
    global _model_loaded, _model_loading, _load_error, _current_model_key

    if model_key is None:
        model_key = _current_model_key
    if model_key not in MODEL_CONFIGS:
        _load_error = f"Unknown model: {model_key}"
        logger.error(_load_error)
        return

    # Only one load runs at a time; a concurrent request is skipped, as the
    # original "already loading" check intended.
    if not _load_lock.acquire(blocking=False):
        return
    try:
        if _model_loaded and _current_model_key == model_key:
            return

        _model_loading = True
        _load_error = None

        # Past the early return, a loaded model necessarily has another key.
        if _model_loaded:
            logger.info("Switching model from %s to %s", _current_model_key, model_key)
        # Free the previous model (or leftovers of a failed load) before
        # allocating the new one, so only one model is resident at a time.
        if _model is not None or _tokenizer_or_processor is not None:
            _unload_model()

        _current_model_key = model_key
        config = MODEL_CONFIGS[model_key]
        model_id = config["id"]
        try:
            import torch

            use_cuda = torch.cuda.is_available()
            dtype = torch.float16 if use_cuda else torch.float32
            device_map = "auto" if use_cuda else None
            if config["type"] == "vlm":
                _load_vlm_model(model_id, dtype, device_map)
            else:
                _load_text_model(model_id, dtype, device_map)
            _model_loaded = True
            logger.info("%s model loaded successfully.", config["name"])
        except Exception as exc:
            _load_error = str(exc)
            logger.exception("Failed to load model %s: %s", model_id, exc)
            # Do not keep a half-loaded tokenizer/processor or model around.
            _unload_model()
    finally:
        _model_loading = False
        _load_lock.release()
def _load_text_model(model_id: str, dtype, device_map) -> None:
    """Load a text-only model (AutoModelForCausalLM + AutoTokenizer).

    Tokenizer and model are loaded into locals and published to the module
    globals together, only after both succeed. A failed model load therefore
    no longer leaves a tokenizer set with no matching model.

    Args:
        model_id: Identifier passed to ``from_pretrained``.
        dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
        device_map: Forwarded to ``from_pretrained``; when ``None`` the
            model is explicitly moved to CPU.
    """
    global _model, _tokenizer_or_processor
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("Loading %s (text model)...", model_id)
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if device_map is None:
        model = model.to("cpu")
    model.eval()

    # Publish both objects only once everything above has succeeded.
    _tokenizer_or_processor = tokenizer
    _model = model
def _load_vlm_model(model_id: str, dtype, device_map) -> None:
    """Load a vision-language model (AutoModelForImageTextToText + AutoProcessor).

    Falls back to ``AutoModel`` when the installed transformers release does
    not provide ``AutoModelForImageTextToText``. Processor and model are
    loaded into locals and published to the module globals together, only
    after both succeed, so a failed model load leaves no orphaned processor.

    Args:
        model_id: Identifier passed to ``from_pretrained``.
        dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
        device_map: Forwarded to ``from_pretrained``; when ``None`` the
            model is explicitly moved to CPU.
    """
    global _model, _tokenizer_or_processor
    try:
        from transformers import AutoModelForImageTextToText, AutoProcessor
    except ImportError:
        # Fallback for older transformers
        logger.warning("AutoModelForImageTextToText not found, trying AutoModel...")
        from transformers import AutoModel as AutoModelForImageTextToText
        from transformers import AutoProcessor

    logger.info("Loading %s (VLM)...", model_id)
    processor = AutoProcessor.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if device_map is None:
        model = model.to("cpu")
    model.eval()

    # Publish both objects only once everything above has succeeded.
    _tokenizer_or_processor = processor
    _model = model
def start_background_load(model_key: str | None = None) -> threading.Thread:
    """Run load_model(model_key) on a freshly started daemon thread.

    Returns the thread right after starting it; callers that need to wait
    for the load can ``join()`` it.
    """
    worker = threading.Thread(
        target=load_model,
        kwargs={"model_key": model_key},
        daemon=True,
    )
    worker.start()
    return worker
def switch_model(model_key: str) -> dict[str, Any]:
    """Request a switch to *model_key*; the load itself runs in the background.

    Returns immediately. The module state is deliberately left for
    load_model() to update. It compares the requested key with the current
    one to decide whether to unload the old model, and it skips the load
    entirely when the key already matches a loaded model.

    Args:
        model_key: Key into ``MODEL_CONFIGS``.

    Returns:
        A dict with ``success`` and ``message``. When a background load was
        started it also carries ``model_key`` and ``model_name``.
    """
    if model_key not in MODEL_CONFIGS:
        return {"success": False, "message": f"Unknown model: {model_key}"}
    config = MODEL_CONFIGS[model_key]
    if _current_model_key == model_key and _model_loaded:
        return {"success": True, "message": f"Already using {config['name']}"}
    if _model_loading:
        # load_model() ignores requests while another load runs; report that
        # instead of claiming a switch that will never happen.
        return {
            "success": False,
            "message": "Another model is still loading; try again shortly.",
        }

    # Start loading in background
    start_background_load(model_key)
    return {
        "success": True,
        "message": f"Switching to {config['name']}...",
        "model_key": model_key,
        "model_name": config["name"],
    }
def get_model_status() -> dict[str, Any]:
    """Summarize the selected model's load state as a plain dict.

    ``status`` is the first matching state among "ready", "loading", "error"
    and "unknown"; every result also carries the selected model's key,
    display name and type.
    """
    config = MODEL_CONFIGS.get(_current_model_key, {})

    if _model_loaded:
        state = "ready"
        text = f"{config.get('name', 'Model')} loaded and ready"
    elif _model_loading:
        state = "loading"
        text = f"Loading {config.get('name', 'model')}... (this may take a few minutes)"
    elif _load_error:
        state = "error"
        text = f"Model load error: {_load_error}"
    else:
        state = "unknown"
        text = "Model not initialized"

    return {
        "status": state,
        "message": text,
        "model_key": _current_model_key,
        "model_name": config.get("name", ""),
        "model_type": config.get("type", "text"),
    }
def get_model():
    """Give callers the model object currently held by this module, or None."""
    return _model


def get_tokenizer_or_processor():
    """Give callers the held tokenizer (text) or processor (VLM), or None."""
    return _tokenizer_or_processor


def is_model_loaded() -> bool:
    """Report whether the last load_model() call finished successfully."""
    return _model_loaded


def get_current_model_key() -> str:
    """Report which MODEL_CONFIGS key is currently selected."""
    return _current_model_key


def get_current_model_type() -> str:
    """Report the selected config's ``type`` field, defaulting to 'text'."""
    selected = MODEL_CONFIGS.get(_current_model_key, {})
    return selected.get("type", "text")
|