Update app.py
Browse files
app.py
CHANGED
|
@@ -266,6 +266,11 @@ print("=" * 80)
|
|
| 266 |
print("Downloading LTX-2.3 distilled model + Gemma...")
|
| 267 |
print("=" * 80)
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
|
| 270 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 271 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
|
@@ -303,56 +308,92 @@ pipeline = LTX23DistilledA2VPipeline(
|
|
| 303 |
|
| 304 |
def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
|
| 305 |
"""
|
| 306 |
-
Apply LoRAs
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
"""
|
| 313 |
ledger = pipeline.model_ledger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
loras_for_builder = [
|
| 323 |
-
LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
|
| 324 |
-
for path, strength in entries
|
| 325 |
-
if path is not None and float(strength) != 0.0
|
| 326 |
-
]
|
| 327 |
-
|
| 328 |
-
if len(loras_for_builder) == 0:
|
| 329 |
-
print("[LoRA] No nonzero LoRA strengths — skipping rebuild.")
|
| 330 |
return
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
|
| 340 |
orig_tmp_device = getattr(tmp_ledger, "device", None)
|
| 341 |
try:
|
| 342 |
-
# _target_device is expected to be callable by model_ledger.transformer()
|
| 343 |
-
# set it to a callable that returns CPU so builder.build(device=...) works.
|
| 344 |
tmp_ledger._target_device = (lambda: torch.device("cpu"))
|
| 345 |
-
# ledger.device is used after build: set it to CPU so .to(self.device) keeps the model on CPU.
|
| 346 |
tmp_ledger.device = torch.device("cpu")
|
| 347 |
print("[LoRA] Building fused transformer on CPU (no GPU allocation)...")
|
| 348 |
-
new_transformer_cpu = tmp_ledger.transformer() #
|
| 349 |
print("[LoRA] Fused transformer built on CPU.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
finally:
|
| 351 |
-
#
|
| 352 |
if orig_tmp_target is not None:
|
| 353 |
tmp_ledger._target_device = orig_tmp_target
|
| 354 |
else:
|
| 355 |
-
# remove attribute if ledger did not have it previously
|
| 356 |
try:
|
| 357 |
delattr(tmp_ledger, "_target_device")
|
| 358 |
except Exception:
|
|
@@ -365,31 +406,54 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
|
|
| 365 |
except Exception:
|
| 366 |
pass
|
| 367 |
|
| 368 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
global _transformer
|
| 370 |
try:
|
| 371 |
existing_transformer = _transformer
|
| 372 |
except NameError:
|
| 373 |
-
# If not cached, ask ledger for it (this will be the GPU-resident model already loaded).
|
| 374 |
existing_transformer = ledger.transformer()
|
| 375 |
_transformer = existing_transformer
|
| 376 |
|
| 377 |
-
# Map existing parameters & buffers for quick lookup
|
| 378 |
existing_params = {name: param for name, param in existing_transformer.named_parameters()}
|
| 379 |
existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
|
| 380 |
|
| 381 |
-
# State dict of CPU model (fused with LoRAs)
|
| 382 |
-
new_state = new_transformer_cpu.state_dict()
|
| 383 |
# diagnostics: how many keys will be copied
|
| 384 |
total_keys = len(new_state)
|
| 385 |
matched = sum(1 for k in new_state if k in existing_params or k in existing_buffers)
|
| 386 |
print(f"[LoRA] Transformer state keys: total={total_keys} matched_for_copy={matched}")
|
| 387 |
if matched == 0:
|
| 388 |
-
# helpful hint if naming differs
|
| 389 |
sample_keys = list(new_state.keys())[:10]
|
| 390 |
print(f"[LoRA] Warning: 0 matching keys found. sample new_state keys: {sample_keys}")
|
| 391 |
|
| 392 |
-
# Copy CPU tensors into
|
| 393 |
with torch.no_grad():
|
| 394 |
for k, v in new_state.items():
|
| 395 |
if k in existing_params:
|
|
@@ -405,30 +469,22 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
|
|
| 405 |
except Exception as e:
|
| 406 |
print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
|
| 407 |
else:
|
| 408 |
-
#
|
| 409 |
-
# This can happen if LoRA changes expected keys; not fatal.
|
| 410 |
-
# Print debug once for the first few unmatched keys.
|
| 411 |
pass
|
| 412 |
|
| 413 |
-
#
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
del tmp_ledger
|
| 417 |
-
except Exception:
|
| 418 |
-
pass
|
| 419 |
gc.collect()
|
| 420 |
torch.cuda.empty_cache()
|
| 421 |
-
|
| 422 |
print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
|
| 423 |
return
|
| 424 |
|
| 425 |
except Exception as e:
|
| 426 |
import traceback
|
| 427 |
-
print(f"[LoRA] Error during in-place LoRA application: {type(e).__name__}: {e}")
|
| 428 |
print(traceback.format_exc())
|
| 429 |
-
|
| 430 |
-
# If something unexpectedly failed, bail out (no fallback).
|
| 431 |
-
print("[LoRA] apply_loras_to_pipeline finished (LOADING FAILED — no changes applied).")
|
| 432 |
|
| 433 |
# ---- REPLACE PRELOAD BLOCK START ----
|
| 434 |
# Preload all models for ZeroGPU tensor packing.
|
|
|
|
| 266 |
print("Downloading LTX-2.3 distilled model + Gemma...")
|
| 267 |
print("=" * 80)
|
| 268 |
|
| 269 |
+
# LoRA cache directory and currently-applied key
|
| 270 |
+
LORA_CACHE_DIR = Path("lora_cache")
|
| 271 |
+
LORA_CACHE_DIR.mkdir(exist_ok=True)
|
| 272 |
+
current_lora_key: str | None = None
|
| 273 |
+
|
| 274 |
checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
|
| 275 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 276 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
|
|
|
| 308 |
|
| 309 |
def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
|
| 310 |
"""
|
| 311 |
+
Apply LoRAs using cached fused state_dicts when available, otherwise
|
| 312 |
+
build fused transformer on CPU, save its state_dict to cache, and then
|
| 313 |
+
copy parameters in-place into the GPU-resident transformer.
|
| 314 |
+
|
| 315 |
+
Caching key:
|
| 316 |
+
sha256( f"{pose_path}:{round(pose,2)}|{general_path}:{round(general,2)}|{motion_path}:{round(motion,2)}" )
|
| 317 |
+
|
| 318 |
+
Rounding to 2 decimals reduces unique keys and increases cache hits.
|
| 319 |
"""
|
| 320 |
ledger = pipeline.model_ledger
|
| 321 |
+
global current_lora_key, LORA_CACHE_DIR
|
| 322 |
+
|
| 323 |
+
# Round strengths to 2 decimals for cache stability
|
| 324 |
+
rp = round(float(pose_strength), 2)
|
| 325 |
+
rg = round(float(general_strength), 2)
|
| 326 |
+
rm = round(float(motion_strength), 2)
|
| 327 |
|
| 328 |
+
key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}"
|
| 329 |
+
key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
|
| 330 |
+
cache_path = LORA_CACHE_DIR / f"{key}.pt"
|
| 331 |
+
|
| 332 |
+
# If same key already applied, skip
|
| 333 |
+
if current_lora_key == key:
|
| 334 |
+
print("[LoRA] Key unchanged; skipping rebuild/copy.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
return
|
| 336 |
|
| 337 |
+
# If cache exists, load the fused state_dict directly
|
| 338 |
+
new_state = None
|
| 339 |
+
if cache_path.exists():
|
| 340 |
+
try:
|
| 341 |
+
print("[LoRA] Cache hit: loading fused state_dict from cache.")
|
| 342 |
+
new_state = torch.load(cache_path, map_location="cpu")
|
| 343 |
+
print("[LoRA] Loaded cached fused state_dict.")
|
| 344 |
+
except Exception as e:
|
| 345 |
+
print(f"[LoRA] Failed to load cache {cache_path}: {type(e).__name__}: {e}")
|
| 346 |
+
new_state = None
|
| 347 |
+
|
| 348 |
+
# If no cache, build tmp_ledger and create fused transformer on CPU, save state_dict
|
| 349 |
+
if new_state is None:
|
| 350 |
+
# Only build if there is at least one non-zero strength
|
| 351 |
+
entries = [
|
| 352 |
+
(pose_lora_path, rp),
|
| 353 |
+
(general_lora_path, rg),
|
| 354 |
+
(motion_lora_path, rm),
|
| 355 |
+
]
|
| 356 |
+
loras_for_builder = [
|
| 357 |
+
LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
|
| 358 |
+
for path, strength in entries
|
| 359 |
+
if path is not None and float(strength) != 0.0
|
| 360 |
+
]
|
| 361 |
+
if len(loras_for_builder) == 0:
|
| 362 |
+
print("[LoRA] No nonzero LoRA strengths — skipping rebuild.")
|
| 363 |
+
return
|
| 364 |
|
| 365 |
+
try:
|
| 366 |
+
tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
|
| 367 |
+
print(f"[LoRA] Built temporary ledger with {len(loras_for_builder)} LoRA(s).")
|
| 368 |
+
except Exception as e:
|
| 369 |
+
print(f"[LoRA] Failed to create temporary ledger: {type(e).__name__}: {e}")
|
| 370 |
+
return
|
| 371 |
+
|
| 372 |
+
# Build fused transformer on CPU only. Ensure tmp_ledger._target_device is callable.
|
| 373 |
orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
|
| 374 |
orig_tmp_device = getattr(tmp_ledger, "device", None)
|
| 375 |
try:
|
|
|
|
|
|
|
| 376 |
tmp_ledger._target_device = (lambda: torch.device("cpu"))
|
|
|
|
| 377 |
tmp_ledger.device = torch.device("cpu")
|
| 378 |
print("[LoRA] Building fused transformer on CPU (no GPU allocation)...")
|
| 379 |
+
new_transformer_cpu = tmp_ledger.transformer() # model on CPU
|
| 380 |
print("[LoRA] Fused transformer built on CPU.")
|
| 381 |
+
except Exception as e:
|
| 382 |
+
import traceback
|
| 383 |
+
print(f"[LoRA] Error while building fused transformer on CPU: {type(e).__name__}: {e}")
|
| 384 |
+
print(traceback.format_exc())
|
| 385 |
+
# cleanup
|
| 386 |
+
try:
|
| 387 |
+
del tmp_ledger
|
| 388 |
+
except Exception:
|
| 389 |
+
pass
|
| 390 |
+
gc.collect()
|
| 391 |
+
return
|
| 392 |
finally:
|
| 393 |
+
# restore attributes
|
| 394 |
if orig_tmp_target is not None:
|
| 395 |
tmp_ledger._target_device = orig_tmp_target
|
| 396 |
else:
|
|
|
|
| 397 |
try:
|
| 398 |
delattr(tmp_ledger, "_target_device")
|
| 399 |
except Exception:
|
|
|
|
| 406 |
except Exception:
|
| 407 |
pass
|
| 408 |
|
| 409 |
+
# Extract state_dict on CPU and save to cache
|
| 410 |
+
try:
|
| 411 |
+
new_state = new_transformer_cpu.state_dict()
|
| 412 |
+
# Save CPU state_dict for future reuse
|
| 413 |
+
try:
|
| 414 |
+
torch.save(new_state, cache_path)
|
| 415 |
+
print(f"[LoRA] Saved fused state_dict to cache: {cache_path}")
|
| 416 |
+
except Exception as e:
|
| 417 |
+
print(f"[LoRA] Warning: failed to save cache {cache_path}: {type(e).__name__}: {e}")
|
| 418 |
+
except Exception as e:
|
| 419 |
+
print(f"[LoRA] Failed to get state_dict from CPU model: {type(e).__name__}: {e}")
|
| 420 |
+
new_state = None
|
| 421 |
+
|
| 422 |
+
# Free CPU model and temporary ledger to release memory
|
| 423 |
+
try:
|
| 424 |
+
del new_transformer_cpu
|
| 425 |
+
del tmp_ledger
|
| 426 |
+
except Exception:
|
| 427 |
+
pass
|
| 428 |
+
gc.collect()
|
| 429 |
+
torch.cuda.empty_cache()
|
| 430 |
+
|
| 431 |
+
if new_state is None:
|
| 432 |
+
print("[LoRA] Building fused state failed; aborting.")
|
| 433 |
+
return
|
| 434 |
+
|
| 435 |
+
# At this point new_state is a CPU state_dict (either from cache or just-built)
|
| 436 |
+
try:
|
| 437 |
+
# Get existing GPU-resident transformer (cached reference _transformer)
|
| 438 |
global _transformer
|
| 439 |
try:
|
| 440 |
existing_transformer = _transformer
|
| 441 |
except NameError:
|
|
|
|
| 442 |
existing_transformer = ledger.transformer()
|
| 443 |
_transformer = existing_transformer
|
| 444 |
|
|
|
|
| 445 |
existing_params = {name: param for name, param in existing_transformer.named_parameters()}
|
| 446 |
existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
|
| 447 |
|
|
|
|
|
|
|
| 448 |
# diagnostics: how many keys will be copied
|
| 449 |
total_keys = len(new_state)
|
| 450 |
matched = sum(1 for k in new_state if k in existing_params or k in existing_buffers)
|
| 451 |
print(f"[LoRA] Transformer state keys: total={total_keys} matched_for_copy={matched}")
|
| 452 |
if matched == 0:
|
|
|
|
| 453 |
sample_keys = list(new_state.keys())[:10]
|
| 454 |
print(f"[LoRA] Warning: 0 matching keys found. sample new_state keys: {sample_keys}")
|
| 455 |
|
| 456 |
+
# Copy CPU tensors into GPU-resident transformer's params/buffers in-place
|
| 457 |
with torch.no_grad():
|
| 458 |
for k, v in new_state.items():
|
| 459 |
if k in existing_params:
|
|
|
|
| 469 |
except Exception as e:
|
| 470 |
print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
|
| 471 |
else:
|
| 472 |
+
# name mismatch — skip
|
|
|
|
|
|
|
| 473 |
pass
|
| 474 |
|
| 475 |
+
# mark this key as applied
|
| 476 |
+
current_lora_key = key
|
| 477 |
+
# optional small GC
|
|
|
|
|
|
|
|
|
|
| 478 |
gc.collect()
|
| 479 |
torch.cuda.empty_cache()
|
|
|
|
| 480 |
print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
|
| 481 |
return
|
| 482 |
|
| 483 |
except Exception as e:
|
| 484 |
import traceback
|
| 485 |
+
print(f"[LoRA] Error during in-place LoRA application (copy stage): {type(e).__name__}: {e}")
|
| 486 |
print(traceback.format_exc())
|
| 487 |
+
return
|
|
|
|
|
|
|
| 488 |
|
| 489 |
# ---- REPLACE PRELOAD BLOCK START ----
|
| 490 |
# Preload all models for ZeroGPU tensor packing.
|