File size: 11,748 Bytes
b816a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import torch
import torch.nn as nn
from typing import Optional, Set

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

def inverse_2x2(matrices):
    """
    Invert a batch of 2x2 matrices using the closed-form adjugate formula.

    For each matrix ``[[a, b], [c, d]]`` in the trailing two dimensions the
    result is ``(1 / (a*d - b*c)) * [[d, -b], [-c, a]]``.

    Args:
        matrices: Tensor whose last two dimensions index the 2x2 entries;
            any leading dimensions are treated as batch dimensions.

    Returns:
        Tensor created with ``torch.empty_like(matrices)`` holding the
        inverses. No singularity check is performed: the reciprocal of the
        determinant is taken as-is.
    """
    m00, m01 = matrices[..., 0, 0], matrices[..., 0, 1]
    m10, m11 = matrices[..., 1, 0], matrices[..., 1, 1]

    # Reciprocal of the determinant a*d - b*c.
    reciprocal_det = 1.0 / (m00 * m11 - m01 * m10)

    # Fill the output one row at a time: [d, -b] then [-c, a], each scaled by 1/det.
    result = torch.empty_like(matrices)
    result[..., 0, :] = torch.stack((m11 * reciprocal_det, -m01 * reciprocal_det), dim=-1)
    result[..., 1, :] = torch.stack((-m10 * reciprocal_det, m00 * reciprocal_det), dim=-1)
    return result

class Rotation(nn.Module):
    """
    Low-rank Cayley-transform layer for parameter-efficient fine-tuning.

    The layer holds ``num_rotations`` parameter pairs ``U`` and ``V``, each of
    shape ``(r, dim)``. For every pair it forms the stacked factors

        X = [U; -V]          (2r, dim)
        Y = T * [V; U]       (2r, dim)

    so that ``A = X^T Y = T * (U^T V - V^T U)`` is skew-symmetric, and applies
    the Cayley transform ``(I - A)^{-1} (I + A)``. Using the Woodbury identity
    ``(I - X^T Y)^{-1} = I + X^T (I - Y X^T)^{-1} Y`` this equals

        I + 2 X^T (I - Y X^T)^{-1} Y

    which only requires inverting a ``2r x 2r`` matrix. The low-rank terms of
    all ``num_rotations`` pairs are summed into a single update.

    ``V`` is initialised to zero, so a freshly constructed layer is the
    identity map.
    """

    def __init__(self, r, dim, T=1.0, num_rotations=4):
        """
        Args:
            r: Rank of each of ``U`` and ``V`` (the inner system is ``2r x 2r``).
            dim: Size of the feature dimension the layer acts on.
            T: Scalar multiplier applied to ``Y``.
            num_rotations: Number of independent ``(U, V)`` pairs.
        """
        super().__init__()
        self.r = r
        self.T = T
        self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
        self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
        self.num_rotations = num_rotations

    def forward(self, x):
        """
        Apply the transform to ``x`` without materialising a ``dim x dim`` matrix.

        Computes ``x + 2 * sum_n X_n^T (I - Y_n X_n^T)^{-1} Y_n x``, which is the
        same linear map as ``(I + get_delta_weight()) @ x``.

        Args:
            x: Input tensor of shape ``(..., dim)``.

        Returns:
            Transformed tensor of shape ``(..., dim)``.
        """
        x_dtype = x.dtype
        X = torch.cat([self.U, -self.V], dim=1)  # (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # (num_rotations, 2r, dim)

        Y_X_T = torch.matmul(Y, X.transpose(1, 2))  # (num_rotations, 2r, 2r)
        # The identity broadcasts over the leading num_rotations dimension.
        I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype)
        I_minus_YX = I_2r - Y_X_T

        if self.r == 1:
            # 2x2 systems have a cheap closed-form inverse.
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
        else:
            # Invert in float32, then return to the input dtype.
            I_minus_YX_inv = torch.linalg.inv(I_minus_YX.to(torch.float32)).to(x_dtype)

        Yx = torch.einsum("...d,nrd->...nr", x, Y)  # (..., num_rotations, 2r)
        # Full matrix-vector product with the inverse. The two distinct indices
        # (R, r) matter: a repeated index ("nrr") would select only the
        # diagonal of the inverse and silently drop its off-diagonal entries.
        coeffs = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)

        # Project back to feature space and sum over both 2r and num_rotations.
        second_term = torch.einsum("...nr,nrd->...d", coeffs, X)  # (..., dim)

        return x + 2 * second_term

    def get_delta_weight(self):
        """
        Compute the dense update ``2 * sum_n X_n^T (I - Y_n X_n^T)^{-1} Y_n``.

        Adding the identity to the result gives the matrix that ``forward``
        applies to each input vector.

        Returns:
            Delta weight matrix of shape ``(dim, dim)``.
        """
        X = torch.cat([self.U, -self.V], dim=1)  # (num_rotations, 2r, dim)
        Y = torch.cat([self.V, self.U], dim=1) * self.T  # (num_rotations, 2r, dim)

        Y_X_T = torch.matmul(Y, X.transpose(1, 2))  # (num_rotations, 2r, 2r)
        I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype)
        I_minus_YX = I_2r - Y_X_T

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
            I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y)  # (num_rotations, 2r, dim)
        else:
            # Solve (I - Y X^T) Z = Y in float32 instead of forming the inverse.
            I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32))
            I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)  # (num_rotations, 2r, dim)

        second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y)  # (num_rotations, dim, dim)
        return 2 * second_term.sum(dim=0)
    

