dagloop5 committed on
Commit
428e251
·
verified ·
1 Parent(s): d27af4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -138
app.py CHANGED
@@ -272,6 +272,10 @@ LORA_CACHE_DIR = Path("lora_cache")
272
  LORA_CACHE_DIR.mkdir(exist_ok=True)
273
  current_lora_key: str | None = None
274
 
 
 
 
 
275
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
276
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
277
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
@@ -307,91 +311,76 @@ pipeline = LTX23DistilledA2VPipeline(
307
  )
308
  # ----------------------------------------------------------------
309
 
310
- def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
311
- """
312
- Apply LoRAs using cached fused state_dicts when available, otherwise
313
- build fused transformer on CPU, save its state_dict to cache, and then
314
- copy parameters in-place into the GPU-resident transformer.
315
-
316
- Caching key:
317
- sha256( f"{pose_path}:{round(pose,2)}|{general_path}:{round(general,2)}|{motion_path}:{round(motion,2)}" )
318
-
319
- Rounding to 2 decimals reduces unique keys and increases cache hits.
320
- """
321
- ledger = pipeline.model_ledger
322
- global current_lora_key, LORA_CACHE_DIR
323
-
324
- # Round strengths to 2 decimals for cache stability
325
  rp = round(float(pose_strength), 2)
326
  rg = round(float(general_strength), 2)
327
  rm = round(float(motion_strength), 2)
328
-
329
  key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}"
330
  key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
331
- cache_path = LORA_CACHE_DIR / f"{key}.pt"
332
 
333
- # If same key already applied, skip
334
- if current_lora_key == key:
335
- print("[LoRA] Key unchanged; skipping rebuild/copy.")
336
- return
337
 
338
- # If cache exists, load the fused state_dict directly
339
- new_state = None
340
- if cache_path.exists():
341
- try:
342
- print("[LoRA] Cache hit: loading fused state_dict from cache.")
343
- new_state = torch.load(cache_path, map_location="cpu")
344
- print("[LoRA] Loaded cached fused state_dict.")
345
- except Exception as e:
346
- print(f"[LoRA] Failed to load cache {cache_path}: {type(e).__name__}: {e}")
347
- new_state = None
348
-
349
- # If no cache, build tmp_ledger and create fused transformer on CPU, save state_dict
350
- if new_state is None:
351
- # Only build if there is at least one non-zero strength
352
- entries = [
353
- (pose_lora_path, rp),
354
- (general_lora_path, rg),
355
- (motion_lora_path, rm),
356
- ]
357
- loras_for_builder = [
358
- LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
359
- for path, strength in entries
360
- if path is not None and float(strength) != 0.0
361
- ]
362
- if len(loras_for_builder) == 0:
363
- print("[LoRA] No nonzero LoRA strengths — skipping rebuild.")
364
- return
365
 
 
 
 
 
 
 
366
  try:
367
- tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
368
- print(f"[LoRA] Built temporary ledger with {len(loras_for_builder)} LoRA(s).")
 
 
 
 
369
  except Exception as e:
370
- print(f"[LoRA] Failed to create temporary ledger: {type(e).__name__}: {e}")
371
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
- # Build fused transformer on CPU only. Ensure tmp_ledger._target_device is callable.
374
  orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
375
  orig_tmp_device = getattr(tmp_ledger, "device", None)
376
  try:
377
- tmp_ledger._target_device = (lambda: torch.device("cpu"))
378
  tmp_ledger.device = torch.device("cpu")
379
- print("[LoRA] Building fused transformer on CPU (no GPU allocation)...")
380
- new_transformer_cpu = tmp_ledger.transformer() # model on CPU
381
- print("[LoRA] Fused transformer built on CPU.")
382
- except Exception as e:
383
- import traceback
384
- print(f"[LoRA] Error while building fused transformer on CPU: {type(e).__name__}: {e}")
385
- print(traceback.format_exc())
386
- # cleanup
387
- try:
388
- del tmp_ledger
389
- except Exception:
390
- pass
391
- gc.collect()
392
- return
393
  finally:
394
- # restore attributes
395
  if orig_tmp_target is not None:
