# Hugging Face Space demo initialization (commit d066167)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Dict, List
from einops import rearrange
from refnet.util import exists, default
from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
def get_module_safe(self, module_path: str):
    """
    Resolve a dotted attribute path (e.g. "model.diffusion_model") from `self`.

    Args:
        self: Root object to start the attribute walk from.
        module_path: Dot-separated attribute path.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If any segment of the path does not exist.
    """
    current_module = self
    try:
        for part in module_path.split('.'):
            current_module = getattr(current_module, part)
        return current_module
    except AttributeError as err:
        # Chain the original error so the failing attribute segment stays
        # visible in the traceback; also fixes the "modules" grammar typo.
        raise AttributeError(f"Cannot find module {module_path}") from err
def switch_lora(self, v, label=None):
    """Toggle the LoRA adapter named `label` on the q/k/v projection layers.

    Args:
        v: Whether the adapter should be active.
        label: Adapter label forwarded to each projection layer.
    """
    projections = (self.to_q, self.to_k, self.to_v)
    for proj in projections:
        proj.set_lora_active(v, label)
def lora_forward(self, x, context, mask, scale=1., scale_factor= None):
def qkv_forward(x, context):
q = self.to_q(x)
k = self.to_k(context)
v = self.to_v(context)
return q, k, v
assert exists(scale_factor), "Scale factor must be assigned before masked attention"
mask = rearrange(
F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"),
"b c h w -> b (h w) c"
).contiguous()
c1, c2 = context.chunk(2, dim=1)
# Background region cross-attention
if self.use_lora:
self.switch_lora(False, "foreground")
q2, k2, v2 = qkv_forward(x, c2)
bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale
# Character region cross-attention
if self.use_lora:
self.switch_lora(True, "foreground")
q1, k1, v1 = qkv_forward(x, c1)
fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale
fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
return fg_out * mask + bg_out * (1 - mask)
# return torch.where(mask > self.mask_threshold, fg_out, bg_out)
def dual_lora_forward(self, x, context, mask, scale=1., scale_factor=None):
    """
    Masked dual-adapter cross-attention forward (bound as `masked_forward`).

    Runs two passes with complementary "background"/"foreground" adapters and
    selects per-token output with a hard mask threshold.

    Args:
        x: Query input.
        context: Key/value input; chunked in half along dim 1 into
            character (first) and background (second) context.
        mask: Character mask in (b, c, h, w) layout.
        scale: Attention scale.
        scale_factor: Current latent size factor; must be set by the caller.
    """
    assert exists(scale_factor), "Scale factor must be assigned before masked attention"

    def project(query_in, kv_in):
        # q/k/v projections through the (possibly LoRA-injected) linears.
        return self.to_q(query_in), self.to_k(kv_in), self.to_v(kv_in)

    resized = F.interpolate(mask, scale_factor=scale_factor, mode="bicubic")
    region_mask = rearrange(resized, "b c h w -> b (h w) c").contiguous()
    fg_ctx, bg_ctx = context.chunk(2, dim=1)

    # Background pass: background adapter on, foreground adapter off.
    if self.use_lora:
        self.switch_lora(True, "background")
        self.switch_lora(False, "foreground")
    bg_out = self.attn_forward(*project(x, bg_ctx), scale) * self.bg_scale

    # Foreground pass: background adapter off, foreground adapter on.
    if self.use_lora:
        self.switch_lora(False, "background")
        self.switch_lora(True, "foreground")
    fg_out = self.attn_forward(*project(x, fg_ctx), scale) * self.fg_scale

    # Mix, then pick foreground or background per token by hard threshold.
    blended = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale
    return torch.where(region_mask > self.mask_threshold, blended, bg_out)
