import warnings
from dataclasses import asdict
from enum import Enum
from typing import Optional

import torch
import torch.nn as nn
from tqdm import tqdm

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules

from .layer import Linear, RotationLayer

# Rotation adapters reuse the per-architecture target-module defaults defined for LoRA.
TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()


class RotationTuner(BaseTuner):
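    """
    Creates a rotation-adapter model from a pretrained model by wrapping the configured target modules in
    `RotationLayer` modules. Only the rotation parameters (prefixed with `rotation_`) are trainable; the base
    weights stay frozen.

    Example (illustrative sketch; `RotationConfig` stands in for this tuner's companion config class and the
    checkpoint path is a placeholder):

    ```py
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import get_peft_model

    >>> base_model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
    >>> config = RotationConfig(r=8, target_modules=["q_proj", "v_proj"])
    >>> model = get_peft_model(base_model, config)
    ```
    """
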
    prefix: str = "rotation_"
    tuner_layer_class = RotationLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING

    @staticmethod
    def _check_target_module_exists(rotation_config, key: str) -> bool:
        return check_target_module_exists(rotation_config, key)

    def _create_and_replace(
        self,
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        target_name: str,
        parent: nn.Module,
        current_key: str,
        **optional_kwargs,
    ) -> None:
| """ |
| Create and replace a target module with a rotation-augmented version. |
| |
| This method is called when an existing module is already a RotationLayer |
| and needs to have a new adapter added to it. |
| |
| Args: |
| rotation_config: Configuration for the rotation adapter |
| adapter_name: Name of the adapter to add |
| target: The target module to augment |
| target_name: Name of the target module |
| parent: Parent module containing the target |
| current_key: Full key path to the current module |
| **optional_kwargs: Additional optional arguments |
| |
| Raises: |
| ValueError: If current_key is not provided |
| """ |
        if current_key is None:
            raise ValueError("current_key must be provided to create Rotation layer")

        if isinstance(target, RotationLayer):
            target.update_layer(
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
            )
        else:
            new_module = self._create_new_module(
                rotation_config=rotation_config,
                adapter_name=adapter_name,
                target=target,
                **optional_kwargs,
            )
            if new_module is not None:
                self._replace_module(parent, target_name, new_module, target)

    def _replace_module(self, parent, child_name, new_module, child):
        # Swap the child module on the parent so the new module takes its place.
        setattr(parent, child_name, new_module)

        # If the old child was itself an adapter layer, read weight/device information from its base layer.
        if hasattr(child, "base_layer"):
            child = child.base_layer

        meta = torch.device("meta")
        # Move the adapter submodules to the same device as the base weights, skipping modules that
        # still hold parameters on the meta device.
        for name, module in new_module.named_modules():
            if (self.prefix in name) or ("ranknum" in name):
                if hasattr(child, "qweight"):
                    weight = child.qweight
                elif hasattr(child, "W_q"):
                    weight = child.W_q
                elif hasattr(child, "weight"):
                    weight = child.weight
                elif getattr(child, "in_proj_weight", None) is not None:
                    weight = child.in_proj_weight
                else:
                    weight = next(child.parameters())
                if not any(p.device == meta for p in module.parameters()):
                    module.to(weight.device)

    def _mark_only_adapters_as_trainable(self, model):
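        """Freeze all parameters except the rotation adapter parameters (and optionally biases)."""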
        # Freeze everything except the rotation adapter parameters.
        for n, p in model.named_parameters():
            if self.prefix not in n:
                p.requires_grad = False
            else:
                p.requires_grad = True

        # Depending on the configured bias mode, optionally keep bias parameters trainable as well.
        for active_adapter in self.active_adapters:
            bias_config = self.peft_config[active_adapter].bias

            if bias_config == "none":
                continue
            elif bias_config == "all":
                # Train every bias term in the model.
                for n, p in model.named_parameters():
                    if "bias" in n:
                        p.requires_grad = True
            elif bias_config == "rotation_only":
                # Train only the biases that belong to rotation layers.
                for name, m in model.named_modules():
                    if isinstance(m, RotationLayer):
                        if hasattr(m, "bias") and m.bias is not None:
                            m.bias.requires_grad = True
            else:
                raise NotImplementedError(
                    f"Requested bias configuration '{bias_config}' is not implemented. "
                    f"Supported values: 'none', 'all', 'rotation_only'"
                )

    @staticmethod
    def _create_new_module(
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        **kwargs,
    ) -> Optional[nn.Module]:
        """
        Create a new rotation-augmented module.

        Args:
            rotation_config: Configuration for the rotation adapter.
            adapter_name: Name of the adapter.
            target: Base module to augment.
            **kwargs: Additional arguments.

        Returns:
            New RotationLayer module wrapping the target, or None if the target type is unsupported.
        """
        if isinstance(target, nn.Linear):
            return Linear(
                base_layer=target,
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
                **kwargs,
            )
        else:
            # Unsupported module types are skipped rather than wrapped.
            warnings.warn(
                f"Rotation layer does not support {type(target).__name__} yet. Skipping this module."
            )
            return None

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "model":
                raise
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
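        """Return the registered adapter configurations as plain dictionaries, keyed by adapter name."""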
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        return config_dict

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

    def enable_adapter_layers(self) -> None:
        """Enable all adapters.

        Call this if you have previously disabled all adapters and want to re-enable them.
        """
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
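        """Disable all adapters.

        When disabling all adapters, the model output corresponds to the output of the base model.
        """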
        for active_adapter in self.active_adapters:
            val = self.peft_config[active_adapter].bias
            if val != "none":
                msg = (
                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
                    "output as the base model would without adaptation."
                )
                warnings.warn(msg)
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        """Set the active adapter(s).

        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this
        is not desired, use the following code.

        ```py
        >>> for name, param in model_peft.named_parameters():
        ...     if ...:  # some check on name (ex. if 'rotation' in name)
        ...         param.requires_grad = False
        ```

        Args:
            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name

    def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge adapter weights into the base model weights.

        This can speed up inference by eliminating the need for runtime rotation computations.

        Args:
            adapter_names: List of adapter names to merge. If None, merges all active adapters.
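
        Example (illustrative sketch, assuming `model` is the rotation-adapted model):

        ```py
        >>> model.merge_adapter()  # fold the rotation weights into the base layers
        >>> output = model(**inputs)  # inference without the extra rotation computation
        >>> model.unmerge_adapter()  # restore the standalone adapter weights
        ```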
| """ |
| for module in self.model.modules(): |
| if isinstance(module, RotationLayer): |
| module.merge(safe_merge=False, adapter_names=adapter_names) |
| |
| |
| def unmerge_adapter(self) -> None: |
| """ |
| Unmerge adapter weights from the base model weights. |
| |
| This reverses the merge operation, restoring dynamic adapter behavior. |
| """ |
| for module in self.model.modules(): |
| if isinstance(module, RotationLayer): |
| module.unmerge() |
| |
| @staticmethod |
| def _prepare_adapter_config(peft_config, model_config): |
| |
| if peft_config.target_modules is None: |
| if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING: |
| raise ValueError("Please specify `target_modules` in `peft_config`") |
| peft_config.target_modules = set( |
| TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]] |
| ) |
| |
| return peft_config |
| |
    def _check_new_adapter_config(self, config) -> None:
        """
        Check the validity of a new adapter configuration.

        Args:
            config: Configuration to validate.

        Raises:
            ValueError: If the configuration is invalid.
        """
        if config.r <= 0:
            raise ValueError(f"r must be positive, got {config.r}")

        if config.num_rotations <= 0:
            raise ValueError(
                f"num_rotations must be positive, got {config.num_rotations}"
            )

        valid_bias_configs = ["none", "all", "rotation_only"]
        if hasattr(config, "bias") and config.bias not in valid_bias_configs:
            raise ValueError(
                f"Invalid bias configuration '{config.bias}'. "
                f"Must be one of {valid_bias_configs}"
            )

    def _unload_and_optionally_merge(
        self,
        merge=True,
        progressbar: bool = False,
        safe_merge: bool = False,
        adapter_names: Optional[list[str]] = None,
    ):
        if merge:
            self._check_merge_allowed()

        # Walk over the non-adapter submodules and swap each wrapped layer back to its base layer,
        # merging the adapter weights first if requested.
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue
            with onload_layer(target):
                if hasattr(target, "unload_and_optionally_merge_module"):
                    unloaded_module = target.unload_and_optionally_merge_module(
                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
                    )
                    self._replace_module(parent, target_name, unloaded_module, target)
                elif hasattr(target, "base_layer"):
                    if merge:
                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                    self._replace_module(parent, target_name, target.get_base_layer(), target)

        return self.model

    def delete_adapter(self, adapter_name: str) -> None:
        """
        Deletes an existing adapter.

        Args:
            adapter_name (str): Name of the adapter to be deleted.
        """
        if adapter_name not in list(self.peft_config.keys()):
            raise ValueError(f"Adapter {adapter_name} does not exist")
        del self.peft_config[adapter_name]

        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        new_adapter = None
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, RotationLayer):
                target.delete_adapter(adapter_name)
                if new_adapter is None:
                    new_adapter = target.active_adapters[:]

        self.active_adapter = new_adapter or []
        self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)

    def merge_and_unload(
        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> torch.nn.Module:
        r"""
        This method merges the rotation layers into the base model. This is needed if someone wants to use the base
        model as a standalone model.

        Args:
            progressbar (`bool`):
                whether to show a progressbar indicating the unload and merge process
            safe_merge (`bool`):
                whether to activate the safe merging check to check if there is any potential NaN in the adapter
                weights
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged.
                Defaults to `None`.
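
        Example (illustrative sketch; the save path is a placeholder):

        ```py
        >>> merged_model = model.merge_and_unload()
        >>> merged_model.save_pretrained("path/to/merged-model")
        ```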
| """ |
| return self._unload_and_optionally_merge( |
| progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names |
| ) |
|
|
| def unload(self) -> torch.nn.Module: |
| """ |
| Gets back the base model by removing all the oft modules without merging. This gives back the original base |
| model. |
| """ |
| return self._unload_and_optionally_merge(merge=False) |
| |
|
|