dagloop5 commited on
Commit
cafc416
·
verified ·
1 Parent(s): 75ef1d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -25
app.py CHANGED
@@ -35,6 +35,7 @@ import tempfile
35
  from pathlib import Path
36
  import gc
37
  import hashlib
 
38
 
39
  import torch
40
  torch._dynamo.config.suppress_errors = True
@@ -44,7 +45,6 @@ import spaces
44
  import gradio as gr
45
  import numpy as np
46
  from huggingface_hub import hf_hub_download, snapshot_download
47
- from safetensors.torch import load_file, save_file
48
  from safetensors import safe_open
49
  import json
50
  import requests
@@ -263,9 +263,14 @@ print("=" * 80)
263
  print("Downloading LTX-2.3 distilled model + Gemma...")
264
  print("=" * 80)
265
 
266
- # LoRA cache directory and currently-applied key
267
- LORA_CACHE_DIR = Path("lora_cache")
268
- LORA_CACHE_DIR.mkdir(exist_ok=True)
 
 
 
 
 
269
  current_lora_key: str | None = None
270
 
271
  PENDING_LORA_KEY: str | None = None
@@ -425,21 +430,16 @@ def prepare_lora_cache(
425
  global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
426
 
427
  ledger = pipeline.model_ledger
428
- key, _ = _make_lora_key(singularity_strength, teneros_strength, sulphur_strength, pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength, mcfurry_strength, dm_strength, praxis_strength, threed_strength, concept_strength, bulge_strength)
429
- cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
 
 
 
 
 
 
430
 
431
  progress(0.05, desc="Preparing LoRA state")
432
- if cache_path.exists():
433
- try:
434
- progress(0.20, desc="Loading cached fused state")
435
- state = load_file(str(cache_path))
436
- PENDING_LORA_KEY = key
437
- PENDING_LORA_STATE = state
438
- PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
439
- return PENDING_LORA_STATUS
440
- except Exception as e:
441
- print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
442
-
443
  entries = [
444
  (singularity_lora_path, round(float(singularity_strength), 2)),
445
  (teneros_lora_path, round(float(teneros_strength), 2)),
@@ -498,11 +498,9 @@ def prepare_lora_cache(
498
  k: v.detach().cpu().contiguous()
499
  for k, v in new_transformer_cpu.state_dict().items()
500
  }
501
- save_file(state, str(cache_path))
502
-
503
  PENDING_LORA_KEY = key
504
  PENDING_LORA_STATE = state
505
- PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
506
  return PENDING_LORA_STATUS
507
 
508
  except Exception as e:
@@ -528,19 +526,28 @@ def prepare_lora_cache(
528
 
529
  def apply_prepared_lora_state_to_pipeline():
530
  """
531
- Fast step: copy the already prepared CPU state into the live transformer.
532
- This is the only part that should remain near generation time.
533
  """
534
- global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
535
 
536
- if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
537
- print("[LoRA] No prepared LoRA state available; skipping.")
538
  return False
539
 
 
 
540
  if current_lora_key == PENDING_LORA_KEY:
 
 
 
541
  print("[LoRA] Prepared LoRA state already active; skipping.")
542
  return True
543
 
 
 
 
 
544
  existing_transformer = _transformer
545
  with torch.no_grad():
546
  missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
@@ -548,6 +555,11 @@ def apply_prepared_lora_state_to_pipeline():
548
  print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
549
 
550
  current_lora_key = PENDING_LORA_KEY
 
 
 
 
 
551
  print("[LoRA] Prepared LoRA state applied to the pipeline.")
552
  return True
553
 
 
35
  from pathlib import Path
36
  import gc
37
  import hashlib
38
+ import shutil
39
 
40
  import torch
41
  torch._dynamo.config.suppress_errors = True
 
45
  import gradio as gr
46
  import numpy as np
47
  from huggingface_hub import hf_hub_download, snapshot_download
 
48
  from safetensors import safe_open
49
  import json
50
  import requests
 
263
  print("Downloading LTX-2.3 distilled model + Gemma...")
264
  print("=" * 80)
265
 
266
+ # Legacy on-disk LoRA cache is no longer used; delete any old copies so the
267
+ # Space does not run out of persistent storage and reset.
268
+ _legacy_lora_cache_dir = Path("lora_cache")
269
+ if _legacy_lora_cache_dir.exists():
270
+ shutil.rmtree(_legacy_lora_cache_dir, ignore_errors=True)
271
+
272
+ # Only the currently-loaded key and a single transient CPU-side pending state
273
+ # are kept. The pending state is deleted as soon as it is copied to the GPU.
274
  current_lora_key: str | None = None
275
 
276
  PENDING_LORA_KEY: str | None = None
 
430
  global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
431
 
432
  ledger = pipeline.model_ledger
433
+ key, _ = _make_lora_key(
434
+ singularity_strength, teneros_strength, sulphur_strength, pose_strength,
435
+ general_strength, motion_strength, dreamlay_strength, mself_strength,
436
+ dramatic_strength, fluid_strength, liquid_strength, demopose_strength,
437
+ voice_strength, realism_strength, transition_strength, physics_strength,
438
+ reasoning_strength, twostep_strength, mcfurry_strength, dm_strength,
439
+ praxis_strength, threed_strength, concept_strength, bulge_strength,
440
+ )
441
 
442
  progress(0.05, desc="Preparing LoRA state")
 
 
 
 
 
 
 
 
 
 
 
443
  entries = [
444
  (singularity_lora_path, round(float(singularity_strength), 2)),
445
  (teneros_lora_path, round(float(teneros_strength), 2)),
 
498
  k: v.detach().cpu().contiguous()
499
  for k, v in new_transformer_cpu.state_dict().items()
500
  }
 
 
501
  PENDING_LORA_KEY = key
502
  PENDING_LORA_STATE = state
503
+ PENDING_LORA_STATUS = "Built LoRA state (ready to apply)."
504
  return PENDING_LORA_STATUS
505
 
506
  except Exception as e:
 
526
 
527
  def apply_prepared_lora_state_to_pipeline():
528
  """
529
+ Fast ZeroGPU step: copy the already prepared CPU state into the live
530
+ transformer, then immediately free the CPU-side copy.
531
  """
532
+ global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
533
 
534
+ if PENDING_LORA_KEY is None:
535
+ print("[LoRA] No prepared LoRA key available; skipping.")
536
  return False
537
 
538
+ # The same LoRA weights are already loaded into _transformer; just make
539
+ # sure the CPU copy is not wasting RAM.
540
  if current_lora_key == PENDING_LORA_KEY:
541
+ if PENDING_LORA_STATE is not None:
542
+ PENDING_LORA_STATE = None
543
+ PENDING_LORA_STATUS = "LoRA state already active; CPU copy freed."
544
  print("[LoRA] Prepared LoRA state already active; skipping.")
545
  return True
546
 
547
+ if PENDING_LORA_STATE is None:
548
+ print("[LoRA] LoRA key changed but no CPU state is available; skipping.")
549
+ return False
550
+
551
  existing_transformer = _transformer
552
  with torch.no_grad():
553
  missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
 
555
  print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
556
 
557
  current_lora_key = PENDING_LORA_KEY
558
+
559
+ # The weights now live in the live transformer, so drop the CPU copy to
560
+ # avoid holding an extra full transformer in RAM.
561
+ PENDING_LORA_STATE = None
562
+ PENDING_LORA_STATUS = "LoRA state applied to pipeline."
563
  print("[LoRA] Prepared LoRA state applied to the pipeline.")
564
  return True
565