import copy
from typing import TYPE_CHECKING, Dict, List, Union
|
|
from ..utils import logging
|
|
|
|
if TYPE_CHECKING:
    from ..models import UNet2DConditionModel
|
|
logger = logging.get_logger(__name__)
|
|
|
|
def _translate_into_actual_layer_name(name):
    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
    if name == "mid":
        return "mid_block.attentions.0"

    updown, block, attn = name.split(".")

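    # e.g. "up.block_1.2" -> "up_blocks.1.attentions.2"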
    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))
|
|
|
|
def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
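    """Expand the user-friendly per-block `weight_scales` of each adapter into per-layer scales keyed by actual layer name."""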
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
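    # up blocks have one more transformer layer per block than down blocks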
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
|
|
    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            unet.state_dict(),
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales
|
|
|
|
def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    state_dict: Dict,
    default_scale: float = 1.0,
):
| """ |
| Expands the inputs into a more granular dictionary. See the example below for more details. |
| |
| Parameters: |
| scales (`Union[float, Dict]`): |
| Scales dict to expand. |
| blocks_with_transformer (`Dict[str, int]`): |
| Dict with keys 'up' and 'down', showing which blocks have transformer layers |
| transformer_per_block (`Dict[str, int]`): |
| Dict with keys 'up' and 'down', showing how many transformer layers each block has |
| |
| E.g. turns |
| ```python |
| scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}} |
| blocks_with_transformer = {"down": [1, 2], "up": [0, 1]} |
| transformer_per_block = {"down": 2, "up": 3} |
| ``` |
| into |
| ```python |
| { |
| "down.block_1.0": 2, |
| "down.block_1.1": 2, |
| "down.block_2.0": 2, |
| "down.block_2.1": 2, |
| "mid": 3, |
| "up.block_0.0": 4, |
| "up.block_0.1": 4, |
| "up.block_0.2": 4, |
| "up.block_1.0": 5, |
| "up.block_1.1": 6, |
| "up.block_1.2": 7, |
| } |
| ``` |
| """ |
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down'` and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down'` and `'up'`")
|
|
    if not isinstance(scales, dict):
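        # a single number applies uniformly to every layer, so there is nothing to expand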
        return scales
|
|
    scales = copy.deepcopy(scales)  # don't mutate the caller's dict
|
|
| if "mid" not in scales: |
| scales["mid"] = default_scale |
| elif isinstance(scales["mid"], list): |
| if len(scales["mid"]) == 1: |
| scales["mid"] = scales["mid"][0] |
| else: |
| raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.") |
|
|
    for updown in ["up", "down"]:
        if updown not in scales:
            scales[updown] = default_scale

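        # expand a per-side scalar to all blocks, e.g. {"down": 1} -> {"down": {"block_1": 1, "block_2": 1}}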
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}

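        # expand per-block scalars to one scale per transformer layer, e.g. {"block_1": 1} -> {"block_1": [1, 1]}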
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
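            # blocks the user didn't mention fall back to the default scale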
            if block not in scales[updown]:
                scales[updown][block] = default_scale
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
            elif len(scales[updown][block]) == 1:
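                # broadcast a one-element list to every transformer layer in the block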
                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
            elif len(scales[updown][block]) != transformer_per_block[updown]:
                raise ValueError(
                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
                )

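        # flatten into fully qualified keys, e.g. {"down": {"block_1": [1, 1]}} -> {"down.block_1.0": 1, "down.block_1.1": 1}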
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            for tf_idx, value in enumerate(scales[updown][block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value
|
|
        del scales[updown]  # the nested per-side dict is now fully flattened into top-level keys
|
|
    # every expanded layer name must correspond to an actual module in the unet
    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )
|
|
    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
|
|