Spaces:
Running
Running
| """Model loading and status management. | |
| Supports two models: | |
| - MiniCPM5-1B (text-only, fast) | |
| - MiniCPM-V-4.6 (vision + text, image understanding) | |
| Only one model is loaded at a time to conserve memory. | |
| The model is loaded in a background thread on startup. | |
| """ | |
| from __future__ import annotations | |
| import gc | |
| import logging | |
| import threading | |
| from typing import Any | |
| from code.config.constants import DEFAULT_MODEL_KEY, MODEL_CONFIGS | |
logger = logging.getLogger(__name__)

# ─── Module-level state ─────────────────────────────────────────────────
# Shared between callers of this module and the loader thread created by
# start_background_load().

# MODEL_CONFIGS key of the selected model (loaded, loading, or failed).
_current_model_key: str = DEFAULT_MODEL_KEY
# The model object; None until a load assigns it.
_model: Any = None
# Tokenizer (text models) or processor (VLMs) paired with _model.
_tokenizer_or_processor: Any = None
# Set True after a successful load, False again by _unload_model().
_model_loaded: bool = False
# True while load_model() is in progress.
_model_loading: bool = False
# Message describing the most recent load failure, or None.
_load_error: str | None = None
def _unload_model() -> None:
    """Drop the current model and tokenizer/processor and reclaim memory.

    Clears the module-level references and the loaded flag, and forces a
    garbage-collection pass. When PyTorch is importable and a CUDA device is
    available, it also asks PyTorch to release its cached CUDA memory. If
    PyTorch is not installed, the CUDA step is skipped silently.
    """
    global _model, _tokenizer_or_processor, _model_loaded

    # Rebinding the globals drops this module's references to the objects.
    # gc.collect() below then reclaims anything kept alive only by cycles.
    _model = _tokenizer_or_processor = None
    _model_loaded = False
    gc.collect()

    try:
        import torch

        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
    except ImportError:
        pass
# Makes the "is a load already running?" check and the claiming of the
# loader slot a single atomic step. Without it, concurrent callers (the
# startup thread and a switch request) could both start loading.
_load_claim_lock = threading.Lock()


def load_model(model_key: str | None = None) -> None:
    """Load a model by key, unloading the previously loaded model first.

    Meant to run on a background thread (see start_background_load). The call
    does nothing when another load is already in progress, or when the
    requested model is already loaded. Failures are not raised. Instead they
    are logged and stored in ``_load_error`` so get_model_status() can report
    them, and any partially loaded objects are released.

    Args:
        model_key: Key into MODEL_CONFIGS. Defaults to the currently selected
            model key.
    """
    global _model_loaded, _model_loading, _load_error, _current_model_key

    if model_key is None:
        model_key = _current_model_key
    if model_key not in MODEL_CONFIGS:
        _load_error = f"Unknown model: {model_key}"
        logger.error(_load_error)
        return

    # Check-and-claim under the lock. Otherwise two callers could both see
    # _model_loading == False and load two models concurrently.
    with _load_claim_lock:
        if _model_loading:
            return
        if _model_loaded and _current_model_key == model_key:
            return
        _model_loading = True
        _load_error = None

    config = MODEL_CONFIGS[model_key]
    model_id = config["id"]
    try:
        # Free the previous model before loading the next one, so the two
        # never occupy memory at the same time. This sits inside the try so a
        # failure here still resets _model_loading in the finally clause.
        if _model_loaded and _current_model_key != model_key:
            logger.info("Switching model from %s to %s", _current_model_key, model_key)
            _unload_model()
        _current_model_key = model_key

        import torch

        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        device_map = "auto" if torch.cuda.is_available() else None
        if config["type"] == "vlm":
            _load_vlm_model(model_id, dtype, device_map)
        else:
            _load_text_model(model_id, dtype, device_map)
        _model_loaded = True
        logger.info("%s model loaded successfully.", config["name"])
    except Exception as exc:
        _load_error = str(exc)
        logger.exception("Failed to load model %s: %s", model_id, exc)
        # The loaders assign the tokenizer/processor before the model, so a
        # failure can leave objects behind. Release them rather than holding
        # the memory.
        _unload_model()
    finally:
        # Always release the loader slot so later loads are not blocked.
        _model_loading = False
