import torch
import torch.nn as nn
from typing import Optional

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge


def inverse_2x2(matrices):
    """
    Invert a batch of 2x2 matrices in closed form via the adjugate formula.

    Args:
        matrices: Tensor of shape (..., 2, 2).

    Returns:
        Tensor of the same shape holding the inverse of each 2x2 matrix.
    """
    a = matrices[..., 0, 0]
    b = matrices[..., 0, 1]
    c = matrices[..., 1, 0]
    d = matrices[..., 1, 1]

    det = a * d - b * c
    inv_det = 1.0 / det

    inv_matrices = torch.empty_like(matrices)
    inv_matrices[..., 0, 0] = d * inv_det
    inv_matrices[..., 0, 1] = -b * inv_det
    inv_matrices[..., 1, 0] = -c * inv_det
    inv_matrices[..., 1, 1] = a * inv_det

    return inv_matrices
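

# Illustrative sanity check (not part of the adapter API; name and sizes are
# arbitrary): the closed-form 2x2 inverse above should agree with
# torch.linalg.inv on a batch of well-conditioned matrices.
def _check_inverse_2x2(num: int = 8, seed: int = 0) -> None:
    torch.manual_seed(seed)
    mats = torch.randn(num, 2, 2)
    # Make the batch symmetric positive definite so every matrix is safely invertible.
    mats = mats @ mats.transpose(-1, -2) + torch.eye(2)
    expected = torch.linalg.inv(mats)
    actual = inverse_2x2(mats)
    assert torch.allclose(actual, expected, atol=1e-5), "inverse_2x2 disagrees with torch.linalg.inv"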


class Rotation(nn.Module):
    """
    Rotation layer based on the Cayley transformation for parameter-efficient fine-tuning.

    This layer implements orthogonal fine-tuning through the Cayley transformation:
        h(x) = (I - A)^{-1} (I + A) x

    where A = XY^T is skew-symmetric, with X = [U; -V] and Y = [V; U].
    """

    def __init__(self, r, dim, T=1.0, num_rotations=4):
        super().__init__()
        self.r = r
        self.T = T
        # U starts near zero and V starts at exactly zero, so the initial
        # transformation is (close to) the identity.
        self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
        self.V = nn.Parameter(torch.zeros(num_rotations, r, dim), requires_grad=True)
        self.num_rotations = num_rotations

    def forward(self, x):
        """
        Apply the Cayley transformation to the input x.

        A = XY^T where X = [U; -V], Y = [V; U]
        Cayley transformation: h(x) = (I - A)^{-1} (I + A) x

        Uses the Woodbury identity for efficient computation:
            (I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T

        so only a (2r x 2r) system is inverted per rotation instead of a (dim x dim) one.

        Args:
            x: Input tensor of shape (..., dim)

        Returns:
            Transformed tensor of shape (..., dim)
        """
        x_dtype = x.dtype
        X = torch.cat([self.U, -self.V], dim=1)
        Y = torch.cat([self.V, self.U], dim=1) * self.T

        # (I - Y^T X) in the low-rank (2r x 2r) space, one matrix per rotation.
        Y_T_X = torch.matmul(Y, X.transpose(1, 2))
        I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            # 2x2 case: use the closed-form inverse.
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
        else:
            # Invert in float32 for numerical stability, then cast back.
            I_minus_YX = I_minus_YX.to(torch.float32)
            I_minus_YX_inv = torch.linalg.inv(I_minus_YX)
            I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)

        # Project the input onto the low-rank space, apply the inverse there, and map back.
        Yx = torch.einsum("...d,nrd->...nr", x, Y)
        I_minus_YX_inv_Yx = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)

        second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X)
        # Sum the contributions of the individual rotations.
        second_term = second_term.sum(dim=-2)

        output = x + 2 * second_term

        return output

    def get_delta_weight(self):
        """
        Compute the delta weight matrix induced by the rotation layer, i.e. the
        matrix `delta` such that forward(x) = x @ (I + delta)^T.

        Returns:
            Delta weight matrix of shape (dim, dim)
        """
        X = torch.cat([self.U, -self.V], dim=1)
        Y = torch.cat([self.V, self.U], dim=1) * self.T

        Y_T_X = torch.matmul(Y, X.transpose(1, 2))
        I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
        I_minus_YX = I_2r - Y_T_X

        if self.r == 1:
            I_minus_YX_inv = inverse_2x2(I_minus_YX)
            I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y)
        else:
            # Solve (I - Y^T X) Z = Y in float32 for numerical stability.
            I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32))
            I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)

        # delta = 2 * X (I - Y^T X)^{-1} Y^T, summed over rotations.
        second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y)
        second_term = second_term.sum(dim=0)
        total_delta_weight = 2 * second_term
        return total_delta_weight
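

# Illustrative sanity check (not part of the adapter API; helper name and test
# sizes are arbitrary): with a single rotation, A = XY^T is skew-symmetric, so
# the Cayley factor R = I + delta returned by get_delta_weight() should be
# orthogonal, and forward() should agree with applying R directly.
def _check_rotation_cayley(dim: int = 16, r: int = 3, seed: int = 0) -> None:
    torch.manual_seed(seed)
    rot = Rotation(r=r, dim=dim, T=1.0, num_rotations=1)
    # Overwrite the near-identity initialization with larger values so the check is non-trivial.
    with torch.no_grad():
        rot.U.normal_(std=0.1)
        rot.V.normal_(std=0.1)

    delta = rot.get_delta_weight()
    R = torch.eye(dim) + delta

    # Orthogonality of the Cayley transform of a skew-symmetric matrix.
    assert torch.allclose(R @ R.T, torch.eye(dim), atol=1e-5)

    # forward() and get_delta_weight() describe the same linear map.
    x = torch.randn(4, dim)
    assert torch.allclose(rot(x), x @ R.T, atol=1e-5)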


class RotationLayer(BaseTunerLayer):
    """
    Adapter-like wrapper that attaches Rotation modules to a base linear layer.
    """

    adapter_layer_names: tuple[str, ...] = ("rotation",)
    other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")

    def __init__(self, base_layer: nn.Module, **kwargs):
        super().__init__()

        self.base_layer = base_layer
        self.rotation = nn.ModuleDict()
        self.scaling = {}
        self._adapter_config = {}

        # Book-keeping attributes expected by the PEFT tuner machinery.
        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        if isinstance(base_layer, nn.Linear):
            self.in_features = base_layer.in_features
            self.out_features = base_layer.out_features
        else:
            raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")

    @property
    def _available_adapters(self) -> set[str]:
        return set(self.rotation.keys())

    @property
    def disable_adapters(self) -> bool:
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        # Default to all registered adapters unless an explicit list was set
        # via set_active_adapters().
        return getattr(self, "_active_adapters", list(self.rotation.keys()))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        **kwargs,
    ):
        """
        Add or update a rotation adapter for this layer.
        """
        if r <= 0:
            raise ValueError(f"r must be positive, got {r}")
        if num_rotations <= 0:
            raise ValueError(f"num_rotations must be positive, got {num_rotations}")

        rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
        self.rotation[adapter_name] = rot
        self.scaling[adapter_name] = 1.0
        self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}

    def set_active_adapters(self, adapters: Optional[list[str]]):
        if adapters is None:
            # Reset to the default behaviour (all registered adapters are active).
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters


class Linear(nn.Module, RotationLayer):
    """
    A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        adapter_name: str,
        r: int,
        T: float,
        num_rotations: int,
        **kwargs,
    ):
        super().__init__()
        RotationLayer.__init__(self, base_layer=base_layer, **kwargs)

        self._active_adapter = adapter_name

        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            T=T,
            num_rotations=num_rotations,
            **kwargs,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
        """
        Merge the adapter effect into the base layer weights:
            W_merged = W @ R
        where R = I + delta (delta returned by get_delta_weight()).
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)

        if not adapter_names:
            # Nothing left to merge (e.g. all requested adapters are already merged).
            return

        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        for active_adapter in adapter_names:
            if active_adapter not in self._available_adapters:
                continue
            delta_R = self.rotation[active_adapter].get_delta_weight()
            R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R

            # Read the current weight each iteration so successive merges compose.
            W = base_layer.weight.data
            merged_W = W.to(R.dtype) @ R
            if safe_merge and not torch.isfinite(merged_W).all():
                raise ValueError("Merging resulted in non-finite weights. Aborting merge.")

            base_layer.weight.data = merged_W.contiguous().to(orig_dtype)

            self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Reverse merges in LIFO order (pop merged adapters and invert R).
        """
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self._available_adapters:
                continue
            delta_R = self.rotation[active_adapter].get_delta_weight()
            R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
            R_inv = torch.linalg.inv(R)
            merged_W = base_layer.weight.data.to(R.dtype)
            unmerged_W = merged_W @ R_inv
            base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        x_dtype = x.dtype
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            # Adapters are disabled: make sure nothing is merged, then run the base layer only.
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)

        if self.merged:
            # The adapter effect already lives in the base weights.
            return base_layer(x, *args, **kwargs).to(x_dtype)

        # Apply the active rotation adapters to the input, then the base layer.
        for active_adapter in self.active_adapters:
            if active_adapter not in self.rotation:
                continue
            rotation = self.rotation[active_adapter]
            x = self._cast_input_dtype(x, rotation.U.dtype)
            x = rotation(x)

        return base_layer(x, *args, **kwargs).to(x_dtype)

    def __repr__(self):
        return f"rotation.{super().__repr__()}"
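

# Minimal usage sketch (illustrative, not part of the adapter API; names and
# sizes are arbitrary): wrap a base nn.Linear in the rotation-adapted Linear
# layer, check that merging the adapter into the base weights reproduces the
# adapted forward pass, and that unmerge() restores the original weights.
def _example_merge_roundtrip(in_features: int = 16, out_features: int = 8, seed: int = 0) -> None:
    torch.manual_seed(seed)
    base = nn.Linear(in_features, out_features)
    layer = Linear(base, adapter_name="default", r=2, T=1.0, num_rotations=2)

    # Perturb the adapter away from its near-identity initialization.
    with torch.no_grad():
        layer.rotation["default"].U.normal_(std=0.05)
        layer.rotation["default"].V.normal_(std=0.05)

    x = torch.randn(4, in_features)
    original_weight = base.weight.data.clone()
    y_adapter = layer(x)

    layer.merge(safe_merge=True)
    y_merged = layer(x)
    assert torch.allclose(y_adapter, y_merged, atol=1e-5)

    layer.unmerge()
    assert torch.allclose(base.weight.data, original_weight, atol=1e-5)


if __name__ == "__main__":
    _check_inverse_2x2()
    _check_rotation_cayley()
    _example_merge_roundtrip()
    print("All sanity checks passed.")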