Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -255,8 +255,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 255 |
# Model repos
|
| 256 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 257 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 258 |
-
GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
|
| 259 |
-
GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
|
| 260 |
|
| 261 |
# Download model checkpoints
|
| 262 |
print("=" * 80)
|
|
@@ -272,81 +270,7 @@ PENDING_LORA_KEY: str | None = None
|
|
| 272 |
PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
|
| 273 |
PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
|
| 274 |
|
| 275 |
-
|
| 276 |
-
weights_dir.mkdir(exist_ok=True)
|
| 277 |
-
checkpoint_path = hf_hub_download(
|
| 278 |
-
repo_id="SulphurAI/Sulphur-2-base",
|
| 279 |
-
filename="sulphur_distil_bf16.safetensors",
|
| 280 |
-
local_dir=str(weights_dir),
|
| 281 |
-
local_dir_use_symlinks=False,
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
print("[Gemma] Setting up abliterated Gemma text encoder...")
|
| 285 |
-
MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
|
| 286 |
-
gemma_root = "/tmp/abliterated_gemma"
|
| 287 |
-
os.makedirs(gemma_root, exist_ok=True)
|
| 288 |
-
|
| 289 |
-
gemma_official_dir = snapshot_download(
|
| 290 |
-
repo_id=GEMMA_REPO,
|
| 291 |
-
ignore_patterns=["*.safetensors", "*.safetensors.index.json"],
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
for fname in os.listdir(gemma_official_dir):
|
| 295 |
-
src = os.path.join(gemma_official_dir, fname)
|
| 296 |
-
dst = os.path.join(gemma_root, fname)
|
| 297 |
-
if os.path.isfile(src) and not fname.endswith(".safetensors") and fname != "model.safetensors.index.json":
|
| 298 |
-
if not os.path.exists(dst):
|
| 299 |
-
os.symlink(src, dst)
|
| 300 |
-
|
| 301 |
-
if os.path.exists(MERGED_WEIGHTS):
|
| 302 |
-
print("[Gemma] Using cached merged weights")
|
| 303 |
-
else:
|
| 304 |
-
abliterated_weights_path = hf_hub_download(
|
| 305 |
-
repo_id=GEMMA_ABLITERATED_REPO,
|
| 306 |
-
filename=GEMMA_ABLITERATED_FILE,
|
| 307 |
-
)
|
| 308 |
-
index_path = hf_hub_download(
|
| 309 |
-
repo_id=GEMMA_REPO,
|
| 310 |
-
filename="model.safetensors.index.json"
|
| 311 |
-
)
|
| 312 |
-
with open(index_path) as f:
|
| 313 |
-
weight_index = json.load(f)
|
| 314 |
-
|
| 315 |
-
vision_keys = {}
|
| 316 |
-
for key, shard in weight_index["weight_map"].items():
|
| 317 |
-
if "vision_tower" in key or "multi_modal_projector" in key:
|
| 318 |
-
vision_keys[key] = shard
|
| 319 |
-
needed_shards = set(vision_keys.values())
|
| 320 |
-
|
| 321 |
-
shard_paths = {}
|
| 322 |
-
for shard_name in needed_shards:
|
| 323 |
-
shard_paths[shard_name] = hf_hub_download(
|
| 324 |
-
repo_id=GEMMA_REPO,
|
| 325 |
-
filename=shard_name
|
| 326 |
-
)
|
| 327 |
-
|
| 328 |
-
_fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
|
| 329 |
-
raw = load_file(abliterated_weights_path)
|
| 330 |
-
merged = {}
|
| 331 |
-
for key, tensor in raw.items():
|
| 332 |
-
t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
|
| 333 |
-
merged[f"language_model.{key}"] = t
|
| 334 |
-
del raw
|
| 335 |
-
|
| 336 |
-
for key, shard_name in vision_keys.items():
|
| 337 |
-
with safe_open(shard_paths[shard_name], framework="pt") as f:
|
| 338 |
-
merged[key] = f.get_tensor(key)
|
| 339 |
-
|
| 340 |
-
save_file(merged, MERGED_WEIGHTS)
|
| 341 |
-
del merged
|
| 342 |
-
gc.collect()
|
| 343 |
-
|
| 344 |
-
weight_link = os.path.join(gemma_root, "model.safetensors")
|
| 345 |
-
if os.path.exists(weight_link):
|
| 346 |
-
os.remove(weight_link)
|
| 347 |
-
os.symlink(MERGED_WEIGHTS, weight_link)
|
| 348 |
-
print(f"[Gemma] Root ready: {gemma_root}")
|
| 349 |
-
|
| 350 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
|
| 351 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 352 |
|
|
|
|
| 255 |
# Model repos
|
| 256 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 257 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
|
|
|
|
|
|
| 258 |
|
| 259 |
# Download model checkpoints
|
| 260 |
print("=" * 80)
|
|
|
|
| 270 |
PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
|
| 271 |
PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
|
| 272 |
|
| 273 |
+
checkpoint_path = hf_hub_download(repo_id="SulphurAI/Sulphur-2-base", filename="sulphur_distil_bf16.safetensors")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
|
| 275 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 276 |
|