"""Model loading and status management. Supports two models: - MiniCPM5-1B (text-only, fast) - MiniCPM-V-4.6 (vision + text, image understanding) Only one model is loaded at a time to conserve memory. The model is loaded in a background thread on startup. """ from __future__ import annotations import gc import logging import threading from typing import Any from code.config.constants import DEFAULT_MODEL_KEY, MODEL_CONFIGS logger = logging.getLogger(__name__) # ─── Module-level state ───────────────────────────────────────────────── _current_model_key: str = DEFAULT_MODEL_KEY _model = None _tokenizer_or_processor = None _model_loaded = False _model_loading = False _load_error: str | None = None def _unload_model() -> None: """Unload current model and free memory.""" global _model, _tokenizer_or_processor, _model_loaded if _model is not None: del _model _model = None if _tokenizer_or_processor is not None: del _tokenizer_or_processor _tokenizer_or_processor = None _model_loaded = False gc.collect() try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() except ImportError: pass def load_model(model_key: str | None = None) -> None: """Load a model by key. Unloads the previous model first.""" global _model, _tokenizer_or_processor, _model_loaded, _model_loading global _load_error, _current_model_key if model_key is None: model_key = _current_model_key if model_key not in MODEL_CONFIGS: _load_error = f"Unknown model: {model_key}" logger.error(_load_error) return # Skip if already loading or already loaded with same key if _model_loading: return if _model_loaded and _current_model_key == model_key: return _model_loading = True _load_error = None # Unload previous model if switching if _model_loaded and _current_model_key != model_key: logger.info("Switching model from %s to %s", _current_model_key, model_key) _unload_model() _current_model_key = model_key config = MODEL_CONFIGS[model_key] model_id = config["id"] try: import torch dtype = torch.float16 if torch.cuda.is_available() else torch.float32 device_map = "auto" if torch.cuda.is_available() else None if config["type"] == "vlm": _load_vlm_model(model_id, dtype, device_map) else: _load_text_model(model_id, dtype, device_map) _model_loaded = True logger.info("%s model loaded successfully.", config["name"]) except Exception as exc: _load_error = str(exc) logger.exception("Failed to load model %s: %s", model_id, exc) finally: _model_loading = False def _load_text_model(model_id: str, dtype, device_map) -> None: """Load a text-only model (AutoModelForCausalLM + AutoTokenizer).""" global _model, _tokenizer_or_processor from transformers import AutoModelForCausalLM, AutoTokenizer logger.info("Loading %s (text model)...", model_id) _tokenizer_or_processor = AutoTokenizer.from_pretrained( model_id, trust_remote_code=True, ) _model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=dtype, device_map=device_map, trust_remote_code=True, low_cpu_mem_usage=True, ) if device_map is None: _model = _model.to("cpu") _model.eval() def _load_vlm_model(model_id: str, dtype, device_map) -> None: """Load a vision-language model (AutoModelForImageTextToText + AutoProcessor).""" global _model, _tokenizer_or_processor try: from transformers import AutoModelForImageTextToText, AutoProcessor except ImportError: # Fallback for older transformers logger.warning("AutoModelForImageTextToText not found, trying AutoModel...") from transformers import AutoModel as AutoModelForImageTextToText from transformers import AutoProcessor logger.info("Loading %s (VLM)...", model_id) _tokenizer_or_processor = AutoProcessor.from_pretrained( model_id, trust_remote_code=True, ) _model = AutoModelForImageTextToText.from_pretrained( model_id, torch_dtype=dtype, device_map=device_map, trust_remote_code=True, low_cpu_mem_usage=True, ) if device_map is None: _model = _model.to("cpu") _model.eval() def start_background_load(model_key: str | None = None) -> threading.Thread: """Start loading the model in a background daemon thread.""" thread = threading.Thread(target=load_model, args=(model_key,), daemon=True) thread.start() return thread def switch_model(model_key: str) -> dict[str, Any]: """Switch to a different model. Returns status immediately, loads in background.""" global _current_model_key if model_key not in MODEL_CONFIGS: return {"success": False, "message": f"Unknown model: {model_key}"} if _current_model_key == model_key and _model_loaded: return {"success": True, "message": f"Already using {MODEL_CONFIGS[model_key]['name']}"} _current_model_key = model_key _model_loaded = False # Start loading in background start_background_load(model_key) config = MODEL_CONFIGS[model_key] return { "success": True, "message": f"Switching to {config['name']}...", "model_key": model_key, "model_name": config["name"], } def get_model_status() -> dict[str, Any]: """Return current model loading status.""" config = MODEL_CONFIGS.get(_current_model_key, {}) if _model_loaded: return { "status": "ready", "message": f"{config.get('name', 'Model')} loaded and ready", "model_key": _current_model_key, "model_name": config.get("name", ""), "model_type": config.get("type", "text"), } if _model_loading: return { "status": "loading", "message": f"Loading {config.get('name', 'model')}... (this may take a few minutes)", "model_key": _current_model_key, "model_name": config.get("name", ""), "model_type": config.get("type", "text"), } if _load_error: return { "status": "error", "message": f"Model load error: {_load_error}", "model_key": _current_model_key, "model_name": config.get("name", ""), "model_type": config.get("type", "text"), } return { "status": "unknown", "message": "Model not initialized", "model_key": _current_model_key, "model_name": config.get("name", ""), "model_type": config.get("type", "text"), } def get_model(): """Return the loaded model instance (or None).""" return _model def get_tokenizer_or_processor(): """Return the loaded tokenizer or processor (or None).""" return _tokenizer_or_processor def is_model_loaded() -> bool: """Return True if the model has been loaded successfully.""" return _model_loaded def get_current_model_key() -> str: """Return the key of the currently selected model.""" return _current_model_key def get_current_model_type() -> str: """Return 'text' or 'vlm' for the current model.""" return MODEL_CONFIGS.get(_current_model_key, {}).get("type", "text")