# nvan15's picture
# Batch upload part 19
# b816a2c verified
from typing import Optional
import torch
import torch.nn as nn
from enum import Enum
from dataclasses import asdict
from tqdm import tqdm
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules
from .layer import RotationLayer, Linear
# Rotation adapters reuse the per-architecture default target modules defined
# for LoRA. `.copy()` so later edits to this mapping cannot mutate the shared
# LoRA mapping imported from peft.utils.
TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
class RotationTuner(BaseTuner):
    """PEFT tuner that augments target modules with rotation adapters.

    Follows the standard ``BaseTuner`` contract: matching modules of the
    wrapped model are replaced with ``RotationLayer`` wrappers (see
    ``_create_and_replace``), and this class manages enabling, disabling,
    merging, unloading and deleting those adapters.
    """

    # Parameter-name prefix identifying adapter parameters. Used to decide
    # which parameters stay trainable and which module keys belong to the
    # base model (rather than an adapter).
    prefix: str = "rotation_"
    tuner_layer_class = RotationLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING

    @staticmethod
    def _check_target_module_exists(rotation_config, key: str) -> bool:
        """Return whether module ``key`` is selected by ``rotation_config.target_modules``."""
        return check_target_module_exists(rotation_config, key)

    def _create_and_replace(
        self,
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        target_name: str,
        parent: nn.Module,
        current_key: str,
        **optional_kwargs,
    ) -> None:
        """
        Create and replace a target module with a rotation-augmented version.

        If ``target`` is already a ``RotationLayer``, the new adapter is added
        to it in place; otherwise the module is wrapped in a new
        ``RotationLayer`` and swapped into ``parent``.

        Args:
            rotation_config: Configuration for the rotation adapter.
            adapter_name: Name of the adapter to add.
            target: The target module to augment.
            target_name: Attribute name of the target inside ``parent``.
            parent: Parent module containing the target.
            current_key: Full key path to the current module.
            **optional_kwargs: Additional arguments forwarded to module creation.

        Raises:
            ValueError: If ``current_key`` is not provided.
        """
        if current_key is None:
            raise ValueError("current_key must be provided to create Rotation layer")
        if isinstance(target, RotationLayer):
            # Module already wrapped: just register the new adapter on it.
            target.update_layer(
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
            )
        else:
            # Wrap the base module in a new rotation layer.
            new_module = self._create_new_module(
                rotation_config=rotation_config,
                adapter_name=adapter_name,
                target=target,
                **optional_kwargs,
            )
            # _create_new_module returns None for unsupported layer types.
            if new_module is not None:
                self._replace_module(parent, target_name, new_module, target)

    def _replace_module(self, parent, child_name, new_module, child):
        """Install ``new_module`` on ``parent`` and move adapter submodules to the base weight's device.

        ``child`` is the module being replaced; its (possibly quantized) weight
        determines the target device for the adapter parameters.
        """
        setattr(parent, child_name, new_module)
        # child layer wraps the original module, unpack it
        if hasattr(child, "base_layer"):
            child = child.base_layer
        meta = torch.device("meta")
        # Dispatch adapter submodules to the device of the base weight. The
        # attribute probing order covers quantized layers (qweight / W_q),
        # plain layers (weight) and MultiheadAttention (in_proj_weight).
        for name, module in new_module.named_modules():
            if (self.prefix in name) or ("ranknum" in name):
                if hasattr(child, "qweight"):
                    weight = child.qweight
                elif hasattr(child, "W_q"):
                    weight = child.W_q
                elif hasattr(child, "weight"):
                    weight = child.weight
                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
                    weight = child.in_proj_weight
                else:
                    weight = next(child.parameters())
                # Never move modules that still hold meta (uninitialized) params.
                if not any(p.device == meta for p in module.parameters()):
                    module.to(weight.device)

    def _mark_only_adapters_as_trainable(self, model) -> None:
        """Freeze all non-adapter parameters, then apply the per-adapter bias policy.

        Raises:
            NotImplementedError: For an unrecognized ``bias`` configuration.
        """
        # First, freeze everything that is not an adapter parameter.
        for n, p in model.named_parameters():
            p.requires_grad = self.prefix in n
        # Handle bias parameters based on each active adapter's config.
        for active_adapter in self.active_adapters:
            bias_config = self.peft_config[active_adapter].bias
            if bias_config == "none":
                continue
            elif bias_config == "all":
                # Enable every bias parameter in the model.
                for n, p in model.named_parameters():
                    if "bias" in n:
                        p.requires_grad = True
            elif bias_config == "rotation_only":
                # Enable only biases that live on rotation layers.
                # NOTE(review): this checks ``bias`` on the wrapper itself; if the
                # bias lives on the wrapped base layer, it is not touched — confirm
                # against RotationLayer's attribute forwarding.
                for name, m in model.named_modules():
                    if isinstance(m, RotationLayer):
                        if hasattr(m, "bias") and m.bias is not None:
                            m.bias.requires_grad = True
            else:
                raise NotImplementedError(
                    f"Requested bias configuration '{bias_config}' is not implemented. "
                    f"Supported values: 'none', 'all', 'rotation_only'"
                )

    @staticmethod
    def _create_new_module(
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        **kwargs,
    ) -> Optional[nn.Module]:
        """
        Create a new rotation-augmented module wrapping ``target``.

        Args:
            rotation_config: Configuration for the rotation adapter.
            adapter_name: Name of the adapter.
            target: Base module to augment.
            **kwargs: Additional arguments forwarded to the wrapper.

        Returns:
            A new ``Linear`` rotation layer wrapping the target, or ``None``
            if the target's type is unsupported (a message is printed).
        """
        if isinstance(target, nn.Linear):
            return Linear(
                base_layer=target,
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
                **kwargs,
            )
        # Unsupported layer type: skip rather than fail hard.
        print(
            f"Rotation layer does not support {type(target).__name__} yet. "
            f"Skipping this module."
        )
        return None

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
                raise
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False) -> dict:
        """Return all adapter configs as plain dicts, keyed by adapter name.

        Args:
            inference: If True, set ``inference_mode=True`` in each returned config.

        Returns:
            Mapping of adapter name to its config dict (Enum fields unwrapped
            to their values).
        """
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        # Bug fix: previously returned `config` — only the last adapter's dict —
        # instead of the accumulated mapping.
        return config_dict

    def _set_adapter_layers(self, enabled: bool = True) -> None:
        """Enable or disable adapters on every tuner layer and modules-to-save wrapper."""
        for module in self.model.modules():
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

    def enable_adapter_layers(self) -> None:
        """Enable all adapters.

        Call this if you have previously disabled all adapters and want to re-enable them.
        """
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self) -> None:
        """Disable all adapters, warning when a bias config makes this lossy."""
        for active_adapter in self.active_adapters:
            val = self.peft_config[active_adapter].bias
            if val != "none":
                # Trained biases stay in the base modules, so disabling the
                # adapter does not fully recover base-model behavior.
                msg = (
                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
                    "output as the base model would without adaption."
                )
                print(msg)
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name) -> None:
        """Set the active adapter(s).

        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
        not desired, use the following code.

        ```py
        >>> for name, param in model_peft.named_parameters():
        ...     if ...:  # some check on name (ex. if 'lora' in name)
        ...         param.requires_grad = False
        ```

        Args:
            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                if module.merged:
                    # Merged weights would make the adapter switch a no-op.
                    print("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name

    def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge adapter weights into the base model weights.

        This can speed up inference by eliminating the need for runtime
        rotation computations.

        Args:
            adapter_names: List of adapter names to merge. If None, merges all
                active adapters.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                module.merge(safe_merge=False, adapter_names=adapter_names)

    def unmerge_adapter(self) -> None:
        """
        Unmerge adapter weights from the base model weights.

        This reverses the merge operation, restoring dynamic adapter behavior.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                module.unmerge()

    @staticmethod
    def _prepare_adapter_config(peft_config, model_config):
        """Fill in default ``target_modules`` from the model type when unset.

        Raises:
            ValueError: If no default exists for this model type and
                ``target_modules`` was not given explicitly.
        """
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = set(
                TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]]
            )
        return peft_config

    def _check_new_adapter_config(self, config) -> None:
        """
        Check the validity of a new adapter configuration.

        Args:
            config: Configuration to validate.

        Raises:
            ValueError: If ``r`` or ``num_rotations`` is non-positive, or the
                ``bias`` setting is not one of the supported values.
        """
        if config.r <= 0:
            raise ValueError(f"r must be positive, got {config.r}")
        if config.num_rotations <= 0:
            raise ValueError(
                f"num_rotations must be positive, got {config.num_rotations}"
            )
        valid_bias_configs = ["none", "all", "rotation_only"]
        if hasattr(config, "bias") and config.bias not in valid_bias_configs:
            raise ValueError(
                f"Invalid bias configuration '{config.bias}'. "
                f"Must be one of {valid_bias_configs}"
            )

    def _unload_and_optionally_merge(
        self,
        merge: bool = True,
        progressbar: bool = False,
        safe_merge: bool = False,
        adapter_names: Optional[list[str]] = None,
    ):
        """Strip rotation wrappers from the model, optionally merging weights first.

        Args:
            merge: Whether to merge adapter weights into the base weights before
                unwrapping.
            progressbar: Show a tqdm progress bar over the module keys.
            safe_merge: Forwarded to each layer's ``merge`` (NaN checking).
            adapter_names: Subset of adapters to merge; None means all active.

        Returns:
            The underlying base model with wrappers removed.
        """
        if merge:
            self._check_merge_allowed()
        # Only walk base-model keys; adapter-internal modules are removed with
        # their parent wrapper.
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                # Key no longer resolvable (e.g. already replaced); skip.
                continue
            with onload_layer(target):
                if hasattr(target, "unload_and_optionally_merge_module"):
                    # if layers have special unloading method, like MultiheadAttention, use that
                    unloaded_module = target.unload_and_optionally_merge_module(
                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
                    )
                    self._replace_module(parent, target_name, unloaded_module, target)
                elif hasattr(target, "base_layer"):
                    if merge:
                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                    self._replace_module(parent, target_name, target.get_base_layer(), target)
        return self.model

    def delete_adapter(self, adapter_name: str) -> None:
        """
        Deletes an existing adapter.

        Args:
            adapter_name (str): Name of the adapter to be deleted.

        Raises:
            ValueError: If the adapter does not exist.
        """
        if adapter_name not in list(self.peft_config.keys()):
            raise ValueError(f"Adapter {adapter_name} does not exist")
        del self.peft_config[adapter_name]
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        new_adapter = None
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, RotationLayer):
                target.delete_adapter(adapter_name)
                if new_adapter is None:
                    # Adopt whatever the first rotation layer reports as active
                    # after deletion.
                    new_adapter = target.active_adapters[:]
        self.active_adapter = new_adapter or []
        self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)

    def merge_and_unload(
        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> torch.nn.Module:
        r"""
        This method merges the rotation layers into the base model. This is needed if someone wants to use the base
        model as a standalone model.

        Args:
            progressbar (`bool`):
                whether to show a progressbar indicating the unload and merge process
            safe_merge (`bool`):
                whether to activate the safe merging check to check if there is any potential Nan in the adapter
                weights
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.
        """
        return self._unload_and_optionally_merge(
            progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
        )

    def unload(self) -> torch.nn.Module:
        """
        Gets back the base model by removing all the rotation modules without merging. This gives back the original
        base model.
        """
        return self._unload_and_optionally_merge(merge=False)