Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Union, Dict, List | |
| from einops import rearrange | |
| from refnet.util import exists, default | |
| from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock | |
def get_module_safe(self, module_path: str):
    """Resolve a dotted attribute path starting from ``self``.

    Args:
        self: Root object the lookup starts from.
        module_path: Dot-separated attribute names, followed one by one
            with ``getattr``.

    Returns:
        The object reached after following every component of the path.

    Raises:
        AttributeError: If any component cannot be resolved. The original
            lookup error is chained as ``__cause__`` so the failing
            component stays visible in the traceback.
    """
    current_module = self
    for part in module_path.split('.'):
        # Keep the try body to the single lookup that can fail.
        try:
            current_module = getattr(current_module, part)
        except AttributeError as err:
            raise AttributeError(f"Cannot find modules {module_path}") from err
    return current_module
def switch_lora(self, v, label=None):
    """Toggle LoRA adapters on this module's q/k/v projections.

    Args:
        self: Module exposing ``to_q`` / ``to_k`` / ``to_v`` projections.
        v: ``True`` to activate, ``False`` to deactivate.
        label: Adapter label to toggle. ``None`` toggles every adapter the
            projection lists in its ``active_loras`` mapping.

    Projections that are missing or that do not provide ``set_lora_active``
    (for example plain linear layers that were not selected for injection)
    are skipped instead of raising ``AttributeError``.
    """
    for name in ("to_q", "to_k", "to_v"):
        layer = getattr(self, name, None)
        toggle = getattr(layer, "set_lora_active", None)
        if toggle is None:
            continue
        if label is None:
            # Snapshot the labels so the toggle may update the mapping safely.
            for lora_label in list(getattr(layer, "active_loras", ())):
                toggle(v, lora_label)
        else:
            toggle(v, label)
def lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """
    Masked cross-attention that blends a foreground and a background pass.

    Args:
        x: Query input
        context: Key and value input; it is split into two halves along
            dim 1, the first feeding the foreground (character) pass and
            the second the background pass
        mask: Character mask laid out as (b, c, h, w); it is resized with
            bicubic interpolation and flattened to (b, h*w, c)
        scale: Attention scale
        scale_factor: Resize factor applied to the mask (required)

    Returns:
        ``fg * mask + bg * (1 - mask)``, where ``fg`` has first been mixed
        with ``bg`` according to ``self.merge_scale``.
    """
    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    resized_mask = F.interpolate(mask, scale_factor=scale_factor, mode="bicubic")
    mask = rearrange(resized_mask, "b c h w -> b (h w) c").contiguous()
    fg_context, bg_context = context.chunk(2, dim=1)

    def attend(tokens):
        # Queries always come from x; keys and values from the given half.
        query = self.to_q(x)
        key = self.to_k(tokens)
        value = self.to_v(tokens)
        return self.attn_forward(query, key, value, scale)

    # Background region: run with the "foreground" adapter switched off.
    if self.use_lora:
        self.switch_lora(False, "foreground")
    bg_out = attend(bg_context) * self.bg_scale

    # Character region: run with the "foreground" adapter switched on.
    if self.use_lora:
        self.switch_lora(True, "foreground")
    fg_out = attend(fg_context) * self.fg_scale

    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    return fg_out * mask + bg_out * (1 - mask)
def dual_lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """
    Masked cross-attention with separate foreground and background adapters.

    This function hacks cross-attention layers.

    Args:
        x: Query input
        context: Key and value input; it is split into two halves along
            dim 1, the first feeding the foreground pass and the second
            the background pass
        mask: Character mask laid out as (b, c, h, w); it is resized with
            bicubic interpolation and flattened to (b, h*w, c)
        scale: Attention scale
        scale_factor: Current latent size factor (required)

    Returns:
        The foreground output where the resized mask exceeds
        ``self.mask_threshold`` and the background output elsewhere.
    """
    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    resized_mask = F.interpolate(mask, scale_factor=scale_factor, mode="bicubic")
    mask = rearrange(resized_mask, "b c h w -> b (h w) c").contiguous()
    fg_context, bg_context = context.chunk(2, dim=1)

    def attend(tokens):
        # Queries always come from x; keys and values from the given half.
        query = self.to_q(x)
        key = self.to_k(tokens)
        value = self.to_v(tokens)
        return self.attn_forward(query, key, value, scale)

    # Background region: only the "background" adapter is active.
    if self.use_lora:
        self.switch_lora(True, "background")
        self.switch_lora(False, "foreground")
    bg_out = attend(bg_context) * self.bg_scale

    # Foreground region: only the "foreground" adapter is active.
    if self.use_lora:
        self.switch_lora(False, "background")
        self.switch_lora(True, "foreground")
    fg_out = attend(fg_context) * self.fg_scale

    fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    return torch.where(mask > self.mask_threshold, fg_out, bg_out)
