""" LoRA adapter hot-swap manager. Uses PEFT's multi-adapter API: - model.load_adapter(path, adapter_name=lang) — first load (~2s per adapter) - model.set_adapter(lang) — subsequent swap (~50ms) This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights, vs reloading the full 1.5GB model per language. """ from __future__ import annotations import logging from pathlib import Path from typing import TYPE_CHECKING from peft import PeftModel if TYPE_CHECKING: from transformers import WhisperForConditionalGeneration logger = logging.getLogger(__name__) class AdapterManager: """Manages registration and hot-swapping of LoRA language adapters.""" def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None: self._base_model = base_model self._config = config self._registry: dict[str, str] = {} # language_code -> adapter_path self._peft_model: PeftModel | None = None self._active: str | None = None def register(self, language: str, adapter_path: str) -> None: """Register an adapter path. Does not load it yet.""" path = Path(adapter_path) if not path.exists(): logger.warning( "Adapter path '%s' for language '%s' does not exist. " "Run training first, or check the path.", adapter_path, language, ) self._registry[language] = str(path) logger.info("Registered adapter '%s' → %s", language, adapter_path) def load_adapter(self, language: str) -> None: """ Load an adapter into the model for the first time. Slow (~2s): reads adapter weights from disk. Subsequent activate() calls reuse the already-loaded weights. """ if language not in self._registry: raise KeyError(f"No adapter registered for language '{language}'. " f"Available: {list(self._registry)}") adapter_path = self._registry[language] if self._peft_model is None: # First adapter: wrap the base model with PeftModel logger.info("Wrapping base model with first adapter '%s'...", language) self._peft_model = PeftModel.from_pretrained( self._base_model, adapter_path, adapter_name=language, ) else: # Subsequent adapters: load into the existing PeftModel logger.info("Loading adapter '%s' into existing PeftModel...", language) self._peft_model.load_adapter(adapter_path, adapter_name=language) self._active = language logger.info("Adapter '%s' loaded and active.", language) def activate(self, language: str) -> None: """ Hot-swap to a previously loaded adapter (~50ms). Call load_adapter() first if this adapter hasn't been loaded. """ if self._peft_model is None: self.load_adapter(language) return loaded = set(self._peft_model.peft_config.keys()) if language not in loaded: self.load_adapter(language) return self._peft_model.set_adapter(language) self._active = language logger.debug("Hot-swapped to adapter '%s'.", language) def get_model(self) -> "WhisperForConditionalGeneration | PeftModel": """Return the PeftModel (or base model if no adapter loaded yet).""" return self._peft_model if self._peft_model is not None else self._base_model def get_active(self) -> str | None: return self._active def list_available(self) -> list[str]: return list(self._registry.keys()) def list_loaded(self) -> list[str]: if self._peft_model is None: return [] return list(self._peft_model.peft_config.keys())