"""
LoRA adapter hot-swap manager.

Uses PEFT's multi-adapter API:
  - model.load_adapter(path, adapter_name=lang)  — first load (~2s per adapter)
  - model.set_adapter(lang)                       — subsequent swap (~50ms)

This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights,
vs reloading the full 1.5GB model per language.
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from peft import PeftModel

if TYPE_CHECKING:
    from transformers import WhisperForConditionalGeneration

logger = logging.getLogger(__name__)


class AdapterManager:
    """Manages registration and hot-swapping of LoRA language adapters."""

    def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None:
        self._base_model = base_model
        self._config = config
        self._registry: dict[str, str] = {}  # language_code -> adapter_path
        self._peft_model: PeftModel | None = None
        self._active: str | None = None

    def register(self, language: str, adapter_path: str) -> None:
        """Register an adapter path. Does not load it yet."""
        path = Path(adapter_path)
        if not path.exists():
            logger.warning(
                "Adapter path '%s' for language '%s' does not exist. "
                "Run training first, or check the path.",
                adapter_path, language,
            )
        self._registry[language] = str(path)
        logger.info("Registered adapter '%s' → %s", language, adapter_path)

    def load_adapter(self, language: str) -> None:
        """
        Load an adapter into the model for the first time.
        Slow (~2s): reads adapter weights from disk.
        Subsequent activate() calls reuse the already-loaded weights.
        """
        if language not in self._registry:
            raise KeyError(f"No adapter registered for language '{language}'. "
                           f"Available: {list(self._registry)}")

        adapter_path = self._registry[language]

        if self._peft_model is None:
            # First adapter: wrap the base model with PeftModel
            logger.info("Wrapping base model with first adapter '%s'...", language)
            self._peft_model = PeftModel.from_pretrained(
                self._base_model,
                adapter_path,
                adapter_name=language,
            )
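        else:
            # Subsequent adapters: load into the existing PeftModel.
            logger.info("Loading adapter '%s' into existing PeftModel...", language)
            self._peft_model.load_adapter(adapter_path, adapter_name=language)
            # PEFT's load_adapter() does not switch the active adapter, so do it explicitly.
            self._peft_model.set_adapter(language)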

        self._active = language
        logger.info("Adapter '%s' loaded and active.", language)

    def activate(self, language: str) -> None:
        """
        Hot-swap to a previously loaded adapter (~50ms).
        Call load_adapter() first if this adapter hasn't been loaded.
        """
        # Fall back to a full load if nothing is wrapped yet or this adapter is new.
        if self._peft_model is None or language not in self._peft_model.peft_config:
            self.load_adapter(language)
            return

        self._peft_model.set_adapter(language)
        self._active = language
        logger.debug("Hot-swapped to adapter '%s'.", language)

    def get_model(self) -> "WhisperForConditionalGeneration | PeftModel":
        """Return the PeftModel (or base model if no adapter loaded yet)."""
        return self._peft_model if self._peft_model is not None else self._base_model

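    def get_active(self) -> str | None:
        """Return the language code of the active adapter, or None if none is loaded."""
        return self._active

    def list_available(self) -> list[str]:
        """Return the language codes of all registered adapters, loaded or not."""
        return list(self._registry.keys())

    def list_loaded(self) -> list[str]:
        """Return the language codes of adapters already loaded into the PeftModel."""
        if self._peft_model is None:
            return []
        return list(self._peft_model.peft_config.keys())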
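

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of how register() might be driven from a directory
# of trained adapters. Assumptions: each adapter lives in its own subdirectory
# named after its language code, and that subdirectory contains the
# `adapter_config.json` file that PEFT's save_pretrained() writes.
# `discover_adapters` is a hypothetical helper; `manager` only needs a
# register(language, adapter_path) method, such as AdapterManager above.
def discover_adapters(manager, adapters_root: str) -> list[str]:
    """Register every subdirectory of adapters_root that holds an adapter config."""
    from pathlib import Path  # local import keeps the sketch self-contained

    found: list[str] = []
    root = Path(adapters_root)
    if not root.is_dir():
        # Nothing to scan; return an empty list rather than raising.
        return found
    for sub in sorted(root.iterdir()):
        # Only directories with a PEFT adapter config count as adapters.
        if sub.is_dir() and (sub / "adapter_config.json").is_file():
            manager.register(sub.name, str(sub))
            found.append(sub.name)
    return found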