class MultiLoraInjectedLinear(nn.Linear):
    """
    A frozen linear layer that can hold multiple LoRA adapters and merge them.

    Each adapter is identified by a label and consists of a ``down`` and an
    ``up`` linear layer plus a dropout. Active adapters add
    ``scale * up(dropout(down(x)))`` to the base output; they can also be
    baked into the base weights and later removed again.
    """

    def __init__(
            self,
            in_features,
            out_features,
            bias = False,
    ):
        super().__init__(in_features, out_features, bias)
        self.lora_adapters: Dict[str, Dict[str, nn.Module]] = {}  # {label: {up/down/dropout: layer}}
        self.lora_scales: Dict[str, float] = {}
        self.active_loras: Dict[str, bool] = {}
        # Snapshot of the base parameters, held only while adapters are merged.
        self.original_weight = None
        self.original_bias = None
        # Labels whose delta is currently baked into ``self.weight``.
        self._merged_labels = set()

        # Freeze original weights
        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)

    def add_lora_adapter(self, label: str, r: int, scale: float = 1.0, dropout_p: float = 0.0):
        """Add a new LoRA adapter with the given label.

        Args:
            label: Name of the adapter; also used in the submodule names.
            r: Adapter rank. A float is interpreted as a fraction of
                ``out_features`` and is clamped to a rank of at least 1.
            scale: Multiplier applied to the adapter output.
            dropout_p: Dropout probability between the down and up layers.

        Raises:
            ValueError: If the resulting rank is smaller than 1.
        """
        if isinstance(r, float):
            r = max(1, int(r * self.out_features))
        if r < 1:
            raise ValueError(f"LoRA rank must be a positive integer, got {r}")

        use_bias = self.bias is not None
        lora_down = nn.Linear(self.in_features, r, bias=use_bias)
        lora_up = nn.Linear(r, self.out_features, bias=use_bias)
        dropout = nn.Dropout(dropout_p)

        # Initialize weights so that a fresh adapter contributes exactly zero.
        nn.init.normal_(lora_down.weight, std=1 / r)
        nn.init.zeros_(lora_up.weight)
        if use_bias:
            # Without this the default bias init would shift the output.
            nn.init.zeros_(lora_up.bias)

        self.lora_adapters[label] = {
            'down': lora_down,
            'up': lora_up,
            'dropout': dropout,
        }
        self.lora_scales[label] = scale
        self.active_loras[label] = True

        # Register as submodules
        self.add_module(f'lora_down_{label}', lora_down)
        self.add_module(f'lora_up_{label}', lora_up)
        self.add_module(f'lora_dropout_{label}', dropout)

    def get_trainable_layers(self, label: str = None):
        """Get trainable layers for specific LoRA or all LoRAs.

        Returns the ``down`` and ``up`` layers of the adapter named
        ``label`` (an empty list if this layer has no such adapter), or of
        every adapter when ``label`` is ``None``.
        """
        if label is not None:
            adapters = [self.lora_adapters[label]] if label in self.lora_adapters else []
        else:
            adapters = list(self.lora_adapters.values())

        layers = []
        for adapter in adapters:
            layers.extend([adapter['down'], adapter['up']])
        return layers

    def set_lora_active(self, active: bool, label: str):
        """Activate or deactivate a specific LoRA adapter.

        Unknown labels are ignored.
        """
        if label in self.active_loras:
            self.active_loras[label] = active

    def set_lora_scale(self, scale: float, label: str):
        """Set the scale for a specific LoRA adapter.

        Unknown labels are ignored.
        """
        if label in self.lora_scales:
            self.lora_scales[label] = scale

    @torch.no_grad()
    def merge_lora_weights(self, labels: List[str] = None):
        """Merge specified LoRA adapters into the base weights.

        Args:
            labels: Adapter labels to merge (a single string is accepted);
                ``None`` selects every adapter of this layer.

        Only adapters that exist on this layer, are currently active and
        are not merged yet are folded in, so repeated or incremental calls
        accumulate instead of discarding earlier merges. Merged adapters
        are deactivated so ``forward`` does not apply them a second time.
        Dropout is not part of the merged weights.
        """
        if labels is None:
            labels = list(self.lora_adapters)
        elif isinstance(labels, str):
            labels = [labels]

        for label in labels:
            if label not in self.lora_adapters or label in self._merged_labels:
                continue
            if not self.active_loras.get(label, False):
                continue

            # Store original weights once, right before the first change.
            if self.original_weight is None:
                self.original_weight = self.weight.detach().clone()
                if self.bias is not None:
                    self.original_bias = self.bias.detach().clone()

            adapter = self.lora_adapters[label]
            lora_up, lora_down = adapter['up'], adapter['down']
            scale = self.lora_scales[label]

            lora_weight = lora_up.weight @ lora_down.weight
            self.weight.add_((scale * lora_weight).to(self.weight))

            if self.bias is not None and lora_up.bias is not None:
                lora_bias = lora_up.bias + lora_up.weight @ lora_down.bias
                self.bias.add_((scale * lora_bias).to(self.bias))

            self.active_loras[label] = False
            self._merged_labels.add(label)

    @torch.no_grad()
    def recover_original_weight(self):
        """Recover the original weights before any LoRA modifications.

        Restores the snapshot in place, which keeps the base parameters
        frozen, and reactivates exactly the adapters that had been merged.
        """
        if self.original_weight is not None:
            self.weight.copy_(self.original_weight)
            if self.bias is not None and self.original_bias is not None:
                self.bias.copy_(self.original_bias)
            # Drop the snapshot so a later merge captures the current weights.
            self.original_weight = None
            self.original_bias = None

        for label in self._merged_labels:
            if label in self.active_loras:
                self.active_loras[label] = True
        self._merged_labels.clear()

    def forward(self, input):
        """Apply the base projection plus every active LoRA adapter."""
        output = super().forward(input)

        # Add contributions from active LoRAs
        for label, adapter in self.lora_adapters.items():
            if self.active_loras.get(label, False):
                lora_out = adapter['up'](adapter['dropout'](adapter['down'](input)))
                output = output + self.lora_scales[label] * lora_out

        return output
