dagloop5 commited on
Commit
7b5d90c
·
verified ·
1 Parent(s): 0374898

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -120
app.py CHANGED
@@ -75,6 +75,8 @@ from ltx_pipelines.utils.helpers import (
75
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
76
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
77
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
 
 
78
 
79
  # Force-patch xformers attention into the LTX attention module.
80
  from ltx_core.model.transformer import attention as _attn_mod
@@ -339,10 +341,6 @@ LORA_CACHE_DIR = Path("lora_cache")
339
  LORA_CACHE_DIR.mkdir(exist_ok=True)
340
  current_lora_key: str | None = None
341
 
342
- PENDING_LORA_KEY: str | None = None
343
- PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
344
- PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
345
-
346
  weights_dir = Path("weights")
347
  weights_dir.mkdir(exist_ok=True)
348
  checkpoint_path = hf_hub_download(
@@ -419,7 +417,110 @@ def _make_lora_key(pose_strength: float, general_strength: float, motion_strengt
419
  return key, key_str
420
 
421
 
422
- def prepare_lora_cache(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  pose_strength: float,
424
  general_strength: float,
425
  motion_strength: float,
@@ -435,128 +536,130 @@ def prepare_lora_cache(
435
  progress=gr.Progress(track_tqdm=True),
436
  ):
437
  """
438
- CPU-only step:
439
- - checks cache
440
- - loads cached fused transformer state_dict, or
441
- - builds fused transformer on CPU and saves it
442
- The resulting state_dict is stored in memory and can be applied later.
 
 
 
 
 
443
  """
444
- global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
445
-
446
- ledger = pipeline.model_ledger
447
- key, _ = _make_lora_key(pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength)
448
- cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
449
-
450
- progress(0.05, desc="Preparing LoRA state")
451
- if cache_path.exists():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  try:
453
- progress(0.20, desc="Loading cached fused state")
454
- state = load_file(str(cache_path))
455
- PENDING_LORA_KEY = key
456
- PENDING_LORA_STATE = state
457
- PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
458
- return PENDING_LORA_STATUS
 
 
459
  except Exception as e:
460
- print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
461
-
462
- entries = [
463
- (pose_lora_path, round(float(pose_strength), 2)),
464
- (general_lora_path, round(float(general_strength), 2)),
465
- (motion_lora_path, round(float(motion_strength), 2)),
466
- (dreamlay_lora_path, round(float(dreamlay_strength), 2)),
467
- (mself_lora_path, round(float(mself_strength), 2)),
468
- (dramatic_lora_path, round(float(dramatic_strength), 2)),
469
- (fluid_lora_path, round(float(fluid_strength), 2)),
470
- (liquid_lora_path, round(float(liquid_strength), 2)),
471
- (demopose_lora_path, round(float(demopose_strength), 2)),
472
- (voice_lora_path, round(float(voice_strength), 2)),
473
- (realism_lora_path, round(float(realism_strength), 2)),
474
- (transition_lora_path, round(float(transition_strength), 2)),
475
- ]
476
- loras_for_builder = [
477
- LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
478
- for path, strength in entries
479
- if path is not None and float(strength) != 0.0
480
- ]
481
-
482
- if not loras_for_builder:
483
- PENDING_LORA_KEY = None
484
- PENDING_LORA_STATE = None
485
- PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
486
- return PENDING_LORA_STATUS
487
-
488
- tmp_ledger = None
489
- new_transformer_cpu = None
490
  try:
491
- progress(0.35, desc="Building fused CPU transformer")
492
- tmp_ledger = pipeline.model_ledger.__class__(
493
- dtype=ledger.dtype,
494
- device=torch.device("cpu"),
495
- checkpoint_path=str(checkpoint_path),
496
- spatial_upsampler_path=str(spatial_upsampler_path),
497
- gemma_root_path=str(gemma_root),
498
- loras=tuple(loras_for_builder),
499
- quantization=getattr(ledger, "quantization", None),
500
- )
501
- new_transformer_cpu = tmp_ledger.transformer()
502
-
503
- progress(0.70, desc="Extracting fused state_dict")
504
- state = {
505
- k: v.detach().cpu().contiguous()
506
- for k, v in new_transformer_cpu.state_dict().items()
507
- }
508
- save_file(state, str(cache_path))
509
-
510
- PENDING_LORA_KEY = key
511
- PENDING_LORA_STATE = state
512
- PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
513
- return PENDING_LORA_STATUS
514
-
515
  except Exception as e:
516
  import traceback
517
- print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
518
  print(traceback.format_exc())
519
- PENDING_LORA_KEY = None
520
- PENDING_LORA_STATE = None
521
- PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
522
- return PENDING_LORA_STATUS
523
-
524
- finally:
525
- try:
526
- del new_transformer_cpu
527
- except Exception:
528
- pass
529
  try:
530
- del tmp_ledger
 
 
 
531
  except Exception:
532
  pass
533
- gc.collect()
534
-
535
-
536
- def apply_prepared_lora_state_to_pipeline():
537
- """
538
- Fast step: copy the already prepared CPU state into the live transformer.
539
- This is the only part that should remain near generation time.
540
- """
541
- global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
542
-
543
- if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
544
- print("[LoRA] No prepared LoRA state available; skipping.")
545
- return False
546
-
547
- if current_lora_key == PENDING_LORA_KEY:
548
- print("[LoRA] Prepared LoRA state already active; skipping.")
549
- return True
550
-
551
- existing_transformer = _transformer
552
- with torch.no_grad():
553
- missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
554
- if missing or unexpected:
555
- print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
556
-
557
- current_lora_key = PENDING_LORA_KEY
558
- print("[LoRA] Prepared LoRA state applied to the pipeline.")
559
- return True
560
 
561
  # Preload all models for ZeroGPU tensor packing.
562
  print("Preloading all models (including Gemma and audio components)...")
@@ -770,8 +873,6 @@ def generate_video(
770
 
771
  log_memory("before pipeline call")
772
 
773
- apply_prepared_lora_state_to_pipeline()
774
-
775
  video, audio = pipeline(
776
  prompt=prompt,
777
  negative_prompt=negative_prompt,
@@ -956,7 +1057,7 @@ with gr.Blocks(title="LTX-2.3 Distilled with LoRAs, Negative Prompting, and Adva
956
  high_res.change(fn=on_highres_toggle, inputs=[first_image, last_image, high_res], outputs=[width, height])
957
 
958
  prepare_lora_btn.click(
959
- fn=prepare_lora_cache,
960
  inputs=[pose_strength, general_strength, motion_strength, dreamlay_strength,
961
  mself_strength, dramatic_strength, fluid_strength, liquid_strength,
962
  demopose_strength, voice_strength, realism_strength, transition_strength],
 
75
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
76
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
77
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
78
+ from ltx_core.loader.module_ops.apply_loras import apply_loras
79
+ from safetensors import safe_open
80
 
81
  # Force-patch xformers attention into the LTX attention module.
82
  from ltx_core.model.transformer import attention as _attn_mod
 
341
  LORA_CACHE_DIR.mkdir(exist_ok=True)
342
  current_lora_key: str | None = None
343
 
 
 
 
 
344
  weights_dir = Path("weights")
345
  weights_dir.mkdir(exist_ok=True)
346
  checkpoint_path = hf_hub_download(
 
417
  return key, key_str
418
 
419
 
420
+ # =============================================================================
421
+ # LoRA Cache (In-Memory) - Ultra-Fast In-Place Application
422
+ # =============================================================================
423
+
424
+ # In-memory caches to avoid redundant disk I/O
425
+ LORA_SD_CACHE: dict[str, StateDict] = {} # lora_path -> loaded StateDict
426
+ FUSED_CACHE: dict[str, dict] = {} # cache key -> fused state dict (CPU)
427
+ current_lora_key: str | None = None
428
+
429
+
430
+ def load_lora_into_cache(lora_path: str) -> StateDict:
431
+ """
432
+ Load a LoRA safetensor file into a cached StateDict.
433
+ Subsequent calls return the cached version instantly.
434
+
435
+ This replaces repeated disk reads with a one-time load + memory cache.
436
+ """
437
+ if lora_path in LORA_SD_CACHE:
438
+ return LORA_SD_CACHE[lora_path]
439
+
440
+ print(f"[LoRA] Loading {os.path.basename(lora_path)} into memory cache...")
441
+
442
+ # Use safe_open for memory-efficient streaming reads of large files
443
+ tensors = {}
444
+ with safe_open(lora_path, framework="safetensors") as f:
445
+ for key in f.keys():
446
+ tensors[key] = f.get_tensor(key)
447
+
448
+ state_dict = StateDict(
449
+ sd=tensors,
450
+ device=torch.device("cpu"),
451
+ size=sum(t.nbytes for t in tensors.values()),
452
+ dtype=set(t.dtype for t in tensors.values())
453
+ )
454
+
455
+ LORA_SD_CACHE[lora_path] = state_dict
456
+ print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
457
+ return state_dict
458
+
459
+
460
+ def build_fused_state_dict(
461
+ base_transformer,
462
+ lora_configs: list[tuple[str, float]],
463
+ progress_callback=None
464
+ ) -> dict[str, torch.Tensor]:
465
+ """
466
+ Fuse multiple LoRAs into a single state dict ready for load_state_dict().
467
+ Uses LTX's apply_loras function which handles FP8 quantization correctly.
468
+
469
+ Args:
470
+ base_transformer: The preloaded transformer model
471
+ lora_configs: List of (lora_path, strength) tuples for non-zero LoRAs
472
+ progress_callback: Optional callback(step, desc) for progress updates
473
+
474
+ Returns:
475
+ Dictionary of fused weights ready for load_state_dict()
476
+ """
477
+ if not lora_configs:
478
+ # No LoRAs - return base transformer state dict
479
+ return {k: v.clone() for k, v in base_transformer.state_dict().items()}
480
+
481
+ if progress_callback:
482
+ progress_callback(0.1, "Loading LoRA state dicts into memory")
483
+
484
+ # Step 1: Load all LoRA state dicts (uses cache after first load)
485
+ lora_sd_with_strengths = []
486
+ for lora_path, strength in lora_configs:
487
+ sd = load_lora_into_cache(lora_path)
488
+ lora_sd_with_strengths.append(LoraStateDictWithStrength(sd, float(strength)))
489
+
490
+ if progress_callback:
491
+ progress_callback(0.3, "Extracting base transformer state dict")
492
+
493
+ # Step 2: Get base transformer state dict (already in memory from preloading!)
494
+ base_dict = base_transformer.state_dict()
495
+ base_sd = StateDict(
496
+ sd={k: v.detach().cpu().contiguous() for k, v in base_dict.items()},
497
+ device=torch.device("cpu"),
498
+ size=sum(v.nbytes for v in base_dict.values()),
499
+ dtype=set(v.dtype for v in base_dict.values())
500
+ )
501
+
502
+ if progress_callback:
503
+ progress_callback(0.5, "Fusing LoRAs with base weights (CPU)")
504
+
505
+ # Step 3: Fuse using LTX's apply_loras function
506
+ # This function handles:
507
+ # - FP8 quantized weights (_fuse_delta_with_scaled_fp8 / _fuse_delta_with_cast_fp8)
508
+ # - BFloat16 weights (_fuse_delta_with_bfloat16)
509
+ # - Proper delta accumulation for multiple LoRAs
510
+ fused_sd = apply_loras(
511
+ model_sd=base_sd,
512
+ lora_sd_and_strengths=lora_sd_with_strengths,
513
+ dtype=torch.bfloat16
514
+ )
515
+
516
+ if progress_callback:
517
+ progress_callback(0.9, "Extracting fused state dict")
518
+
519
+ # Step 4: Return the fused state dict as a plain dict (for load_state_dict)
520
+ return fused_sd.sd
521
+
522
+
523
+ def on_prepare_loras_click(
524
  pose_strength: float,
525
  general_strength: float,
526
  motion_strength: float,
 
536
  progress=gr.Progress(track_tqdm=True),
537
  ):
538
  """
539
+ Called when user clicks the 'Prepare LoRA Cache' button.
540
+
541
+ This function:
542
+ 1. Checks if LoRA combination is already applied (skip if so)
543
+ 2. Checks in-memory FUSED_CACHE (skip building if cached)
544
+ 3. Loads LoRA files into cache (reuses LORA_SD_CACHE on subsequent calls)
545
+ 4. Builds fused state dict if needed (only new combinations)
546
+ 5. Applies to the preloaded transformer
547
+
548
+ Only runs on button click, NOT on slider change.
549
  """
550
+ global current_lora_key, FUSED_CACHE
551
+
552
+ # Compute the cache key for this combination of strengths
553
+ key, _ = _make_lora_key(
554
+ pose_strength, general_strength, motion_strength, dreamlay_strength,
555
+ mself_strength, dramatic_strength, fluid_strength, liquid_strength,
556
+ demopose_strength, voice_strength, realism_strength, transition_strength
557
+ )
558
+
559
+ # Already applied with these exact strengths? Nothing to do.
560
+ if current_lora_key == key:
561
+ return f"✓ LoRAs already applied with current strengths"
562
+
563
+ progress(0.0, desc="Starting LoRA preparation")
564
+
565
+ # Build the list of active (non-zero) LoRAs
566
+ active_loras = []
567
+ lora_entries = [
568
+ (pose_lora_path, pose_strength, "Anthro Enhancer"),
569
+ (general_lora_path, general_strength, "Reasoning Enhancer"),
570
+ (motion_lora_path, motion_strength, "Anthro Posing"),
571
+ (dreamlay_lora_path, dreamlay_strength, "Dreamlay"),
572
+ (mself_lora_path, mself_strength, "Mself"),
573
+ (dramatic_lora_path, dramatic_strength, "Dramatic"),
574
+ (fluid_lora_path, fluid_strength, "Fluid Helper"),
575
+ (liquid_lora_path, liquid_strength, "Liquid Helper"),
576
+ (demopose_lora_path, demopose_strength, "Audio Helper"),
577
+ (voice_lora_path, voice_strength, "Voice Helper"),
578
+ (realism_lora_path, realism_strength, "Anthro Realism"),
579
+ (transition_lora_path, transition_strength, "POV"),
580
+ ]
581
+
582
+ for path, strength, name in lora_entries:
583
+ if float(strength) != 0.0:
584
+ active_loras.append((path, float(strength)))
585
+ print(f"[LoRA] Active: {name} = {strength}")
586
+
587
+ if not active_loras:
588
+ # No LoRAs selected - apply base model weights (reset from any previous LoRAs)
589
+ print("[LoRA] No LoRAs selected, resetting to base model weights")
590
  try:
591
+ transformer = ledger.transformer()
592
+ base_weights = {k: v.cpu() for k, v in transformer.state_dict().items()}
593
+ transformer.load_state_dict(base_weights, strict=False)
594
+ if torch.cuda.is_available():
595
+ transformer = transformer.to("cuda")
596
+ current_lora_key = key
597
+ progress(1.0, desc="Done")
598
+ return "✓ Reset to base model (no LoRAs active)"
599
  except Exception as e:
600
+ return f" Reset failed: {e}"
601
+
602
+ # Check in-memory cache for this strength combination
603
+ if key in FUSED_CACHE:
604
+ print(f"[LoRA] Using cached fused state for: {key[:16]}...")
605
+ fused_state = FUSED_CACHE[key]
606
+ progress(0.85, desc="Using cached fused state")
607
+ else:
608
+ # Need to build the fused state dict (the expensive part)
609
+ print(f"[LoRA] Building new fused state dict for {len(active_loras)} LoRA(s)...")
610
+
611
+ # Progress callback that maps to Gradio's progress tracker
612
+ def progress_cb(step, desc):
613
+ progress(0.1 + step * 0.8, desc=desc)
614
+
615
+ transformer = ledger.transformer()
616
+ fused_state = build_fused_state_dict(transformer, active_loras, progress_cb)
617
+
618
+ # Cache the fused state for future reuse (keyed by strength combination)
619
+ FUSED_CACHE[key] = fused_state
620
+ print(f"[LoRA] Cached fused state for: {key[:16]}...")
621
+
622
+ # Apply fused state to transformer
623
+ progress(0.92, desc="Applying fused weights to transformer")
624
+
 
 
 
 
 
625
  try:
626
+ transformer = ledger.transformer()
627
+ target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
628
+
629
+ # Move transformer to CPU for loading (avoids device mismatch)
630
+ transformer = transformer.to("cpu")
631
+ torch.cuda.empty_cache() # Free VRAM from the CPU copy
632
+
633
+ # Load the fused state dict
634
+ missing, unexpected = transformer.load_state_dict(fused_state, strict=False)
635
+ if missing:
636
+ print(f"[LoRA] Warning: {len(missing)} keys not found in fused state")
637
+ if unexpected:
638
+ print(f"[LoRA] Warning: {len(unexpected)} unexpected keys in fused state")
639
+
640
+ # Move transformer to target device (GPU for generation)
641
+ if target_device.type != "cpu":
642
+ transformer = transformer.to(target_device)
643
+
644
+ current_lora_key = key
645
+ progress(1.0, desc="Done")
646
+ return f"✓ Applied {len(active_loras)} LoRA(s) successfully"
647
+
 
 
648
  except Exception as e:
649
  import traceback
650
+ print(f"[LoRA] Apply failed: {e}")
651
  print(traceback.format_exc())
652
+
653
+ # Try to restore transformer to GPU on error
 
 
 
 
 
 
 
 
654
  try:
655
+ transformer = ledger.transformer()
656
+ if next(transformer.parameters()).device.type == "cpu":
657
+ if torch.cuda.is_available():
658
+ transformer = transformer.to("cuda")
659
  except Exception:
660
  pass
661
+
662
+ return f"✗ LoRA application failed: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
  # Preload all models for ZeroGPU tensor packing.
665
  print("Preloading all models (including Gemma and audio components)...")
 
873
 
874
  log_memory("before pipeline call")
875
 
 
 
876
  video, audio = pipeline(
877
  prompt=prompt,
878
  negative_prompt=negative_prompt,
 
1057
  high_res.change(fn=on_highres_toggle, inputs=[first_image, last_image, high_res], outputs=[width, height])
1058
 
1059
  prepare_lora_btn.click(
1060
+ fn=on_prepare_loras_click,
1061
  inputs=[pose_strength, general_strength, motion_strength, dreamlay_strength,
1062
  mself_strength, dramatic_strength, fluid_strength, liquid_strength,
1063
  demopose_strength, voice_strength, realism_strength, transition_strength],