class MultiLoraInjectedLinear(nn.Linear):
    """
    A linear layer that can hold multiple LoRA adapters and merge them.

    The base weight/bias are frozen at construction; each active adapter adds
    `scale * up(dropout(down(x)))` to the output. Adapters can be merged into
    the base weights for inference and the originals recovered later.
    """
    def __init__(
        self,
        in_features,
        out_features,
        bias = False,
    ):
        super().__init__(in_features, out_features, bias)
        self.lora_adapters: Dict[str, Dict[str, nn.Module]] = {}  # {label: {down/up/dropout: layer}}
        self.lora_scales: Dict[str, float] = {}   # per-label output scale
        self.active_loras: Dict[str, bool] = {}   # per-label on/off switch
        self.original_weight = None  # snapshot of the pre-merge base weight
        self.original_bias = None    # snapshot of the pre-merge base bias
        # Freeze original weights; only the LoRA adapters should train.
        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)

    def add_lora_adapter(self, label: str, r: int, scale: float = 1.0, dropout_p: float = 0.0):
        """Add a new LoRA adapter with the given label.

        Args:
            label: Unique adapter name.
            r: LoRA rank; a float is interpreted as a fraction of out_features.
            scale: Output scale applied to this adapter's contribution.
            dropout_p: Dropout between the down and up projections.
        """
        if isinstance(r, float):
            r = int(r * self.out_features)
        lora_down = nn.Linear(self.in_features, r, bias=self.bias is not None)
        lora_up = nn.Linear(r, self.out_features, bias=self.bias is not None)
        dropout = nn.Dropout(dropout_p)
        # Initialize so a freshly added adapter is a no-op: the up projection's
        # weight AND bias are zeroed (the original code left lora_up.bias at
        # nn.Linear's random init, which shifted the output immediately).
        nn.init.normal_(lora_down.weight, std=1 / r)
        nn.init.zeros_(lora_up.weight)
        if lora_up.bias is not None:
            nn.init.zeros_(lora_up.bias)
        self.lora_adapters[label] = {
            'down': lora_down,
            'up': lora_up,
            'dropout': dropout,
        }
        self.lora_scales[label] = scale
        self.active_loras[label] = True
        # Register as submodules so they appear in parameters()/state_dict().
        self.add_module(f'lora_down_{label}', lora_down)
        self.add_module(f'lora_up_{label}', lora_up)
        self.add_module(f'lora_dropout_{label}', dropout)

    def get_trainable_layers(self, label: str = None):
        """Get trainable layers for a specific LoRA, or for all LoRAs."""
        layers = []
        if label is not None:
            if label in self.lora_adapters:
                adapter = self.lora_adapters[label]
                layers.extend([adapter['down'], adapter['up']])
        else:
            for adapter in self.lora_adapters.values():
                layers.extend([adapter['down'], adapter['up']])
        return layers

    def set_lora_active(self, active: bool, label: str):
        """Activate or deactivate a specific LoRA adapter."""
        if label in self.active_loras:
            self.active_loras[label] = active

    def set_lora_scale(self, scale: float, label: str):
        """Set the scale for a specific LoRA adapter."""
        if label in self.lora_scales:
            self.lora_scales[label] = scale

    def merge_lora_weights(self, labels: List[str] = None):
        """Merge specified (active) LoRA adapters into the base weights.

        Merged adapters are deactivated afterwards so the contribution is not
        applied twice; `recover_original_weight` undoes the merge.
        """
        if labels is None:
            labels = list(self.lora_adapters.keys())
        # Snapshot original weights once, before the first merge.
        if self.original_weight is None:
            self.original_weight = self.weight.clone()
            if self.bias is not None:
                self.original_bias = self.bias.clone()
        merged_weight = self.original_weight.clone()
        merged_bias = self.original_bias.clone() if self.original_bias is not None else None
        for label in labels:
            if label in self.lora_adapters and self.active_loras.get(label, False):
                lora_up, lora_down = self.lora_adapters[label]['up'], self.lora_adapters[label]['down']
                scale = self.lora_scales[label]
                # W' = W + s * (up @ down); equivalent to the runtime path.
                lora_weight = lora_up.weight @ lora_down.weight
                merged_weight += scale * lora_weight
                if merged_bias is not None and lora_up.bias is not None:
                    # b' = b + s * (up_bias + up @ down_bias)
                    lora_bias = lora_up.bias + lora_up.weight @ lora_down.bias
                    merged_bias += scale * lora_bias
        # Update weights; keep them frozen like the constructor does.
        self.weight = nn.Parameter(merged_weight, requires_grad=False)
        if merged_bias is not None:
            self.bias = nn.Parameter(merged_bias, requires_grad=False)
        # Deactivate all merged LoRAs so forward() does not double-count them.
        for label in labels:
            self.active_loras[label] = False

    def recover_original_weight(self):
        """Recover the original weights before any LoRA merge.

        Also reactivates all adapters so runtime LoRA application resumes.
        """
        if self.original_weight is not None:
            # Keep the base frozen, matching __init__ (the original code
            # recreated these Parameters with requires_grad=True by default).
            self.weight = nn.Parameter(self.original_weight.clone(), requires_grad=False)
            if self.original_bias is not None:
                self.bias = nn.Parameter(self.original_bias.clone(), requires_grad=False)
        # Reactivate all LoRAs.
        for label in self.active_loras:
            self.active_loras[label] = True

    def forward(self, input):
        output = super().forward(input)
        # Add contributions from active LoRAs.
        for label, adapter in self.lora_adapters.items():
            if self.active_loras.get(label, False):
                lora_out = adapter['up'](adapter['dropout'](adapter['down'](input)))
                # Out-of-place add is autograd-safe if `output` is reused.
                output = output + self.lora_scales[label] * lora_out
        return output
