| """ |
| model/lora.py — LoRA (Low-Rank Adaptation) for EVAFRILL-Mo hybrid models. |
| |
| Injects trainable low-rank adapters into: |
| - Attention layers: qkv_proj, out_proj |
| - Mamba-2 layers: in_proj, out_proj |
| |
| Usage: |
| model = LLM.from_pretrained(checkpoint) |
| apply_lora(model, rank=32, alpha=64) |
| # Only LoRA params are trainable; base model is frozen |
| |
| # After training, merge LoRA weights back: |
| merge_lora(model) |
| |
| # Or save/load LoRA weights separately: |
| save_lora(model, path) |
| load_lora(model, path) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .attention import MultiHeadAttention |
| from .mamba_block import Mamba2Block |
|
|
|
|
class LoRALinear(nn.Module):
    """LoRA adapter wrapping an existing nn.Linear layer.

    Computes: output = original_linear(x) + (alpha/rank) * dropout(x) @ A^T @ B^T
    where A: (rank, in_features), B: (out_features, rank).

    ``lora_B`` is zero-initialized, so at construction the adapter is an exact
    no-op (output equals the wrapped layer's output) and only drifts as B trains.
    The wrapped layer's weight/bias are frozen; only lora_A/lora_B are trainable.
    """

    def __init__(
        self,
        original: nn.Linear,
        rank: int = 32,
        alpha: float = 64.0,
        dropout: float = 0.0,
    ) -> None:
        """Wrap *original* with a rank-``rank`` LoRA adapter.

        Args:
            original: The nn.Linear to adapt; it is frozen in place.
            rank: LoRA rank (must be >= 1).
            alpha: Scaling numerator; effective scale is alpha/rank.
            dropout: Dropout probability applied to x on the LoRA path only.

        Raises:
            ValueError: If ``rank`` is less than 1 (previously this failed with
                an opaque ZeroDivisionError on ``alpha / rank``).
        """
        super().__init__()
        if rank < 1:
            raise ValueError(f"LoRA rank must be >= 1, got {rank}")
        self.original = original
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        in_features = original.in_features
        out_features = original.out_features

        # Allocate the factors on the same device/dtype as the wrapped weight so
        # the adapter follows the base model's placement.
        _dev = original.weight.device
        _dt = original.weight.dtype
        self.lora_A = nn.Parameter(torch.empty(rank, in_features, device=_dev, dtype=_dt))
        # B starts at zero => B @ A == 0, so the adapter is initially inert.
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank, device=_dev, dtype=_dt))

        # Same init nn.Linear uses for its own weight.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Freeze the wrapped layer; only the LoRA factors should train.
        original.weight.requires_grad = False
        if original.bias is not None:
            original.bias.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base output plus the scaled low-rank correction."""
        result = self.original(x)
        # Dropout is applied only on the LoRA path, never to the base output.
        lora_out = self.dropout(x)
        lora_out = F.linear(lora_out, self.lora_A)
        lora_out = F.linear(lora_out, self.lora_B)
        return result + lora_out * self.scaling

    def merge_weights(self) -> None:
        """Merge LoRA weights into the original linear layer permanently.

        NOTE: calling this more than once adds the delta again; callers
        (see ``merge_lora``) are expected to discard the adapter afterwards.
        """
        with torch.no_grad():
            # W <- W + scaling * (B @ A), matching the forward computation.
            self.original.weight.add_(
                (self.lora_B @ self.lora_A) * self.scaling
            )

    @property
    def weight(self) -> torch.Tensor:
        """Access original (unmerged) weight for compatibility."""
        return self.original.weight

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """Access original bias (may be None) for compatibility."""
        return self.original.bias
|
|
|
|
def apply_lora(
    model: nn.Module,
    rank: int = 32,
    alpha: float = 64.0,
    dropout: float = 0.0,
    target_modules: Optional[list[str]] = None,
) -> int:
    """Apply LoRA adapters to a model, freeze base weights.

    Args:
        model: The LLM model (raw, not DDP-wrapped).
        rank: LoRA rank (default 32).
        alpha: LoRA scaling factor (default 64).
        dropout: Dropout on LoRA path (default 0).
        target_modules: List of module attribute names to adapt.
            Default: ["qkv_proj", "out_proj", "in_proj"]
            (covers both Attention and Mamba layers).

    Returns:
        Number of LoRA parameters added.
    """
    if target_modules is None:
        target_modules = ["qkv_proj", "out_proj", "in_proj"]

    # Freeze everything first; the LoRALinear constructor creates its A/B
    # parameters afterwards, so only those remain trainable.
    for param in model.parameters():
        param.requires_grad = False

    lora_count = 0
    total_lora_params = 0

    # Attention and Mamba layers are handled identically — the target_modules
    # list simply filters which projections exist on each block type.
    # (Previously this was two byte-identical isinstance branches.)
    for _name, module in model.named_modules():
        if not isinstance(module, (MultiHeadAttention, Mamba2Block)):
            continue
        for attr in target_modules:
            original = getattr(module, attr, None)
            # isinstance(…, nn.Linear) also skips already-wrapped LoRALinear
            # layers, making repeated apply_lora calls a safe no-op.
            if isinstance(original, nn.Linear):
                lora_layer = LoRALinear(original, rank=rank, alpha=alpha, dropout=dropout)
                setattr(module, attr, lora_layer)
                # A is (rank, in), B is (out, rank).
                total_lora_params += rank * (original.in_features + original.out_features)
                lora_count += 1

    print(f"[LoRA] Applied {lora_count} adapters, {total_lora_params:,} trainable params "
          f"(rank={rank}, alpha={alpha})")
    return total_lora_params
|
|
|
|
def merge_lora(model: nn.Module) -> int:
    """Merge all LoRA weights back into base model and remove LoRA layers.

    Each LoRALinear folds its B@A delta into the wrapped Linear's weight, then
    the plain Linear is restored in the adapter's place. Finally every parameter
    is unfrozen so the merged model is fully trainable again.

    Returns:
        Number of LoRA layers merged.
    """
    merged = 0
    # (Removed a leftover no-op loop over vars(module) that did nothing.)
    for _name, module in model.named_modules():
        if isinstance(module, (MultiHeadAttention, Mamba2Block)):
            for attr in ["qkv_proj", "out_proj", "in_proj"]:
                layer = getattr(module, attr, None)
                if isinstance(layer, LoRALinear):
                    layer.merge_weights()
                    # Swap the adapter out for the now-updated plain Linear.
                    setattr(module, attr, layer.original)
                    merged += 1

    # Unfreeze the whole model (apply_lora froze the base weights).
    for param in model.parameters():
        param.requires_grad = True

    print(f"[LoRA] Merged {merged} adapters back into base model")
    return merged
|
|
|
|
def get_lora_params(model: nn.Module) -> list[nn.Parameter]:
    """Collect every trainable LoRA parameter (A then B per adapter)."""
    return [
        factor
        for adapter in model.modules()
        if isinstance(adapter, LoRALinear)
        for factor in (adapter.lora_A, adapter.lora_B)
    ]
|
|
|
|
def save_lora(model: nn.Module, path: str | Path) -> Path:
    """Write only the LoRA adapter tensors (A/B pairs) to ``path``."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # One entry per LoRA factor, keyed by the module's qualified name.
    lora_state = {
        f"{name}.{tag}": getattr(module, tag).data.cpu()
        for name, module in model.named_modules()
        if isinstance(module, LoRALinear)
        for tag in ("lora_A", "lora_B")
    }

    save_path = path / "lora_weights.pt"
    torch.save(lora_state, save_path)
    n_params = sum(v.numel() for v in lora_state.values())
    size_mb = save_path.stat().st_size / 1e6
    print(f"[LoRA] Saved {len(lora_state)} tensors ({n_params:,} params, {size_mb:.1f} MB) → {save_path}")
    return save_path
|
|
|
|
def load_lora(model: nn.Module, path: str | Path) -> int:
    """Load LoRA adapter weights. LoRA layers must already be applied."""
    path = Path(path)
    # Accept either the directory save_lora created or a direct file path.
    lora_file = path / "lora_weights.pt" if path.is_dir() else path
    lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)

    loaded = 0
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        a_tensor = lora_state.get(f"{name}.lora_A")
        b_tensor = lora_state.get(f"{name}.lora_B")
        # Copy only when both factors are present for this adapter.
        if a_tensor is not None and b_tensor is not None:
            module.lora_A.data.copy_(a_tensor)
            module.lora_B.data.copy_(b_tensor)
            loaded += 1

    print(f"[LoRA] Loaded {loaded} adapter weight pairs from {lora_file}")
    return loaded
|
|