dagloop5 committed on
Commit
cf88763
·
verified ·
1 Parent(s): ca86fe3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -93
app.py CHANGED
@@ -291,6 +291,8 @@ pipeline = LTX23DistilledA2VPipeline(
291
  quantization=QuantizationPolicy.fp8_cast(),
292
  )
293
 
 
 
294
  # Preload all models for ZeroGPU tensor packing.
295
  # >>> REPLACE the "Preload all models" block with this one:
296
  print("Preloading models (pinning decoders/encoders but leaving transformer dynamic)...")
@@ -417,100 +419,57 @@ def generate_video(
417
  tiling_config = TilingConfig.default()
418
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
419
 
420
- # >>> RUNTIME LoRA application (robust, multi-fallback)
421
- # We cannot rely on mutating the original descriptor (some implementations are immutable),
422
- # so create a fresh runtime descriptor and try multiple ways to install it.
423
  runtime_strength = float(lora_strength)
424
- replaced = False
425
-
426
- # 1) Try simple approach: build a new LoraPathStrengthAndSDOps
427
- runtime_lora = LoraPathStrengthAndSDOps(lora_path, runtime_strength, LTXV_LORA_COMFY_RENAMING_MAP)
428
- print(f"[LoRA] attempting to apply runtime LoRA (strength={runtime_strength})")
429
-
430
- # Try a few likely places to replace the descriptor used by the pipeline/ledger.
431
- try:
432
- # common attribute on pipeline
433
- if hasattr(pipeline, "loras"):
434
- try:
435
- pipeline.loras = [runtime_lora]
436
- replaced = True
437
- print("[LoRA] replaced pipeline.loras")
438
- except Exception as e:
439
- print(f"[LoRA] pipeline.loras assignment failed: {e}")
440
- except Exception:
441
- pass
442
-
443
- try:
444
- # common attribute on the model ledger
445
- if hasattr(pipeline, "model_ledger") and hasattr(pipeline.model_ledger, "loras"):
446
- try:
447
- pipeline.model_ledger.loras = [runtime_lora]
448
- replaced = True
449
- print("[LoRA] replaced pipeline.model_ledger.loras")
450
- except Exception as e:
451
- print(f"[LoRA] pipeline.model_ledger.loras assignment failed: {e}")
452
- except Exception:
453
- pass
454
-
455
- try:
456
- # some internals use a private _loras list
457
- if hasattr(pipeline, "model_ledger") and hasattr(pipeline.model_ledger, "_loras"):
458
- try:
459
- pipeline.model_ledger._loras = [runtime_lora]
460
- replaced = True
461
- print("[LoRA] replaced pipeline.model_ledger._loras")
462
- except Exception as e:
463
- print(f"[LoRA] pipeline.model_ledger._loras assignment failed: {e}")
464
- except Exception:
465
- pass
466
-
467
- # 2) If we succeeded replacing the descriptor in-place, clear transformer cache so it will rebuild
468
- if replaced:
469
- try:
470
- if hasattr(pipeline.model_ledger, "_transformer"):
471
- pipeline.model_ledger._transformer = None
472
- # also clear potential caches named similar to 'transformer_cache' if present
473
- if hasattr(pipeline.model_ledger, "transformer_cache"):
474
- try:
475
- pipeline.model_ledger.transformer_cache = {}
476
- except Exception:
477
- pass
478
- print("[LoRA] in-place descriptor replacement done; transformer cache cleared")
479
- except Exception as e:
480
- print(f"[LoRA] replacement succeeded but cache clearing failed: {e}")
481
-
482
- # 3) FINAL FALLBACK - if none of the in-place replacements worked, rebuild the pipeline
483
- if not replaced:
484
- print("[LoRA] in-place replacement FAILED; rebuilding pipeline with runtime LoRA (this is slow)")
485
- try:
486
- # Rebuild pipeline object with the new LoRA descriptor
487
- # NOTE: this replaces the global `pipeline`. We must declare global to reassign it.
488
- pipeline = LTX23DistilledA2VPipeline(
489
- distilled_checkpoint_path=checkpoint_path,
490
- spatial_upsampler_path=spatial_upsampler_path,
491
- gemma_root=gemma_root,
492
- loras=[runtime_lora],
493
- quantization=QuantizationPolicy.fp8_cast(),
494
- )
495
-
496
- # After rebuilding, we *do not* re-run the original module-level preloads here,
497
- # because re-pinning may be complex; the rebuilt pipeline will construct its
498
- # own ledger as part of the first call. This is slower but reliable.
499
- # Clear any transformer caches if they exist on the new ledger as well.
500
- try:
501
- if hasattr(pipeline.model_ledger, "_transformer"):
502
- pipeline.model_ledger._transformer = None
503
- except Exception:
504
- pass
505
-
506
- print("[LoRA] pipeline rebuilt with runtime LoRA")
507
- except Exception as e:
508
- print(f"[LoRA] pipeline rebuild FAILED: {e}")
509
-
510
- # Finally, log memory then proceed
511
- log_memory("before pipeline call")
512
-
513
- video, audio = pipeline(
514
  prompt=prompt,
515
  seed=current_seed,
516
  height=int(height),
 
291
  quantization=QuantizationPolicy.fp8_cast(),
292
  )
293
 
294
+ pipeline_cache = {}
295
+
296
  # Preload all models for ZeroGPU tensor packing.
297
  # >>> REPLACE the "Preload all models" block with this one:
298
  print("Preloading models (pinning decoders/encoders but leaving transformer dynamic)...")
 
419
  tiling_config = TilingConfig.default()
420
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
421
 
422
+ # >>> NEW: deterministic pipeline-per-strength approach
423
+
 
424
  runtime_strength = float(lora_strength)
425
+
426
+ # round strength to avoid cache explosion (important)
427
+ cache_key = round(runtime_strength, 2)
428
+
429
+ if cache_key not in pipeline_cache:
430
+ print(f"[LoRA] building new pipeline for strength={cache_key}")
431
+
432
+ runtime_lora = LoraPathStrengthAndSDOps(
433
+ lora_path,
434
+ cache_key,
435
+ LTXV_LORA_COMFY_RENAMING_MAP
436
+ )
437
+
438
+ new_pipeline = LTX23DistilledA2VPipeline(
439
+ distilled_checkpoint_path=checkpoint_path,
440
+ spatial_upsampler_path=spatial_upsampler_path,
441
+ gemma_root=gemma_root,
442
+ loras=[runtime_lora],
443
+ quantization=QuantizationPolicy.fp8_cast(),
444
+ )
445
+
446
+ # OPTIONAL: minimal preload (safe components only)
447
+ ledger = new_pipeline.model_ledger
448
+ _video_encoder = ledger.video_encoder()
449
+ _video_decoder = ledger.video_decoder()
450
+ _audio_encoder = ledger.audio_encoder()
451
+ _audio_decoder = ledger.audio_decoder()
452
+ _vocoder = ledger.vocoder()
453
+ _spatial_upsampler = ledger.spatial_upsampler()
454
+ _text_encoder = ledger.text_encoder()
455
+ _embeddings_processor = ledger.gemma_embeddings_processor()
456
+
457
+ ledger.video_encoder = lambda: _video_encoder
458
+ ledger.video_decoder = lambda: _video_decoder
459
+ ledger.audio_encoder = lambda: _audio_encoder
460
+ ledger.audio_decoder = lambda: _audio_decoder
461
+ ledger.vocoder = lambda: _vocoder
462
+ ledger.spatial_upsampler = lambda: _spatial_upsampler
463
+ ledger.text_encoder = lambda: _text_encoder
464
+ ledger.gemma_embeddings_processor = lambda: _embeddings_processor
465
+
466
+ pipeline_cache[cache_key] = new_pipeline
467
+ else:
468
+ print(f"[LoRA] reusing cached pipeline (strength={cache_key})")
469
+
470
+ active_pipeline = pipeline_cache[cache_key]
471
+
472
+ video, audio = active_pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  prompt=prompt,
474
  seed=current_seed,
475
  height=int(height),