dagloop5 committed on
Commit
5490bd4
·
verified ·
1 Parent(s): 3cc8e3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -1
app.py CHANGED
@@ -456,6 +456,32 @@ def load_lora_into_cache(lora_path: str) -> StateDict:
456
  print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
457
  return state_dict
458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  def build_fused_state_dict(
461
  base_transformer,
@@ -482,10 +508,13 @@ def build_fused_state_dict(
482
  progress_callback(0.1, "Loading LoRA state dicts into memory")
483
 
484
  # Step 1: Load all LoRA state dicts (uses cache after first load)
 
485
  lora_sd_with_strengths = []
486
  for lora_path, strength in lora_configs:
487
  sd = load_lora_into_cache(lora_path)
488
- lora_sd_with_strengths.append(LoraStateDictWithStrength(sd, float(strength)))
 
 
489
 
490
  if progress_callback:
491
  progress_callback(0.3, "Extracting base transformer state dict")
 
456
  print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
457
  return state_dict
458
 
459
def _rename_lora_keys_for_base_model(lora_sd: StateDict) -> StateDict:
    """
    Return a copy of *lora_sd* whose keys match the base model's key format.

    LoRA checkpoints often use ComfyUI-style keys such as
    "diffusion_model.layer1.blocks.0.lora_A.weight", while the base model
    expects "layer1.blocks.0.lora_A.weight". Any leading
    "diffusion_model." prefix is stripped; all other keys pass through
    unchanged. The tensors themselves are not copied — the returned
    StateDict references the same tensor objects and carries over the
    original device/size/dtype metadata.
    """
    prefix = "diffusion_model."
    cut = len(prefix)
    remapped = {
        (key[cut:] if key.startswith(prefix) else key): tensor
        for key, tensor in lora_sd.sd.items()
    }
    return StateDict(
        sd=remapped,
        device=lora_sd.device,
        size=lora_sd.size,
        dtype=lora_sd.dtype,
    )
485
 
486
  def build_fused_state_dict(
487
  base_transformer,
 
508
  progress_callback(0.1, "Loading LoRA state dicts into memory")
509
 
510
  # Step 1: Load all LoRA state dicts (uses cache after first load)
511
+ # and rename keys to match base model's key format
512
  lora_sd_with_strengths = []
513
  for lora_path, strength in lora_configs:
514
  sd = load_lora_into_cache(lora_path)
515
+ # Apply key renaming to strip "diffusion_model." prefix
516
+ sd_renamed = _rename_lora_keys_for_base_model(sd)
517
+ lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
518
 
519
  if progress_callback:
520
  progress_callback(0.3, "Extracting base transformer state dict")