R-Kentaren's picture
Upload folder using huggingface_hub
380204e verified
raw
history blame contribute delete
7.62 kB
"""Model loading and status management.
Supports two models:
- MiniCPM5-1B (text-only, fast)
- MiniCPM-V-4.6 (vision + text, image understanding)
Only one model is loaded at a time to conserve memory.
The model is loaded in a background thread on startup.
"""
from __future__ import annotations
import gc
import logging
import threading
from typing import Any
from code.config.constants import DEFAULT_MODEL_KEY, MODEL_CONFIGS
logger = logging.getLogger(__name__)
# ─── Module-level state ─────────────────────────────────────────────────
# Shared, mutable state for this module. The functions below rebind these
# names through ``global`` declarations; no lock protects them.
# Key into MODEL_CONFIGS for the currently selected model.
_current_model_key: str = DEFAULT_MODEL_KEY
# The loaded model object, or None when nothing is loaded.
_model = None
# Tokenizer (text models) or processor (VLMs) paired with ``_model``, or None.
_tokenizer_or_processor = None
# True only after a load has completed successfully.
_model_loaded = False
# True while load_model() is in the middle of a load.
_model_loading = False
# Message of the most recent load failure; reset to None when a new load starts.
_load_error: str | None = None
def _unload_model() -> None:
    """Drop the resident model and tokenizer/processor and reclaim memory.

    Resets ``_model`` and ``_tokenizer_or_processor`` to ``None``, clears
    ``_model_loaded``, runs a garbage-collection pass and, when torch is
    importable and CUDA is available, empties the CUDA allocator cache.
    """
    global _model, _tokenizer_or_processor, _model_loaded

    # Rebind to None rather than ``del``-ing the globals first: deleting a
    # module-level name leaves it undefined until it is reassigned, so a
    # get_model() / get_tokenizer_or_processor() call from another thread
    # could raise NameError in that window.
    _model = None
    _tokenizer_or_processor = None
    _model_loaded = False

    gc.collect()

    try:
        import torch
    except ImportError:
        # Best effort: without torch there is no CUDA cache to release.
        return

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_model(model_key: str | None = None) -> None:
    """Load the model registered under *model_key*, replacing any resident one.

    Runs synchronously in the calling thread. The outcome is reported through
    module state rather than a return value or exception: ``_model_loaded``
    becomes True on success, ``_load_error`` holds the message on failure.

    The call does nothing when a load is already in progress or when the
    requested model is already loaded.

    Args:
        model_key: Key into ``MODEL_CONFIGS``. ``None`` means the currently
            selected key.
    """
    global _model, _tokenizer_or_processor, _model_loaded, _model_loading
    global _load_error, _current_model_key

    if model_key is None:
        model_key = _current_model_key

    if model_key not in MODEL_CONFIGS:
        _load_error = f"Unknown model: {model_key}"
        logger.error(_load_error)
        return

    # Skip if already loading or already loaded with the same key.
    # NOTE: this check-then-set is not guarded by a lock.
    if _model_loading:
        return
    if _model_loaded and _current_model_key == model_key:
        return

    config = MODEL_CONFIGS[model_key]
    model_id = config["id"]

    _model_loading = True
    _load_error = None

    # Everything that runs while the flag is raised sits inside try/finally,
    # so a failure anywhere (including the unload step) cannot leave
    # ``_model_loading`` stuck at True and block every later load.
    try:
        if _model_loaded and _current_model_key != model_key:
            logger.info("Switching model from %s to %s", _current_model_key, model_key)

        # Release whatever is resident, even if ``_model_loaded`` was already
        # cleared by a caller, so two models never sit in memory at once.
        if _model is not None or _tokenizer_or_processor is not None:
            _unload_model()

        _current_model_key = model_key

        import torch

        cuda_available = torch.cuda.is_available()
        dtype = torch.float16 if cuda_available else torch.float32
        device_map = "auto" if cuda_available else None

        if config["type"] == "vlm":
            _load_vlm_model(model_id, dtype, device_map)
        else:
            _load_text_model(model_id, dtype, device_map)

        _model_loaded = True
        logger.info("%s model loaded successfully.", config["name"])
    except Exception as exc:
        _load_error = str(exc)
        logger.exception("Failed to load model %s: %s", model_id, exc)
        # Do not expose a half-initialised tokenizer/model pair after a
        # failed load.
        _model = None
        _tokenizer_or_processor = None
        _model_loaded = False
    finally:
        _model_loading = False
def _load_text_model(model_id: str, dtype, device_map) -> None:
    """Fetch the tokenizer and causal-LM weights for a text-only model.

    The tokenizer is stored in ``_tokenizer_or_processor`` and the model in
    ``_model``. When ``device_map`` is None the model is moved to the CPU
    explicitly; in both cases it is switched to eval mode.
    """
    global _model, _tokenizer_or_processor
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("Loading %s (text model)...", model_id)

    _tokenizer_or_processor = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True
    )

    model_kwargs = {
        "torch_dtype": dtype,
        "device_map": device_map,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }
    _model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    # No device map means nothing placed the weights for us.
    if device_map is None:
        _model = _model.to("cpu")
    _model.eval()
