| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Adapted from |
| https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py |
| """ |
|
|
import inspect
from inspect import signature
from typing import Tuple, Union

from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
from ..quantization_config import QuantizationMethod
|
|
|
|
| if is_torch_available(): |
| import torch |
| import torch.nn as nn |
|
|
| if is_bitsandbytes_available(): |
| import bitsandbytes as bnb |
|
|
| if is_accelerate_available(): |
| import accelerate |
| from accelerate import init_empty_weights |
| from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| def _replace_with_bnb_linear( |
| model, |
| modules_to_not_convert=None, |
| current_key_name=None, |
| quantization_config=None, |
| has_been_replaced=False, |
| ): |
| """ |
| Private method that wraps the recursion for module replacement. |
| |
| Returns the converted model and a boolean that indicates if the conversion has been successfull or not. |
| """ |
| for name, module in model.named_children(): |
| if current_key_name is None: |
| current_key_name = [] |
| current_key_name.append(name) |
|
|
| if isinstance(module, nn.Linear) and name not in modules_to_not_convert: |
| |
| current_key_name_str = ".".join(current_key_name) |
| if not any( |
| (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert |
| ): |
| with init_empty_weights(): |
| in_features = module.in_features |
| out_features = module.out_features |
|
|
| if quantization_config.quantization_method() == "llm_int8": |
| model._modules[name] = bnb.nn.Linear8bitLt( |
| in_features, |
| out_features, |
| module.bias is not None, |
| has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, |
| threshold=quantization_config.llm_int8_threshold, |
| ) |
| has_been_replaced = True |
| else: |
| if ( |
| quantization_config.llm_int8_skip_modules is not None |
| and name in quantization_config.llm_int8_skip_modules |
| ): |
| pass |
| else: |
| extra_kwargs = ( |
| {"quant_storage": quantization_config.bnb_4bit_quant_storage} |
| if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters) |
| else {} |
| ) |
| model._modules[name] = bnb.nn.Linear4bit( |
| in_features, |
| out_features, |
| module.bias is not None, |
| quantization_config.bnb_4bit_compute_dtype, |
| compress_statistics=quantization_config.bnb_4bit_use_double_quant, |
| quant_type=quantization_config.bnb_4bit_quant_type, |
| **extra_kwargs, |
| ) |
| has_been_replaced = True |
| |
| model._modules[name].source_cls = type(module) |
| |
| model._modules[name].requires_grad_(False) |
| if len(list(module.children())) > 0: |
| _, has_been_replaced = _replace_with_bnb_linear( |
| module, |
| modules_to_not_convert, |
| current_key_name, |
| quantization_config, |
| has_been_replaced=has_been_replaced, |
| ) |
| |
| current_key_name.pop(-1) |
| return model, has_been_replaced |
|
|
|
|
def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    """
    Swap the `nn.Linear` layers of `model` for `bnb.nn.Linear8bitLt` or `bnb.nn.Linear4bit` layers from the
    `bitsandbytes` library.

    This is a thin wrapper around `_replace_with_bnb_linear`: it forwards its arguments unchanged, logs a warning
    when the helper reports that nothing was replaced, and returns the model.

    References:
        * `bnb.nn.Linear8bitLt`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
          Scale](https://arxiv.org/abs/2208.07339)
        * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)

    Parameters:
        model (`torch.nn.Module`):
            Model whose linear layers should be replaced.
        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `None`):
            Names of the modules that must not be converted. Forwarded as-is to `_replace_with_bnb_linear`.
        current_key_name (`List[`str`]`, *optional*):
            List used by the recursion to track the key of the module currently being visited.
        quantization_config:
            Quantization settings object, forwarded as-is to `_replace_with_bnb_linear`.

    Returns:
        The model returned by `_replace_with_bnb_linear`.
    """
    converted_model, any_replaced = _replace_with_bnb_linear(
        model, modules_to_not_convert, current_key_name, quantization_config
    )
    if any_replaced:
        return converted_model

    # Nothing was swapped: tell the user, but still hand the model back.
    logger.warning(
        "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
        " Please double check your model architecture, or submit an issue on github if you think this is"
        " a bug."
    )
    return converted_model
|
|
|
|
| |
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
    """
    Dequantize a bitsandbytes 4-bit or 8-bit weight.

    The kind of weight is decided from the class name of `weight`: `Params4bit` goes through
    `bnb.functional.dequantize_4bit`, `Int8Params` goes through the int8 path below, and any other
    `nn.Parameter` is returned as is.

    Args:
        weight: The parameter to dequantize.
        state: Only read on the `Int8Params` path, where its `SCB`, `CxB`, `SB` and `formatB` attributes are
            used; `SCB`, `CxB` and `SB` are filled in on `state` when they are `None`.

    Raises:
        TypeError: If `weight` is not an `nn.Parameter`.
    """
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    param_kind = weight.__class__.__name__

    if param_kind == "Params4bit":
        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        logger.warning_once(
            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
        )
        return output_tensor

    if param_kind != "Int8Params":
        # Not a bnb quantized weight: nothing to do.
        return weight

    # --- Int8Params path ---
    if state.SCB is None:
        state.SCB = weight.SCB

    # Build a half-precision identity matrix on the weight's device and run it through the bnb int8 matmul
    # against the stored weight, then dequantize and transpose the product.
    identity = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    identity, _, identity_stats, _, _ = bnb.functional.double_quant(identity)
    identity, identity_layout = bnb.functional.transform(identity, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
    product32, product32_layout = bnb.functional.igemmlt(identity, state.CxB, identity_layout, state.SB)
    return bnb.functional.mm_dequant(product32, product32_layout, identity_stats, state.SCB, bias=None).t()
|
|
|
|
def _create_accelerate_new_hook(old_hook):
    r"""
    Build a fresh hook of the same kind as `old_hook`. Use it only if you know what you are doing ! This method is a
    copy of:
    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
    some changes

    The hook class is looked up in `accelerate.hooks` by the class name of `old_hook`, and is instantiated with
    those entries of `old_hook.__dict__` whose names are parameters of the class `__init__`.
    """
    hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    accepted_params = inspect.signature(hook_cls.__init__).parameters
    # Keep only the instance attributes that the constructor accepts.
    init_kwargs = {attr: value for attr, value in old_hook.__dict__.items() if attr in accepted_params}
    return hook_cls(**init_kwargs)
|
|
|
|
def _dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Converts a quantized model into its dequantized original version. The newly converted model will have some
    performance drop compared to the original model before quantization - use it only for specific usecases such as
    QLoRA adapters merging.

    Walks the direct children of `model`. Every child that is an instance of the bnb linear class matching
    `quantization_config.quantization_method()` and is not excluded is replaced in `model._modules` by a
    `torch.nn.Linear` holding the dequantized weight. Children that have children of their own are then processed
    recursively.

    Args:
        model: Module whose children are inspected; it is modified in place.
        modules_to_not_convert: Names to leave untouched. An entry excludes a child when it equals the child's
            name, equals the dotted key built from `current_key_name`, or when `entry + "."` occurs in that
            dotted key. `None` is treated as an empty list.
        current_key_name: List of child names leading to `model`. It is shared with the recursive calls; each
            name pushed here is popped again before the function returns.
        quantization_config: Object providing `quantization_method()`.
        has_been_replaced: Flag carried through the recursion.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    # The default is `None`; normalise it so the membership tests below cannot raise `TypeError`.
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    quant_method = quantization_config.quantization_method()

    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

    for name, module in model.named_children():
        current_key_name.append(name)

        if isinstance(module, target_cls) and name not in modules_to_not_convert:
            # Check that the dotted key of this child is not excluded either.
            current_key_name_str = ".".join(current_key_name)

            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                bias = getattr(module, "bias", None)

                # Remember where the quantized weight lives so the replacement can be moved there.
                device = module.weight.device
                with init_empty_weights():
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)

                # Only the int8 path passes the module state to the dequantization helper.
                if quant_method == "llm_int8":
                    state = module.state
                else:
                    state = None

                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))

                if bias is not None:
                    new_module.bias = bias

                # If the old module carries an accelerate hook, rebuild it and attach it to the new module.
                if hasattr(module, "_hf_hook"):
                    old_hook = module._hf_hook
                    new_hook = _create_accelerate_new_hook(old_hook)

                    remove_hook_from_module(module)
                    add_hook_to_module(new_module, new_hook)

                new_module.to(device)
                model._modules[name] = new_module
                has_been_replaced = True

        if any(True for _ in module.children()):
            _, has_been_replaced = _dequantize_and_replace(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )

        # Remove the last key so the shared list is correct for the next sibling.
        current_key_name.pop(-1)
    return model, has_been_replaced
|
|
|
|
def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
):
    """
    Dequantize the bitsandbytes linear layers of `model`.

    This is a thin wrapper around `_dequantize_and_replace`: it forwards `modules_to_not_convert` and
    `quantization_config` unchanged, logs a warning when the helper reports that nothing was replaced, and
    returns the model.
    """
    dequantized_model, any_replaced = _dequantize_and_replace(
        model,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )
    if any_replaced:
        return dequantized_model

    # Nothing was swapped: tell the user, but still hand the model back.
    logger.warning(
        "For some reason the model has not been properly dequantized. You might see unexpected behavior."
    )
    return dequantized_model
|
|
|
|
| def _check_bnb_status(module) -> Union[bool, bool]: |
| is_loaded_in_4bit_bnb = ( |
| hasattr(module, "is_loaded_in_4bit") |
| and module.is_loaded_in_4bit |
| and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES |
| ) |
| is_loaded_in_8bit_bnb = ( |
| hasattr(module, "is_loaded_in_8bit") |
| and module.is_loaded_in_8bit |
| and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES |
| ) |
| return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb |
|
|