import torch
import torch.nn as nn
from typing import Optional

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

def inverse_2x2(matrices):
    """Batched closed-form inverse of 2x2 matrices of shape (..., 2, 2)."""
    a = matrices[..., 0, 0]
    b = matrices[..., 0, 1]
    c = matrices[..., 1, 0]
    d = matrices[..., 1, 1]

    # Adjugate formula: inv(M) = adj(M) / det(M)
    det = a * d - b * c
    inv_det = 1.0 / det

    inv_matrices = torch.empty_like(matrices)
    inv_matrices[..., 0, 0] = d * inv_det
    inv_matrices[..., 0, 1] = -b * inv_det
    inv_matrices[..., 1, 0] = -c * inv_det
    inv_matrices[..., 1, 1] = a * inv_det

    return inv_matrices

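# Example (a sketch, not executed at import time): for reasonably conditioned
# batches of 2x2 matrices this matches torch.linalg.inv, e.g.
#
#     m = torch.randn(4, 2, 2) + 3 * torch.eye(2)  # keep determinants away from zero
#     torch.testing.assert_close(inverse_2x2(m), torch.linalg.inv(m))
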
class Rotation(nn.Module):
    """
    Rotation layer based on the Cayley transformation for parameter-efficient fine-tuning.

    This layer implements orthogonal fine-tuning through the Cayley transformation:
        h(x) = (I - A)^{-1} (I + A) x

    where A = XY^T with X = [U; -V] and Y = [V; U], so that A is skew-symmetric
    and the resulting transform is orthogonal.
    """

    def __init__(self, r, dim, T=1.0, num_rotations=4):
        super().__init__()
        self.r = r
        self.T = T
        # U gets a small random init; V starts at zero, so the initial rotation is the identity.
        self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
        self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
        self.num_rotations = num_rotations

    def forward(self, x):
        """
        Apply the Cayley transformation to the input x.

        A = XY^T where X = [U; -V], Y = [V; U]
        Cayley transformation: h(x) = (I - A)^{-1} (I + A) x

        Uses the Woodbury identity for efficient computation:
        (I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T

        Args:
            x: Input tensor of shape (..., dim)

        Returns:
            Transformed tensor of shape (..., dim)
        """
        x_dtype = x.dtype
        X = torch.cat([self.U, -self.V], dim=1)          # (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # (num_rotations, 2r, dim)

        # Small (2r x 2r) system from the Woodbury identity: I - Y X^T
        Y_T_X = torch.matmul(Y, X.transpose(1, 2))
        I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            # 2x2 case: use the closed-form inverse.
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
        else:
            # Invert in float32 for numerical stability, then cast back.
            I_minus_YX = I_minus_YX.to(torch.float32)
            I_minus_YX_inv = torch.linalg.inv(I_minus_YX)
            I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)

        # Project the input onto the low-rank subspace: Y x
        Yx = torch.einsum("...d,nrd->...nr", x, Y)
        # Apply the small inverse: (I - Y X^T)^{-1} (Y x). Note the distinct labels
        # "R" and "r" -- a repeated label here would take the diagonal instead of
        # performing the matrix-vector product.
        I_minus_YX_inv_Yx = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)

        # Map back to the full dimension: X^T (I - Y X^T)^{-1} Y x
        second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X)
        second_term = second_term.sum(dim=-2)

        # h(x) = x + 2 X^T (I - Y X^T)^{-1} Y x, summed over rotations
        output = x + 2 * second_term

        return output

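    # Sketch of the Woodbury (push-through) identity used in forward(), shown only
    # as an illustrative check with made-up shapes:
    # for X, Y of shape (2r, dim), (I_dim - X^T Y)^{-1} equals
    # I_dim + X^T (I_2r - Y X^T)^{-1} Y, e.g.
    #
    #     X, Y = torch.randn(4, 16) * 0.01, torch.randn(4, 16) * 0.01
    #     lhs = torch.linalg.inv(torch.eye(16) - X.T @ Y)
    #     rhs = torch.eye(16) + X.T @ torch.linalg.inv(torch.eye(4) - Y @ X.T) @ Y
    #     torch.testing.assert_close(lhs, rhs)
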
    def get_delta_weight(self):
        """
        Compute the delta weight matrix induced by the rotation layer.

        Returns:
            Delta weight matrix of shape (dim, dim)
        """
        X = torch.cat([self.U, -self.V], dim=1)
        Y = torch.cat([self.V, self.U], dim=1) * self.T

        Y_T_X = torch.matmul(Y, X.transpose(1, 2))
        I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
            I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y)
        else:
            # Solve instead of inverting explicitly; use float32 for stability.
            I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32))
            I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)

        # delta = 2 X^T (I - Y X^T)^{-1} Y, summed over rotations,
        # so that the full rotation is R = I + delta.
        second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y)
        second_term = second_term.sum(dim=0)
        total_delta_weight = 2 * second_term
        return total_delta_weight

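# Minimal self-check sketch, not part of the adapter API: `_rotation_self_check`
# is a hypothetical helper (its name, sizes, and the perturbation of V are
# assumptions). For a single rotation, R = I + get_delta_weight() should be
# numerically orthogonal, and forward() (the Woodbury path) should agree with
# applying R directly.
def _rotation_self_check(dim: int = 8, r: int = 2, seed: int = 0) -> None:
    torch.manual_seed(seed)
    rot = Rotation(r=r, dim=dim, T=1.0, num_rotations=1)
    # V is initialized to zero; perturb it so the check is non-trivial.
    with torch.no_grad():
        rot.V.copy_(torch.randn_like(rot.V) * 0.01)

    delta = rot.get_delta_weight()
    R = torch.eye(dim) + delta

    # The Cayley transform of a skew-symmetric matrix is orthogonal: R^T R ~ I.
    assert torch.allclose(R.T @ R, torch.eye(dim), atol=1e-4)

    # forward() should match the dense rotation applied to row vectors.
    x = torch.randn(5, dim)
    assert torch.allclose(rot(x), x @ R.T, atol=1e-4)
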
class RotationLayer(BaseTunerLayer):
    """
    Adapter-like wrapper that attaches Rotation modules to a base linear layer.
    """

    adapter_layer_names: tuple[str, ...] = ("rotation",)
    other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")

    def __init__(self, base_layer: nn.Module, **kwargs):
        super().__init__()

        self.base_layer = base_layer
        self.rotation = nn.ModuleDict()
        self.scaling = {}
        self._adapter_config = {}

        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        if isinstance(base_layer, nn.Linear):
            self.in_features = base_layer.in_features
            self.out_features = base_layer.out_features
        else:
            raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")

    @property
    def _available_adapters(self) -> set[str]:
        return set(self.rotation.keys())

    @property
    def disable_adapters(self) -> bool:
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        # Fall back to all registered adapters when no explicit selection was made.
        return getattr(self, "_active_adapters", list(self.rotation.keys()))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        **kwargs,
    ):
        """
        Add / update a rotation adapter for this layer.
        """
        if r <= 0:
            raise ValueError(f"r must be positive, got {r}")
        if num_rotations <= 0:
            raise ValueError(f"num_rotations must be positive, got {num_rotations}")

        rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
        self.rotation[adapter_name] = rot
        self.scaling[adapter_name] = 1.0
        self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}

    def set_active_adapters(self, adapters: Optional[list[str]]):
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters

class Linear(nn.Module, RotationLayer):
    """
    A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        **kwargs,
    ):
        super().__init__()
        RotationLayer.__init__(self, base_layer=base_layer, **kwargs)

        self._active_adapter = adapter_name

        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            T=T,
            num_rotations=num_rotations,
            **kwargs,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
        """
        Merge the adapter effect into the base layer weights:

            W_merged = W @ R

        where R = I + delta (delta returned by get_delta_weight()).
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # No adapters left to merge.
            return

        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        for active_adapter in adapter_names:
            if active_adapter not in self._available_adapters:
                continue
            delta_R = self.rotation[active_adapter].get_delta_weight()
            R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R

            # Re-read the weight on each iteration so successive merges compose
            # instead of overwriting one another.
            W = base_layer.weight.data
            merged_W = W.to(R.dtype) @ R
            if safe_merge and not torch.isfinite(merged_W).all():
                raise ValueError("Merging resulted in non-finite weights. Aborting merge.")

            base_layer.weight.data = merged_W.contiguous().to(orig_dtype)
            self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Reverse merges in LIFO order (pop merged adapters and invert R).
        """
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self._available_adapters:
                continue
            delta_R = self.rotation[active_adapter].get_delta_weight()
            R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
            R_inv = torch.linalg.inv(R)
            merged_W = base_layer.weight.data.to(R.dtype)
            unmerged_W = merged_W @ R_inv
            base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x_dtype = x.dtype
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)

        if self.merged:
            # The adapter effect already lives in the base weights.
            return base_layer(x, *args, **kwargs).to(x_dtype)

        for active_adapter in self.active_adapters:
            if active_adapter not in self.rotation:
                continue
            rotation = self.rotation[active_adapter]
            x = self._cast_input_dtype(x, rotation.U.dtype)
            x = rotation(x)

        return base_layer(x, *args, **kwargs).to(x_dtype)

    def __repr__(self):
        return f"rotation.{super().__repr__()}"

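# Quick usage sketch (illustrative only: the toy sizes, the adapter name
# "default", and the perturbation of V below are assumptions, not part of the
# adapter API). It checks that merge() reproduces the adapter forward pass via
# W_merged = W @ R and that unmerge() restores the original weights.
if __name__ == "__main__":
    _rotation_self_check()

    torch.manual_seed(0)
    base = nn.Linear(8, 4)
    layer = Linear(base_layer=base, adapter_name="default", r=2, T=1.0, num_rotations=1)

    # Perturb V so the rotation is not the identity.
    with torch.no_grad():
        layer.rotation["default"].V.copy_(torch.randn_like(layer.rotation["default"].V) * 0.01)

    x = torch.randn(3, 8)
    original_weight = base.weight.data.clone()
    y_adapter = layer(x)

    layer.merge(safe_merge=True)
    y_merged = layer(x)
    assert torch.allclose(y_adapter, y_merged, atol=1e-5)

    layer.unmerge()
    assert torch.allclose(base.weight.data, original_weight, atol=1e-5)
    print("rotation adapter sanity checks passed")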