dagloop5 committed on
Commit
acd6807
·
verified ·
1 Parent(s): 45e9a0f

Update app(workingwithsafetensorscache).py

Browse files
Files changed (1) hide show
  1. app(workingwithsafetensorscache).py +55 -8
app(workingwithsafetensorscache).py CHANGED
@@ -36,9 +36,12 @@ import random
36
  import tempfile
37
  from pathlib import Path
38
  import gc
 
39
  import hashlib
40
 
 
41
  import torch
 
42
  torch._dynamo.config.suppress_errors = True
43
  torch._dynamo.config.disable = True
44
 
@@ -264,6 +267,9 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
264
  # Model repos
265
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
266
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
 
 
 
267
 
268
 
269
  # Download model checkpoints
@@ -280,16 +286,50 @@ PENDING_LORA_KEY: str | None = None
280
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
281
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
282
 
283
- weights_dir = Path("weights")
284
- weights_dir.mkdir(exist_ok=True)
285
- checkpoint_path = hf_hub_download(
286
- repo_id=LTX_MODEL_REPO,
287
- filename="ltx-2.3-22b-distilled.safetensors",
288
- local_dir=str(weights_dir),
289
- local_dir_use_symlinks=False,
290
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
292
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
 
293
 
294
  # ---- Insert block (LoRA downloads) between lines 268 and 269 ----
295
  # LoRA repo + download the requested LoRA adapters
@@ -322,9 +362,16 @@ print(f"Demopose LoRA: {demopose_lora_path}")
322
  print(f"Checkpoint: {checkpoint_path}")
323
  print(f"Spatial upsampler: {spatial_upsampler_path}")
324
  print(f"Gemma root: {gemma_root}")
 
325
 
326
  # Initialize pipeline WITH text encoder and optional audio support
327
  # ---- Replace block (pipeline init) lines 275-281 ----
 
 
 
 
 
 
328
  pipeline = LTX23DistilledA2VPipeline(
329
  distilled_checkpoint_path=checkpoint_path,
330
  spatial_upsampler_path=spatial_upsampler_path,
 
36
  import tempfile
37
  from pathlib import Path
38
  import gc
39
+ import json
40
  import hashlib
41
 
42
+ import requests
43
  import torch
44
+ from safetensors import safe_open
45
  torch._dynamo.config.suppress_errors = True
46
  torch._dynamo.config.disable = True
47
 
 
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
270
+ DISTILLED_LORA_FILE = "ltx-2.3-22b-distilled-lora-384.safetensors"
271
+ EROS_REPO = "dagloop5/LoRA"
272
+ EROS_FILE = "ltx2310eros_beta.safetensors"
273
 
274
 
275
  # Download model checkpoints
 
286
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
287
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
288
 
289
+ # --- 1. 10Eros checkpoint: FP8→BF16 conversion + DEV metadata injection ---
290
+ # Official DEV checkpoint is BF16. fp8_cast() expects BF16 input.
291
+ # 10Eros is FP8 (CivitAI format) → convert to BF16 to match official dtype distribution.
292
+ print("[1/4] Preparing 10Eros checkpoint...")
293
+ eros_fp8_path = hf_hub_download(repo_id=EROS_REPO, filename=EROS_FILE)
294
+ print(f" Downloaded: {eros_fp8_path}")
295
+
296
+ EROS_FIXED = "/tmp/eros_bf16_with_meta.safetensors"
297
+ if os.path.exists(EROS_FIXED):
298
+ print(" Using cached BF16 checkpoint")
299
+ else:
300
+ # Fetch DEV checkpoint metadata from header only (first 2MB, not full 46GB)
301
+ print(" Fetching DEV checkpoint metadata (header only)...")
302
+ dev_url = f"https://huggingface.co/{LTX_MODEL_REPO}/resolve/main/ltx-2.3-22b-dev.safetensors"
303
+ hdr_resp = requests.get(dev_url, headers={"Range": "bytes=0-2000000"}, timeout=30)
304
+ hdr_resp.raise_for_status()
305
+ hdr_size = int.from_bytes(hdr_resp.content[:8], "little")
306
+ hdr_json = json.loads(hdr_resp.content[8:8 + min(hdr_size, len(hdr_resp.content) - 8)])
307
+ dev_metadata = hdr_json.get("__metadata__", {})
308
+ print(f" DEV metadata keys: {list(dev_metadata.keys())}")
309
+
310
+ # Convert FP8→BF16 + inject metadata
311
+ print(" Converting FP8→BF16 (lossless upcast)...")
312
+ _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
313
+ tensors = {}
314
+ _converted = 0
315
+ with safe_open(eros_fp8_path, framework="pt") as f:
316
+ for key in f.keys():
317
+ tensor = f.get_tensor(key)
318
+ if tensor.dtype in _fp8_types:
319
+ tensors[key] = tensor.to(torch.bfloat16)
320
+ _converted += 1
321
+ else:
322
+ tensors[key] = tensor
323
+ print(f" Converted {_converted} FP8→BF16, kept {len(tensors)-_converted} as-is")
324
+ save_file(tensors, EROS_FIXED, metadata=dev_metadata)
325
+ del tensors
326
+ gc.collect()
327
+ print(" Saved with DEV metadata")
328
+
329
+ checkpoint_path = EROS_FIXED
330
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
331
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
332
+ distilled_lora_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=DISTILLED_LORA_FILE)
333
 
334
  # ---- Insert block (LoRA downloads) between lines 268 and 269 ----
335
  # LoRA repo + download the requested LoRA adapters
 
362
  print(f"Checkpoint: {checkpoint_path}")
363
  print(f"Spatial upsampler: {spatial_upsampler_path}")
364
  print(f"Gemma root: {gemma_root}")
365
+ print(f"Distilled LoRA: {distilled_lora_path}")
366
 
367
  # Initialize pipeline WITH text encoder and optional audio support
368
  # ---- Replace block (pipeline init) lines 275-281 ----
369
+ print("Creating TI2VidTwoStagesPipeline...")
370
+ distilled_lora = [LoraPathStrengthAndSDOps(
371
+ path=distilled_lora_path,
372
+ strength=0.9,
373
+ sd_ops=SDOps(name="distilled_lora", mapping=()),
374
+ )]
375
  pipeline = LTX23DistilledA2VPipeline(
376
  distilled_checkpoint_path=checkpoint_path,
377
  spatial_upsampler_path=spatial_upsampler_path,