class RotationLayer(BaseTunerLayer):
    """
    Adapter-like wrapper that attaches Rotation modules to a base linear layer.

    Each adapter is a ``Rotation`` stored in ``self.rotation`` under its adapter
    name. The hyper-parameters of every adapter are kept in per-adapter
    dictionaries whose attribute names match ``other_param_names``.
    """

    # Attribute holding the trainable adapter modules.
    adapter_layer_names: tuple[str, ...] = ("rotation",)
    # Attributes holding per-adapter, non-trainable settings. Each name listed
    # here must exist on the instance as a mapping keyed by adapter name.
    other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")

    def __init__(self, base_layer: nn.Module, **kwargs):
        """
        Args:
            base_layer: The layer to wrap; must be an ``nn.Linear``.
            **kwargs: Stored unchanged on ``self.kwargs``.

        Raises:
            NotImplementedError: If ``base_layer`` is not an ``nn.Linear``.
        """
        # Validate before touching any state so a failed construction leaves
        # nothing half-initialised.
        if not isinstance(base_layer, nn.Linear):
            raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")

        super().__init__()
        # store base layer and adapter containers
        self.base_layer = base_layer
        self.rotation = nn.ModuleDict()  # mapping adapter_name -> Rotation module

        # One dict per entry of ``other_param_names``; without these attributes
        # any generic code that walks ``other_param_names`` with ``getattr``
        # would fail with AttributeError.
        self.r = {}
        self.T = {}
        self.num_rotations = {}
        self.scaling = {}  # default scaling per adapter
        self._adapter_config = {}  # combined r, T, num_rotations per adapter

        # flags (exposed in a simple way)
        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        self.in_features = base_layer.in_features
        self.out_features = base_layer.out_features

    @property
    def _available_adapters(self) -> set[str]:
        """Names of all adapters that have been added to this layer."""
        return set(self.rotation.keys())

    @property
    def disable_adapters(self) -> bool:
        """Whether adapters are currently disabled for this layer."""
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        """True when at least one adapter is merged into the base weights."""
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        """
        Adapters applied in ``forward``.

        Returns the list set through ``set_active_adapters`` if there is one,
        otherwise every adapter that has been added.
        """
        # NOTE(review): only ``_active_adapters`` (plural) is consulted here; a
        # singular ``_active_adapter`` attribute is not read. Confirm this is
        # the intended interaction with the base class's adapter selection.
        return getattr(self, "_active_adapters", list(self.rotation.keys()))

    def get_base_layer(self) -> nn.Module:
        """Return the wrapped base layer."""
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Cast ``x`` to ``dtype`` unless input casting has been switched off."""
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        **kwargs,
    ):
        """
        Add / update a rotation adapter for this layer.

        A new ``Rotation`` sized to the base layer's ``in_features`` replaces
        any existing adapter of the same name, and its scaling is reset to 1.0.

        Args:
            adapter_name: Key under which the adapter is stored.
            r: Rank passed to ``Rotation``; must be positive.
            T: Multiplier passed to ``Rotation``.
            num_rotations: Number of rotations; must be positive.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``r`` or ``num_rotations`` is not positive.
        """
        if r <= 0:
            raise ValueError(f"r must be positive, got {r}")
        if num_rotations <= 0:
            raise ValueError(f"num_rotations must be positive, got {num_rotations}")

        self.rotation[adapter_name] = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
        self.r[adapter_name] = r
        self.T[adapter_name] = T
        self.num_rotations[adapter_name] = num_rotations
        self.scaling[adapter_name] = 1.0
        self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}

    # (optional) helper to set currently active adapters externally
    def set_active_adapters(self, adapters: Optional[list[str]]):
        """
        Select which adapters ``forward`` applies.

        Args:
            adapters: Adapter names to activate, or ``None`` to fall back to
                "all added adapters".
        """
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters
            
        
class Linear(nn.Module, RotationLayer):
    """
    A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.

    In the unmerged state the active rotations are applied to the input, one
    after another, before the base layer is called. Merging folds those
    rotations into the base layer's weight instead.
    """

    def __init__(self,
                 base_layer: nn.Linear,
                 adapter_name: str,
                 r: int,
                 T: float,
                 num_rotations: int,
                 **kwargs):
        """
        Args:
            base_layer: The ``nn.Linear`` to wrap.
            adapter_name: Name of the first adapter to create.
            r: Rank passed to the adapter.
            T: Multiplier passed to the adapter.
            num_rotations: Number of rotations in the adapter.
            **kwargs: Forwarded to ``RotationLayer.__init__`` and ``update_layer``.
        """
        super().__init__()
        RotationLayer.__init__(self, base_layer=base_layer, **kwargs)

        self._active_adapter = adapter_name

        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            T=T,
            num_rotations=num_rotations,
            **kwargs,
        )

    def _rotation_matrix(self, adapter_name: str) -> torch.Tensor:
        """Return ``I + delta`` for one adapter, in at least float32 precision."""
        delta = self.rotation[adapter_name].get_delta_weight()  # (in, in)
        # Half-precision dtypes are upcast: the matrix product loses accuracy
        # there and torch.linalg.inv does not accept them at all.
        compute_dtype = torch.float64 if delta.dtype == torch.float64 else torch.float32
        delta = delta.to(compute_dtype)
        return torch.eye(delta.size(0), device=delta.device, dtype=compute_dtype) + delta

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
        """
        Merge the adapter effect into the base layer weights:
            W_merged = W @ R
        where R = I + delta (delta returned by get_delta_weight()).

        With several adapters, ``forward`` computes ``W @ R_k ... R_1 @ x``, so
        the rotations are right-multiplied starting from the last adapter to
        reproduce the same product.

        Args:
            safe_merge: If True, raise instead of writing non-finite weights.
            adapter_names: Adapters to merge; passed through
                ``check_adapters_to_merge`` to obtain the final list.

        Raises:
            ValueError: If ``safe_merge`` is set and a merged weight contains
                non-finite values. Adapters merged before the failing one stay
                merged and recorded in ``merged_adapters``.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)

        if not adapter_names:
            return

        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        with torch.no_grad():
            for active_adapter in reversed(adapter_names):
                if active_adapter not in self._available_adapters:
                    continue
                R = self._rotation_matrix(active_adapter)  # (in, in)
                # Re-read the weight on every iteration so each rotation is
                # applied on top of the ones merged before it.
                merged_W = base_layer.weight.data.to(R.dtype) @ R  # (out, in)
                if safe_merge and not torch.isfinite(merged_W).all():
                    raise ValueError("Merging resulted in non-finite weights. Aborting merge.")

                base_layer.weight.data = merged_W.contiguous().to(orig_dtype)
                # mark merged (so unmerge can restore by inverse)
                self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Reverse merges in LIFO order (pop merged adapters and invert R).

        Each popped adapter's ``R`` is inverted and right-multiplied onto the
        current weight. Adapters that are no longer available are dropped from
        ``merged_adapters`` without touching the weight.
        """
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        with torch.no_grad():
            while self.merged_adapters:
                active_adapter = self.merged_adapters.pop()
                if active_adapter not in self._available_adapters:
                    continue
                R = self._rotation_matrix(active_adapter)  # (in, in)
                R_inv = torch.linalg.inv(R)
                unmerged_W = base_layer.weight.data.to(R.dtype) @ R_inv
                base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Run the base layer, applying the active rotations to the input first.

        Args:
            x: Input tensor; its last dimension is consumed by the rotations
                and the base layer.
            *args, **kwargs: Forwarded to the base layer.

        Returns:
            The base layer output cast back to the dtype of ``x``.
        """
        x_dtype = x.dtype
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            # if merged, unmerge to ensure base_layer produces original behavior
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)

        if self.merged:
            # if merged into base layer, just forward
            return base_layer(x, *args, **kwargs).to(x_dtype)

        # otherwise apply active adapters (transform inputs) then call base layer
        for active_adapter in self.active_adapters:
            if active_adapter not in self.rotation:
                continue
            rotation = self.rotation[active_adapter]
            x = self._cast_input_dtype(x, rotation.U.dtype)
            x = rotation(x)

        return base_layer(x, *args, **kwargs).to(x_dtype)

    def __repr__(self):
        return f"rotation.{super().__repr__()}"