dagloop5 committed on
Commit
9392385
·
verified ·
1 Parent(s): d7cff3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -30
app.py CHANGED
@@ -276,7 +276,14 @@ PENDING_LORA_KEY: str | None = None
276
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
277
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
278
 
279
- checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
 
 
 
 
 
 
 
280
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
281
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
282
 
@@ -377,35 +384,16 @@ def prepare_lora_cache(
377
  new_transformer_cpu = None
378
  try:
379
  progress(0.35, desc="Building fused CPU transformer")
380
- tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
381
- # --- FIX: force tmp_ledger to use already downloaded checkpoints ---
382
- tmp_ledger._distilled_checkpoint_path = checkpoint_path
383
- tmp_ledger._spatial_upsampler_path = spatial_upsampler_path
384
- tmp_ledger._gemma_root = gemma_root
385
- # ---------------------------------------------------------------
386
- assert os.path.exists(checkpoint_path), f"Checkpoint missing: {checkpoint_path}"
387
-
388
- orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
389
- orig_tmp_device = getattr(tmp_ledger, "device", None)
390
- try:
391
- tmp_ledger._target_device = lambda: torch.device("cpu")
392
- tmp_ledger.device = torch.device("cpu")
393
- new_transformer_cpu = tmp_ledger.transformer()
394
- finally:
395
- if orig_tmp_target is not None:
396
- tmp_ledger._target_device = orig_tmp_target
397
- else:
398
- try:
399
- delattr(tmp_ledger, "_target_device")
400
- except Exception:
401
- pass
402
- if orig_tmp_device is not None:
403
- tmp_ledger.device = orig_tmp_device
404
- else:
405
- try:
406
- delattr(tmp_ledger, "device")
407
- except Exception:
408
- pass
409
 
410
  progress(0.70, desc="Extracting fused state_dict")
411
  state = new_transformer_cpu.state_dict()
 
276
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
277
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
278
 
279
+ weights_dir = Path("weights")
280
+ weights_dir.mkdir(exist_ok=True)
281
+ checkpoint_path = hf_hub_download(
282
+ repo_id=LTX_MODEL_REPO,
283
+ filename="ltx-2.3-22b-distilled.safetensors",
284
+ local_dir=str(weights_dir),
285
+ local_dir_use_symlinks=False,
286
+ )
287
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
288
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
289
 
 
384
  new_transformer_cpu = None
385
  try:
386
  progress(0.35, desc="Building fused CPU transformer")
387
+ tmp_ledger = pipeline.model_ledger.__class__(
388
+ dtype=ledger.dtype,
389
+ device=torch.device("cpu"),
390
+ checkpoint_path=str(checkpoint_path),
391
+ spatial_upsampler_path=str(spatial_upsampler_path),
392
+ gemma_root_path=str(gemma_root),
393
+ loras=tuple(loras_for_builder),
394
+ quantization=getattr(ledger, "quantization", None),
395
+ )
396
+ new_transformer_cpu = tmp_ledger.transformer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  progress(0.70, desc="Extracting fused state_dict")
399
  state = new_transformer_cpu.state_dict()