dagloop5 committed on
Commit
2b4573b
·
verified ·
1 Parent(s): dd5f993

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -30
app.py CHANGED
@@ -460,39 +460,41 @@ def _rename_lora_keys_for_base_model(lora_sd: StateDict, base_keys: set[str]) ->
460
  """
461
  Rename LoRA state dict keys to match the base model's key format.
462
 
463
- LoRA files from different sources use various key structures.
464
- We need to normalize them to match the base model.
465
 
466
- Common patterns:
467
- - "diffusion_model.transformer.xxx" -> "xxx" (ComfyUI export)
468
- - "diffusion_model.xxx" -> "xxx" (strip prefix)
469
- - "transformer.xxx" -> "xxx" (strip module name)
 
 
 
470
  """
471
  renamed_sd = {}
 
 
 
 
 
472
  for key, tensor in lora_sd.sd.items():
473
  new_key = key
474
 
475
- # Strip "diffusion_model." prefix
476
  if new_key.startswith("diffusion_model."):
477
  new_key = new_key[len("diffusion_model."):]
478
 
479
- # If key still doesn't match any base keys, try stripping "transformer." prefix
480
- if new_key not in base_keys and new_key.startswith("transformer."):
481
- new_key = new_key[len("transformer."):]
482
 
483
- # If it's a LoRA key, try to match the base weight key
484
- if ".lora_A.weight" in new_key or ".lora_B.weight" in new_key:
485
- base_path = new_key.replace(".lora_A.weight", ".weight").replace(".lora_B.weight", ".weight")
486
- if base_path not in base_keys:
487
- parts = base_path.split(".")
488
- if parts[0] == "transformer":
489
- alt_path = ".".join(parts[1:])
490
- if alt_path in base_keys:
491
- new_key = new_key.replace(base_path, alt_path)
492
- elif parts[0] == "layer" and len(parts) > 1:
493
- alt_path = ".".join(parts[1:])
494
- if alt_path in base_keys:
495
- new_key = new_key.replace(base_path, alt_path)
496
 
497
  renamed_sd[new_key] = tensor
498
 
@@ -542,14 +544,18 @@ def build_fused_state_dict(
542
  # Show before/after for first few keys
543
  original_keys = list(sd.sd.keys())[:3]
544
  renamed_keys = list(sd_renamed.sd.keys())[:3]
545
- print(f"[LoRA DEBUG] Before rename: {original_keys}")
546
- print(f"[LoRA DEBUG] After rename: {renamed_keys}")
547
 
548
- lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
549
- print(f"[LoRA] Loaded and renamed {len(sd_renamed.sd)} keys from {os.path.basename(lora_path)}")
550
-
551
- # Debug: Check LoRA key matching
552
- print(f"[LoRA DEBUG] Sample base keys: {list(base_key_set)[:10]}")
 
 
 
 
553
 
554
  if progress_callback:
555
  progress_callback(0.3, "Extracting base transformer state dict")
 
460
  """
461
  Rename LoRA state dict keys to match the base model's key format.
462
 
463
+ LoRA keys: transformer_blocks.0.attn1.to_k.lora_A.weight
464
+ Base keys: velocity_model.transformer_blocks.4.audio_attn1.to_out.0.weight
465
 
466
+ We need to:
467
+ 1. Add 'velocity_model.' prefix
468
+ 2. Match block indices (LoRA block 0 might correspond to base block 4, etc.)
469
+
470
+ Actually, LTX LoRA files typically have entries for ALL blocks.
471
+ The LoRA files from ComfyUI exports may have different structures.
472
+ We need to normalize to: velocity_model.transformer_blocks.X.attn.to_weight
473
  """
474
  renamed_sd = {}
475
+
476
+ # First pass: identify the pattern
477
+ # LoRA keys like: transformer_blocks.0.attn1.to_k.lora_A.weight
478
+ # Need to become: velocity_model.transformer_blocks.0.attn1.to_k.weight
479
+
480
  for key, tensor in lora_sd.sd.items():
481
  new_key = key
482
 
483
+ # Strip "diffusion_model." prefix if present
484
  if new_key.startswith("diffusion_model."):
485
  new_key = new_key[len("diffusion_model."):]
486
 
487
+ # Now add "velocity_model." prefix
488
+ if not new_key.startswith("velocity_model."):
489
+ new_key = "velocity_model." + new_key
490
 
491
+ # Convert LoRA key suffix to base weight suffix
492
+ # .lora_A.weight -> .weight
493
+ # .lora_B.weight -> .weight
494
+ if new_key.endswith(".lora_A.weight"):
495
+ new_key = new_key[:-len(".lora_A.weight")] + ".weight"
496
+ elif new_key.endswith(".lora_B.weight"):
497
+ new_key = new_key[:-len(".lora_B.weight")] + ".weight"
 
 
 
 
 
 
498
 
499
  renamed_sd[new_key] = tensor
500
 
 
544
  # Show before/after for first few keys
545
  original_keys = list(sd.sd.keys())[:3]
546
  renamed_keys = list(sd_renamed.sd.keys())[:3]
547
+ print(f"[LoRA DEBUG] Before: {original_keys}")
548
+ print(f"[LoRA DEBUG] After: {renamed_keys}")
549
 
550
+ # Check if renamed keys exist in base model
551
+ sample_renamed = list(sd_renamed.sd.keys())[0]
552
+ exists_in_base = sample_renamed in base_key_set
553
+ print(f"[LoRA DEBUG] Sample renamed key exists in base? {exists_in_base}")
554
+ if not exists_in_base:
555
+ print(f"[LoRA DEBUG] Checking for similar base keys...")
556
+ prefix = sample_renamed.rsplit(".", 1)[0] # Remove .weight suffix
557
+ similar = [k for k in base_key_set if k.startswith(prefix[:30])]
558
+ print(f"[LoRA DEBUG] Similar base keys: {similar[:3]}")
559
 
560
  if progress_callback:
561
  progress_callback(0.3, "Extracting base transformer state dict")