class LoraModules:
    """
    Injects labelled LoRA adapters into a model and controls them as a group.

    ``multi_lora_layers`` maps a module path to every injected
    ``MultiLoraInjectedLinear``; ``modules`` maps each label to the parent
    modules whose children were injected for that label.
    """

    def __init__(self, sd, lora_params, *args, **kwargs):
        """Inject one LoRA configuration per entry of ``lora_params``.

        Args:
            sd: Root object that the ``root_module`` paths are resolved on.
            lora_params: Iterable of mappings. Each needs ``root_module``
                (a dotted path), may carry ``label`` (default ``"lora"``),
                and passes its remaining keys to ``inject_lora``.
        """
        self.modules = {}
        self.multi_lora_layers: Dict[str, MultiLoraInjectedLinear] = {}  # path -> MultiLoraLayer

        for cfg in lora_params:
            # Work on a copy so the caller's configuration is not consumed.
            cfg = dict(cfg)
            root_module = get_module_safe(sd, cfg.pop("root_module"))
            label = cfg.pop("label", "lora")
            self.inject_lora(label, root_module, **cfg)

    def _register_layer(self, path, layer):
        """Record ``layer`` under ``path``, disambiguating the key on collision.

        Paths are relative to the root passed to ``inject_lora``, so two
        different roots can yield the same path for different layers. In
        that case a ``#<n>`` suffix keeps both layers reachable.
        """
        key = path
        suffix = 1
        while key in self.multi_lora_layers and self.multi_lora_layers[key] is not layer:
            key = f"{path}#{suffix}"
            suffix += 1
        self.multi_lora_layers[key] = layer

    @staticmethod
    def _adopt_base_parameters(new_layer, old_layer):
        """Carry the replaced layer's device and parameter values over."""
        with torch.no_grad():
            new_layer.to(old_layer.weight.device)
            # Only copy when the layouts agree; otherwise leave the new
            # layer's own initialization in place.
            if new_layer.weight.shape == old_layer.weight.shape:
                new_layer.weight.copy_(old_layer.weight)
            old_bias = getattr(old_layer, "bias", None)
            if (
                    exists(new_layer.bias)
                    and exists(old_bias)
                    and new_layer.bias.shape == old_bias.shape
            ):
                new_layer.bias.copy_(old_bias)

    def inject_lora(
            self,
            label,
            root_module,
            r,
            split_forward = False,
            target_keys = ("to_q", "to_k", "to_v"),
            filter_keys = None,
            target_class = None,
            scale = 1.0,
            dropout_p = 0.0,
    ):
        """Add a LoRA adapter named ``label`` to matching layers under ``root_module``.

        Args:
            label: Adapter label.
            root_module: Module whose descendants are searched.
            r: Adapter rank (see ``MultiLoraInjectedLinear.add_lora_adapter``).
            split_forward: Bind ``dual_lora_forward`` instead of
                ``lora_forward`` as the parent's ``masked_forward``.
            target_keys: A child is selected when its path ends with one
                of these keys.
            filter_keys: A child is rejected when its path contains one of
                these keys; this takes precedence over the selectors.
            target_class: Import strings of classes whose instances are
                selected as well.
            scale: Adapter scale.
            dropout_p: Adapter dropout probability.

        Returns:
            ``(self.multi_lora_layers, modules)`` where ``modules`` lists
            each affected parent module once.
        """
        def check_condition(path, child, class_list):
            if exists(filter_keys) and any(key in path for key in filter_keys):
                return False
            if exists(target_keys) and any(path.endswith(key) for key in target_keys):
                return True
            if exists(class_list) and any(
                    isinstance(child, module_class) for module_class in class_list
            ):
                return True
            return False

        def retrieve_target_modules():
            from refnet.util import get_obj_from_str
            target_class_list = [get_obj_from_str(t) for t in target_class] if exists(target_class) else None

            found = []
            for name, module in root_module.named_modules():
                # Never descend into an injected layer: its down/up
                # sub-layers must not receive adapters of their own.
                if isinstance(module, MultiLoraInjectedLinear):
                    continue
                for key, child in module._modules.items():
                    if child is None:
                        continue
                    full_path = name + '.' + key if name else key
                    if check_condition(full_path, child, target_class_list):
                        found.append((module, child, key, full_path))
            return found

        modules: List[nn.Module] = []

        for parent, child, child_name, full_path in retrieve_target_modules():
            if isinstance(child, MultiLoraInjectedLinear):
                # Already injected: attach the adapter to the layer that
                # actually sits at this position in the model.
                multi_lora_layer = child
                multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
            else:
                # Replace with MultiLoraInjectedLinear and add first LoRA
                multi_lora_layer = MultiLoraInjectedLinear(
                    in_features=child.weight.shape[1],
                    out_features=child.weight.shape[0],
                    bias=exists(getattr(child, "bias", None)),
                )
                multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
                self._adopt_base_parameters(multi_lora_layer, child)
                parent._modules[child_name] = multi_lora_layer
            self._register_layer(full_path, multi_lora_layer)

            if split_forward:
                parent.masked_forward = dual_lora_forward.__get__(parent, parent.__class__)
            else:
                parent.masked_forward = lora_forward.__get__(parent, parent.__class__)

            parent.use_lora = True
            parent.switch_lora = switch_lora.__get__(parent, parent.__class__)
            # A parent usually owns several targets; record it only once.
            if not any(parent is known for known in modules):
                modules.append(parent)

        # Extend rather than overwrite so repeated injections of one label
        # keep the parents found by earlier calls.
        registered = self.modules.setdefault(label, [])
        for parent in modules:
            if not any(parent is known for known in registered):
                registered.append(parent)

        print(f"Activated {label} lora with {len(self.multi_lora_layers)} layers")
        return self.multi_lora_layers, modules

    def get_trainable_layers(self, label = None):
        """Get all trainable layers, optionally filtered by label."""
        layers = []
        for lora_layer in self.multi_lora_layers.values():
            layers += lora_layer.get_trainable_layers(label)
        return layers

    def switch_lora(self, mode, label = None):
        """Activate or deactivate one adapter label, or all when ``label`` is None.

        Also sets ``use_lora`` to ``mode`` on the affected parent modules.
        Raises ``KeyError`` if ``label`` was never injected.
        """
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_active(mode, label)
            for module in self.modules[label]:
                module.use_lora = mode
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in list(layer.lora_adapters):
                    layer.set_lora_active(mode, lora_label)
            for modules in self.modules.values():
                for module in modules:
                    module.use_lora = mode

    def adjust_lora_scales(self, scale, label = None):
        """Set the scale of one adapter label, or of all when ``label`` is None."""
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_scale(scale, label)
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in list(layer.lora_adapters):
                    layer.set_lora_scale(scale, lora_label)

    def merge_lora(self, labels = None):
        """Merge the given labels (default: every injected label) into the base weights."""
        if labels is None:
            labels = list(self.modules)
        elif isinstance(labels, str):
            labels = [labels]

        for layer in self.multi_lora_layers.values():
            layer.merge_lora_weights(labels)

    def recover_lora(self):
        """Undo ``merge_lora`` on every injected layer."""
        for layer in self.multi_lora_layers.values():
            layer.recover_original_weight()

    def get_lora_info(self):
        """Get information about all LoRA adapters."""
        info = {}
        for path, layer in self.multi_lora_layers.items():
            info[path] = {
                'labels': list(layer.lora_adapters),
                'active': dict(layer.active_loras),
                'scales': layer.lora_scales.copy()
            }
        return info