| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Adapted from |
| https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py |
| """ |
| import warnings |
| from typing import Dict, Optional, Union |
|
|
| from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer |
| from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod |
|
|
|
|
# Maps a quant-method key (e.g. "bitsandbytes_4bit") to the quantizer class that
# performs the actual quantization work when the model is loaded.
AUTO_QUANTIZER_MAPPING = {
    "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
    "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
}

# Maps the same keys to the config class used to (de)serialize the quantization
# settings; both bitsandbytes variants share a single config class, which tells
# 4-bit and 8-bit apart via its load_in_4bit / load_in_8bit flags.
AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
}
|
|
|
|
class DiffusersAutoQuantizer:
    """
    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
    `DiffusersQuantizer` given the `QuantizationConfig`.
    """

    @classmethod
    def from_dict(cls, quantization_config_dict: Dict):
        """Instantiate the registered config class from a serialized config dict.

        Args:
            quantization_config_dict: Serialized quantization config (e.g. read from a
                model's config). Must either carry a `quant_method` key or the legacy
                bitsandbytes `load_in_8bit` / `load_in_4bit` flags.

        Returns:
            An instance of the config class registered for the resolved quant method
            (e.g. `BitsAndBytesConfig`).

        Raises:
            ValueError: If no quant method can be resolved, or it is not supported.
        """
        quant_method = quantization_config_dict.get("quant_method", None)
        # Legacy bitsandbytes configs only carry the load_in_*bit flags; derive the
        # mapping key from them. 4-bit takes precedence when both flags are set.
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
            # Fixed: report the keys of the mapping actually checked above
            # (previously this listed AUTO_QUANTIZER_MAPPING's keys).
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        return target_cls.from_dict(quantization_config_dict)

    @classmethod
    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
        """Instantiate the quantizer matching `quantization_config`.

        Args:
            quantization_config: Either an already-built config instance or its
                serialized dict form (converted via `from_dict`).
            **kwargs: Forwarded to the quantizer's constructor.

        Returns:
            The `DiffusersQuantizer` subclass instance registered for the method.

        Raises:
            ValueError: If the resolved quant method is not supported.
        """
        # Convert the serialized form into a real config object first.
        if isinstance(quantization_config, dict):
            quantization_config = cls.from_dict(quantization_config)

        quant_method = quantization_config.quant_method

        # bitsandbytes registers two distinct quantizers; disambiguate by bit width.
        # (QuantizationMethod is string-valued, so `+=` yields the mapping key.)
        if quant_method == QuantizationMethod.BITS_AND_BYTES:
            if quantization_config.load_in_8bit:
                quant_method += "_8bit"
            else:
                quant_method += "_4bit"

        if quant_method not in AUTO_QUANTIZER_MAPPING:
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
        return target_cls(quantization_config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a quantizer from the quantization config stored with a pretrained model.

        Args:
            pretrained_model_name_or_path: Model identifier or local path whose config
                carries a `quantization_config` entry.
            **kwargs: Passed to `load_config` and merged into the resulting config.

        Raises:
            ValueError: If the model's config has no `quantization_config`.
        """
        # NOTE(review): `load_config` is not defined on this class or visible in this
        # module — presumably inherited/provided elsewhere (e.g. a ConfigMixin); confirm.
        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
        if getattr(model_config, "quantization_config", None) is None:
            # Fixed grammar: "Did not found" -> "Did not find".
            raise ValueError(
                f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        quantization_config_dict = model_config.quantization_config
        quantization_config = cls.from_dict(quantization_config_dict)
        # Update with potential kwargs that are passed through from_pretrained.
        quantization_config.update(kwargs)

        return cls.from_config(quantization_config)

    @classmethod
    def merge_quantization_configs(
        cls,
        quantization_config: Union[dict, QuantizationConfigMixin],
        quantization_config_from_args: Optional[QuantizationConfigMixin],
    ):
        """
        handles situations where both quantization_config from args and quantization_config from model config are
        present.

        The config baked into the model always wins; a warning is emitted when the
        caller-supplied config is silently ignored.
        """
        if quantization_config_from_args is not None:
            warning_msg = (
                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
            )
        else:
            warning_msg = ""

        if isinstance(quantization_config, dict):
            quantization_config = cls.from_dict(quantization_config)

        if warning_msg != "":
            warnings.warn(warning_msg)

        return quantization_config
|
|