File size: 8,325 Bytes
29fc577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
model/lora.py β€” LoRA (Low-Rank Adaptation) for EVAFRILL-Mo hybrid models.

Injects trainable low-rank adapters into:
  - Attention layers: qkv_proj, out_proj
  - Mamba-2 layers: in_proj, out_proj

Usage:
    model = LLM.from_pretrained(checkpoint)
    apply_lora(model, rank=32, alpha=64)
    # Only LoRA params are trainable; base model is frozen

    # After training, merge LoRA weights back:
    merge_lora(model)

    # Or save/load LoRA weights separately:
    save_lora(model, path)
    load_lora(model, path)
"""

from __future__ import annotations

import math
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import MultiHeadAttention
from .mamba_block import Mamba2Block


class LoRALinear(nn.Module):
    """LoRA adapter wrapping an existing nn.Linear layer.

    Computes: output = original_linear(x) + (alpha/rank) * x @ A^T @ B^T
    where A: (rank, in_features), B: (out_features, rank)

    The wrapped layer is frozen on construction; only ``lora_A`` and
    ``lora_B`` remain trainable. Because B starts at zero, the adapter is
    a no-op until training updates it.
    """

    def __init__(
        self,
        original: nn.Linear,
        rank: int = 32,
        alpha: float = 64.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.original = original
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling: alpha / rank.
        self.scaling = alpha / rank

        # Allocate the low-rank factors on the same device/dtype as the
        # wrapped weight so the adapter composes with mixed precision.
        ref = original.weight
        self.lora_A = nn.Parameter(
            torch.empty(rank, original.in_features, device=ref.device, dtype=ref.dtype)
        )
        self.lora_B = nn.Parameter(
            torch.zeros(original.out_features, rank, device=ref.device, dtype=ref.dtype)
        )
        # A: kaiming-uniform; B: zeros => initial LoRA contribution is exactly 0.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        # Optional dropout on the LoRA input path only.
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

        # Freeze the wrapped layer; LoRA factors are the only trainables here.
        ref.requires_grad = False
        if original.bias is not None:
            original.bias.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base path plus scaled low-rank delta: dropout -> A -> B.
        delta = F.linear(F.linear(self.dropout(x), self.lora_A), self.lora_B)
        return self.original(x) + delta * self.scaling

    def merge_weights(self) -> None:
        """Merge LoRA weights into the original linear layer permanently."""
        # W' = W + scaling * B @ A (done in-place, no grad tracking).
        with torch.no_grad():
            self.original.weight.add_((self.lora_B @ self.lora_A) * self.scaling)

    @property
    def weight(self) -> torch.Tensor:
        """Access original weight for compatibility."""
        return self.original.weight

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """Access original bias (may be None) for compatibility."""
        return self.original.bias


def apply_lora(
    model: nn.Module,
    rank: int = 32,
    alpha: float = 64.0,
    dropout: float = 0.0,
    target_modules: Optional[list[str]] = None,
) -> int:
    """Apply LoRA adapters to a model, freeze base weights.

    Args:
        model: The LLM model (raw, not DDP-wrapped).
        rank: LoRA rank (default 32).
        alpha: LoRA scaling factor (default 64).
        dropout: Dropout on LoRA path (default 0).
        target_modules: List of module attribute names to adapt.
            Default: ["qkv_proj", "out_proj", "in_proj"]
            (covers both Attention and Mamba layers).

    Returns:
        Number of LoRA parameters added.
    """
    if target_modules is None:
        target_modules = ["qkv_proj", "out_proj", "in_proj"]

    # Freeze the entire base model first; only the LoRA factors created
    # below (inside LoRALinear.__init__) will be trainable.
    for param in model.parameters():
        param.requires_grad = False

    lora_count = 0
    total_lora_params = 0

    # Attention and Mamba blocks were previously handled by two identical
    # copy-pasted branches; a single isinstance check covers both.
    for module in model.modules():
        if not isinstance(module, (MultiHeadAttention, Mamba2Block)):
            continue
        for attr in target_modules:
            original = getattr(module, attr, None)
            # Only wrap plain Linear layers (skip missing attrs and
            # already-wrapped LoRALinear instances).
            if isinstance(original, nn.Linear):
                lora_layer = LoRALinear(original, rank=rank, alpha=alpha, dropout=dropout)
                setattr(module, attr, lora_layer)
                # Params added: A (rank x in) + B (out x rank).
                total_lora_params += rank * (original.in_features + original.out_features)
                lora_count += 1

    print(f"[LoRA] Applied {lora_count} adapters, {total_lora_params:,} trainable params "
          f"(rank={rank}, alpha={alpha})")
    return total_lora_params


def merge_lora(model: nn.Module) -> int:
    """Merge all LoRA weights back into base model and remove LoRA layers.

    Folds each adapter's delta (scaling * B @ A) into the wrapped linear's
    weight, restores the plain nn.Linear in place of the LoRALinear
    wrapper, then unfreezes every parameter.

    Returns:
        Number of LoRA layers merged.
    """
    merged = 0
    # NOTE: a leftover loop over vars(module) that only executed `pass`
    # was removed here — it had no effect.
    for module in model.modules():
        if isinstance(module, (MultiHeadAttention, Mamba2Block)):
            for attr in ["qkv_proj", "out_proj", "in_proj"]:
                layer = getattr(module, attr, None)
                if isinstance(layer, LoRALinear):
                    layer.merge_weights()
                    # Swap the (now-updated) original linear back in.
                    setattr(module, attr, layer.original)
                    merged += 1

    # Unfreeze all parameters after merging so the model is fully
    # trainable again (apply_lora froze everything).
    for param in model.parameters():
        param.requires_grad = True

    print(f"[LoRA] Merged {merged} adapters back into base model")
    return merged


def get_lora_params(model: nn.Module) -> list[nn.Parameter]:
    """Get all LoRA trainable parameters."""
    # Flatten (A, B) pairs from every LoRALinear in traversal order.
    return [
        param
        for module in model.modules()
        if isinstance(module, LoRALinear)
        for param in (module.lora_A, module.lora_B)
    ]


def save_lora(model: nn.Module, path: str | Path) -> Path:
    """Save only the LoRA adapter weights."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Collect A/B factors keyed by their module path, moved to CPU.
    lora_state: dict[str, torch.Tensor] = {}
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        for suffix in ("lora_A", "lora_B"):
            lora_state[f"{name}.{suffix}"] = getattr(module, suffix).data.cpu()

    save_path = path / "lora_weights.pt"
    torch.save(lora_state, save_path)
    n_params = sum(tensor.numel() for tensor in lora_state.values())
    size_mb = save_path.stat().st_size / 1e6
    print(f"[LoRA] Saved {len(lora_state)} tensors ({n_params:,} params, {size_mb:.1f} MB) → {save_path}")
    return save_path


def load_lora(model: nn.Module, path: str | Path) -> int:
    """Load LoRA adapter weights. LoRA layers must already be applied."""
    path = Path(path)
    # Accept either a directory (containing lora_weights.pt) or a file.
    lora_file = path / "lora_weights.pt" if path.is_dir() else path
    lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)

    loaded = 0
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        key_a = f"{name}.lora_A"
        key_b = f"{name}.lora_B"
        # Only copy when both factors are present for this adapter.
        if key_a in lora_state and key_b in lora_state:
            module.lora_A.data.copy_(lora_state[key_a])
            module.lora_B.data.copy_(lora_state[key_b])
            loaded += 1

    print(f"[LoRA] Loaded {loaded} adapter weight pairs from {lora_file}")
    return loaded