Update app.py
Browse files
app.py
CHANGED
|
@@ -272,6 +272,10 @@ LORA_CACHE_DIR = Path("lora_cache")
|
|
| 272 |
LORA_CACHE_DIR.mkdir(exist_ok=True)
|
| 273 |
current_lora_key: str | None = None
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
|
| 276 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 277 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
|
@@ -307,91 +311,76 @@ pipeline = LTX23DistilledA2VPipeline(
|
|
| 307 |
)
|
| 308 |
# ----------------------------------------------------------------
|
| 309 |
|
| 310 |
-
def
|
| 311 |
-
"""
|
| 312 |
-
Apply LoRAs using cached fused state_dicts when available, otherwise
|
| 313 |
-
build fused transformer on CPU, save its state_dict to cache, and then
|
| 314 |
-
copy parameters in-place into the GPU-resident transformer.
|
| 315 |
-
|
| 316 |
-
Caching key:
|
| 317 |
-
sha256( f"{pose_path}:{round(pose,2)}|{general_path}:{round(general,2)}|{motion_path}:{round(motion,2)}" )
|
| 318 |
-
|
| 319 |
-
Rounding to 2 decimals reduces unique keys and increases cache hits.
|
| 320 |
-
"""
|
| 321 |
-
ledger = pipeline.model_ledger
|
| 322 |
-
global current_lora_key, LORA_CACHE_DIR
|
| 323 |
-
|
| 324 |
-
# Round strengths to 2 decimals for cache stability
|
| 325 |
rp = round(float(pose_strength), 2)
|
| 326 |
rg = round(float(general_strength), 2)
|
| 327 |
rm = round(float(motion_strength), 2)
|
| 328 |
-
|
| 329 |
key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}"
|
| 330 |
key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
|
| 331 |
-
|
| 332 |
|
| 333 |
-
# If same key already applied, skip
|
| 334 |
-
if current_lora_key == key:
|
| 335 |
-
print("[LoRA] Key unchanged; skipping rebuild/copy.")
|
| 336 |
-
return
|
| 337 |
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
entries = [
|
| 353 |
-
(pose_lora_path, rp),
|
| 354 |
-
(general_lora_path, rg),
|
| 355 |
-
(motion_lora_path, rm),
|
| 356 |
-
]
|
| 357 |
-
loras_for_builder = [
|
| 358 |
-
LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
|
| 359 |
-
for path, strength in entries
|
| 360 |
-
if path is not None and float(strength) != 0.0
|
| 361 |
-
]
|
| 362 |
-
if len(loras_for_builder) == 0:
|
| 363 |
-
print("[LoRA] No nonzero LoRA strengths — skipping rebuild.")
|
| 364 |
-
return
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
try:
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
except Exception as e:
|
| 370 |
-
print(f"[LoRA]
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
# Build fused transformer on CPU only. Ensure tmp_ledger._target_device is callable.
|
| 374 |
orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
|
| 375 |
orig_tmp_device = getattr(tmp_ledger, "device", None)
|
| 376 |
try:
|
| 377 |
-
tmp_ledger._target_device =
|
| 378 |
tmp_ledger.device = torch.device("cpu")
|
| 379 |
-
|
| 380 |
-
new_transformer_cpu = tmp_ledger.transformer() # model on CPU
|
| 381 |
-
print("[LoRA] Fused transformer built on CPU.")
|
| 382 |
-
except Exception as e:
|
| 383 |
-
import traceback
|
| 384 |
-
print(f"[LoRA] Error while building fused transformer on CPU: {type(e).__name__}: {e}")
|
| 385 |
-
print(traceback.format_exc())
|
| 386 |
-
# cleanup
|
| 387 |
-
try:
|
| 388 |
-
del tmp_ledger
|
| 389 |
-
except Exception:
|
| 390 |
-
pass
|
| 391 |
-
gc.collect()
|
| 392 |
-
return
|
| 393 |
finally:
|
| 394 |
-
# restore attributes
|
| 395 |
if orig_tmp_target is not None:
|
| 396 |
tmp_ledger._target_device = orig_tmp_target
|
| 397 |
else:
|
|
@@ -407,85 +396,65 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
|
|
| 407 |
except Exception:
|
| 408 |
pass
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
-
|
| 424 |
try:
|
| 425 |
del new_transformer_cpu
|
|
|
|
|
|
|
|
|
|
| 426 |
del tmp_ledger
|
| 427 |
except Exception:
|
| 428 |
pass
|
| 429 |
gc.collect()
|
| 430 |
-
torch.cuda.empty_cache()
|
| 431 |
|
| 432 |
-
if new_state is None:
|
| 433 |
-
print("[LoRA] Building fused state failed; aborting.")
|
| 434 |
-
return
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
except NameError:
|
| 443 |
-
existing_transformer = ledger.transformer()
|
| 444 |
-
_transformer = existing_transformer
|
| 445 |
-
|
| 446 |
-
existing_params = {name: param for name, param in existing_transformer.named_parameters()}
|
| 447 |
-
existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
|
| 448 |
-
|
| 449 |
-
# diagnostics: how many keys will be copied
|
| 450 |
-
total_keys = len(new_state)
|
| 451 |
-
matched = sum(1 for k in new_state if k in existing_params or k in existing_buffers)
|
| 452 |
-
print(f"[LoRA] Transformer state keys: total={total_keys} matched_for_copy={matched}")
|
| 453 |
-
if matched == 0:
|
| 454 |
-
sample_keys = list(new_state.keys())[:10]
|
| 455 |
-
print(f"[LoRA] Warning: 0 matching keys found. sample new_state keys: {sample_keys}")
|
| 456 |
-
|
| 457 |
-
# Copy CPU tensors into GPU-resident transformer's params/buffers in-place
|
| 458 |
-
with torch.no_grad():
|
| 459 |
-
for k, v in new_state.items():
|
| 460 |
-
if k in existing_params:
|
| 461 |
-
tgt = existing_params[k].data
|
| 462 |
-
try:
|
| 463 |
-
tgt.copy_(v.to(tgt.device))
|
| 464 |
-
except Exception as e:
|
| 465 |
-
print(f"[LoRA] Failed to copy parameter {k}: {type(e).__name__}: {e}")
|
| 466 |
-
elif k in existing_buffers:
|
| 467 |
-
tgt = existing_buffers[k].data
|
| 468 |
-
try:
|
| 469 |
-
tgt.copy_(v.to(tgt.device))
|
| 470 |
-
except Exception as e:
|
| 471 |
-
print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
|
| 472 |
-
else:
|
| 473 |
-
# name mismatch — skip
|
| 474 |
-
pass
|
| 475 |
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
gc.collect()
|
| 480 |
-
torch.cuda.empty_cache()
|
| 481 |
-
print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
|
| 482 |
-
return
|
| 483 |
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
# ---- REPLACE PRELOAD BLOCK START ----
|
| 491 |
# Preload all models for ZeroGPU tensor packing.
|
|
@@ -648,7 +617,7 @@ def generate_video(
|
|
| 648 |
|
| 649 |
log_memory("before pipeline call")
|
| 650 |
|
| 651 |
-
|
| 652 |
|
| 653 |
video, audio = pipeline(
|
| 654 |
prompt=prompt,
|
|
@@ -729,6 +698,12 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
|
|
| 729 |
label="Motion Helper strength",
|
| 730 |
minimum=0.0, maximum=2.0, value=0.0, step=0.01
|
| 731 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
with gr.Column():
|
| 734 |
output_video = gr.Video(label="Generated Video", autoplay=False)
|
|
@@ -789,6 +764,12 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
|
|
| 789 |
outputs=[width, height],
|
| 790 |
)
|
| 791 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
generate_btn.click(
|
| 793 |
fn=generate_video,
|
| 794 |
inputs=[
|
|
|
|
| 272 |
LORA_CACHE_DIR.mkdir(exist_ok=True)
|
| 273 |
current_lora_key: str | None = None
|
| 274 |
|
| 275 |
+
PENDING_LORA_KEY: str | None = None
|
| 276 |
+
PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
|
| 277 |
+
PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
|
| 278 |
+
|
| 279 |
checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
|
| 280 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 281 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
|
|
|
| 311 |
)
|
| 312 |
# ----------------------------------------------------------------
|
| 313 |
|
| 314 |
+
def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float) -> tuple[str, str]:
    """Derive the cache key for one combination of LoRA strengths.

    Every strength is coerced to ``float`` and rounded to two decimals so
    that near-identical values collapse onto the same key. The rounded
    strengths are paired with the module-level ``pose_lora_path``,
    ``general_lora_path`` and ``motion_lora_path`` (defined outside this
    block) to form a ``path:strength`` string joined by ``|``.

    Returns:
        A ``(key, key_str)`` tuple: the SHA-256 hex digest of the UTF-8
        encoded key string, followed by the key string itself.
    """
    pose_r, general_r, motion_r = (
        round(float(raw), 2)
        for raw in (pose_strength, general_strength, motion_strength)
    )
    key_str = "|".join(
        f"{lora_path}:{strength}"
        for lora_path, strength in (
            (pose_lora_path, pose_r),
            (general_lora_path, general_r),
            (motion_lora_path, motion_r),
        )
    )
    digest = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
    return digest, key_str
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
+
def prepare_lora_cache(
|
| 324 |
+
pose_strength: float,
|
| 325 |
+
general_strength: float,
|
| 326 |
+
motion_strength: float,
|
| 327 |
+
progress=gr.Progress(track_tqdm=True),
|
| 328 |
+
):
|
| 329 |
+
"""
|
| 330 |
+
CPU-only step:
|
| 331 |
+
- checks cache
|
| 332 |
+
- loads cached fused transformer state_dict, or
|
| 333 |
+
- builds fused transformer on CPU and saves it
|
| 334 |
+
The resulting state_dict is stored in memory and can be applied later.
|
| 335 |
+
"""
|
| 336 |
+
global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
+
ledger = pipeline.model_ledger
|
| 339 |
+
key, _ = _make_lora_key(pose_strength, general_strength, motion_strength)
|
| 340 |
+
cache_path = LORA_CACHE_DIR / f"{key}.pt"
|
| 341 |
+
|
| 342 |
+
progress(0.05, desc="Preparing LoRA state")
|
| 343 |
+
if cache_path.exists():
|
| 344 |
try:
|
| 345 |
+
progress(0.20, desc="Loading cached fused state")
|
| 346 |
+
state = torch.load(cache_path, map_location="cpu")
|
| 347 |
+
PENDING_LORA_KEY = key
|
| 348 |
+
PENDING_LORA_STATE = state
|
| 349 |
+
PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
|
| 350 |
+
return PENDING_LORA_STATUS
|
| 351 |
except Exception as e:
|
| 352 |
+
print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
|
| 353 |
+
|
| 354 |
+
entries = [
|
| 355 |
+
(pose_lora_path, round(float(pose_strength), 2)),
|
| 356 |
+
(general_lora_path, round(float(general_strength), 2)),
|
| 357 |
+
(motion_lora_path, round(float(motion_strength), 2)),
|
| 358 |
+
]
|
| 359 |
+
loras_for_builder = [
|
| 360 |
+
LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
|
| 361 |
+
for path, strength in entries
|
| 362 |
+
if path is not None and float(strength) != 0.0
|
| 363 |
+
]
|
| 364 |
+
|
| 365 |
+
if not loras_for_builder:
|
| 366 |
+
PENDING_LORA_KEY = None
|
| 367 |
+
PENDING_LORA_STATE = None
|
| 368 |
+
PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
|
| 369 |
+
return PENDING_LORA_STATUS
|
| 370 |
+
|
| 371 |
+
tmp_ledger = None
|
| 372 |
+
new_transformer_cpu = None
|
| 373 |
+
try:
|
| 374 |
+
progress(0.35, desc="Building fused CPU transformer")
|
| 375 |
+
tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
|
| 376 |
|
|
|
|
| 377 |
orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
|
| 378 |
orig_tmp_device = getattr(tmp_ledger, "device", None)
|
| 379 |
try:
|
| 380 |
+
tmp_ledger._target_device = lambda: torch.device("cpu")
|
| 381 |
tmp_ledger.device = torch.device("cpu")
|
| 382 |
+
new_transformer_cpu = tmp_ledger.transformer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
finally:
|
|
|
|
| 384 |
if orig_tmp_target is not None:
|
| 385 |
tmp_ledger._target_device = orig_tmp_target
|
| 386 |
else:
|
|
|
|
| 396 |
except Exception:
|
| 397 |
pass
|
| 398 |
|
| 399 |
+
progress(0.70, desc="Extracting fused state_dict")
|
| 400 |
+
state = new_transformer_cpu.state_dict()
|
| 401 |
+
torch.save(state, cache_path)
|
| 402 |
+
|
| 403 |
+
PENDING_LORA_KEY = key
|
| 404 |
+
PENDING_LORA_STATE = state
|
| 405 |
+
PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
|
| 406 |
+
return PENDING_LORA_STATUS
|
| 407 |
+
|
| 408 |
+
except Exception as e:
|
| 409 |
+
import traceback
|
| 410 |
+
print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
|
| 411 |
+
print(traceback.format_exc())
|
| 412 |
+
PENDING_LORA_KEY = None
|
| 413 |
+
PENDING_LORA_STATE = None
|
| 414 |
+
PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
|
| 415 |
+
return PENDING_LORA_STATUS
|
| 416 |
|
| 417 |
+
finally:
|
| 418 |
try:
|
| 419 |
del new_transformer_cpu
|
| 420 |
+
except Exception:
|
| 421 |
+
pass
|
| 422 |
+
try:
|
| 423 |
del tmp_ledger
|
| 424 |
except Exception:
|
| 425 |
pass
|
| 426 |
gc.collect()
|
|
|
|
| 427 |
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
+
def apply_prepared_lora_state_to_pipeline():
    """Copy the prepared CPU state dict into the live transformer in place.

    Reads the module globals ``PENDING_LORA_STATE`` and ``PENDING_LORA_KEY``
    and writes every entry whose name matches a parameter or buffer of the
    module-level ``_transformer`` (defined outside this block) into that
    tensor. Entries whose names match neither are skipped.

    ``current_lora_key`` is only updated when at least one tensor was
    copied. Otherwise a state dict with entirely mismatched names would be
    recorded as active and every later call would short-circuit as
    "already active", hiding the failure.

    Returns:
        bool: ``False`` when nothing is prepared or when no prepared key
        matches the live transformer; ``True`` when the prepared state is
        already active or was copied successfully.
    """
    global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE

    if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
        print("[LoRA] No prepared LoRA state available; skipping.")
        return False

    if current_lora_key == PENDING_LORA_KEY:
        print("[LoRA] Prepared LoRA state already active; skipping.")
        return True

    existing_transformer = _transformer
    # Buffers first, then parameters, so a parameter wins on a name clash,
    # mirroring the original "params, elif buffers" lookup order.
    targets = dict(existing_transformer.named_buffers())
    targets.update(existing_transformer.named_parameters())

    matched = [name for name in PENDING_LORA_STATE if name in targets]
    total = len(PENDING_LORA_STATE)
    if not matched:
        sample_keys = list(PENDING_LORA_STATE)[:10]
        print(
            f"[LoRA] Warning: 0 of {total} prepared keys match the live transformer; "
            f"state NOT applied. Sample prepared keys: {sample_keys}"
        )
        return False

    with torch.no_grad():
        for name in matched:
            target = targets[name]
            # Move the source tensor to the target's device, then overwrite in place.
            target.data.copy_(PENDING_LORA_STATE[name].to(target.device))

    current_lora_key = PENDING_LORA_KEY
    print(f"[LoRA] Copied {len(matched)} of {total} prepared tensors into the live transformer.")
    print("[LoRA] Prepared LoRA state applied to the pipeline.")
    return True
|
| 458 |
|
| 459 |
# ---- REPLACE PRELOAD BLOCK START ----
|
| 460 |
# Preload all models for ZeroGPU tensor packing.
|
|
|
|
| 617 |
|
| 618 |
log_memory("before pipeline call")
|
| 619 |
|
| 620 |
+
apply_prepared_lora_state_to_pipeline()
|
| 621 |
|
| 622 |
video, audio = pipeline(
|
| 623 |
prompt=prompt,
|
|
|
|
| 698 |
label="Motion Helper strength",
|
| 699 |
minimum=0.0, maximum=2.0, value=0.0, step=0.01
|
| 700 |
)
|
| 701 |
+
prepare_lora_btn = gr.Button("Prepare / Load LoRA Cache", variant="secondary")
|
| 702 |
+
lora_status = gr.Textbox(
|
| 703 |
+
label="LoRA Cache Status",
|
| 704 |
+
value="No LoRA state prepared yet.",
|
| 705 |
+
interactive=False,
|
| 706 |
+
)
|
| 707 |
|
| 708 |
with gr.Column():
|
| 709 |
output_video = gr.Video(label="Generated Video", autoplay=False)
|
|
|
|
| 764 |
outputs=[width, height],
|
| 765 |
)
|
| 766 |
|
| 767 |
+
prepare_lora_btn.click(
|
| 768 |
+
fn=prepare_lora_cache,
|
| 769 |
+
inputs=[pose_strength, general_strength, motion_strength],
|
| 770 |
+
outputs=[lora_status],
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
generate_btn.click(
|
| 774 |
fn=generate_video,
|
| 775 |
inputs=[
|