| """Model loading and status management. |
| |
| Supports two models: |
| - MiniCPM5-1B (text-only, fast) |
| - MiniCPM-V-4.6 (vision + text, image understanding) |
| |
| Only one model is loaded at a time to conserve memory. |
| The model is loaded in a background thread on startup. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import gc |
| import logging |
| import threading |
| from typing import Any |
|
|
| from code.config.constants import DEFAULT_MODEL_KEY, MODEL_CONFIGS |
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Loader state shared by the functions in this module.
# ---------------------------------------------------------------------------
# Key into MODEL_CONFIGS of the selected / loading / loaded model.
_current_model_key: str = DEFAULT_MODEL_KEY
# The loaded model object and its tokenizer (text) or processor (vlm); None when empty.
_model: Any = None
_tokenizer_or_processor: Any = None
# Progress flags: set True on a successful load / while a load is running.
_model_loaded: bool = False
_model_loading: bool = False
# Message from the most recent failed load attempt, if any.
_load_error: str | None = None
|
|
|
|
def _unload_model() -> None:
    """Release the current model and tokenizer/processor and reclaim memory.

    Clears the module references, resets the "loaded" flag, runs the garbage
    collector and, when PyTorch is importable and CUDA is available, empties
    the CUDA caching allocator.
    """
    global _model, _tokenizer_or_processor, _model_loaded

    # Rebinding to None drops this module's references to the objects.
    _model = None
    _tokenizer_or_processor = None
    _model_loaded = False

    gc.collect()

    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        # PyTorch is not installed: nothing beyond gc.collect() to release.
        pass
|
|
|
|
# Guards the check-and-set of ``_model_loading`` in load_model(). Loads are
# started from background threads (start_background_load), so without it two
# callers could both observe "not loading" and run concurrently.
_load_lock = threading.Lock()


def load_model(model_key: str | None = None) -> None:
    """Load a model by key, unloading whatever model is currently held first.

    Intended to run on a background thread; the outcome is reported through
    the module-level state (read by get_model_status) rather than by raising.

    Args:
        model_key: Key into ``MODEL_CONFIGS``. Defaults to the currently
            selected key.

    Behavior:
        * Unknown key: ``_load_error`` is set and nothing else changes.
        * Another load already in progress: the request is ignored.
        * Requested model already loaded: no-op.
        * Load failure: ``_load_error`` holds the message and any partially
          loaded objects are released.
    """
    global _model_loaded, _model_loading, _load_error, _current_model_key

    if model_key is None:
        model_key = _current_model_key

    if model_key not in MODEL_CONFIGS:
        _load_error = f"Unknown model: {model_key}"
        logger.error(_load_error)
        return

    # Atomically claim the single loader slot.
    with _load_lock:
        if _model_loading:
            if _current_model_key != model_key:
                logger.warning(
                    "Model load already in progress; ignoring request for %s",
                    model_key,
                )
            return
        if _model_loaded and _current_model_key == model_key:
            return
        _model_loading = True
        _load_error = None

    config = MODEL_CONFIGS[model_key]
    # Everything after claiming the slot is inside try/finally so that
    # ``_model_loading`` can never be left stuck at True.
    try:
        # Free anything still held (the previous model, or leftovers of a
        # failed load) before loading, so two models never share memory.
        if _model is not None or _tokenizer_or_processor is not None:
            if _current_model_key != model_key:
                logger.info(
                    "Switching model from %s to %s", _current_model_key, model_key
                )
            _unload_model()

        _current_model_key = model_key
        model_id = config["id"]

        import torch

        use_cuda = torch.cuda.is_available()
        # fp16 with automatic device placement on GPU; fp32 on CPU.
        dtype = torch.float16 if use_cuda else torch.float32
        device_map = "auto" if use_cuda else None

        if config["type"] == "vlm":
            _load_vlm_model(model_id, dtype, device_map)
        else:
            _load_text_model(model_id, dtype, device_map)
    except Exception as exc:
        # Top-level boundary of a background thread: record and log instead of
        # raising. Fall back to the type name so the status never shows an
        # empty error (which get_model_status would report as "unknown").
        _load_error = str(exc) or type(exc).__name__
        logger.exception(
            "Failed to load model %s: %s", config.get("id", model_key), exc
        )
        # Drop any half-loaded objects and reclaim their memory.
        _unload_model()
    else:
        _model_loaded = True
        logger.info("%s model loaded successfully.", config.get("name", model_key))
    finally:
        _model_loading = False
|
|
|
|
def _load_text_model(model_id: str, dtype, device_map) -> None:
    """Load a text-only model (AutoModelForCausalLM + AutoTokenizer).

    Args:
        model_id: Identifier handed to ``from_pretrained`` (the ``"id"`` entry
            of the model's MODEL_CONFIGS record).
        dtype: torch dtype forwarded as ``torch_dtype``.
        device_map: ``"auto"`` for automatic placement, or None to keep the
            model on CPU (load_model passes None when CUDA is unavailable).

    The tokenizer and model are built in locals and assigned to the module
    globals only after both have loaded. A failure part-way through therefore
    never leaves a tokenizer published without its matching model.
    """
    global _model, _tokenizer_or_processor

    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("Loading %s (text model)...", model_id)

    # NOTE: trust_remote_code=True executes Python shipped with the checkpoint;
    # only point MODEL_CONFIGS at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    if device_map is None:
        # No automatic placement requested: keep the model explicitly on CPU.
        model = model.to("cpu")

    model.eval()  # inference mode

    # Publish both together only once everything succeeded.
    _tokenizer_or_processor = tokenizer
    _model = model
|
|
|
|
def _load_vlm_model(model_id: str, dtype, device_map) -> None:
    """Load a vision-language model (AutoModelForImageTextToText + AutoProcessor).

    Falls back to ``AutoModel`` when the installed transformers does not
    provide ``AutoModelForImageTextToText``.

    Args:
        model_id: Identifier handed to ``from_pretrained`` (the ``"id"`` entry
            of the model's MODEL_CONFIGS record).
        dtype: torch dtype forwarded as ``torch_dtype``.
        device_map: ``"auto"`` for automatic placement, or None to keep the
            model on CPU (load_model passes None when CUDA is unavailable).

    The processor and model are built in locals and assigned to the module
    globals only after both have loaded. A failure part-way through therefore
    never leaves a processor published without its matching model.
    """
    global _model, _tokenizer_or_processor

    from transformers import AutoProcessor

    try:
        from transformers import AutoModelForImageTextToText as model_class
    except ImportError:
        # This transformers build lacks the dedicated image-text-to-text auto class.
        logger.warning("AutoModelForImageTextToText not found, trying AutoModel...")
        from transformers import AutoModel as model_class

    logger.info("Loading %s (VLM)...", model_id)

    # NOTE: trust_remote_code=True executes Python shipped with the checkpoint;
    # only point MODEL_CONFIGS at repositories you trust.
    processor = AutoProcessor.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    model = model_class.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    if device_map is None:
        # No automatic placement requested: keep the model explicitly on CPU.
        model = model.to("cpu")

    model.eval()  # inference mode

    # Publish both together only once everything succeeded.
    _tokenizer_or_processor = processor
    _model = model
|
|
|
|
def start_background_load(model_key: str | None = None) -> threading.Thread:
    """Run ``load_model(model_key)`` on a freshly started daemon thread.

    The thread is returned already started, so callers may ``join()`` it to
    wait for the load. As a daemon it does not keep the interpreter alive.
    """
    worker = threading.Thread(target=load_model, args=(model_key,))
    worker.daemon = True
    worker.start()
    return worker
|
|
|
|
def switch_model(model_key: str) -> dict[str, Any]:
    """Request a switch to *model_key*; the load itself runs in the background.

    Returns immediately with a dict containing ``success`` and ``message``.
    When a load was started, the dict also includes ``model_key`` and
    ``model_name``. Progress can be followed via get_model_status().

    This function deliberately does not modify the module state.
    load_model() returns early when ``_model_loaded`` is True and
    ``_current_model_key`` already equals the requested key. Changing the key
    here first would make it believe the requested model is already loaded,
    and the switch would never happen.
    """
    if model_key not in MODEL_CONFIGS:
        return {"success": False, "message": f"Unknown model: {model_key}"}

    config = MODEL_CONFIGS[model_key]

    if _current_model_key == model_key and _model_loaded:
        return {"success": True, "message": f"Already using {config['name']}"}

    # load_model() ignores requests while another load is running, so refuse
    # a switch to a different model instead of silently dropping it.
    if _model_loading and _current_model_key != model_key:
        busy_name = MODEL_CONFIGS.get(_current_model_key, {}).get(
            "name", _current_model_key
        )
        return {
            "success": False,
            "message": (
                f"Cannot switch to {config['name']} while {busy_name} "
                "is loading; try again once it finishes"
            ),
        }

    start_background_load(model_key)

    return {
        "success": True,
        "message": f"Switching to {config['name']}...",
        "model_key": model_key,
        "model_name": config["name"],
    }
|
|
|
|
def get_model_status() -> dict[str, Any]:
    """Describe the loader state as a dict.

    ``status`` is one of ``"ready"``, ``"loading"``, ``"error"`` or
    ``"unknown"``, checked in that priority order. ``message`` is a
    human-readable summary. ``model_key``, ``model_name`` and ``model_type``
    describe the currently selected model.
    """
    cfg = MODEL_CONFIGS.get(_current_model_key, {})

    if _model_loaded:
        state = "ready"
        text = f"{cfg.get('name', 'Model')} loaded and ready"
    elif _model_loading:
        state = "loading"
        text = f"Loading {cfg.get('name', 'model')}... (this may take a few minutes)"
    elif _load_error:
        state = "error"
        text = f"Model load error: {_load_error}"
    else:
        state = "unknown"
        text = "Model not initialized"

    return {
        "status": state,
        "message": text,
        "model_key": _current_model_key,
        "model_name": cfg.get("name", ""),
        "model_type": cfg.get("type", "text"),
    }
|
|
|
|
def get_model():
    """Give back the model object this module currently holds (None if none)."""
    current = _model
    return current
|
|
|
|
def get_tokenizer_or_processor():
    """Give back the held tokenizer (text models) or processor (VLMs), or None."""
    held = _tokenizer_or_processor
    return held
|
|
|
|
def is_model_loaded() -> bool:
    """Report whether a model load has completed successfully and not been unloaded."""
    loaded = _model_loaded
    return loaded
|
|
|
|
def get_current_model_key() -> str:
    """Key (into MODEL_CONFIGS) of the model that is selected, loading or loaded."""
    key = _current_model_key
    return key
|
|
|
|
def get_current_model_type() -> str:
    """Type of the selected model per its config ('text' or 'vlm'); 'text' if unspecified."""
    cfg = MODEL_CONFIGS.get(_current_model_key, {})
    model_type = cfg.get("type", "text")
    return model_type
|
|