Update app.py
Browse files
app.py
CHANGED
|
@@ -267,8 +267,8 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
|
|
| 267 |
# Model repos
|
| 268 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 269 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 270 |
-
GEMMA_ABLITERATED_REPO = "
|
| 271 |
-
GEMMA_ABLITERATED_FILE = "
|
| 272 |
|
| 273 |
# Download model checkpoints
|
| 274 |
print("=" * 80)
|
|
@@ -338,10 +338,12 @@ else:
|
|
| 338 |
filename=shard_name
|
| 339 |
)
|
| 340 |
|
|
|
|
| 341 |
raw = load_file(abliterated_weights_path)
|
| 342 |
merged = {}
|
| 343 |
for key, tensor in raw.items():
|
| 344 |
-
|
|
|
|
| 345 |
del raw
|
| 346 |
|
| 347 |
for key, shard_name in vision_keys.items():
|
|
|
|
| 267 |
# Model repos
|
| 268 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 269 |
GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 270 |
+
GEMMA_ABLITERATED_REPO = "Sikaworld1990/gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition-Ltx-2"
|
| 271 |
+
GEMMA_ABLITERATED_FILE = "gemma-3-12b-it-abliterated-sikaworld-high-fidelity-edition.safetensors"
|
| 272 |
|
| 273 |
# Download model checkpoints
|
| 274 |
print("=" * 80)
|
|
|
|
| 338 |
filename=shard_name
|
| 339 |
)
|
| 340 |
|
| 341 |
+
_fp8_types = {torch.float8_e4m3fn, torch.float8_e5m2}
|
| 342 |
raw = load_file(abliterated_weights_path)
|
| 343 |
merged = {}
|
| 344 |
for key, tensor in raw.items():
|
| 345 |
+
t = tensor.to(torch.bfloat16) if tensor.dtype in _fp8_types else tensor
|
| 346 |
+
merged[f"language_model.{key}"] = t
|
| 347 |
del raw
|
| 348 |
|
| 349 |
for key, shard_name in vision_keys.items():
|