dagloop5 committed on
Commit
e949b11
·
verified ·
1 Parent(s): 1b65c80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -267,8 +267,8 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
270
- GEMMA_ABLITERATED_REPO = "FusionCow/Gemma-3-12b-Abliterated-LTX2"
271
- GEMMA_ABLITERATED_FILE = "gemma_ablit_fixed_bf16.safetensors"
272
 
273
  # Download model checkpoints
274
  print("=" * 80)
@@ -338,10 +338,12 @@ else:
338
  filename=shard_name
339
  )
340
 
 
341
  raw = load_file(abliterated_weights_path)
342
  merged = {}
343
  for key, tensor in raw.items():
344
- merged[f"language_model.{key}"] = tensor
 
345
  del raw
346
 
347
  for key, shard_name in vision_keys.items():
 
267
  # Model repos
268
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
269
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
270
+ GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
271
+ GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
272
 
273
  # Download model checkpoints
274
  print("=" * 80)
 
338
  filename=shard_name
339
  )
340
 
341
+ _fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
342
  raw = load_file(abliterated_weights_path)
343
  merged = {}
344
  for key, tensor in raw.items():
345
+ t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
346
+ merged[f"language_model.{key}"] = t
347
  del raw
348
 
349
  for key, shard_name in vision_keys.items():