dagloop5 commited on
Commit
b024f1b
·
verified ·
1 Parent(s): d7e2706

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -77
app.py CHANGED
@@ -255,8 +255,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
255
  # Model repos
256
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
257
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
258
- GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
259
- GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
260
 
261
  # Download model checkpoints
262
  print("=" * 80)
@@ -272,81 +270,7 @@ PENDING_LORA_KEY: str | None = None
272
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
273
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
274
 
275
- weights_dir = Path("weights")
276
- weights_dir.mkdir(exist_ok=True)
277
- checkpoint_path = hf_hub_download(
278
- repo_id="SulphurAI/Sulphur-2-base",
279
- filename="sulphur_distil_bf16.safetensors",
280
- local_dir=str(weights_dir),
281
- local_dir_use_symlinks=False,
282
- )
283
-
284
- print("[Gemma] Setting up abliterated Gemma text encoder...")
285
- MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
286
- gemma_root = "/tmp/abliterated_gemma"
287
- os.makedirs(gemma_root, exist_ok=True)
288
-
289
- gemma_official_dir = snapshot_download(
290
- repo_id=GEMMA_REPO,
291
- ignore_patterns=["*.safetensors", "*.safetensors.index.json"],
292
- )
293
-
294
- for fname in os.listdir(gemma_official_dir):
295
- src = os.path.join(gemma_official_dir, fname)
296
- dst = os.path.join(gemma_root, fname)
297
- if os.path.isfile(src) and not fname.endswith(".safetensors") and fname != "model.safetensors.index.json":
298
- if not os.path.exists(dst):
299
- os.symlink(src, dst)
300
-
301
- if os.path.exists(MERGED_WEIGHTS):
302
- print("[Gemma] Using cached merged weights")
303
- else:
304
- abliterated_weights_path = hf_hub_download(
305
- repo_id=GEMMA_ABLITERATED_REPO,
306
- filename=GEMMA_ABLITERATED_FILE,
307
- )
308
- index_path = hf_hub_download(
309
- repo_id=GEMMA_REPO,
310
- filename="model.safetensors.index.json"
311
- )
312
- with open(index_path) as f:
313
- weight_index = json.load(f)
314
-
315
- vision_keys = {}
316
- for key, shard in weight_index["weight_map"].items():
317
- if "vision_tower" in key or "multi_modal_projector" in key:
318
- vision_keys[key] = shard
319
- needed_shards = set(vision_keys.values())
320
-
321
- shard_paths = {}
322
- for shard_name in needed_shards:
323
- shard_paths[shard_name] = hf_hub_download(
324
- repo_id=GEMMA_REPO,
325
- filename=shard_name
326
- )
327
-
328
- _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
329
- raw = load_file(abliterated_weights_path)
330
- merged = {}
331
- for key, tensor in raw.items():
332
- t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
333
- merged[f"language_model.{key}"] = t
334
- del raw
335
-
336
- for key, shard_name in vision_keys.items():
337
- with safe_open(shard_paths[shard_name], framework="pt") as f:
338
- merged[key] = f.get_tensor(key)
339
-
340
- save_file(merged, MERGED_WEIGHTS)
341
- del merged
342
- gc.collect()
343
-
344
- weight_link = os.path.join(gemma_root, "model.safetensors")
345
- if os.path.exists(weight_link):
346
- os.remove(weight_link)
347
- os.symlink(MERGED_WEIGHTS, weight_link)
348
- print(f"[Gemma] Root ready: {gemma_root}")
349
-
350
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
351
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
352
 
 
255
  # Model repos
256
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
257
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
 
 
258
 
259
  # Download model checkpoints
260
  print("=" * 80)
 
270
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
271
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
272
 
273
+ checkpoint_path = hf_hub_download(repo_id="SulphurAI/Sulphur-2-base", filename="sulphur_distil_bf16.safetensors")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
275
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
276