Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -456,6 +456,32 @@ def load_lora_into_cache(lora_path: str) -> StateDict:
|
|
| 456 |
print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
|
| 457 |
return state_dict
|
| 458 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
def build_fused_state_dict(
|
| 461 |
base_transformer,
|
|
@@ -482,10 +508,13 @@ def build_fused_state_dict(
|
|
| 482 |
progress_callback(0.1, "Loading LoRA state dicts into memory")
|
| 483 |
|
| 484 |
# Step 1: Load all LoRA state dicts (uses cache after first load)
|
|
|
|
| 485 |
lora_sd_with_strengths = []
|
| 486 |
for lora_path, strength in lora_configs:
|
| 487 |
sd = load_lora_into_cache(lora_path)
|
| 488 |
-
|
|
|
|
|
|
|
| 489 |
|
| 490 |
if progress_callback:
|
| 491 |
progress_callback(0.3, "Extracting base transformer state dict")
|
|
|
|
| 456 |
print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
|
| 457 |
return state_dict
|
| 458 |
|
| 459 |
+
def _rename_lora_keys_for_base_model(lora_sd: StateDict) -> StateDict:
|
| 460 |
+
"""
|
| 461 |
+
Rename LoRA state dict keys to match the base model's key format.
|
| 462 |
+
|
| 463 |
+
LoRA files typically use ComfyUI-style keys like:
|
| 464 |
+
"diffusion_model.layer1.blocks.0.lora_A.weight"
|
| 465 |
+
|
| 466 |
+
But the base model uses keys like:
|
| 467 |
+
"layer1.blocks.0.lora_A.weight"
|
| 468 |
+
|
| 469 |
+
This function strips the "diffusion_model." prefix to match them.
|
| 470 |
+
"""
|
| 471 |
+
renamed_sd = {}
|
| 472 |
+
for key, tensor in lora_sd.sd.items():
|
| 473 |
+
if key.startswith("diffusion_model."):
|
| 474 |
+
new_key = key[len("diffusion_model."):]
|
| 475 |
+
else:
|
| 476 |
+
new_key = key
|
| 477 |
+
renamed_sd[new_key] = tensor
|
| 478 |
+
|
| 479 |
+
return StateDict(
|
| 480 |
+
sd=renamed_sd,
|
| 481 |
+
device=lora_sd.device,
|
| 482 |
+
size=lora_sd.size,
|
| 483 |
+
dtype=lora_sd.dtype
|
| 484 |
+
)
|
| 485 |
|
| 486 |
def build_fused_state_dict(
|
| 487 |
base_transformer,
|
|
|
|
| 508 |
progress_callback(0.1, "Loading LoRA state dicts into memory")
|
| 509 |
|
| 510 |
# Step 1: Load all LoRA state dicts (uses cache after first load)
|
| 511 |
+
# and rename keys to match base model's key format
|
| 512 |
lora_sd_with_strengths = []
|
| 513 |
for lora_path, strength in lora_configs:
|
| 514 |
sd = load_lora_into_cache(lora_path)
|
| 515 |
+
# Apply key renaming to strip "diffusion_model." prefix
|
| 516 |
+
sd_renamed = _rename_lora_keys_for_base_model(sd)
|
| 517 |
+
lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
|
| 518 |
|
| 519 |
if progress_callback:
|
| 520 |
progress_callback(0.3, "Extracting base transformer state dict")
|