Update app.py

app.py CHANGED
@@ -472,18 +472,56 @@ def prepare_lora_cache(
     """Build cached LoRA state for single transformer."""
     global pipeline
 
+    print("[LoRA] === Starting LoRA Cache Preparation ===")
     progress(0.05, desc="Preparing LoRA cache...")
 
+    # Validate all LoRA files exist
+    print("[LoRA] Validating LoRA file paths...")
+    lora_files = [
+        ("Distilled", distilled_lora_path, distilled_strength),
+        ("Pose", pose_lora_path, pose_strength),
+        ("General", general_lora_path, general_strength),
+        ("Motion", motion_lora_path, motion_strength),
+        ("Dreamlay", dreamlay_lora_path, dreamlay_strength),
+        ("Mself", mself_lora_path, mself_strength),
+        ("Dramatic", dramatic_lora_path, dramatic_strength),
+        ("Fluid", fluid_lora_path, fluid_strength),
+        ("Liquid", liquid_lora_path, liquid_strength),
+        ("Demopose", demopose_lora_path, demopose_strength),
+        ("Voice", voice_lora_path, voice_strength),
+        ("Realism", realism_lora_path, realism_strength),
+        ("Transition", transition_lora_path, transition_strength),
+    ]
+
+    active_loras = []
+    for name, path, strength in lora_files:
+        if path is not None and float(strength) != 0.0:
+            active_loras.append((name, path, strength))
+            print(f"[LoRA] - {name}: strength={strength}")
+
+    print(f"[LoRA] Active LoRAs: {len(active_loras)}")
+
     key_str = f"{checkpoint_path}:{distilled_strength}:{pose_strength}:{general_strength}:{motion_strength}:{dreamlay_strength}:{mself_strength}:{dramatic_strength}:{fluid_strength}:{liquid_strength}:{demopose_strength}:{voice_strength}:{realism_strength}:{transition_strength}"
     key = hashlib.sha256(key_str.encode()).hexdigest()
 
     cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
+    print(f"[LoRA] Cache key: {key[:16]}...")
+    print(f"[LoRA] Cache path: {cache_path}")
 
     if cache_path.exists():
+        print("[LoRA] Loading from existing cache...")
         progress(0.20, desc="Loading cached LoRA state...")
         state = load_file(str(cache_path))
+        print(f"[LoRA] Loaded state dict with {len(state)} keys, size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
         pipeline.apply_cached_lora_state(state)
-
+        print("[LoRA] State applied to pipeline._cached_state")
+        print("[LoRA] === LoRA Cache Preparation Complete ===")
+        return f"Loaded cached LoRA state: {cache_path.name} ({len(state)} keys)"
+
+    if not active_loras:
+        print("[LoRA] No non-zero LoRA strengths selected; nothing to prepare.")
+        print("[LoRA] === LoRA Cache Preparation Complete (no LoRAs) ===")
+        return "No non-zero LoRA strengths selected; nothing to prepare."
 
     entries = [
         (distilled_lora_path, distilled_strength),
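The cache-hit fast path added above works because the key is a pure function of the checkpoint path and all thirteen LoRA strengths. For illustration, a minimal self-contained sketch of that keying scheme; `CACHE_DIR` and `cache_path_for` are hypothetical stand-ins for `LORA_CACHE_DIR` and the inline logic in `prepare_lora_cache`:

    import hashlib
    from pathlib import Path

    CACHE_DIR = Path("lora_cache")  # hypothetical stand-in for LORA_CACHE_DIR

    def cache_path_for(checkpoint_path, strengths):
        # Any change to the checkpoint or to any single strength changes the
        # digest, so each distinct LoRA mix gets its own .safetensors file.
        key_str = ":".join([str(checkpoint_path)] + [str(s) for s in strengths])
        key = hashlib.sha256(key_str.encode()).hexdigest()
        return CACHE_DIR / f"{key}.safetensors"

    # Identical inputs map to the identical file, so a repeat request with the
    # same strengths is a cache hit and skips the expensive CPU fusion step.
    assert cache_path_for("ckpt.safetensors", [1.0, 0.5]) == cache_path_for("ckpt.safetensors", [1.0, 0.5])

Note that strengths set to 0.0 still participate in the key string even though they are filtered out of `active_loras`, which keeps the digest unambiguous across different mixes.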
@@ -501,34 +539,55 @@ def prepare_lora_cache(
         (transition_lora_path, transition_strength),
     ]
 
-
+    loras_for_builder = [
         LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
         for path, strength in entries
         if path is not None and float(strength) != 0.0
     ]
 
+    print(f"[LoRA] Building fused state on CPU with {len(loras_for_builder)} LoRAs...")
+    print("[LoRA] This may take several minutes (do not close the Space)...")
     progress(0.35, desc="Building fused state (CPU)...")
+
+    import time
+    start_time = time.time()
+
     tmp_ledger = pipeline.model_ledger.__class__(
         dtype=torch.bfloat16,
         device=torch.device("cpu"),
         checkpoint_path=str(checkpoint_path),
         spatial_upsampler_path=str(spatial_upsampler_path),
         gemma_root_path=str(gemma_root),
-        loras=tuple(
+        loras=tuple(loras_for_builder),
         quantization=None,
     )
+    print(f"[LoRA] Temporary ledger created in {time.time() - start_time:.1f}s")
+
+    print("[LoRA] Loading transformer with LoRAs applied...")
     transformer = tmp_ledger.transformer()
+    print(f"[LoRA] Transformer loaded in {time.time() - start_time:.1f}s")
+
+    print("[LoRA] Extracting state dict...")
+    progress(0.70, desc="Extracting fused state dict...")
     state = {k: v.detach().cpu().contiguous() for k, v in transformer.state_dict().items()}
+    print(f"[LoRA] State dict extracted: {len(state)} keys")
+
+    print(f"[LoRA] Saving to cache: {cache_path}")
     save_file(state, str(cache_path))
+    print(f"[LoRA] Cache saved, size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
 
+    print("[LoRA] Cleaning up temporary ledger...")
     del transformer, tmp_ledger
     gc.collect()
+    print(f"[LoRA] Cleanup complete in {time.time() - start_time:.1f}s total")
 
+    print("[LoRA] Applying state to pipeline._cached_state...")
     progress(0.90, desc="Applying LoRA state to pipeline...")
     pipeline.apply_cached_lora_state(state)
 
     progress(1.0, desc="Done!")
-
+    print("[LoRA] === LoRA Cache Preparation Complete ===")
+    return f"Built and cached LoRA state: {cache_path.name} ({len(state)} keys, {time.time() - start_time:.1f}s)"
 
 # =============================================================================
 # LoRA State Application (called BEFORE pipeline generation)
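The build path serializes the fused transformer with safetensors, which rejects non-contiguous tensors; the diff additionally detaches the weights and moves them to CPU before `save_file`. A minimal round-trip sketch on a toy module, with an illustrative filename standing in for the real cache path:

    import torch
    from safetensors.torch import load_file, save_file

    # Stand-in for the fused transformer produced by tmp_ledger.transformer().
    model = torch.nn.Linear(4, 4)

    # Normalize exactly as the build path does before writing the cache file:
    # detached, CPU-resident, contiguous tensors keyed by state_dict() names.
    state = {k: v.detach().cpu().contiguous() for k, v in model.state_dict().items()}
    save_file(state, "fused_example.safetensors")

    # load_file returns a flat {name: tensor} dict whose keys mirror
    # state_dict(), ready to be pushed back into a live module later.
    restored = load_file("fused_example.safetensors")
    assert set(restored) == set(state)

Because the fusion runs on a temporary CPU ledger, the preloaded GPU transformer is only touched later, when the cached state is applied.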
@@ -539,23 +598,40 @@ def apply_prepared_lora_state_to_pipeline():
     Apply the prepared LoRA state from pipeline._cached_state to the preloaded
     transformer. This should be called BEFORE pipeline generation, not during.
     """
+    print("[LoRA] === Applying LoRA State to Transformer ===")
+
     if pipeline._cached_state is None:
         print("[LoRA] No prepared LoRA state available; skipping.")
+        print("[LoRA] === LoRA Application Complete (no state) ===")
         return False
 
     try:
         existing_transformer = _transformer  # The preloaded transformer from globals
+        state = pipeline._cached_state
+        print(f"[LoRA] Applying state dict with {len(state)} keys...")
+        print(f"[LoRA] State dict size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
+
+        import time
+        start_time = time.time()
+
         with torch.no_grad():
-            missing, unexpected = existing_transformer.load_state_dict(
-
-
-
-
+            missing, unexpected = existing_transformer.load_state_dict(state, strict=False)
+
+        print(f"[LoRA] load_state_dict completed in {time.time() - start_time:.1f}s")
+
+        if missing:
+            print(f"[LoRA] WARNING: {len(missing)} keys missing from state dict")
+        if unexpected:
+            print(f"[LoRA] WARNING: {len(unexpected)} unexpected keys in state dict")
+
+        if not missing and not unexpected:
+            print("[LoRA] State dict loaded successfully with no mismatches!")
 
-        print("[LoRA]
+        print("[LoRA] === LoRA Application Complete (success) ===")
         return True
     except Exception as e:
-        print(f"[LoRA]
+        print(f"[LoRA] FAILED to apply LoRA state: {type(e).__name__}: {e}")
+        print("[LoRA] === LoRA Application Complete (FAILED) ===")
         return False
 
 # =============================================================================
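The important detail in the application path is `strict=False`: `load_state_dict` then returns the mismatch lists instead of raising, which is what makes the missing/unexpected warnings above possible. A self-contained sketch on a toy module:

    import torch

    target = torch.nn.Linear(4, 4)

    # Deliberately mismatched state: no "bias", plus a stray "extra" key.
    state = {"weight": torch.zeros(4, 4), "extra": torch.zeros(1)}

    with torch.no_grad():
        # With strict=False this reports mismatches instead of raising.
        missing, unexpected = target.load_state_dict(state, strict=False)

    print(missing)     # ['bias']  (parameter keeps its current value)
    print(unexpected)  # ['extra'] (key is ignored by the module)

Missing keys keep whatever weights the preloaded transformer already has, so a partially matching cached state degrades gracefully rather than failing the whole run.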
@@ -759,7 +835,6 @@ def generate_video(
         width=width,
         num_frames=num_frames,
         frame_rate=DEFAULT_FRAME_RATE,
-        num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
         video_guider_params=video_guider_params,
         audio_guider_params=audio_guider_params,
         images=images,