396
  tmp_ledger._target_device = orig_tmp_target
397
  else:
@@ -407,85 +396,65 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
407
  except Exception:
408
  pass
409
 
410
- # Extract state_dict on CPU and save to cache
411
- try:
412
- new_state = new_transformer_cpu.state_dict()
413
- # Save CPU state_dict for future reuse
414
- try:
415
- torch.save(new_state, cache_path)
416
- print(f"[LoRA] Saved fused state_dict to cache: {cache_path}")
417
- except Exception as e:
418
- print(f"[LoRA] Warning: failed to save cache {cache_path}: {type(e).__name__}: {e}")
419
- except Exception as e:
420
- print(f"[LoRA] Failed to get state_dict from CPU model: {type(e).__name__}: {e}")
421
- new_state = None
 
 
 
 
 
422
 
423
- # Free CPU model and temporary ledger to release memory
424
  try:
425
  del new_transformer_cpu
 
 
 
426
  del tmp_ledger
427
  except Exception:
428
  pass
429
  gc.collect()
430
- torch.cuda.empty_cache()
431
 
432
- if new_state is None:
433
- print("[LoRA] Building fused state failed; aborting.")
434
- return
435
 
436
- # At this point new_state is a CPU state_dict (either from cache or just-built)
437
- try:
438
- # Get existing GPU-resident transformer (cached reference _transformer)
439
- global _transformer
440
- try:
441
- existing_transformer = _transformer
442
- except NameError:
443
- existing_transformer = ledger.transformer()
444
- _transformer = existing_transformer
445
-
446
- existing_params = {name: param for name, param in existing_transformer.named_parameters()}
447
- existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
448
-
449
- # diagnostics: how many keys will be copied
450
- total_keys = len(new_state)
451
- matched = sum(1 for k in new_state if k in existing_params or k in existing_buffers)
452
- print(f"[LoRA] Transformer state keys: total={total_keys} matched_for_copy={matched}")
453
- if matched == 0:
454
- sample_keys = list(new_state.keys())[:10]
455
- print(f"[LoRA] Warning: 0 matching keys found. sample new_state keys: {sample_keys}")
456
-
457
- # Copy CPU tensors into GPU-resident transformer's params/buffers in-place
458
- with torch.no_grad():
459
- for k, v in new_state.items():
460
- if k in existing_params:
461
- tgt = existing_params[k].data
462
- try:
463
- tgt.copy_(v.to(tgt.device))
464
- except Exception as e:
465
- print(f"[LoRA] Failed to copy parameter {k}: {type(e).__name__}: {e}")
466
- elif k in existing_buffers:
467
- tgt = existing_buffers[k].data
468
- try:
469
- tgt.copy_(v.to(tgt.device))
470
- except Exception as e:
471
- print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
472
- else:
473
- # name mismatch — skip
474
- pass
475
 
476
- # mark this key as applied
477
- current_lora_key = key
478
- # optional small GC
479
- gc.collect()
480
- torch.cuda.empty_cache()
481
- print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
482
- return
483
 
484
- except Exception as e:
485
- import traceback
486
- print(f"[LoRA] Error during in-place LoRA application (copy stage): {type(e).__name__}: {e}")
487
- print(traceback.format_exc())
488
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
  # ---- REPLACE PRELOAD BLOCK START ----
491
  # Preload all models for ZeroGPU tensor packing.
@@ -648,7 +617,7 @@ def generate_video(
648
 
649
  log_memory("before pipeline call")
650
 
651
- apply_loras_to_pipeline(pose_strength, general_strength, motion_strength)
652
 
653
  video, audio = pipeline(
654
  prompt=prompt,
@@ -729,6 +698,12 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
729
  label="Motion Helper strength",
730
  minimum=0.0, maximum=2.0, value=0.0, step=0.01
731
  )
 
 
 
 
 
 
732
 
733
  with gr.Column():
734
  output_video = gr.Video(label="Generated Video", autoplay=False)
@@ -789,6 +764,12 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
789
  outputs=[width, height],
790
  )
791
 
 
 
 
 
 
 
