dagloop5 committed on
Commit
9921524
·
verified ·
1 Parent(s): 2ec52b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -42
app.py CHANGED
@@ -32,6 +32,7 @@ import logging
32
  import random
33
  import tempfile
34
  from pathlib import Path
 
35
 
36
  import torch
37
  torch._dynamo.config.suppress_errors = True
@@ -302,26 +303,22 @@ pipeline = LTX23DistilledA2VPipeline(
302
 
303
  def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
304
  """
305
- Build a temporary ModelLedger configured with the requested LoRAs, build the transformer,
306
- and hot-swap it into the existing ledger without recreating the pipeline object.
307
-
308
- Strategy:
309
- 1. Construct LoraPathStrengthAndSDOps entries for any non-zero strengths.
310
- 2. Use ledger.with_loras(...) to get a temporary ledger configured with those loras.
311
- 3. Optionally clear the existing cached transformer to reduce peak VRAM, then build the
312
- transformer from the temporary ledger and hot-swap it into the live ledger.
313
- 4. If anything fails, print diagnostics and leave the existing pipeline in place.
314
  """
315
  ledger = pipeline.model_ledger
316
 
317
- # Build convenience list and convert to the LTX primitive type (with sd_ops)
318
  entries = [
319
  (pose_lora_path, float(pose_strength)),
320
  (general_lora_path, float(general_strength)),
321
  (motion_lora_path, float(motion_strength)),
322
  ]
323
 
324
- # Keep only nonzero strengths and valid paths (zero == disabled)
325
  loras_for_builder = [
326
  LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
327
  for path, strength in entries
@@ -333,53 +330,83 @@ def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motio
333
  return
334
 
335
  try:
336
- # Create a temporary ledger configured with the extra LoRAs.
337
- # with_loras accepts an iterable of LoraPathStrengthAndSDOps.
338
  tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
339
  print(f"[LoRA] Built temporary ledger with {len(loras_for_builder)} LoRA(s).")
340
 
341
- # Attempt to free previously cached transformer instance to reduce peak VRAM.
342
- # (We cached instances earlier and replaced ledger.<component> with lambdas returning them.)
 
 
343
  try:
344
- # If ModelLedger implements clear_vram, call it to release cached GPU tensors.
345
- if hasattr(ledger, "clear_vram"):
346
- ledger.clear_vram()
347
- print("[LoRA] Cleared old ledger VRAM cache before building new transformer.")
348
- except Exception as e:
349
- print(f"[LoRA] Warning: ledger.clear_vram() failed: {type(e).__name__}: {e}")
350
-
351
- # Build the new transformer from the temporary ledger (this will load & fuse LoRAs).
352
- print("[LoRA] Building transformer from temporary ledger (this may take time / spike VRAM)...")
353
- new_transformer = tmp_ledger.transformer() # returns an X0Model moved to device
354
- print("[LoRA] New transformer built successfully.")
355
-
356
- # Replace cached transformer instance and hot-swap ledger.transformer to return the new one.
357
  global _transformer
358
  try:
359
- # Remove old Python ref if present to allow GC.
360
- del _transformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  except Exception:
362
  pass
 
363
  torch.cuda.empty_cache()
364
- _transformer = new_transformer
365
- ledger.transformer = lambda: _transformer
366
- print("[LoRA] Hot-swapped new transformer into ledger successfully.")
367
 
368
- # Done
369
  return
370
 
371
  except Exception as e:
372
  import traceback
373
- print(f"[LoRA] Error during builder-based LoRA application: {type(e).__name__}: {e}")
374
  print(traceback.format_exc())
375
 
376
- # Final fallback (should rarely hit if above path works):
377
- try:
378
- print("[LoRA] Falling back to pipeline.loras attribute assignment (best-effort).")
379
- pipeline.loras = [(p, float(s)) for p, s in entries if p is not None]
380
- except Exception as e:
381
- print(f"[LoRA] Fallback pipeline.loras assignment failed: {type(e).__name__}: {e}")
382
- print("[LoRA] apply_loras_to_pipeline finished (some approaches may not have taken effect).")
383
 
384
  # ---- REPLACE PRELOAD BLOCK START ----
385
  # Preload all models for ZeroGPU tensor packing.
 
32
  import random
33
  import tempfile
34
  from pathlib import Path
35
+ import gc
36
 
37
  import torch
38
  torch._dynamo.config.suppress_errors = True
 
303
 
304
  def apply_loras_to_pipeline(pose_strength: float, general_strength: float, motion_strength: float):
305
  """
306
+ Apply LoRAs by:
307
+ 1) creating a temporary ledger with requested LoRAs,
308
+ 2) building the fused transformer on CPU only,
309
+ 3) copying parameters & buffers in-place into the existing GPU transformer,
310
+ 4) freeing CPU objects and clearing cache.
311
+ This avoids having two full transformers on GPU simultaneously.
 
 
 
312
  """
313
  ledger = pipeline.model_ledger
314
 
 
315
  entries = [
316
  (pose_lora_path, float(pose_strength)),
317
  (general_lora_path, float(general_strength)),
318
  (motion_lora_path, float(motion_strength)),
319
  ]
320
 
321
+ # Build LoraPathStrengthAndSDOps for non-zero strengths
322
  loras_for_builder = [
323
  LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
324
  for path, strength in entries
 
330
  return
331
 
332
  try:
333
+ # Create temporary ledger configured with LoRAs
 
334
  tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
335
  print(f"[LoRA] Built temporary ledger with {len(loras_for_builder)} LoRA(s).")
336
 
337
+ # Force the temporary ledger to build on CPU so the fused model is built on CPU.
338
+ # Save original attributes to restore them later.
339
+ orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
340
+ orig_tmp_device = getattr(tmp_ledger, "device", None)
341
  try:
342
+ tmp_ledger._target_device = torch.device("cpu")
343
+ tmp_ledger.device = torch.device("cpu")
344
+ print("[LoRA] Building fused transformer on CPU (no GPU allocation)...")
345
+ new_transformer_cpu = tmp_ledger.transformer() # returns model on CPU now
346
+ print("[LoRA] Fused transformer built on CPU.")
347
+ finally:
348
+ # Restore attributes (defensive)
349
+ if orig_tmp_target is not None:
350
+ tmp_ledger._target_device = orig_tmp_target
351
+ if orig_tmp_device is not None:
352
+ tmp_ledger.device = orig_tmp_device
353
+
354
+ # Get the existing transformer instance (the one currently used by the pipeline).
355
  global _transformer
356
  try:
357
+ existing_transformer = _transformer
358
+ except NameError:
359
+ # If not cached, ask ledger for it (this will be the GPU-resident model already loaded).
360
+ existing_transformer = ledger.transformer()
361
+ _transformer = existing_transformer
362
+
363
+ # Map existing parameters & buffers for quick lookup
364
+ existing_params = {name: param for name, param in existing_transformer.named_parameters()}
365
+ existing_buffers = {name: buf for name, buf in existing_transformer.named_buffers()}
366
+
367
+ # State dict of CPU model (fused with LoRAs)
368
+ new_state = new_transformer_cpu.state_dict()
369
+
370
+ # Copy CPU tensors into the GPU-resident transformer's params/buffers in-place
371
+ with torch.no_grad():
372
+ for k, v in new_state.items():
373
+ if k in existing_params:
374
+ tgt = existing_params[k].data
375
+ try:
376
+ tgt.copy_(v.to(tgt.device))
377
+ except Exception as e:
378
+ print(f"[LoRA] Failed to copy parameter {k}: {type(e).__name__}: {e}")
379
+ elif k in existing_buffers:
380
+ tgt = existing_buffers[k].data
381
+ try:
382
+ tgt.copy_(v.to(tgt.device))
383
+ except Exception as e:
384
+ print(f"[LoRA] Failed to copy buffer {k}: {type(e).__name__}: {e}")
385
+ else:
386
+ # Parameter name mismatch — skip
387
+ # This can happen if LoRA changes expected keys; not fatal.
388
+ # Print debug once for the first few unmatched keys.
389
+ pass
390
+
391
+ # Free CPU-built transformer and temporary ledger resources, then clear caches
392
+ try:
393
+ del new_transformer_cpu
394
+ del tmp_ledger
395
  except Exception:
396
  pass
397
+ gc.collect()
398
  torch.cuda.empty_cache()
 
 
 
399
 
400
+ print("[LoRA] In-place parameter copy complete. LoRAs applied to the existing transformer.")
401
  return
402
 
403
  except Exception as e:
404
  import traceback
405
+ print(f"[LoRA] Error during in-place LoRA application: {type(e).__name__}: {e}")
406
  print(traceback.format_exc())
407
 
408
+ # If something unexpectedly failed, bail out (no fallback).
409
+ print("[LoRA] apply_loras_to_pipeline finished (LOADING FAILED — no changes applied).")
 
 
 
 
 
410
 
411
  # ---- REPLACE PRELOAD BLOCK START ----
412
  # Preload all models for ZeroGPU tensor packing.