dagloop5 committed on
Commit
b7c6b7f
·
verified ·
1 Parent(s): ae1a1ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -54
app.py CHANGED
@@ -266,6 +266,11 @@ print("=" * 80)
266
  print("Downloading LTX-2.3 distilled model + Gemma...")
267
  print("=" * 80)
268
 
 
 
 
 
 
269
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
270
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
271
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
@@ -303,56 +308,92 @@ pipeline = LTX23DistilledA2VPipeline(
303
 
304
  def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
305
  """
306
- Apply LoRAs by:
307
- 1) creating a temporary ledger with requested LoRAs,
308
- 2) building the fused transformer on CPU only,
309
- 3) copying parameters & buffers in-place into the existing GPU transformer,
310
- 4) freeing CPU objects and clearing cache.
311
- This avoids having two full transformers on GPU simultaneously.
 
 
312
  """
313
  ledger = pipeline.model_ledger
 
 
 
 
 
 
314
 
315
- entries = [
316
- (pose_lora_path, float(pose_strength)),
317
- (general_lora_path, float(general_strength)),
318
- (motion_lora_path, float(motion_strength)),
319
- ]
320
-
321
- # Build LoraPathStrengthAndSDOps for non-zero strengths
322
- loras_for_builder = [
323
- LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
324
- for path, strength in entries
325
- if path is not None and float(strength) != 0.0
326
- ]
327
-
328
- if len(loras_for_builder) == 0:
329
- print("[LoRA] No nonzero LoRA strengths — skipping rebuild.")
330
  return
331
 
332
- try:
333
- # Create temporary ledger configured with LoRAs
334
- tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
335
- print(f"[LoRA] Built temporary ledger with {len(loras_for_builder)} LoRA(s).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
- # Force the temporary ledger to build on CPU so the fused model is built on CPU.
338
- # Save original attributes to restore them later.
 
 
 
 
 
 
339
  orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
340
  orig_tmp_device = getattr(tmp_ledger, "device", None)
341
  try:
342
- # _target_device is expected to be callable by model_ledger.transformer()
343
- # set it to a callable that returns CPU so builder.build(device=...) works.
344
  tmp_ledger._target_device = (lambda: torch.device("cpu"))
345
- # ledger.device is used after build: set it to CPU so .to(self.device) keeps the model on CPU.
346
  tmp_ledger.device = torch.device("cpu")
347
  print("[LoRA] Building fused transformer on CPU (no GPU allocation)...")
348
- new_transformer_cpu = tmp_ledger.transformer() # should now return a CPU model
349
  print("[LoRA] Fused transformer built on CPU.")
 
 
 
 
 
 
 
 
 
 
 
350
  finally:
351
- # Restore attributes to their previous values (if there were any).
352
  if orig_tmp_target is not None:
353
  tmp_ledger._target_device = orig_tmp_target
354
  else:
355
- # remove attribute if ledger did not have it previously
356
  try:
357
  delattr(tmp_ledger, "_target_device")
358
  except Exception:
@@ -365,31 +406,54 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
365
  except Exception:
366
  pass
367
 
368
- # Get the existing transformer instance (the one currently used by the pipeline).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  global _transformer
370
  try:
371
  existing_transformer = _transformer
372
  except NameError:
373
- # If not cached, ask ledger for it (this will be the GPU-resident model already loaded).
374
  existing_transformer = ledger.transformer()
375
  _transformer = existing_transformer
376
 
377
- # Map existing parameters & buffers for quick lookup
378
  existing_params = {name: param for name, param in existing_transformer.named_parameters()}
379
  existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
380
 
381
- # State dict of CPU model (fused with LoRAs)
382
- new_state = new_transformer_cpu.state_dict()
383
  # diagnostics: how many keys will be copied
384
  total_keys = len(new_state)
385
  matched = sum(1 for k in new_state if k in existing_params or k in existing_buffers)
386
  print(f"[LoRA] Transformer state keys: total={total_keys} matched_for_copy={matched}")
387
  if matched == 0:
388
- # helpful hint if naming differs
389
  sample_keys = list(new_state.keys())[:10]
390
  print(f"[LoRA] Warning: 0 matching keys found. sample new_state keys: {sample_keys}")
391
 
392
- # Copy CPU tensors into the GPU-resident transformer's params/buffers in-place
393
  with torch.no_grad():
394
  for k, v in new_state.items():
395
  if k in existing_params:
@@ -405,30 +469,22 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
405
  except Exception as e:
406
  print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
407
  else:
408
- # Parameter name mismatch — skip
409
- # This can happen if LoRA changes expected keys; not fatal.
410
- # Print debug once for the first few unmatched keys.
411
  pass
412
 
413
- # Free CPU-built transformer and temporary ledger resources, then clear caches
414
- try:
415
- del new_transformer_cpu
416
- del tmp_ledger
417
- except Exception:
418
- pass
419
  gc.collect()
420
  torch.cuda.empty_cache()
421
-
422
  print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
423
  return
424
 
425
  except Exception as e:
426
  import traceback
427
- print(f"[LoRA] Error during in-place LoRA application: {type(e).__name__}: {e}")
428
  print(traceback.format_exc())
429
-
430
- # If something unexpectedly failed, bail out (no fallback).
431
- print("[LoRA] apply_loras_to_pipeline finished (LOADING FAILED — no changes applied).")
432
 
433
  # ---- REPLACE PRELOAD BLOCK START ----
434
  # Preload all models for ZeroGPU tensor packing.
 
266
  print("Downloading LTX-2.3 distilled model + Gemma...")
267
  print("=" * 80)
268
 
269
+ # LoRA cache directory and currently-applied key
270
+ LORA_CACHE_DIR = Path("lora_cache")
271
+ LORA_CACHE_DIR.mkdir(exist_ok=True)
272
+ current_lora_key: str | None = None
273
+
274
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
275
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
276
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
 
308
 
309
  def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
310
  """
311
+ Apply LoRAs using cached fused state_dicts when available, otherwise
312
+ build fused transformer on CPU, save its state_dict to cache, and then
313
+ copy parameters in-place into the GPU-resident transformer.
314
+
315
+ Caching key:
316
+ sha256( f"{pose_path}:{round(pose,2)}|{general_path}:{round(general,2)}|{motion_path}:{round(motion,2)}" )
317
+
318
+ Rounding to 2 decimals reduces unique keys and increases cache hits.
319
  """
320
  ledger = pipeline.model_ledger
321
+ global current_lora_key, LORA_CACHE_DIR
322
+
323
+ # Round strengths to 2 decimals for cache stability
324
+ rp = round(float(pose_strength), 2)
325
+ rg = round(float(general_strength), 2)
326
+ rm = round(float(motion_strength), 2)
327
 
328
+ key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}"
329
+ key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
330
+ cache_path = LORA_CACHE_DIR / f"{key}.pt"
331
+
332
+ # If same key already applied, skip
333
+ if current_lora_key == key:
334
+ print("[LoRA] Key unchanged; skipping rebuild/copy.")
 
 
 
 
 
 
 
 
335
  return
336
 
337
+ # If cache exists, load the fused state_dict directly
338
+ new_state = None
339
+ if cache_path.exists():
340
+ try:
341
+ print("[LoRA] Cache hit: loading fused state_dict from cache.")
342
+ new_state = torch.load(cache_path, map_location="cpu")
343
+ print("[LoRA] Loaded cached fused state_dict.")
344
+ except Exception as e:
345
+ print(f"[LoRA] Failed to load cache {cache_path}: {type(e).__name__}: {e}")
346
+ new_state = None
347
+
348
+ # If no cache, build tmp_ledger and create fused transformer on CPU, save state_dict
349
+ if new_state is None:
350
+ # Only build if there is at least one non-zero strength
351
+ entries = [
352
+ (pose_lora_path, rp),
353
+ (general_lora_path, rg),
354
+ (motion_lora_path, rm),
355
+ ]
356
+ loras_for_builder = [
357
+ LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
358
+ for path, strength in entries
359
+ if path is not None and float(strength) != 0.0
360
+ ]
361
+ if len(loras_for_builder) == 0:
362
+ print("[LoRA] No nonzero LoRA strengths — skipping rebuild.")
363
+ return
364
 
365
+ try:
366
+ tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
367
+ print(f"[LoRA] Built temporary ledger with {len(loras_for_builder)} LoRA(s).")
368
+ except Exception as e:
369
+ print(f"[LoRA] Failed to create temporary ledger: {type(e).__name__}: {e}")
370
+ return
371
+
372
+ # Build fused transformer on CPU only. Ensure tmp_ledger._target_device is callable.
373
  orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
374
  orig_tmp_device = getattr(tmp_ledger, "device", None)
375
  try:
 
 
376
  tmp_ledger._target_device = (lambda: torch.device("cpu"))
 
377
  tmp_ledger.device = torch.device("cpu")
378
  print("[LoRA] Building fused transformer on CPU (no GPU allocation)...")
379
+ new_transformer_cpu = tmp_ledger.transformer() # model on CPU
380
  print("[LoRA] Fused transformer built on CPU.")
381
+ except Exception as e:
382
+ import traceback
383
+ print(f"[LoRA] Error while building fused transformer on CPU: {type(e).__name__}: {e}")
384
+ print(traceback.format_exc())
385
+ # cleanup
386
+ try:
387
+ del tmp_ledger
388
+ except Exception:
389
+ pass
390
+ gc.collect()
391
+ return
392
  finally:
393
+ # restore attributes
394
  if orig_tmp_target is not None:
395
  tmp_ledger._target_device = orig_tmp_target
396
  else:
 
397
  try:
398
  delattr(tmp_ledger, "_target_device")
399
  except Exception:
 
406
  except Exception:
407
  pass
408
 
409
+ # Extract state_dict on CPU and save to cache
410
+ try:
411
+ new_state = new_transformer_cpu.state_dict()
412
+ # Save CPU state_dict for future reuse
413
+ try:
414
+ torch.save(new_state, cache_path)
415
+ print(f"[LoRA] Saved fused state_dict to cache: {cache_path}")
416
+ except Exception as e:
417
+ print(f"[LoRA] Warning: failed to save cache {cache_path}: {type(e).__name__}: {e}")
418
+ except Exception as e:
419
+ print(f"[LoRA] Failed to get state_dict from CPU model: {type(e).__name__}: {e}")
420
+ new_state = None
421
+
422
+ # Free CPU model and temporary ledger to release memory
423
+ try:
424
+ del new_transformer_cpu
425
+ del tmp_ledger
426
+ except Exception:
427
+ pass
428
+ gc.collect()
429
+ torch.cuda.empty_cache()
430
+
431
+ if new_state is None:
432
+ print("[LoRA] Building fused state failed; aborting.")
433
+ return
434
+
435
+ # At this point new_state is a CPU state_dict (either from cache or just-built)
436
+ try:
437
+ # Get existing GPU-resident transformer (cached reference _transformer)
438
  global _transformer
439
  try:
440
  existing_transformer = _transformer
441
  except NameError:
 
442
  existing_transformer = ledger.transformer()
443
  _transformer = existing_transformer
444
 
 
445
  existing_params = {name: param for name, param in existing_transformer.named_parameters()}
446
  existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
447
 
 
 
448
  # diagnostics: how many keys will be copied
449
  total_keys = len(new_state)
450
  matched = sum(1 for k in new_state if k in existing_params or k in existing_buffers)
451
  print(f"[LoRA] Transformer state keys: total={total_keys} matched_for_copy={matched}")
452
  if matched == 0:
 
453
  sample_keys = list(new_state.keys())[:10]
454
  print(f"[LoRA] Warning: 0 matching keys found. sample new_state keys: {sample_keys}")
455
 
456
+ # Copy CPU tensors into GPU-resident transformer's params/buffers in-place
457
  with torch.no_grad():
458
  for k, v in new_state.items():
459
  if k in existing_params:
 
469
  except Exception as e:
470
  print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
471
  else:
472
+ # name mismatch — skip
 
 
473
  pass
474
 
475
+ # mark this key as applied
476
+ current_lora_key = key
477
+ # optional small GC
 
 
 
478
  gc.collect()
479
  torch.cuda.empty_cache()
 
480
  print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
481
  return
482
 
483
  except Exception as e:
484
  import traceback
485
+ print(f"[LoRA] Error during in-place LoRA application (copy stage): {type(e).__name__}: {e}")
486
  print(traceback.format_exc())
487
+ return
 
 
488
 
489
  # ---- REPLACE PRELOAD BLOCK START ----
490
  # Preload all models for ZeroGPU tensor packing.