Update app(workingwithsafetensorscache).py
Browse files
app(workingwithsafetensorscache).py
CHANGED
|
@@ -36,9 +36,12 @@ import random
|
|
| 36 |
import tempfile
|
| 37 |
from pathlib import Path
|
| 38 |
import gc
|
|
|
|
| 39 |
import hashlib
|
| 40 |
|
|
|
|
| 41 |
import torch
|
|
|
|
| 42 |
torch._dynamo.config.suppress_errors = True
|
| 43 |
torch._dynamo.config.disable = True
|
| 44 |
|
|
@@ -264,6 +267,9 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 264 |
# Model repos
|
| 265 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 266 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
|
| 269 |
# Download model checkpoints
|
|
@@ -280,16 +286,50 @@ PENDING_LORA_KEY: str | None = None
|
|
| 280 |
PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
|
| 281 |
PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 292 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
|
|
|
| 293 |
|
| 294 |
# ---- Insert block (LoRA downloads) between lines 268 and 269 ----
|
| 295 |
# LoRA repo + download the requested LoRA adapters
|
|
@@ -322,9 +362,16 @@ print(f"Demopose LoRA: {demopose_lora_path}")
|
|
| 322 |
print(f"Checkpoint: {checkpoint_path}")
|
| 323 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 324 |
print(f"Gemma root: {gemma_root}")
|
|
|
|
| 325 |
|
| 326 |
# Initialize pipeline WITH text encoder and optional audio support
|
| 327 |
# ---- Replace block (pipeline init) lines 275-281 ----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
pipeline = LTX23DistilledA2VPipeline(
|
| 329 |
distilled_checkpoint_path=checkpoint_path,
|
| 330 |
spatial_upsampler_path=spatial_upsampler_path,
|
|
|
|
| 36 |
import tempfile
|
| 37 |
from pathlib import Path
|
| 38 |
import gc
|
| 39 |
+
import json
|
| 40 |
import hashlib
|
| 41 |
|
| 42 |
+
import requests
|
| 43 |
import torch
|
| 44 |
+
from safetensors import safe_open
|
| 45 |
torch._dynamo.config.suppress_errors = True
|
| 46 |
torch._dynamo.config.disable = True
|
| 47 |
|
|
|
|
| 267 |
# Model repos
|
| 268 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 269 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 270 |
+
DISTILLED_LORA_FILE = "ltx-2.3-22b-distilled-lora-384.safetensors"
|
| 271 |
+
EROS_REPO = "dagloop5/LoRA"
|
| 272 |
+
EROS_FILE = "ltx2310eros_beta.safetensors"
|
| 273 |
|
| 274 |
|
| 275 |
# Download model checkpoints
|
|
|
|
| 286 |
PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
|
| 287 |
PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
|
| 288 |
|
| 289 |
+
# --- 1. 10Eros checkpoint: FP8→BF16 conversion + DEV metadata injection ---
# Official DEV checkpoint is BF16. fp8_cast() expects BF16 input.
# 10Eros is FP8 (CivitAI format) → convert to BF16 to match official dtype distribution.


def _fetch_safetensors_metadata(url, initial_last_byte=2_000_000, timeout=30):
    """Return the ``__metadata__`` dict stored in a remote safetensors header.

    A safetensors file begins with an 8-byte little-endian header length
    followed by that many bytes of JSON. Only that prefix is requested (via an
    HTTP ``Range`` header), never the tensor payload.

    Args:
        url: Direct download URL of the ``.safetensors`` file.
        initial_last_byte: Last byte index of the first range request.
        timeout: Per-request timeout in seconds, passed to ``requests.get``.

    Returns:
        The header's ``__metadata__`` mapping, or ``{}`` when it has none.

    Raises:
        requests.HTTPError: If a range request returns an error status.
        ValueError: If the response is too short to contain a complete header
            or declares an implausibly large header.
    """
    # Sanity cap so a non-safetensors response (e.g. an HTML error page) cannot
    # trigger a gigantic follow-up request.
    # NOTE(review): 100 MB is assumed to match the limit enforced by the
    # safetensors reference implementation — confirm.
    max_header_bytes = 100_000_000

    def _get_prefix(last_byte):
        resp = requests.get(
            url,
            headers={"Range": f"bytes=0-{last_byte}"},
            timeout=timeout,
        )
        resp.raise_for_status()
        return resp.content

    payload = _get_prefix(initial_last_byte)
    if len(payload) < 8:
        raise ValueError(f"Response too short for a safetensors header: {url}")

    hdr_size = int.from_bytes(payload[:8], "little")
    if hdr_size > max_header_bytes:
        raise ValueError(f"Implausible safetensors header size {hdr_size}: {url}")

    hdr_end = 8 + hdr_size
    if len(payload) < hdr_end:
        # The header is longer than the first chunk: fetch exactly the bytes
        # it needs instead of parsing truncated JSON.
        payload = _get_prefix(hdr_end - 1)
    if len(payload) < hdr_end:
        raise ValueError(
            f"Truncated safetensors header ({len(payload)} of {hdr_end} bytes): {url}"
        )

    hdr_json = json.loads(payload[8:hdr_end])
    return hdr_json.get("__metadata__", {})


print("[1/4] Preparing 10Eros checkpoint...")
eros_fp8_path = hf_hub_download(repo_id=EROS_REPO, filename=EROS_FILE)
print(f" Downloaded: {eros_fp8_path}")

EROS_FIXED = "/tmp/eros_bf16_with_meta.safetensors"
if os.path.exists(EROS_FIXED):
    print(" Using cached BF16 checkpoint")
else:
    # Read only the header of the DEV checkpoint rather than the whole file.
    print(" Fetching DEV checkpoint metadata (header only)...")
    dev_url = f"https://huggingface.co/{LTX_MODEL_REPO}/resolve/main/ltx-2.3-22b-dev.safetensors"
    dev_metadata = _fetch_safetensors_metadata(dev_url)
    print(f" DEV metadata keys: {list(dev_metadata.keys())}")

    # Convert FP8→BF16 + inject metadata
    print(" Converting FP8→BF16 (lossless upcast)...")
    _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
    tensors = {}
    _converted = 0
    with safe_open(eros_fp8_path, framework="pt") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            if tensor.dtype in _fp8_types:
                tensors[key] = tensor.to(torch.bfloat16)
                _converted += 1
            else:
                tensors[key] = tensor
    print(f" Converted {_converted} FP8→BF16, kept {len(tensors)-_converted} as-is")

    # Write to a sibling temp file and rename it into place, so an interrupted
    # save can never leave a partial file at the path the cache check trusts.
    _eros_tmp_path = EROS_FIXED + ".tmp"
    save_file(tensors, _eros_tmp_path, metadata=dev_metadata)
    os.replace(_eros_tmp_path, EROS_FIXED)
    del tensors
    gc.collect()
    print(" Saved with DEV metadata")

checkpoint_path = EROS_FIXED
|
| 330 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 331 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 332 |
+
distilled_lora_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=DISTILLED_LORA_FILE)
|
| 333 |
|
| 334 |
# ---- Insert block (LoRA downloads) between lines 268 and 269 ----
|
| 335 |
# LoRA repo + download the requested LoRA adapters
|
|
|
|
| 362 |
print(f"Checkpoint: {checkpoint_path}")
|
| 363 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 364 |
print(f"Gemma root: {gemma_root}")
|
| 365 |
+
print(f"Distilled LoRA: {distilled_lora_path}")
|
| 366 |
|
| 367 |
# Initialize pipeline WITH text encoder and optional audio support
|
| 368 |
# ---- Replace block (pipeline init) lines 275-281 ----
|
| 369 |
+
# Announce the pipeline class that is instantiated immediately below; the
# previous message named a different class than the one actually constructed.
print("Creating LTX23DistilledA2VPipeline...")

# Single-entry list describing the distilled LoRA: the downloaded file path,
# the strength to apply it with, and an SDOps entry with an empty mapping.
distilled_lora = [
    LoraPathStrengthAndSDOps(
        path=distilled_lora_path,
        strength=0.9,
        sd_ops=SDOps(name="distilled_lora", mapping=()),
    )
]
|
| 375 |
pipeline = LTX23DistilledA2VPipeline(
|
| 376 |
distilled_checkpoint_path=checkpoint_path,
|
| 377 |
spatial_upsampler_path=spatial_upsampler_path,
|