import torch
import torch.nn as nn
from typing import Optional

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge


def inverse_2x2(matrices):
    # Extract matrix elements.
    # matrices[..., 0, 0] corresponds to 'a' in [[a, b], [c, d]]
    a = matrices[..., 0, 0]
    b = matrices[..., 0, 1]
    c = matrices[..., 1, 0]
    d = matrices[..., 1, 1]

    # Compute the determinant
    det = a * d - b * c

    # Compute the inverse using the formula:
    #   inv = (1/det) * [[d, -b], [-c, a]]
    inv_det = 1.0 / det

    # Create the output tensor
    inv_matrices = torch.empty_like(matrices)
    inv_matrices[..., 0, 0] = d * inv_det
    inv_matrices[..., 0, 1] = -b * inv_det
    inv_matrices[..., 1, 0] = -c * inv_det
    inv_matrices[..., 1, 1] = a * inv_det

    return inv_matrices


class Rotation(nn.Module):
    """
    Rotation layer based on the Cayley transformation for parameter-efficient fine-tuning.

    This layer implements orthogonal fine-tuning through the Cayley transformation:

        h(x) = (I - A)^{-1} (I + A) x

    where A = XY^T with X = [U; -V] and Y = [V; U].
    """

    def __init__(self, r, dim, T=1.0, num_rotations=4):
        super().__init__()
        self.r = r
        self.T = T
        self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
        self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
        self.num_rotations = num_rotations

    def forward(self, x):
        """
        Apply the Cayley transformation to the input x.

        A = XY^T where X = [U; -V], Y = [V; U]
        Cayley transformation: h(x) = (I - A)^{-1} (I + A) x

        Uses the Woodbury identity for efficient computation:

            (I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T

        Args:
            x: Input tensor of shape (..., dim)

        Returns:
            Transformed tensor of shape (..., dim)
        """
        x_dtype = x.dtype
        X = torch.cat([self.U, -self.V], dim=1)  # Shape: (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # Shape: (num_rotations, 2r, dim)

        Y_T_X = torch.matmul(Y, X.transpose(1, 2))  # Shape: (num_rotations, 2r, 2r)
        I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
        else:
            # Invert in float32 for numerical stability, then cast back to the input dtype
            I_minus_YX = I_minus_YX.to(torch.float32)
            I_minus_YX_inv = torch.linalg.inv(I_minus_YX)  # Shape: (num_rotations, 2r, 2r)
            I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)

        Yx = torch.einsum("...d,nrd->...nr", x, Y)  # Shape: (..., num_rotations, 2r)
        # Batched matrix-vector product with the small (2r x 2r) inverse; capital R is the output index
        I_minus_YX_inv_Yx = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)
        second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X)  # Shape: (..., num_rotations, dim)
        second_term = second_term.sum(dim=-2)  # Sum over rotations

        output = x + 2 * second_term  # Shape: (..., dim)
        return output

    def get_delta_weight(self):
        """
        Compute the delta weight matrix induced by the rotation layer.

        Returns:
            Delta weight matrix of shape (dim, dim)
        """
        X = torch.cat([self.U, -self.V], dim=1)  # Shape: (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # Shape: (num_rotations, 2r, dim)

        Y_T_X = torch.matmul(Y, X.transpose(1, 2))  # Shape: (num_rotations, 2r, 2r)
        I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
            I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y)  # Shape: (num_rotations, 2r, dim)
        else:
            # Solve in float32 for numerical stability, then cast back to the parameter dtype
            I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32))  # Shape: (num_rotations, 2r, dim)
            I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)

        second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y)  # Shape: (num_rotations, dim, dim)
        second_term = second_term.sum(dim=0)

        total_delta_weight = 2 * second_term
        return total_delta_weight

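
# --- Illustrative sketch (not part of the original adapter code) -------------------
# A minimal self-check of the math documented above; the helper name
# `_check_rotation_consistency` and all sizes, seeds, and tolerances below are
# arbitrary assumptions for this sketch. It verifies that forward() matches the dense
# map x -> x @ (I + delta)^T obtained from get_delta_weight(), and that for a single
# rotation block the Cayley factor I + delta is orthogonal (A is skew-symmetric).
def _check_rotation_consistency(dim: int = 16, r: int = 2) -> None:
    torch.manual_seed(0)
    rot = Rotation(r=r, dim=dim, num_rotations=1)
    # U starts near zero and V at exactly zero; perturb both so the rotation is non-trivial.
    nn.init.normal_(rot.U, std=0.1)
    nn.init.normal_(rot.V, std=0.1)

    with torch.no_grad():
        x = torch.randn(3, 5, dim)
        R = torch.eye(dim) + rot.get_delta_weight()  # dense Cayley factor, shape (dim, dim)

        # forward() applies x -> x @ R^T (R acts on column vectors)
        assert torch.allclose(rot(x), x @ R.T, atol=1e-5)
        # A single Cayley factor is orthogonal: R^T R = I
        assert torch.allclose(R.T @ R, torch.eye(dim), atol=1e-5)
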

class RotationLayer(BaseTunerLayer):
    """
    Adapter-like wrapper that attaches Rotation modules to a base linear layer.
    """

    adapter_layer_names: tuple[str, ...] = ("rotation",)
    other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")

    def __init__(self, base_layer: nn.Module, **kwargs):
        # Cooperative init; BaseTunerLayer is a mixin, and the concrete adapter class
        # (e.g. Linear below) also inherits from nn.Module, whose __init__ has already run.
        super().__init__()

        # Store the base layer and the adapter containers
        self.base_layer = base_layer
        self.rotation = nn.ModuleDict()  # maps adapter_name -> Rotation module
        self.scaling = {}  # default scaling per adapter
        self._adapter_config = {}  # stores r, T, num_rotations per adapter

        # Flags (exposed in a simple way)
        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        if isinstance(base_layer, nn.Linear):
            self.in_features = base_layer.in_features
            self.out_features = base_layer.out_features
        else:
            raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")

    @property
    def _available_adapters(self) -> set[str]:
        return set(self.rotation.keys())

    @property
    def disable_adapters(self) -> bool:
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        # If some external mechanism has set the active adapters, prefer it; otherwise use all added adapters.
        return getattr(self, "_active_adapters", list(self.rotation.keys()))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        **kwargs,
    ):
        """
        Add or update a rotation adapter for this layer.
        """
        if r <= 0:
            raise ValueError(f"r must be positive, got {r}")
        if num_rotations <= 0:
            raise ValueError(f"num_rotations must be positive, got {num_rotations}")

        rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
        self.rotation[adapter_name] = rot
        self.scaling[adapter_name] = 1.0
        self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}

    # (optional) helper to set the currently active adapters externally
    def set_active_adapters(self, adapters: Optional[list[str]]):
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters

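
# --- Illustrative sketch (not part of the original adapter code) -------------------
# Rough parameter budget behind the "parameter-efficient" framing; the helper name
# `_rotation_parameter_budget` and the sizes below are arbitrary assumptions. Each
# adapter stores only U and V, i.e. 2 * num_rotations * r * dim values, versus the
# dim * dim entries of the dense rotation matrix it implicitly represents.
def _rotation_parameter_budget(dim: int = 768, r: int = 4, num_rotations: int = 4) -> tuple[int, int]:
    rot = Rotation(r=r, dim=dim, num_rotations=num_rotations)
    adapter_params = sum(p.numel() for p in rot.parameters())  # 2 * num_rotations * r * dim
    dense_params = dim * dim  # a full dim x dim matrix, for comparison
    # e.g. 24,576 adapter parameters vs. 589,824 dense entries for dim=768, r=4, num_rotations=4
    return adapter_params, dense_params
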
""" if r <= 0: raise ValueError(f"r must be positive, got {r}") if num_rotations <= 0: raise ValueError(f"num_rotations must be positive, got {num_rotations}") rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations) self.rotation[adapter_name] = rot self.scaling[adapter_name] = 1.0 self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations} # (optional) helper to set currently active adapters externally def set_active_adapters(self, adapters: Optional[list[str]]): if adapters is None: if hasattr(self, "_active_adapters"): delattr(self, "_active_adapters") else: self._active_adapters = adapters class Linear(nn.Module, RotationLayer): """ A linear layer with an integrated rotation layer for parameter-efficient fine-tuning. """ def __init__(self, base_layer: nn.Linear, adapter_name: str, r: int, T: float, num_rotations: int, **kwargs): super().__init__() RotationLayer.__init__(self, base_layer=base_layer, **kwargs) self._active_adapter = adapter_name self.update_layer( adapter_name=adapter_name, r=r, T=T, num_rotations=num_rotations, **kwargs, ) def merge(self, safe_merge: bool = False, adapter_names: Optional[str] = None): """ Merge the adapter effect into the base layer weights: W_merged = W @ R where R = I + delta (delta returned by get_delta_weight()). """ adapter_names = check_adapters_to_merge(self, adapter_names) if not adapter_names: return base_layer = self.get_base_layer() orig_dtype = base_layer.weight.dtype # base_layer.weight shape: (out_features, in_features) W = base_layer.weight.data # (out, in) for active_adapter in adapter_names: if active_adapter not in self._available_adapters: continue delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in) R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R # (in, in) # merged W = W @ R merged_W = W.to(R.dtype) @ R if safe_merge and not torch.isfinite(merged_W).all(): raise ValueError("Merging resulted in non-finite weights. Aborting merge.") base_layer.weight.data = merged_W.contiguous().to(orig_dtype) # mark merged (so unmerge can restore by inverse) self.merged_adapters.append(active_adapter) def unmerge(self): """ Reverse merges in LIFO order (pop merged adapters and invert R). 
""" base_layer = self.get_base_layer() orig_dtype = base_layer.weight.dtype while self.merged_adapters: active_adapter = self.merged_adapters.pop() if active_adapter not in self._available_adapters: continue delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in) R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R R_inv = torch.linalg.inv(R) merged_W = base_layer.weight.data.to(R.dtype) unmerged_W = merged_W @ R_inv base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype) def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: x_dtype = x.dtype base_layer = self.get_base_layer() if self.disable_adapters: # if merged, unmerge to ensure base_layer produces original behavior if self.merged: self.unmerge() return base_layer(x, *args, **kwargs).to(x_dtype) if self.merged: # if merged into base layer, just forward return base_layer(x, *args, **kwargs).to(x_dtype) # otherwise apply active adapters (transform inputs) then call base layer for active_adapter in self.active_adapters: if active_adapter not in self.rotation: continue rotation = self.rotation[active_adapter] x = self._cast_input_dtype(x, rotation.U.dtype) x = rotation(x) return base_layer(x, *args, **kwargs).to(x_dtype) def __repr__(self): return f"rotation.{super().__repr__()}"