792
  generate_btn.click(
793
  fn=generate_video,
794
  inputs=[
 
272
# Directory holding cached fused transformer state_dicts, keyed by a sha256
# of the LoRA paths and strengths (see _make_lora_key).
LORA_CACHE_DIR.mkdir(exist_ok=True)
# Key of the LoRA combination currently applied to the live transformer;
# None until a combination has been applied.
current_lora_key: str | None = None

# Handoff slots between the CPU-only prepare step (prepare_lora_cache) and
# the fast apply step (apply_prepared_lora_state_to_pipeline).
PENDING_LORA_KEY: str | None = None
PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
PENDING_LORA_STATUS: str = "No LoRA state prepared yet."

# Download model assets (huggingface_hub caches these locally).
checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
 
311
  )
312
  # ----------------------------------------------------------------
313
 
314
def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float) -> tuple[str, str]:
    """Derive a stable cache key for a combination of LoRA strengths.

    Strengths are rounded to two decimals so near-identical slider values
    collapse onto the same key, increasing cache hits.

    Returns:
        (key, key_str): the sha256 hex digest and the human-readable
        string it was computed from.
    """
    lora_paths = (pose_lora_path, general_lora_path, motion_lora_path)
    rounded = (
        round(float(pose_strength), 2),
        round(float(general_strength), 2),
        round(float(motion_strength), 2),
    )
    key_str = "|".join(f"{path}:{strength}" for path, strength in zip(lora_paths, rounded))
    key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
    return key, key_str
321
 
 
 
 
 
322
 
323
def prepare_lora_cache(
    pose_strength: float,
    general_strength: float,
    motion_strength: float,
    progress=gr.Progress(track_tqdm=True),
):
    """CPU-only LoRA preparation step.

    - checks the on-disk cache for a previously fused transformer state_dict
    - on a hit, loads it; on a miss, builds the fused transformer on CPU
      (no GPU allocation) and saves its state_dict for future reuse
    The resulting state_dict is kept in memory (PENDING_LORA_*) and applied
    later by apply_prepared_lora_state_to_pipeline().

    Returns:
        A human-readable status string for the UI status textbox.
    """
    global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS

    ledger = pipeline.model_ledger
    key, _ = _make_lora_key(pose_strength, general_strength, motion_strength)
    cache_path = LORA_CACHE_DIR / f"{key}.pt"

    progress(0.05, desc="Preparing LoRA state")

    # Fast path: reuse a previously fused state_dict from disk.
    if cache_path.exists():
        try:
            progress(0.20, desc="Loading cached fused state")
            state = torch.load(cache_path, map_location="cpu")
            PENDING_LORA_KEY = key
            PENDING_LORA_STATE = state
            PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
            return PENDING_LORA_STATUS
        except Exception as e:
            # Corrupt or partial cache file: log and fall through to rebuild.
            print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")

    entries = [
        (pose_lora_path, round(float(pose_strength), 2)),
        (general_lora_path, round(float(general_strength), 2)),
        (motion_lora_path, round(float(motion_strength), 2)),
    ]
    loras_for_builder = [
        LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
        for path, strength in entries
        if path is not None and float(strength) != 0.0
    ]

    if not loras_for_builder:
        PENDING_LORA_KEY = None
        PENDING_LORA_STATE = None
        PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
        return PENDING_LORA_STATUS

    tmp_ledger = None
    new_transformer_cpu = None
    try:
        progress(0.35, desc="Building fused CPU transformer")
        tmp_ledger = ledger.with_loras(tuple(loras_for_builder))

        # Force the temporary ledger to materialize the fused transformer on
        # CPU (no GPU allocation), then restore its device attributes.
        orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
        orig_tmp_device = getattr(tmp_ledger, "device", None)
        try:
            tmp_ledger._target_device = lambda: torch.device("cpu")
            tmp_ledger.device = torch.device("cpu")
            new_transformer_cpu = tmp_ledger.transformer()
        finally:
            # NOTE(review): the restore logic between the diff's context lines
            # was elided; reconstructed from the previous implementation's
            # restore pattern — confirm against the full file.
            try:
                if orig_tmp_target is not None:
                    tmp_ledger._target_device = orig_tmp_target
                else:
                    delattr(tmp_ledger, "_target_device")
                if orig_tmp_device is not None:
                    tmp_ledger.device = orig_tmp_device
            except Exception:
                pass

        progress(0.70, desc="Extracting fused state_dict")
        state = new_transformer_cpu.state_dict()

        # Caching is best-effort: a failed save (e.g. disk full) must not
        # discard the state we just built in memory.
        try:
            torch.save(state, cache_path)
            status = f"Built and cached LoRA state: {cache_path.name}"
        except Exception as e:
            print(f"[LoRA] Warning: failed to save cache {cache_path}: {type(e).__name__}: {e}")
            status = f"Built LoRA state (cache save failed): {cache_path.name}"

        PENDING_LORA_KEY = key
        PENDING_LORA_STATE = state
        PENDING_LORA_STATUS = status
        return PENDING_LORA_STATUS

    except Exception as e:
        import traceback
        print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
        print(traceback.format_exc())
        PENDING_LORA_KEY = None
        PENDING_LORA_STATE = None
        PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
        return PENDING_LORA_STATUS

    finally:
        # Release the CPU-side build artifacts regardless of outcome.
        try:
            del new_transformer_cpu
        except Exception:
            pass
        try:
            del tmp_ledger
        except Exception:
            pass
        gc.collect()
 