class LoraModules:
    """
    Finds target linear layers under one or more root modules, replaces them
    with MultiLoraInjectedLinear, and manages the resulting adapters
    (activation, scaling, merging, recovery) by label.
    """
    def __init__(self, sd, lora_params, *args, **kwargs):
        # label -> list of parent modules whose children were LoRA-injected
        self.modules = {}
        self.multi_lora_layers: Dict[str, MultiLoraInjectedLinear] = {} # path -> MultiLoraLayer
        for cfg in lora_params:
            # NOTE: cfg is mutated by pop(); callers should not reuse these dicts.
            root_module = get_module_safe(sd, cfg.pop("root_module"))
            label = cfg.pop("label", "lora")
            self.inject_lora(label, root_module, **cfg)
    def inject_lora(
        self,
        label,
        root_module,
        r,
        split_forward = False,
        target_keys = ("to_q", "to_k", "to_v"),
        filter_keys = None,
        target_class = None,
        scale = 1.0,
        dropout_p = 0.0,
    ):
        """Inject LoRA adapters labelled `label` into matching children of root_module.

        Args:
            label: Adapter label shared by all injected layers.
            root_module: Module tree to search.
            r: LoRA rank passed to add_lora_adapter.
            split_forward: If True, bind dual_lora_forward (background+foreground
                adapters) as masked_forward; otherwise bind lora_forward.
            target_keys: Child names whose path suffix selects them for injection.
            filter_keys: Substrings that exclude a path even if it matches.
            target_class: Class-path strings; instances of these classes also match.
            scale: Adapter output scale.
            dropout_p: Adapter dropout probability.

        Returns:
            (multi_lora_layers dict, list of parent modules touched).
        """
        def check_condition(path, child, class_list):
            # Exclusion first, then match by path suffix, then by class.
            if exists(filter_keys) and any(path.find(key) > -1 for key in filter_keys):
                return False
            if exists(target_keys) and any(path.endswith(key) for key in target_keys):
                return True
            if exists(class_list) and any(
                isinstance(child, module_class) for module_class in class_list
            ):
                return True
            return False
        def retrieve_target_modules():
            from refnet.util import get_obj_from_str
            target_class_list = [get_obj_from_str(t) for t in target_class] if exists(target_class) else None
            modules = []
            # Walk every (parent, child) pair so we can replace the child
            # in-place via parent._modules later.
            for name, module in root_module.named_modules():
                for key, child in module._modules.items():
                    full_path = name + '.' + key if name else key
                    if check_condition(full_path, child, target_class_list):
                        modules.append((module, child, key, full_path))
            return modules
        modules: list[Union[nn.Module]] = []
        retrieved_modules = retrieve_target_modules()
        for parent, child, child_name, full_path in retrieved_modules:
            # Check if this layer already has a MultiLoraInjectedLinear
            if full_path in self.multi_lora_layers:
                # Add LoRA to existing MultiLoraInjectedLinear
                multi_lora_layer = self.multi_lora_layers[full_path]
                multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
            else:
                # Check if the current layer is already a MultiLoraInjectedLinear
                if isinstance(child, MultiLoraInjectedLinear):
                    child.add_lora_adapter(label, r, scale, dropout_p)
                    self.multi_lora_layers[full_path] = child
                else:
                    # Replace with MultiLoraInjectedLinear and add first LoRA.
                    # NOTE(review): the replacement layer is freshly initialized;
                    # the child's pretrained weight is NOT copied here. This
                    # assumes the checkpoint is loaded after injection — verify.
                    multi_lora_layer = MultiLoraInjectedLinear(
                        in_features=child.weight.shape[1],
                        out_features=child.weight.shape[0],
                        bias=exists(child.bias),
                    )
                    multi_lora_layer.add_lora_adapter(label, r, scale, dropout_p)
                    parent._modules[child_name] = multi_lora_layer
                    self.multi_lora_layers[full_path] = multi_lora_layer
            # Bind the masked attention forward and helpers onto the parent
            # (the attention block owning to_q/to_k/to_v) as bound methods.
            if split_forward:
                parent.masked_forward = dual_lora_forward.__get__(parent, parent.__class__)
            else:
                parent.masked_forward = lora_forward.__get__(parent, parent.__class__)
            parent.use_lora = True
            parent.switch_lora = switch_lora.__get__(parent, parent.__class__)
            modules.append(parent)
        self.modules[label] = modules
        print(f"Activated {label} lora with {len(self.multi_lora_layers)} layers")
        return self.multi_lora_layers, modules
    def get_trainable_layers(self, label = None):
        """Get all trainable layers, optionally filtered by label."""
        layers = []
        for lora_layer in self.multi_lora_layers.values():
            layers += lora_layer.get_trainable_layers(label)
        return layers
    def switch_lora(self, mode, label = None):
        """Toggle adapters on/off — one label, or every adapter if label is None."""
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_active(mode, label)
            for module in self.modules[label]:
                module.use_lora = mode
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_active(mode, lora_label)
            for modules in self.modules.values():
                for module in modules:
                    module.use_lora = mode
    def adjust_lora_scales(self, scale, label = None):
        """Set the output scale for one label, or for every adapter if label is None."""
        if exists(label):
            for layer in self.multi_lora_layers.values():
                layer.set_lora_scale(scale, label)
        else:
            for layer in self.multi_lora_layers.values():
                for lora_label in layer.lora_adapters.keys():
                    layer.set_lora_scale(scale, lora_label)
    def merge_lora(self, labels = None):
        """Merge the given label(s) — or all labels — into the base weights."""
        if labels is None:
            labels = list(self.modules.keys())
        elif isinstance(labels, str):
            labels = [labels]
        for layer in self.multi_lora_layers.values():
            layer.merge_lora_weights(labels)
    def recover_lora(self):
        """Undo merges on every injected layer, restoring pre-merge weights."""
        for layer in self.multi_lora_layers.values():
            layer.recover_original_weight()
    def get_lora_info(self):
        """Get information about all LoRA adapters."""
        info = {}
        for path, layer in self.multi_lora_layers.items():
            info[path] = {
                'labels': list(layer.lora_adapters.keys()),
                'active': {label: active for label, active in layer.active_loras.items()},
                'scales': layer.lora_scales.copy()
            }
        return info