def _load_vlm_model(model_id: str, dtype, device_map) -> None:
    """Fetch the processor and weights for a vision-language model.

    Prefers ``AutoModelForImageTextToText``; if importing it fails, falls
    back to ``AutoModel``. The processor is stored in
    ``_tokenizer_or_processor`` and the model, in eval mode, in ``_model``.
    """
    global _model, _tokenizer_or_processor

    try:
        from transformers import AutoModelForImageTextToText as vlm_cls, AutoProcessor
    except ImportError:
        # Fallback for older transformers
        logger.warning("AutoModelForImageTextToText not found, trying AutoModel...")
        from transformers import AutoModel as vlm_cls
        from transformers import AutoProcessor

    logger.info("Loading %s (VLM)...", model_id)

    _tokenizer_or_processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True
    )

    model_kwargs = {
        "torch_dtype": dtype,
        "device_map": device_map,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }
    _model = vlm_cls.from_pretrained(model_id, **model_kwargs)

    # No device map means nothing placed the weights for us.
    if device_map is None:
        _model = _model.to("cpu")
    _model.eval()
def start_background_load(model_key: str | None = None) -> threading.Thread:
    """Run ``load_model(model_key)`` on a daemon thread and return that thread.

    The thread has already been started when it is returned.
    """
    loader = threading.Thread(
        target=load_model,
        args=(model_key,),
        daemon=True,
    )
    loader.start()
    return loader
def switch_model(model_key: str) -> dict[str, Any]:
    """Select *model_key* as the active model and load it in the background.

    Returns a status dict immediately; the load itself runs on a daemon
    thread started via ``start_background_load``.

    Args:
        model_key: Key into ``MODEL_CONFIGS``.

    Returns:
        A dict with ``"success"`` and ``"message"``. When a switch is started
        (or the requested model is already being loaded) it also carries
        ``"model_key"`` and ``"model_name"``.
    """
    # ``_model_loaded`` is only read here, never assigned. The previous code
    # assigned it without a ``global`` declaration, which made it a local
    # name: the check below raised UnboundLocalError, and the module-level
    # flag was never cleared, so load_model() saw "already loaded" and
    # skipped the switch.
    global _current_model_key

    if model_key not in MODEL_CONFIGS:
        return {"success": False, "message": f"Unknown model: {model_key}"}

    config = MODEL_CONFIGS[model_key]

    if _current_model_key == model_key and _model_loaded:
        return {"success": True, "message": f"Already using {config['name']}"}

    switching = {
        "success": True,
        "message": f"Switching to {config['name']}...",
        "model_key": model_key,
        "model_name": config["name"],
    }

    if _model_loading:
        # load_model() ignores calls made while a load is in flight, so
        # re-pointing the selected key now would leave it out of sync with
        # the model that actually ends up loaded.
        if _current_model_key != model_key:
            return {
                "success": False,
                "message": (
                    f"Cannot switch to {config['name']} while another model "
                    "is still loading"
                ),
            }
        # The requested model is already on its way; nothing more to start.
        return switching

    # Release the old model before the new load starts so only one model is
    # resident at a time. _unload_model() also clears ``_model_loaded``, which
    # lets load_model() proceed instead of returning early.
    _unload_model()
    _current_model_key = model_key

    # Start loading in background
    start_background_load(model_key)
    return switching
def get_model_status() -> dict[str, Any]:
    """Summarise the loader state as a status dict.

    ``"status"`` is ``"ready"``, ``"loading"``, ``"error"`` or ``"unknown"``,
    checked in that order of precedence; ``"message"`` is a human-readable
    description. The selected model's key, name and type are always included.
    """
    config = MODEL_CONFIGS.get(_current_model_key, {})

    if _model_loaded:
        status = "ready"
        message = f"{config.get('name', 'Model')} loaded and ready"
    elif _model_loading:
        status = "loading"
        message = f"Loading {config.get('name', 'model')}... (this may take a few minutes)"
    elif _load_error:
        status = "error"
        message = f"Model load error: {_load_error}"
    else:
        status = "unknown"
        message = "Model not initialized"

    return {
        "status": status,
        "message": message,
        "model_key": _current_model_key,
        "model_name": config.get("name", ""),
        "model_type": config.get("type", "text"),
    }
def get_model():
    """Expose the module's current model object (``None`` if nothing is held)."""
    current = _model
    return current
def get_tokenizer_or_processor():
    """Expose the tokenizer or processor paired with the model (``None`` if absent)."""
    current = _tokenizer_or_processor
    return current
def is_model_loaded() -> bool:
    """Report whether the ``_model_loaded`` flag is currently set."""
    loaded = _model_loaded
    return loaded
def get_current_model_key() -> str:
    """Report which ``MODEL_CONFIGS`` key is currently selected."""
    selected = _current_model_key
    return selected
def get_current_model_type() -> str:
    """Report the selected model's ``"type"`` entry, defaulting to ``"text"``."""
    config = MODEL_CONFIGS.get(_current_model_key, {})
    return config.get("type", "text")