import warnings
from dataclasses import asdict
from enum import Enum
from typing import Optional

import torch
import torch.nn as nn
from tqdm import tqdm

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules

from .layer import RotationLayer, Linear
| |
|
# Rotation adapters reuse LoRA's per-architecture default target modules.
# `.copy()` guards against mutating the shared LoRA mapping if this one is edited later.
TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
| |
|
class RotationTuner(BaseTuner):
    """
    Tuner that augments target modules of a base model with rotation adapters.

    The tuner walks the wrapped model, replaces every matching target module with
    a rotation-augmented counterpart (a ``RotationLayer``), and exposes the usual
    PEFT lifecycle operations: enabling/disabling adapters, merging/unmerging
    adapter weights into the base model, and deleting adapters.

    Class attributes:
        prefix: Parameter-name prefix that identifies rotation adapter parameters.
        tuner_layer_class: Layer class used to recognize rotation-augmented modules.
        target_module_mapping: Default model-type -> target-module-names mapping.
    """

    prefix: str = "rotation_"
    tuner_layer_class = RotationLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING

    @staticmethod
    def _check_target_module_exists(rotation_config, key: str) -> bool:
        """Return whether the module at `key` is targeted by `rotation_config`."""
        return check_target_module_exists(rotation_config, key)

    def _create_and_replace(
        self,
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        target_name: str,
        parent: nn.Module,
        current_key: str,
        **optional_kwargs,
    ) -> None:
        """
        Create and replace a target module with a rotation-augmented version.

        If `target` is already a ``RotationLayer``, the new adapter is registered
        on it in place. Otherwise a new rotation module wrapping `target` is
        created and swapped into `parent`.

        Args:
            rotation_config: Configuration for the rotation adapter.
            adapter_name: Name of the adapter to add.
            target: The target module to augment.
            target_name: Attribute name of the target module inside `parent`.
            parent: Parent module containing the target.
            current_key: Full key path to the current module.
            **optional_kwargs: Additional arguments forwarded to module creation.

        Raises:
            ValueError: If `current_key` is not provided.
        """
        if current_key is None:
            raise ValueError("current_key must be provided to create Rotation layer")

        if isinstance(target, RotationLayer):
            # Already rotation-augmented: just add the new adapter to the layer.
            target.update_layer(
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
            )
        else:
            new_module = self._create_new_module(
                rotation_config=rotation_config,
                adapter_name=adapter_name,
                target=target,
                **optional_kwargs,
            )
            # _create_new_module returns None for unsupported module types; in
            # that case the original module is left untouched.
            if new_module is not None:
                self._replace_module(parent, target_name, new_module, target)

    def _replace_module(self, parent, child_name, new_module, child):
        """
        Swap `child` for `new_module` on `parent` and move adapter parameters
        onto the same device as the base weights.

        Args:
            parent: Module that owns the child being replaced.
            child_name: Attribute name of the child on `parent`.
            new_module: Replacement module.
            child: The module being replaced (possibly an adapter wrapper).
        """
        setattr(parent, child_name, new_module)

        # If the child was itself a wrapper, dispatch on its underlying layer
        # so we read the device from the true base weights.
        if hasattr(child, "base_layer"):
            child = child.base_layer

        meta = torch.device("meta")

        for name, module in new_module.named_modules():
            if (self.prefix in name) or ("ranknum" in name):
                # Locate a reference weight on the base layer; quantized layers
                # expose `qweight`/`W_q` instead of a plain `weight` attribute.
                if hasattr(child, "qweight"):
                    weight = child.qweight
                elif hasattr(child, "W_q"):
                    weight = child.W_q
                elif hasattr(child, "weight"):
                    weight = child.weight
                elif getattr(child, "in_proj_weight", None) is not None:
                    weight = child.in_proj_weight
                else:
                    weight = next(child.parameters())
                # Skip modules that still hold meta-device (uninitialized) params.
                if not any(p.device == meta for p in module.parameters()):
                    module.to(weight.device)

    def _mark_only_adapters_as_trainable(self, model) -> None:
        """
        Freeze all base-model parameters and unfreeze only rotation adapter
        parameters, then apply each active adapter's bias policy.

        Raises:
            NotImplementedError: If an active adapter requests an unsupported
                bias configuration.
        """
        # Only parameters whose name carries the rotation prefix remain trainable.
        for name, param in model.named_parameters():
            param.requires_grad = self.prefix in name

        for active_adapter in self.active_adapters:
            bias_config = self.peft_config[active_adapter].bias

            if bias_config == "none":
                continue
            elif bias_config == "all":
                # Train every bias term in the model.
                for name, param in model.named_parameters():
                    if "bias" in name:
                        param.requires_grad = True
            elif bias_config == "rotation_only":
                # Train only biases that live on rotation-augmented layers.
                for name, module in model.named_modules():
                    if isinstance(module, RotationLayer):
                        if hasattr(module, "bias") and module.bias is not None:
                            module.bias.requires_grad = True
            else:
                raise NotImplementedError(
                    f"Requested bias configuration '{bias_config}' is not implemented. "
                    f"Supported values: 'none', 'all', 'rotation_only'"
                )

    @staticmethod
    def _create_new_module(
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        **kwargs,
    ) -> Optional[nn.Module]:
        """
        Create a new rotation-augmented module.

        Args:
            rotation_config: Configuration for the rotation adapter.
            adapter_name: Name of the adapter.
            target: Base module to augment.
            **kwargs: Additional arguments forwarded to the wrapper layer.

        Returns:
            New rotation ``Linear`` module wrapping the target, or None if the
            target's type is not supported.
        """
        if isinstance(target, nn.Linear):
            return Linear(
                base_layer=target,
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
                **kwargs,
            )
        # Unsupported module type: warn and signal the caller to skip it.
        warnings.warn(
            f"Rotation layer does not support {type(target).__name__} yet. "
            f"Skipping this module."
        )
        return None

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Avoid infinite recursion if `model` itself is not yet set.
            if name == "model":
                raise
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False) -> dict:
        """
        Return all adapter configurations as plain dictionaries.

        Args:
            inference: If True, mark each returned config with
                ``inference_mode=True``.

        Returns:
            Mapping of adapter name to its configuration dictionary, with enum
            values replaced by their underlying values.
        """
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        # Bug fix: previously returned only the last adapter's `config` instead
        # of the accumulated mapping.
        return config_dict

    def _set_adapter_layers(self, enabled: bool = True) -> None:
        """Enable or disable all adapter layers (and modules-to-save wrappers)."""
        for module in self.model.modules():
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

    def enable_adapter_layers(self) -> None:
        """Enable all adapters.

        Call this if you have previously disabled all adapters and want to re-enable them.
        """
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self) -> None:
        """Disable all adapters, falling back to base-model behavior.

        Warns if any active adapter trains bias terms, because merged bias
        changes cannot be fully undone by disabling adapters.
        """
        for active_adapter in self.active_adapters:
            val = self.peft_config[active_adapter].bias
            if val != "none":
                msg = (
                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
                    "output as the base model would without adaption."
                )
                warnings.warn(msg)
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name) -> None:
        """Set the active adapter(s).

        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
        not desired, use the following code.

        ```py
        >>> for name, param in model_peft.named_parameters():
        ...     if ...:  # some check on name (ex. if 'lora' in name)
        ...         param.requires_grad = False
        ```

        Args:
            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                if module.merged:
                    # A merged layer must be unmerged before switching adapters,
                    # otherwise stale merged weights would remain in the base layer.
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name

    def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge adapter weights into the base model weights.

        This can speed up inference by eliminating the need for runtime
        rotation computations.

        Args:
            adapter_names: List of adapter names to merge. If None, merges all
                active adapters.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                module.merge(safe_merge=False, adapter_names=adapter_names)

    def unmerge_adapter(self) -> None:
        """
        Unmerge adapter weights from the base model weights.

        This reverses the merge operation, restoring dynamic adapter behavior.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                module.unmerge()

    @staticmethod
    def _prepare_adapter_config(peft_config, model_config):
        """
        Fill in default target modules for known model types.

        Args:
            peft_config: Adapter configuration, possibly without `target_modules`.
            model_config: Base model configuration dict (must contain "model_type").

        Returns:
            The (possibly updated) `peft_config`.

        Raises:
            ValueError: If `target_modules` is unset and the model type has no
                known default mapping.
        """
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = set(
                TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]]
            )
        return peft_config

    def _check_new_adapter_config(self, config) -> None:
        """
        Check the validity of a new adapter configuration.

        Args:
            config: Configuration to validate.

        Raises:
            ValueError: If `r` or `num_rotations` is non-positive, or the bias
                configuration is not one of 'none', 'all', 'rotation_only'.
        """
        if config.r <= 0:
            raise ValueError(f"r must be positive, got {config.r}")

        if config.num_rotations <= 0:
            raise ValueError(
                f"num_rotations must be positive, got {config.num_rotations}"
            )

        valid_bias_configs = ["none", "all", "rotation_only"]
        if hasattr(config, "bias") and config.bias not in valid_bias_configs:
            raise ValueError(
                f"Invalid bias configuration '{config.bias}'. "
                f"Must be one of {valid_bias_configs}"
            )

    def _unload_and_optionally_merge(
        self,
        merge: bool = True,
        progressbar: bool = False,
        safe_merge: bool = False,
        adapter_names: Optional[list[str]] = None,
    ):
        """
        Replace all rotation-augmented modules with their base layers,
        optionally merging adapter weights into them first.

        Args:
            merge: Whether to merge adapter weights before unloading.
            progressbar: Whether to show a tqdm progress bar.
            safe_merge: Whether to check merged weights for NaNs.
            adapter_names: Adapters to merge; None means all active adapters.

        Returns:
            The underlying base model with adapter wrappers removed.
        """
        if merge:
            self._check_merge_allowed()

        # Only walk modules that are not themselves adapter parameters.
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                # The key may no longer resolve after earlier replacements.
                continue
            with onload_layer(target):
                if hasattr(target, "unload_and_optionally_merge_module"):
                    # Layer knows how to unload itself (e.g. custom wrappers).
                    unloaded_module = target.unload_and_optionally_merge_module(
                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
                    )
                    self._replace_module(parent, target_name, unloaded_module, target)
                elif hasattr(target, "base_layer"):
                    if merge:
                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                    self._replace_module(parent, target_name, target.get_base_layer(), target)

        return self.model

    def delete_adapter(self, adapter_name: str) -> None:
        """
        Deletes an existing adapter.

        Args:
            adapter_name (str): Name of the adapter to be deleted.

        Raises:
            ValueError: If the adapter does not exist.
        """
        if adapter_name not in list(self.peft_config.keys()):
            raise ValueError(f"Adapter {adapter_name} does not exist")
        del self.peft_config[adapter_name]

        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        new_adapter = None
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, RotationLayer):
                target.delete_adapter(adapter_name)
                if new_adapter is None:
                    # Record the remaining active adapters from the first layer.
                    new_adapter = target.active_adapters[:]

        self.active_adapter = new_adapter or []
        self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)

    def merge_and_unload(
        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> torch.nn.Module:
        r"""
        This method merges the rotation layers into the base model. This is needed if someone wants to use the base
        model as a standalone model.

        Args:
            progressbar (`bool`):
                whether to show a progressbar indicating the unload and merge process
            safe_merge (`bool`):
                whether to activate the safe merging check to check if there is any potential Nan in the adapter
                weights
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.
        """
        return self._unload_and_optionally_merge(
            progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
        )

    def unload(self) -> torch.nn.Module:
        """
        Gets back the base model by removing all the rotation modules without merging. This gives back the original
        base model.
        """
        return self._unload_and_optionally_merge(merge=False)
| |
|