"""
LoRA adapter hot-swap manager.

Uses PEFT's multi-adapter API:

- model.load_adapter(path, adapter_name=lang) — first load (~2s per adapter)
- model.set_adapter(lang) — subsequent swap (~50ms)

This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights,
vs reloading the full 1.5GB model per language. A usage sketch lives under the
__main__ guard at the bottom of this file.
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from peft import PeftModel

if TYPE_CHECKING:
    from transformers import WhisperForConditionalGeneration

logger = logging.getLogger(__name__)


class AdapterManager:
    """Manages registration and hot-swapping of LoRA language adapters."""

    def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None:
        self._base_model = base_model
        self._config = config
        self._registry: dict[str, str] = {}  # language_code -> adapter_path
        self._peft_model: PeftModel | None = None
        self._active: str | None = None

    def register(self, language: str, adapter_path: str) -> None:
        """Register an adapter path. Does not load it yet."""
        path = Path(adapter_path)
        if not path.exists():
            logger.warning(
                "Adapter path '%s' for language '%s' does not exist. "
                "Run training first, or check the path.",
                adapter_path, language,
            )
        self._registry[language] = str(path)
        logger.info("Registered adapter '%s' → %s", language, adapter_path)

    def load_adapter(self, language: str) -> None:
        """
        Load an adapter into the model for the first time.

        Slow (~2s): reads adapter weights from disk.
        Subsequent activate() calls reuse the already-loaded weights.
        """
        if language not in self._registry:
            raise KeyError(
                f"No adapter registered for language '{language}'. "
                f"Available: {list(self._registry)}"
            )
        adapter_path = self._registry[language]
        if self._peft_model is None:
            # First adapter: wrap the base model with PeftModel.
            # from_pretrained() makes the named adapter active immediately.
            logger.info("Wrapping base model with first adapter '%s'...", language)
            self._peft_model = PeftModel.from_pretrained(
                self._base_model,
                adapter_path,
                adapter_name=language,
            )
        else:
            # Subsequent adapters: load into the existing PeftModel, then
            # activate it explicitly; load_adapter() alone does not switch
            # the active adapter.
            logger.info("Loading adapter '%s' into existing PeftModel...", language)
            self._peft_model.load_adapter(adapter_path, adapter_name=language)
            self._peft_model.set_adapter(language)
        self._active = language
        logger.info("Adapter '%s' loaded and active.", language)

    def activate(self, language: str) -> None:
        """
        Hot-swap to a previously loaded adapter (~50ms).

        Falls back to the slow load_adapter() path if the adapter
        has not been loaded yet.
        """
        if self._peft_model is None:
            self.load_adapter(language)
            return
        loaded = set(self._peft_model.peft_config.keys())
        if language not in loaded:
            self.load_adapter(language)
            return
        self._peft_model.set_adapter(language)
        self._active = language
        logger.debug("Hot-swapped to adapter '%s'.", language)

    def get_model(self) -> "WhisperForConditionalGeneration | PeftModel":
        """Return the PeftModel (or the base model if no adapter is loaded yet)."""
        return self._peft_model if self._peft_model is not None else self._base_model

    def get_active(self) -> str | None:
        """Return the language code of the currently active adapter, if any."""
        return self._active

    def list_available(self) -> list[str]:
        """Return all registered language codes, loaded or not."""
        return list(self._registry.keys())

    def list_loaded(self) -> list[str]:
        """Return the language codes of adapters already loaded into memory."""
        if self._peft_model is None:
            return []
        return list(self._peft_model.peft_config.keys())
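

if __name__ == "__main__":
    # Usage sketch (illustrative only): the checkpoint name and adapter
    # paths below are hypothetical placeholders; point them at your own
    # base model and trained adapter directories before running.
    logging.basicConfig(level=logging.INFO)
    from transformers import WhisperForConditionalGeneration

    base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    manager = AdapterManager(base, config={})
    manager.register("sw", "adapters/swahili")
    manager.register("yo", "adapters/yoruba")

    manager.activate("sw")  # first use of 'sw': slow path, reads from disk (~2s)
    manager.activate("yo")  # first use of 'yo': slow path again
    manager.activate("sw")  # already loaded: set_adapter hot-swap (~50ms)

    model = manager.get_model()  # one backbone in memory, two ~50MB adapters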