dagloop5 commited on
Commit
94d835a
·
verified ·
1 Parent(s): 50ae02c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -12
app.py CHANGED
@@ -472,18 +472,56 @@ def prepare_lora_cache(
472
  """Build cached LoRA state for single transformer."""
473
  global pipeline
474
 
 
475
  progress(0.05, desc="Preparing LoRA cache...")
476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  key_str = f"{checkpoint_path}:{distilled_strength}:{pose_strength}:{general_strength}:{motion_strength}:{dreamlay_strength}:{mself_strength}:{dramatic_strength}:{fluid_strength}:{liquid_strength}:{demopose_strength}:{voice_strength}:{realism_strength}:{transition_strength}"
478
  key = hashlib.sha256(key_str.encode()).hexdigest()
479
 
480
  cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
 
 
481
 
482
  if cache_path.exists():
 
483
  progress(0.20, desc="Loading cached LoRA state...")
484
  state = load_file(str(cache_path))
 
485
  pipeline.apply_cached_lora_state(state)
486
- return f"Loaded cached LoRA state: {cache_path.name}"
 
 
 
 
 
 
 
487
 
488
  entries = [
489
  (distilled_lora_path, distilled_strength),
@@ -501,34 +539,55 @@ def prepare_lora_cache(
501
  (transition_lora_path, transition_strength),
502
  ]
503
 
504
- loras = [
505
  LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
506
  for path, strength in entries
507
  if path is not None and float(strength) != 0.0
508
  ]
509
 
 
 
510
  progress(0.35, desc="Building fused state (CPU)...")
 
 
 
 
511
  tmp_ledger = pipeline.model_ledger.__class__(
512
  dtype=torch.bfloat16,
513
  device=torch.device("cpu"),
514
  checkpoint_path=str(checkpoint_path),
515
  spatial_upsampler_path=str(spatial_upsampler_path),
516
  gemma_root_path=str(gemma_root),
517
- loras=tuple(loras),
518
  quantization=None,
519
  )
 
 
 
520
  transformer = tmp_ledger.transformer()
 
 
 
 
521
  state = {k: v.detach().cpu().contiguous() for k, v in transformer.state_dict().items()}
 
 
 
522
  save_file(state, str(cache_path))
 
523
 
 
524
  del transformer, tmp_ledger
525
  gc.collect()
 
526
 
 
527
  progress(0.90, desc="Applying LoRA state to pipeline...")
528
  pipeline.apply_cached_lora_state(state)
529
 
530
  progress(1.0, desc="Done!")
531
- return f"Built and cached LoRA state: {cache_path.name}"
 
532
 
533
  # =============================================================================
534
  # LoRA State Application (called BEFORE pipeline generation)
@@ -539,23 +598,40 @@ def apply_prepared_lora_state_to_pipeline():
539
  Apply the prepared LoRA state from pipeline._cached_state to the preloaded
540
  transformer. This should be called BEFORE pipeline generation, not during.
541
  """
 
 
542
  if pipeline._cached_state is None:
543
  print("[LoRA] No prepared LoRA state available; skipping.")
 
544
  return False
545
 
546
  try:
547
  existing_transformer = _transformer # The preloaded transformer from globals
 
 
 
 
 
 
 
548
  with torch.no_grad():
549
- missing, unexpected = existing_transformer.load_state_dict(pipeline._cached_state, strict=False)
550
- if missing:
551
- print(f"[LoRA] load_state_dict mismatch: missing={len(missing)} keys")
552
- if unexpected:
553
- print(f"[LoRA] load_state_dict mismatch: unexpected={len(unexpected)} keys")
 
 
 
 
 
 
554
 
555
- print("[LoRA] Prepared LoRA state applied to the pipeline.")
556
  return True
557
  except Exception as e:
558
- print(f"[LoRA] Failed to apply LoRA state: {type(e).__name__}: {e}")
 
559
  return False
560
 
561
  # =============================================================================
@@ -759,7 +835,6 @@ def generate_video(
759
  width=width,
760
  num_frames=num_frames,
761
  frame_rate=DEFAULT_FRAME_RATE,
762
- num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
763
  video_guider_params=video_guider_params,
764
  audio_guider_params=audio_guider_params,
765
  images=images,
 
472
  """Build cached LoRA state for single transformer."""
473
  global pipeline
474
 
475
+ print("[LoRA] === Starting LoRA Cache Preparation ===")
476
  progress(0.05, desc="Preparing LoRA cache...")
477
 
478
+ # Validate all LoRA files exist
479
+ print("[LoRA] Validating LoRA file paths...")
480
+ lora_files = [
481
+ ("Distilled", distilled_lora_path, distilled_strength),
482
+ ("Pose", pose_lora_path, pose_strength),
483
+ ("General", general_lora_path, general_strength),
484
+ ("Motion", motion_lora_path, motion_strength),
485
+ ("Dreamlay", dreamlay_lora_path, dreamlay_strength),
486
+ ("Mself", mself_lora_path, mself_strength),
487
+ ("Dramatic", dramatic_lora_path, dramatic_strength),
488
+ ("Fluid", fluid_lora_path, fluid_strength),
489
+ ("Liquid", liquid_lora_path, liquid_strength),
490
+ ("Demopose", demopose_lora_path, demopose_strength),
491
+ ("Voice", voice_lora_path, voice_strength),
492
+ ("Realism", realism_lora_path, realism_strength),
493
+ ("Transition", transition_lora_path, transition_strength),
494
+ ]
495
+
496
+ active_loras = []
497
+ for name, path, strength in lora_files:
498
+ if path is not None and float(strength) != 0.0:
499
+ active_loras.append((name, path, strength))
500
+ print(f"[LoRA] - {name}: strength={strength}")
501
+
502
+ print(f"[LoRA] Active LoRAs: {len(active_loras)}")
503
+
504
  key_str = f"{checkpoint_path}:{distilled_strength}:{pose_strength}:{general_strength}:{motion_strength}:{dreamlay_strength}:{mself_strength}:{dramatic_strength}:{fluid_strength}:{liquid_strength}:{demopose_strength}:{voice_strength}:{realism_strength}:{transition_strength}"
505
  key = hashlib.sha256(key_str.encode()).hexdigest()
506
 
507
  cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
508
+ print(f"[LoRA] Cache key: {key[:16]}...")
509
+ print(f"[LoRA] Cache path: {cache_path}")
510
 
511
  if cache_path.exists():
512
+ print("[LoRA] Loading from existing cache...")
513
  progress(0.20, desc="Loading cached LoRA state...")
514
  state = load_file(str(cache_path))
515
+ print(f"[LoRA] Loaded state dict with {len(state)} keys, size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
516
  pipeline.apply_cached_lora_state(state)
517
+ print("[LoRA] State applied to pipeline._cached_state")
518
+ print("[LoRA] === LoRA Cache Preparation Complete ===")
519
+ return f"Loaded cached LoRA state: {cache_path.name} ({len(state)} keys)"
520
+
521
+ if not active_loras:
522
+ print("[LoRA] No non-zero LoRA strengths selected; nothing to prepare.")
523
+ print("[LoRA] === LoRA Cache Preparation Complete (no LoRAs) ===")
524
+ return "No non-zero LoRA strengths selected; nothing to prepare."
525
 
526
  entries = [
527
  (distilled_lora_path, distilled_strength),
 
539
  (transition_lora_path, transition_strength),
540
  ]
541
 
542
+ loras_for_builder = [
543
  LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
544
  for path, strength in entries
545
  if path is not None and float(strength) != 0.0
546
  ]
547
 
548
+ print(f"[LoRA] Building fused state on CPU with {len(loras_for_builder)} LoRAs...")
549
+ print("[LoRA] This may take several minutes (do not close the Space)...")
550
  progress(0.35, desc="Building fused state (CPU)...")
551
+
552
+ import time
553
+ start_time = time.time()
554
+
555
  tmp_ledger = pipeline.model_ledger.__class__(
556
  dtype=torch.bfloat16,
557
  device=torch.device("cpu"),
558
  checkpoint_path=str(checkpoint_path),
559
  spatial_upsampler_path=str(spatial_upsampler_path),
560
  gemma_root_path=str(gemma_root),
561
+ loras=tuple(loras_for_builder),
562
  quantization=None,
563
  )
564
+ print(f"[LoRA] Temporary ledger created in {time.time() - start_time:.1f}s")
565
+
566
+ print("[LoRA] Loading transformer with LoRAs applied...")
567
  transformer = tmp_ledger.transformer()
568
+ print(f"[LoRA] Transformer loaded in {time.time() - start_time:.1f}s")
569
+
570
+ print("[LoRA] Extracting state dict...")
571
+ progress(0.70, desc="Extracting fused stateDict")
572
  state = {k: v.detach().cpu().contiguous() for k, v in transformer.state_dict().items()}
573
+ print(f"[LoRA] State dict extracted: {len(state)} keys")
574
+
575
+ print(f"[LoRA] Saving to cache: {cache_path}")
576
  save_file(state, str(cache_path))
577
+ print(f"[LoRA] Cache saved, size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
578
 
579
+ print("[LoRA] Cleaning up temporary ledger...")
580
  del transformer, tmp_ledger
581
  gc.collect()
582
+ print(f"[LoRA] Cleanup complete in {time.time() - start_time:.1f}s total")
583
 
584
+ print("[LoRA] Applying state to pipeline._cached_state...")
585
  progress(0.90, desc="Applying LoRA state to pipeline...")
586
  pipeline.apply_cached_lora_state(state)
587
 
588
  progress(1.0, desc="Done!")
589
+ print("[LoRA] === LoRA Cache Preparation Complete ===")
590
+ return f"Built and cached LoRA state: {cache_path.name} ({len(state)} keys, {time.time() - start_time:.1f}s)"
591
 
592
  # =============================================================================
593
  # LoRA State Application (called BEFORE pipeline generation)
 
598
  Apply the prepared LoRA state from pipeline._cached_state to the preloaded
599
  transformer. This should be called BEFORE pipeline generation, not during.
600
  """
601
+ print("[LoRA] === Applying LoRA State to Transformer ===")
602
+
603
  if pipeline._cached_state is None:
604
  print("[LoRA] No prepared LoRA state available; skipping.")
605
+ print("[LoRA] === LoRA Application Complete (no state) ===")
606
  return False
607
 
608
  try:
609
  existing_transformer = _transformer # The preloaded transformer from globals
610
+ state = pipeline._cached_state
611
+ print(f"[LoRA] Applying state dict with {len(state)} keys...")
612
+ print(f"[LoRA] State dict size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
613
+
614
+ import time
615
+ start_time = time.time()
616
+
617
  with torch.no_grad():
618
+ missing, unexpected = existing_transformer.load_state_dict(state, strict=False)
619
+
620
+ print(f"[LoRA] load_state_dict completed in {time.time() - start_time:.1f}s")
621
+
622
+ if missing:
623
+ print(f"[LoRA] WARNING: {len(missing)} keys missing from state dict")
624
+ if unexpected:
625
+ print(f"[LoRA] WARNING: {len(unexpected)} unexpected keys in state dict")
626
+
627
+ if not missing and not unexpected:
628
+ print("[LoRA] State dict loaded successfully with no mismatches!")
629
 
630
+ print("[LoRA] === LoRA Application Complete (success) ===")
631
  return True
632
  except Exception as e:
633
+ print(f"[LoRA] FAILED to apply LoRA state: {type(e).__name__}: {e}")
634
+ print("[LoRA] === LoRA Application Complete (FAILED) ===")
635
  return False
636
 
637
  # =============================================================================
 
835
  width=width,
836
  num_frames=num_frames,
837
  frame_rate=DEFAULT_FRAME_RATE,
 
838
  video_guider_params=video_guider_params,
839
  audio_guider_params=audio_guider_params,
840
  images=images,