def _load_text_model(model_id: str, dtype, device_map) -> None:
    """Populate the module-level tokenizer and model for a text-only checkpoint.

    Uses AutoTokenizer and AutoModelForCausalLM with ``trust_remote_code``
    enabled. The tokenizer global is assigned before the model is loaded.
    When ``device_map`` is None the model is moved to the CPU explicitly. The
    model is then put in eval mode.

    Args:
        model_id: Identifier passed to ``from_pretrained``.
        dtype: Weight dtype forwarded as ``torch_dtype``.
        device_map: Forwarded to ``from_pretrained`` (the caller passes
            ``"auto"`` or None).
    """
    global _model, _tokenizer_or_processor
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("Loading %s (text model)...", model_id)

    _tokenizer_or_processor = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    model_kwargs = dict(
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    _model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    # Without a device map, pin the weights to the CPU explicitly.
    if device_map is None:
        _model = _model.to("cpu")
    _model.eval()
def _load_vlm_model(model_id: str, dtype, device_map) -> None:
    """Populate the module-level processor and model for a vision-language checkpoint.

    Prefers AutoModelForImageTextToText. If the installed transformers does
    not provide it, the loader falls back to AutoModel. The processor global
    is assigned before the model is loaded. When ``device_map`` is None the
    model is moved to the CPU explicitly. The model is then put in eval mode.

    Args:
        model_id: Identifier passed to ``from_pretrained``.
        dtype: Weight dtype forwarded as ``torch_dtype``.
        device_map: Forwarded to ``from_pretrained`` (the caller passes
            ``"auto"`` or None).
    """
    global _model, _tokenizer_or_processor
    try:
        from transformers import AutoModelForImageTextToText as vlm_class
        from transformers import AutoProcessor
    except ImportError:
        # Older transformers releases do not ship AutoModelForImageTextToText.
        logger.warning("AutoModelForImageTextToText not found, trying AutoModel...")
        from transformers import AutoModel as vlm_class
        from transformers import AutoProcessor

    logger.info("Loading %s (VLM)...", model_id)

    _tokenizer_or_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    model_kwargs = dict(
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    _model = vlm_class.from_pretrained(model_id, **model_kwargs)

    # Without a device map, pin the weights to the CPU explicitly.
    if device_map is None:
        _model = _model.to("cpu")
    _model.eval()
def start_background_load(model_key: str | None = None) -> threading.Thread:
    """Run load_model(model_key) on a new daemon thread and return that thread.

    The thread is started before returning. Because it is a daemon, it does
    not keep the interpreter alive at shutdown. Callers may join() the
    returned thread to wait for the load to finish.
    """
    loader = threading.Thread(
        target=load_model,
        args=(model_key,),
        daemon=True,
    )
    loader.start()
    return loader
def switch_model(model_key: str) -> dict[str, Any]:
    """Switch to a different model, loading it in the background.

    Returns immediately; poll get_model_status() to follow progress. Unloading
    the old model, updating the current model key, and loading the new model
    are all done by load_model() on the background thread. This function only
    reads module state.

    Args:
        model_key: Key into MODEL_CONFIGS.

    Returns:
        A dict with ``success`` and ``message``. When a switch is started it
        also carries ``model_key`` and ``model_name``.
    """
    if model_key not in MODEL_CONFIGS:
        return {"success": False, "message": f"Unknown model: {model_key}"}

    config = MODEL_CONFIGS[model_key]
    if _current_model_key == model_key and _model_loaded:
        return {"success": True, "message": f"Already using {config['name']}"}

    # load_model() ignores requests while another load is running. A switch
    # requested now would be silently dropped, so report it instead.
    if _model_loading and _current_model_key != model_key:
        return {
            "success": False,
            "message": (
                f"Cannot switch to {config['name']} while another model is "
                "still loading; try again shortly"
            ),
        }

    # Do NOT pre-set _current_model_key / _model_loaded here. load_model()
    # compares against them to decide whether to unload the old model. If the
    # key already looked "current and loaded", it would return without
    # switching.
    start_background_load(model_key)

    return {
        "success": True,
        "message": f"Switching to {config['name']}...",
        "model_key": model_key,
        "model_name": config["name"],
    }
def get_model_status() -> dict[str, Any]:
    """Describe the selected model's state for status polling.

    Returns:
        A dict with these keys:

        - ``status``: ``"ready"``, ``"loading"``, ``"error"`` or
          ``"unknown"``, checked in that order.
        - ``message``: a human-readable description of the status.
        - ``model_key``, ``model_name`` and ``model_type`` for the selected
          model.
    """
    config = MODEL_CONFIGS.get(_current_model_key, {})

    if _model_loaded:
        status = "ready"
        message = f"{config.get('name', 'Model')} loaded and ready"
    elif _model_loading:
        status = "loading"
        message = f"Loading {config.get('name', 'model')}... (this may take a few minutes)"
    elif _load_error:
        status = "error"
        message = f"Model load error: {_load_error}"
    else:
        status = "unknown"
        message = "Model not initialized"

    return {
        "status": status,
        "message": message,
        "model_key": _current_model_key,
        "model_name": config.get("name", ""),
        "model_type": config.get("type", "text"),
    }
def get_model():
    """Give back the model object held by this module, or None if there is none."""
    return _model


def get_tokenizer_or_processor():
    """Give back the tokenizer (text model) or processor (VLM), or None if there is none."""
    return _tokenizer_or_processor


def is_model_loaded() -> bool:
    """Report whether a model has finished loading and has not been unloaded since."""
    return _model_loaded


def get_current_model_key() -> str:
    """Return the MODEL_CONFIGS key of the selected model (loaded, loading, or failed)."""
    return _current_model_key


def get_current_model_type() -> str:
    """Return the selected model's ``type`` entry (``"text"`` or ``"vlm"``).

    Falls back to ``"text"`` when the key or its ``type`` entry is missing.
    """
    selected = MODEL_CONFIGS.get(_current_model_key, {})
    return selected.get("type", "text")