427
 
 
 
 
428
 
429
def apply_prepared_lora_state_to_pipeline():
    """Fast step: copy the prepared CPU state into the live transformer.

    prepare_lora_cache() does the slow CPU work ahead of time; this function
    only performs in-place parameter/buffer copies and is the only LoRA work
    that remains near generation time.

    Returns:
        True when the prepared state is (or already was) active,
        False when no state has been prepared.
    """
    global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE, _transformer

    if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
        print("[LoRA] No prepared LoRA state available; skipping.")
        return False

    if current_lora_key == PENDING_LORA_KEY:
        print("[LoRA] Prepared LoRA state already active; skipping.")
        return True

    # Fall back to building the transformer from the ledger if the preload
    # block has not populated the module-level _transformer reference yet
    # (the previous implementation handled this NameError; keep that safety).
    try:
        existing_transformer = _transformer
    except NameError:
        existing_transformer = pipeline.model_ledger.transformer()
        _transformer = existing_transformer

    existing_params = dict(existing_transformer.named_parameters())
    existing_buffers = dict(existing_transformer.named_buffers())

    # In-place copies keep the GPU-resident module's identity; state keys with
    # no counterpart in the live module are skipped silently.
    with torch.no_grad():
        for name, tensor in PENDING_LORA_STATE.items():
            if name in existing_params:
                target = existing_params[name].data
                target.copy_(tensor.to(target.device))
            elif name in existing_buffers:
                target = existing_buffers[name].data
                target.copy_(tensor.to(target.device))

    current_lora_key = PENDING_LORA_KEY
    print("[LoRA] Prepared LoRA state applied to the pipeline.")
    return True
458
 
459
  # ---- REPLACE PRELOAD BLOCK START ----
460
  # Preload all models for ZeroGPU tensor packing.
 
617
 
618
  log_memory("before pipeline call")
619
 
620
+ apply_prepared_lora_state_to_pipeline()
621
 
622
  video, audio = pipeline(
623
  prompt=prompt,
 
698
  label="Motion Helper strength",
699
  minimum=0.0, maximum=2.0, value=0.0, step=0.01
700
  )
701
+ prepare_lora_btn = gr.Button("Prepare / Load LoRA Cache", variant="secondary")
702
+ lora_status = gr.Textbox(
703
+ label="LoRA Cache Status",
704
+ value="No LoRA state prepared yet.",
705
+ interactive=False,
706
+ )
707
 
708
  with gr.Column():
709
  output_video = gr.Video(label="Generated Video", autoplay=False)
 
764
  outputs=[width, height],
765
  )
766
 
767
+ prepare_lora_btn.click(
768
+ fn=prepare_lora_cache,
769
+ inputs=[pose_strength, general_strength, motion_strength],
770
+ outputs=[lora_status],
771
+ )
772
+
773
  generate_btn.click(
774
  fn=generate_video,
775
  inputs=[