dagloop5 commited on
Commit
007ffb7
·
verified ·
1 Parent(s): ca0ae99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +695 -161
app.py CHANGED
@@ -6,22 +6,17 @@ import sys
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
9
- # Install xformers for memory-efficient attention
10
- subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
  # Clone LTX-2 repo and install packages
13
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
14
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
15
- LTX_COMMIT_SHA = "ae855f8538843825f9015a419cf4ba5edaf5eec2"
 
16
 
17
  if not os.path.exists(LTX_REPO_DIR):
18
  print(f"Cloning {LTX_REPO_URL}...")
19
- os.makedirs(LTX_REPO_DIR)
20
- subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
21
- subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
22
- subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
23
- subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
24
-
25
 
26
  print("Installing ltx-core and ltx-pipelines from cloned repo...")
27
  subprocess.run(
@@ -38,6 +33,8 @@ import logging
38
  import random
39
  import tempfile
40
  from pathlib import Path
 
 
41
 
42
  import torch
43
  torch._dynamo.config.suppress_errors = True
@@ -47,98 +44,32 @@ import spaces
47
  import gradio as gr
48
  import numpy as np
49
  from huggingface_hub import hf_hub_download, snapshot_download
 
 
 
 
50
 
51
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
 
 
 
 
52
  from ltx_core.quantization import QuantizationPolicy
 
53
  from ltx_pipelines.distilled import DistilledPipeline
 
54
  from ltx_pipelines.utils.args import ImageConditioningInput
55
- from ltx_pipelines.utils.media_io import encode_video
56
-
57
- # Force-patch xformers attention into the LTX attention module.
58
- from ltx_core.model.transformer import attention as _attn_mod
59
- print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
60
- try:
61
- from xformers.ops import memory_efficient_attention as _mea
62
- _attn_mod.memory_efficient_attention = _mea
63
- print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
64
- except Exception as e:
65
- print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
66
-
67
- # Disable xformers FA3 dispatch: FA3 kernels are Hopper-only (sm_90a), but
68
- # xformers' dispatcher gates them on `device_capability >= (9, 0)`, which also
69
- # matches Blackwell (RTX PRO 6000, the ZeroGPU fleet hardware since 2026-05-12)
70
- # and crashes at kernel launch with "invalid argument".
71
- try:
72
- from xformers.ops.fmha import _set_use_fa3
73
- _set_use_fa3(False)
74
- print("[ATTN] xformers FA3 dispatch disabled (Blackwell-incompatible)")
75
- except Exception as e:
76
- print(f"[ATTN] FA3 disable FAILED: {type(e).__name__}: {e}")
77
-
78
- # FUSE/mmap workaround: SafetensorsStateDictLoader.load uses safetensors.safe_open
79
- # under the hood, which mmap's the file. On bucket FUSE mounts that triggers a
80
- # page-fault storm and deadlocks loading. Bypass mmap by parsing the safetensors
81
- # header ourselves and reading each tensor's bytes directly.
82
- import json
83
- import struct
84
-
85
- from ltx_core.loader.primitives import StateDict
86
- from ltx_core.loader.sft_loader import SafetensorsStateDictLoader
87
-
88
- _SAFETENSORS_DTYPE_MAP = {
89
- "F64": torch.float64,
90
- "F32": torch.float32,
91
- "F16": torch.float16,
92
- "BF16": torch.bfloat16,
93
- "F8_E5M2": torch.float8_e5m2,
94
- "F8_E4M3": torch.float8_e4m3fn,
95
- "I64": torch.int64,
96
- "I32": torch.int32,
97
- "I16": torch.int16,
98
- "I8": torch.int8,
99
- "U8": torch.uint8,
100
- "BOOL": torch.bool,
101
- }
102
-
103
-
104
- def _patched_load(self, path, sd_ops, device=None):
105
- sd = {}
106
- size = 0
107
- dtype = set()
108
- device = device or torch.device("cpu")
109
- model_paths = path if isinstance(path, list) else [path]
110
- for shard_path in model_paths:
111
- with open(shard_path, "rb") as f:
112
- header_len = struct.unpack("<Q", f.read(8))[0]
113
- header = json.loads(f.read(header_len).decode("utf-8"))
114
- data_base = 8 + header_len
115
- for name, meta in header.items():
116
- if name == "__metadata__":
117
- continue
118
- expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
119
- if expected_name is None:
120
- continue
121
- start, end = meta["data_offsets"]
122
- f.seek(data_base + start)
123
- buf = f.read(end - start)
124
- t = torch.frombuffer(
125
- bytearray(buf), dtype=_SAFETENSORS_DTYPE_MAP[meta["dtype"]]
126
- ).reshape(meta["shape"])
127
- t = t.to(device=device, non_blocking=True, copy=False)
128
- kvs = (
129
- ((expected_name, t),)
130
- if sd_ops is None
131
- else sd_ops.apply_to_key_value(expected_name, t)
132
- )
133
- for key, v in kvs:
134
- size += v.nbytes
135
- dtype.add(v.dtype)
136
- sd[key] = v
137
- return StateDict(sd=sd, device=device, size=size, dtype=dtype)
138
-
139
-
140
- SafetensorsStateDictLoader.load = _patched_load
141
- print("[FUSE-PATCH] SafetensorsStateDictLoader.load replaced (chunked-read)")
142
 
143
  logging.getLogger().setLevel(logging.INFO)
144
 
@@ -157,46 +88,464 @@ RESOLUTIONS = {
157
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
158
  }
159
 
160
- LTX_MOUNT = "/models/ltx"
161
- GEMMA_MOUNT = "/models/gemma"
162
 
163
- DISTILLED_FILENAME = "ltx-2.3-22b-distilled-1.1.safetensors"
164
- UPSCALER_FILENAME = "ltx-2.3-spatial-upscaler-x2-1.1.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- distilled_checkpoint_path = os.path.join(LTX_MOUNT, DISTILLED_FILENAME)
167
- spatial_upsampler_path = os.path.join(LTX_MOUNT, UPSCALER_FILENAME)
168
- gemma_root = GEMMA_MOUNT
169
 
170
- # Initialize pipeline WITH text encoder
171
- pipeline = DistilledPipeline(
172
- distilled_checkpoint_path=distilled_checkpoint_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  spatial_upsampler_path=spatial_upsampler_path,
174
  gemma_root=gemma_root,
175
  loras=[],
176
- quantization=QuantizationPolicy.fp8_cast(),
177
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  # Preload all models for ZeroGPU tensor packing.
180
- print("Preloading all models (including Gemma)...")
181
  ledger = pipeline.model_ledger
182
- _transformer = ledger.transformer()
183
- _video_encoder = ledger.video_encoder()
184
- _video_decoder = ledger.video_decoder()
185
- _audio_decoder = ledger.audio_decoder()
186
- _vocoder = ledger.vocoder()
187
- _spatial_upsampler = ledger.spatial_upsampler()
188
- _text_encoder = ledger.text_encoder()
189
- _embeddings_processor = ledger.gemma_embeddings_processor()
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  ledger.transformer = lambda: _transformer
192
  ledger.video_encoder = lambda: _video_encoder
193
  ledger.video_decoder = lambda: _video_decoder
 
194
  ledger.audio_decoder = lambda: _audio_decoder
195
  ledger.vocoder = lambda: _vocoder
196
  ledger.spatial_upsampler = lambda: _spatial_upsampler
197
  ledger.text_encoder = lambda: _text_encoder
198
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
199
- print("All models preloaded (including Gemma text encoder)!")
 
 
200
 
201
  print("=" * 80)
202
  print("Pipeline ready!")
@@ -212,7 +561,6 @@ def log_memory(tag: str):
212
 
213
 
214
  def detect_aspect_ratio(image) -> str:
215
- """Detect the closest aspect ratio (16:9, 9:16, or 1:1) from an image."""
216
  if image is None:
217
  return "16:9"
218
  if hasattr(image, "size"):
@@ -226,33 +574,82 @@ def detect_aspect_ratio(image) -> str:
226
  return min(candidates, key=lambda k: abs(ratio - candidates[k]))
227
 
228
 
229
- def on_image_upload(image, high_res):
230
- """Auto-set resolution when image is uploaded."""
231
- aspect = detect_aspect_ratio(image)
232
  tier = "high" if high_res else "low"
233
  w, h = RESOLUTIONS[tier][aspect]
234
  return gr.update(value=w), gr.update(value=h)
235
 
236
 
237
- def on_highres_toggle(image, high_res):
238
- """Update resolution when high-res toggle changes."""
239
- aspect = detect_aspect_ratio(image)
240
  tier = "high" if high_res else "low"
241
  w, h = RESOLUTIONS[tier][aspect]
242
  return gr.update(value=w), gr.update(value=h)
243
 
244
 
245
- @spaces.GPU(duration=75)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  @torch.inference_mode()
247
  def generate_video(
248
- input_image,
 
 
249
  prompt: str,
250
- duration: float,
 
251
  enhance_prompt: bool = True,
252
  seed: int = 42,
253
  randomize_seed: bool = True,
254
- height: int = 1024,
255
- width: int = 1536,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  progress=gr.Progress(track_tqdm=True),
257
  ):
258
  try:
@@ -268,21 +665,32 @@ def generate_video(
268
  print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")
269
 
270
  images = []
271
- if input_image is not None:
272
- output_dir = Path("outputs")
273
- output_dir.mkdir(exist_ok=True)
274
- temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
275
- if hasattr(input_image, "save"):
276
- input_image.save(temp_image_path)
 
277
  else:
278
- temp_image_path = Path(input_image)
279
- images = [ImageConditioningInput(path=str(temp_image_path), frame_idx=0, strength=1.0)]
 
 
 
 
 
 
 
 
280
 
281
  tiling_config = TilingConfig.default()
282
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
283
 
284
  log_memory("before pipeline call")
285
 
 
 
286
  video, audio = pipeline(
287
  prompt=prompt,
288
  seed=current_seed,
@@ -291,6 +699,7 @@ def generate_video(
291
  num_frames=num_frames,
292
  frame_rate=frame_rate,
293
  images=images,
 
294
  tiling_config=tiling_config,
295
  enhance_prompt=enhance_prompt,
296
  )
@@ -317,16 +726,15 @@ def generate_video(
317
 
318
 
319
  with gr.Blocks(title="LTX-2.3 Distilled") as demo:
320
- gr.Markdown("# LTX-2.3 Distilled (22B): Fast Audio-Video Generation")
321
- gr.Markdown(
322
- "Fast and high quality video + audio generation "
323
- "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
324
- "[[code]](https://github.com/Lightricks/LTX-2)"
325
- )
326
 
327
  with gr.Row():
328
  with gr.Column():
329
- input_image = gr.Image(label="Input Image (Optional)", type="pil")
 
 
 
330
  prompt = gr.Textbox(
331
  label="Prompt",
332
  info="for best results - make it as elaborate as possible",
@@ -334,12 +742,8 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
334
  lines=3,
335
  placeholder="Describe the motion and animation you want...",
336
  )
337
-
338
- with gr.Row():
339
- duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
340
- with gr.Column():
341
- enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
342
- high_res = gr.Checkbox(label="High Resolution", value=True)
343
 
344
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
345
 
@@ -349,29 +753,161 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
349
  with gr.Row():
350
  width = gr.Number(label="Width", value=1536, precision=0)
351
  height = gr.Number(label="Height", value=1024, precision=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
  with gr.Column():
354
- output_video = gr.Video(label="Generated Video", autoplay=True)
 
 
 
 
 
 
 
355
 
356
- # Auto-detect aspect ratio from uploaded image and set resolution
357
- input_image.change(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  fn=on_image_upload,
359
- inputs=[input_image, high_res],
360
  outputs=[width, height],
361
  )
362
 
363
- # Update resolution when high-res toggle changes
364
  high_res.change(
365
  fn=on_highres_toggle,
366
- inputs=[input_image, high_res],
367
  outputs=[width, height],
368
  )
369
 
 
 
 
 
 
 
370
  generate_btn.click(
371
  fn=generate_video,
372
  inputs=[
373
- input_image, prompt, duration, enhance_prompt,
374
  seed, randomize_seed, height, width,
 
375
  ],
376
  outputs=[output_video, seed],
377
  )
@@ -379,9 +915,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
379
 
380
  css = """
381
  .fillable{max-width: 1200px !important}
382
- .progress-text {color: white}
383
  """
384
 
385
  if __name__ == "__main__":
386
- demo.launch(theme=gr.themes.Citrus(), css=css)
387
-
 
6
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
  os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
 
 
 
9
 
10
  # Clone LTX-2 repo and install packages
11
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
12
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
13
+
14
+ LTX_COMMIT = "ae855f8538843825f9015a419cf4ba5edaf5eec2" # known working commit with decode_video
15
 
16
  if not os.path.exists(LTX_REPO_DIR):
17
  print(f"Cloning {LTX_REPO_URL}...")
18
+ subprocess.run(["git", "clone", LTX_REPO_URL, LTX_REPO_DIR], check=True)
19
+ subprocess.run(["git", "checkout", LTX_COMMIT], cwd=LTX_REPO_DIR, check=True)
 
 
 
 
20
 
21
  print("Installing ltx-core and ltx-pipelines from cloned repo...")
22
  subprocess.run(
 
33
  import random
34
  import tempfile
35
  from pathlib import Path
36
+ import gc
37
+ import hashlib
38
 
39
  import torch
40
  torch._dynamo.config.suppress_errors = True
 
44
  import gradio as gr
45
  import numpy as np
46
  from huggingface_hub import hf_hub_download, snapshot_download
47
+ from safetensors.torch import load_file, save_file
48
+ from safetensors import safe_open
49
+ import json
50
+ import requests
51
 
52
+ from ltx_core.components.diffusion_steps import EulerDiffusionStep
53
+ from ltx_core.components.noisers import GaussianNoiser
54
+ from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
55
+ from ltx_core.model.upsampler import upsample_video
56
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
57
  from ltx_core.quantization import QuantizationPolicy
58
+ from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
59
  from ltx_pipelines.distilled import DistilledPipeline
60
+ from ltx_pipelines.utils import euler_denoising_loop
61
  from ltx_pipelines.utils.args import ImageConditioningInput
62
+ from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
63
+ from ltx_pipelines.utils.helpers import (
64
+ cleanup_memory,
65
+ combined_image_conditionings,
66
+ denoise_video_only,
67
+ encode_prompts,
68
+ simple_denoising_func,
69
+ )
70
+ from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
71
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
72
+ from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  logging.getLogger().setLevel(logging.INFO)
75
 
 
88
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
89
  }
90
 
 
 
91
 
92
+ class LTX23DistilledA2VPipeline(DistilledPipeline):
93
+ """DistilledPipeline with optional audio conditioning."""
94
+
95
+ def __call__(
96
+ self,
97
+ prompt: str,
98
+ seed: int,
99
+ height: int,
100
+ width: int,
101
+ num_frames: int,
102
+ frame_rate: float,
103
+ images: list[ImageConditioningInput],
104
+ audio_path: str | None = None,
105
+ tiling_config: TilingConfig | None = None,
106
+ enhance_prompt: bool = False,
107
+ ):
108
+ # Standard path when no audio input is provided.
109
+ print(prompt)
110
+ if audio_path is None:
111
+ return super().__call__(
112
+ prompt=prompt,
113
+ seed=seed,
114
+ height=height,
115
+ width=width,
116
+ num_frames=num_frames,
117
+ frame_rate=frame_rate,
118
+ images=images,
119
+ tiling_config=tiling_config,
120
+ enhance_prompt=enhance_prompt,
121
+ )
122
+
123
+ generator = torch.Generator(device=self.device).manual_seed(seed)
124
+ noiser = GaussianNoiser(generator=generator)
125
+ stepper = EulerDiffusionStep()
126
+ dtype = torch.bfloat16
127
+
128
+ (ctx_p,) = encode_prompts(
129
+ [prompt],
130
+ self.model_ledger,
131
+ enhance_first_prompt=enhance_prompt,
132
+ enhance_prompt_image=images[0].path if len(images) > 0 else None,
133
+ )
134
+ video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
135
+
136
+ video_duration = num_frames / frame_rate
137
+ decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
138
+ if decoded_audio is None:
139
+ raise ValueError(f"Could not extract audio stream from {audio_path}")
140
+
141
+ encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
142
+ audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
143
+ expected_frames = audio_shape.frames
144
+ actual_frames = encoded_audio_latent.shape[2]
145
+
146
+ if actual_frames > expected_frames:
147
+ encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
148
+ elif actual_frames < expected_frames:
149
+ pad = torch.zeros(
150
+ encoded_audio_latent.shape[0],
151
+ encoded_audio_latent.shape[1],
152
+ expected_frames - actual_frames,
153
+ encoded_audio_latent.shape[3],
154
+ device=encoded_audio_latent.device,
155
+ dtype=encoded_audio_latent.dtype,
156
+ )
157
+ encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
158
+
159
+ video_encoder = self.model_ledger.video_encoder()
160
+ transformer = self.model_ledger.transformer()
161
+ stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
162
+
163
+ def denoising_loop(sigmas, video_state, audio_state, stepper):
164
+ return euler_denoising_loop(
165
+ sigmas=sigmas,
166
+ video_state=video_state,
167
+ audio_state=audio_state,
168
+ stepper=stepper,
169
+ denoise_fn=simple_denoising_func(
170
+ video_context=video_context,
171
+ audio_context=audio_context,
172
+ transformer=transformer,
173
+ ),
174
+ )
175
+
176
+ stage_1_output_shape = VideoPixelShape(
177
+ batch=1,
178
+ frames=num_frames,
179
+ width=width // 2,
180
+ height=height // 2,
181
+ fps=frame_rate,
182
+ )
183
+ stage_1_conditionings = combined_image_conditionings(
184
+ images=images,
185
+ height=stage_1_output_shape.height,
186
+ width=stage_1_output_shape.width,
187
+ video_encoder=video_encoder,
188
+ dtype=dtype,
189
+ device=self.device,
190
+ )
191
+ video_state = denoise_video_only(
192
+ output_shape=stage_1_output_shape,
193
+ conditionings=stage_1_conditionings,
194
+ noiser=noiser,
195
+ sigmas=stage_1_sigmas,
196
+ stepper=stepper,
197
+ denoising_loop_fn=denoising_loop,
198
+ components=self.pipeline_components,
199
+ dtype=dtype,
200
+ device=self.device,
201
+ initial_audio_latent=encoded_audio_latent,
202
+ )
203
+
204
+ torch.cuda.synchronize()
205
+ cleanup_memory()
206
+
207
+ upscaled_video_latent = upsample_video(
208
+ latent=video_state.latent[:1],
209
+ video_encoder=video_encoder,
210
+ upsampler=self.model_ledger.spatial_upsampler(),
211
+ )
212
+ stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
213
+ stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
214
+ stage_2_conditionings = combined_image_conditionings(
215
+ images=images,
216
+ height=stage_2_output_shape.height,
217
+ width=stage_2_output_shape.width,
218
+ video_encoder=video_encoder,
219
+ dtype=dtype,
220
+ device=self.device,
221
+ )
222
+ video_state = denoise_video_only(
223
+ output_shape=stage_2_output_shape,
224
+ conditionings=stage_2_conditionings,
225
+ noiser=noiser,
226
+ sigmas=stage_2_sigmas,
227
+ stepper=stepper,
228
+ denoising_loop_fn=denoising_loop,
229
+ components=self.pipeline_components,
230
+ dtype=dtype,
231
+ device=self.device,
232
+ noise_scale=stage_2_sigmas[0],
233
+ initial_video_latent=upscaled_video_latent,
234
+ initial_audio_latent=encoded_audio_latent,
235
+ )
236
+
237
+ torch.cuda.synchronize()
238
+ del transformer
239
+ del video_encoder
240
+ cleanup_memory()
241
+
242
+ decoded_video = vae_decode_video(
243
+ video_state.latent,
244
+ self.model_ledger.video_decoder(),
245
+ tiling_config,
246
+ generator,
247
+ )
248
+ original_audio = Audio(
249
+ waveform=decoded_audio.waveform.squeeze(0),
250
+ sampling_rate=decoded_audio.sampling_rate,
251
+ )
252
+ return decoded_video, original_audio
253
+
254
+
255
+ # Model repos
256
+ LTX_MODEL_REPO = "Lightricks/LTX-2.3"
257
+ GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
258
+
259
+ # Download model checkpoints
260
+ print("=" * 80)
261
+ print("Downloading LTX-2.3 distilled model + Gemma...")
262
+ print("=" * 80)
263
+
264
+ # LoRA cache directory and currently-applied key
265
+ LORA_CACHE_DIR = Path("lora_cache")
266
+ LORA_CACHE_DIR.mkdir(exist_ok=True)
267
+ current_lora_key: str | None = None
268
+
269
+ PENDING_LORA_KEY: str | None = None
270
+ PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
271
+ PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
272
+
273
+ weights_dir = Path("weights")
274
+ weights_dir.mkdir(exist_ok=True)
275
+ checkpoint_path = hf_hub_download(
276
+ repo_id="SulphurAI/Sulphur-2-base",
277
+ filename="sulphur_distil_bf16.safetensors",
278
+ local_dir=str(weights_dir),
279
+ local_dir_use_symlinks=False,
280
+ )
281
+
282
+ spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
283
+ gemma_root = snapshot_download(repo_id=GEMMA_REPO)
284
+
285
 
286
+ # ---- Insert block (LoRA downloads) between lines 268 and 269 ----
287
+ # LoRA repo + download the requested LoRA adapters
288
+ LORA_REPO = "dagloop5/LoRA"
289
 
290
+ print("=" * 80)
291
+ print("Downloading LoRA adapters from dagloop5/LoRA...")
292
+ print("=" * 80)
293
+ pose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2_3_NSFW_furry_concat_v2.safetensors")
294
+ general_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_reasoning_I2V_V3.safetensors")
295
+ motion_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="Sulphur_LTX 2.3_better _NSFW_motion.safetensors")
296
+ dreamlay_lora_path = hf_hub_download(repo_id="lynaNSFW/DR34ML4Y_AIO_NSFW_LTX23", filename="DR34ML4Y_LTXXX_V2.safetensors") # m15510n4ry, bl0wj0b, d0ubl3_bj, d0gg1e, c0wg1rl
297
+ mself_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="Furry Hyper Masturbation - LTX-2 I2V v1.safetensors") # Hyperfap
298
+ dramatic_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2.3 - Orgasm.safetensors") # buddr
299
+ fluid_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_CREAMPIE_ANIMATION-V0.1.safetensors") # cum
300
+ liquid_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="liquid_wet_dr1pp_ltx2_v1.0_scaled.safetensors") # wet dr1pp
301
+ demopose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="clapping-cheeks-audio-v001-alpha.safetensors")
302
+ voice_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="hentai_voice_ltx23.safetensors")
303
+ realism_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="FurryenhancerLTX2.3V4.094fused.safetensors")
304
+ transition_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2_takerpov_lora_v1.2.safetensors") # takerpov1, taker pov
305
+ physics_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_Physics_V2_000002000.safetensors")
306
+ reasoning_lora_path = hf_hub_download(repo_id="LiconStudio/Ltx2.3-VBVR-lora-I2V", filename="Ltx2.3-Licon-VBVR-I2V-390K-R32.safetensors")
307
+ twostep_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_Multi_step_video_reasoning_V0.1.safetensors")
308
+
309
+ print(f"Pose LoRA: {pose_lora_path}")
310
+ print(f"General LoRA: {general_lora_path}")
311
+ print(f"Motion LoRA: {motion_lora_path}")
312
+ print(f"Dreamlay LoRA: {dreamlay_lora_path}")
313
+ print(f"Mself LoRA: {mself_lora_path}")
314
+ print(f"Dramatic LoRA: {dramatic_lora_path}")
315
+ print(f"Fluid LoRA: {fluid_lora_path}")
316
+ print(f"Liquid LoRA: {liquid_lora_path}")
317
+ print(f"Demopose LoRA: {demopose_lora_path}")
318
+ print(f"Voice LoRA: {voice_lora_path}")
319
+ print(f"Realism LoRA: {realism_lora_path}")
320
+ print(f"Transition LoRA: {transition_lora_path}")
321
+ print(f"Physics LoRA: {physics_lora_path}")
322
+ print(f"Reasoning LoRA: {reasoning_lora_path}")
323
+ print(f"Twostep LoRA: {twostep_lora_path}")
324
+ # ----------------------------------------------------------------
325
+
326
+ print(f"Checkpoint: {checkpoint_path}")
327
+ print(f"Spatial upsampler: {spatial_upsampler_path}")
328
+ print(f"[Gemma] Root ready: {gemma_root}")
329
+
330
+ # Initialize pipeline WITH text encoder and optional audio support
331
+ # ---- Replace block (pipeline init) lines 275-281 ----
332
+ pipeline = LTX23DistilledA2VPipeline(
333
+ distilled_checkpoint_path=checkpoint_path,
334
  spatial_upsampler_path=spatial_upsampler_path,
335
  gemma_root=gemma_root,
336
  loras=[],
337
+ quantization=QuantizationPolicy.fp8_cast(), # keep FP8 quantization unchanged
338
  )
339
+ # ----------------------------------------------------------------
340
+
341
+ def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
342
+ rp = round(float(pose_strength), 2)
343
+ rg = round(float(general_strength), 2)
344
+ rm = round(float(motion_strength), 2)
345
+ rd = round(float(dreamlay_strength), 2)
346
+ rs = round(float(mself_strength), 2)
347
+ rr = round(float(dramatic_strength), 2)
348
+ rf = round(float(fluid_strength), 2)
349
+ rl = round(float(liquid_strength), 2)
350
+ ro = round(float(demopose_strength), 2)
351
+ rv = round(float(voice_strength), 2)
352
+ re = round(float(realism_strength), 2)
353
+ rt = round(float(transition_strength), 2)
354
+ ry = round(float(physics_strength), 2)
355
+ ri = round(float(reasoning_strength), 2)
356
+ rw = round(float(twostep_strength), 2)
357
+ key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}|{dreamlay_lora_path}:{rd}|{mself_lora_path}:{rs}|{dramatic_lora_path}:{rr}|{fluid_lora_path}:{rf}|{liquid_lora_path}:{rl}|{demopose_lora_path}:{ro}|{voice_lora_path}:{rv}|{realism_lora_path}:{re}|{transition_lora_path}:{rt}|{physics_lora_path}:{ry}|{reasoning_lora_path}:{ri}|{twostep_lora_path}:{rw}"
358
+ key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
359
+ return key, key_str
360
+
361
+
362
+ def prepare_lora_cache(
363
+ pose_strength: float,
364
+ general_strength: float,
365
+ motion_strength: float,
366
+ dreamlay_strength: float,
367
+ mself_strength: float,
368
+ dramatic_strength: float,
369
+ fluid_strength: float,
370
+ liquid_strength: float,
371
+ demopose_strength: float,
372
+ voice_strength: float,
373
+ realism_strength: float,
374
+ transition_strength: float,
375
+ physics_strength: float,
376
+ reasoning_strength: float,
377
+ twostep_strength: float,
378
+ progress=gr.Progress(track_tqdm=True),
379
+ ):
380
+ """
381
+ CPU-only step:
382
+ - checks cache
383
+ - loads cached fused transformer state_dict, or
384
+ - builds fused transformer on CPU and saves it
385
+ The resulting state_dict is stored in memory and can be applied later.
386
+ """
387
+ global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
388
+
389
+ ledger = pipeline.model_ledger
390
+ key, _ = _make_lora_key(pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength)
391
+ cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
392
+
393
+ progress(0.05, desc="Preparing LoRA state")
394
+ if cache_path.exists():
395
+ try:
396
+ progress(0.20, desc="Loading cached fused state")
397
+ state = load_file(str(cache_path))
398
+ PENDING_LORA_KEY = key
399
+ PENDING_LORA_STATE = state
400
+ PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
401
+ return PENDING_LORA_STATUS
402
+ except Exception as e:
403
+ print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
404
+
405
+ entries = [
406
+ (pose_lora_path, round(float(pose_strength), 2)),
407
+ (general_lora_path, round(float(general_strength), 2)),
408
+ (motion_lora_path, round(float(motion_strength), 2)),
409
+ (dreamlay_lora_path, round(float(dreamlay_strength), 2)),
410
+ (mself_lora_path, round(float(mself_strength), 2)),
411
+ (dramatic_lora_path, round(float(dramatic_strength), 2)),
412
+ (fluid_lora_path, round(float(fluid_strength), 2)),
413
+ (liquid_lora_path, round(float(liquid_strength), 2)),
414
+ (demopose_lora_path, round(float(demopose_strength), 2)),
415
+ (voice_lora_path, round(float(voice_strength), 2)),
416
+ (realism_lora_path, round(float(realism_strength), 2)),
417
+ (transition_lora_path, round(float(transition_strength), 2)),
418
+ (physics_lora_path, round(float(physics_strength), 2)),
419
+ (reasoning_lora_path, round(float(reasoning_strength), 2)),
420
+ (twostep_lora_path, round(float(twostep_strength), 2)),
421
+ ]
422
+ loras_for_builder = [
423
+ LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
424
+ for path, strength in entries
425
+ if path is not None and float(strength) != 0.0
426
+ ]
427
+
428
+ if not loras_for_builder:
429
+ PENDING_LORA_KEY = None
430
+ PENDING_LORA_STATE = None
431
+ PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
432
+ return PENDING_LORA_STATUS
433
+
434
+ tmp_ledger = None
435
+ new_transformer_cpu = None
436
+ try:
437
+ progress(0.35, desc="Building fused CPU transformer")
438
+ tmp_ledger = pipeline.model_ledger.__class__(
439
+ dtype=ledger.dtype,
440
+ device=torch.device("cpu"),
441
+ checkpoint_path=str(checkpoint_path),
442
+ spatial_upsampler_path=str(spatial_upsampler_path),
443
+ gemma_root_path=str(gemma_root),
444
+ loras=tuple(loras_for_builder),
445
+ quantization=getattr(ledger, "quantization", None),
446
+ )
447
+ new_transformer_cpu = tmp_ledger.transformer()
448
 
449
+ progress(0.70, desc="Extracting fused state_dict")
450
+ state = {
451
+ k: v.detach().cpu().contiguous()
452
+ for k, v in new_transformer_cpu.state_dict().items()
453
+ }
454
+ save_file(state, str(cache_path))
455
+
456
+ PENDING_LORA_KEY = key
457
+ PENDING_LORA_STATE = state
458
+ PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
459
+ return PENDING_LORA_STATUS
460
+
461
+ except Exception as e:
462
+ import traceback
463
+ print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
464
+ print(traceback.format_exc())
465
+ PENDING_LORA_KEY = None
466
+ PENDING_LORA_STATE = None
467
+ PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
468
+ return PENDING_LORA_STATUS
469
+
470
+ finally:
471
+ try:
472
+ del new_transformer_cpu
473
+ except Exception:
474
+ pass
475
+ try:
476
+ del tmp_ledger
477
+ except Exception:
478
+ pass
479
+ gc.collect()
480
+
481
+
482
+ def apply_prepared_lora_state_to_pipeline():
483
+ """
484
+ Fast step: copy the already prepared CPU state into the live transformer.
485
+ This is the only part that should remain near generation time.
486
+ """
487
+ global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
488
+
489
+ if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
490
+ print("[LoRA] No prepared LoRA state available; skipping.")
491
+ return False
492
+
493
+ if current_lora_key == PENDING_LORA_KEY:
494
+ print("[LoRA] Prepared LoRA state already active; skipping.")
495
+ return True
496
+
497
+ existing_transformer = _transformer
498
+ with torch.no_grad():
499
+ missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
500
+ if missing or unexpected:
501
+ print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
502
+
503
+ current_lora_key = PENDING_LORA_KEY
504
+ print("[LoRA] Prepared LoRA state applied to the pipeline.")
505
+ return True
506
+
507
+ # ---- REPLACE PRELOAD BLOCK START ----
508
  # Preload all models for ZeroGPU tensor packing.
509
+ print("Preloading all models (including Gemma and audio components)...")
510
  ledger = pipeline.model_ledger
 
 
 
 
 
 
 
 
511
 
512
+ # Save the original factory methods so we can rebuild individual components later.
513
+ # These are bound callables on ledger that will call the builder when invoked.
514
+ _orig_transformer_factory = ledger.transformer
515
+ _orig_video_encoder_factory = ledger.video_encoder
516
+ _orig_video_decoder_factory = ledger.video_decoder
517
+ _orig_audio_encoder_factory = ledger.audio_encoder
518
+ _orig_audio_decoder_factory = ledger.audio_decoder
519
+ _orig_vocoder_factory = ledger.vocoder
520
+ _orig_spatial_upsampler_factory = ledger.spatial_upsampler
521
+ _orig_text_encoder_factory = ledger.text_encoder
522
+ _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
523
+
524
+ # Call the original factories once to create the cached instances we will serve by default.
525
+ _transformer = _orig_transformer_factory()
526
+ _video_encoder = _orig_video_encoder_factory()
527
+ _video_decoder = _orig_video_decoder_factory()
528
+ _audio_encoder = _orig_audio_encoder_factory()
529
+ _audio_decoder = _orig_audio_decoder_factory()
530
+ _vocoder = _orig_vocoder_factory()
531
+ _spatial_upsampler = _orig_spatial_upsampler_factory()
532
+ _text_encoder = _orig_text_encoder_factory()
533
+ _embeddings_processor = _orig_gemma_embeddings_factory()
534
+
535
+ # Replace ledger methods with lightweight lambdas that return the cached instances.
536
+ # We keep the original factories above so we can call them later to rebuild components.
537
  ledger.transformer = lambda: _transformer
538
  ledger.video_encoder = lambda: _video_encoder
539
  ledger.video_decoder = lambda: _video_decoder
540
+ ledger.audio_encoder = lambda: _audio_encoder
541
  ledger.audio_decoder = lambda: _audio_decoder
542
  ledger.vocoder = lambda: _vocoder
543
  ledger.spatial_upsampler = lambda: _spatial_upsampler
544
  ledger.text_encoder = lambda: _text_encoder
545
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
546
+
547
+ print("All models preloaded (including Gemma text encoder and audio encoder)!")
548
+ # ---- REPLACE PRELOAD BLOCK END ----
549
 
550
  print("=" * 80)
551
  print("Pipeline ready!")
 
561
 
562
 
563
  def detect_aspect_ratio(image) -> str:
 
564
  if image is None:
565
  return "16:9"
566
  if hasattr(image, "size"):
 
574
  return min(candidates, key=lambda k: abs(ratio - candidates[k]))
575
 
576
 
577
+ def on_image_upload(first_image, last_image, high_res):
578
+ ref_image = first_image if first_image is not None else last_image
579
+ aspect = detect_aspect_ratio(ref_image)
580
  tier = "high" if high_res else "low"
581
  w, h = RESOLUTIONS[tier][aspect]
582
  return gr.update(value=w), gr.update(value=h)
583
 
584
 
585
+ def on_highres_toggle(first_image, last_image, high_res):
586
+ ref_image = first_image if first_image is not None else last_image
587
+ aspect = detect_aspect_ratio(ref_image)
588
  tier = "high" if high_res else "low"
589
  w, h = RESOLUTIONS[tier][aspect]
590
  return gr.update(value=w), gr.update(value=h)
591
 
592
 
593
+ def get_gpu_duration(
594
+ first_image,
595
+ last_image,
596
+ input_audio,
597
+ prompt: str,
598
+ duration: float = 0.0,
599
+ gpu_duration: float = 0.0,
600
+ enhance_prompt: bool = False,
601
+ seed: int = 42,
602
+ randomize_seed: bool = False,
603
+ height: int = 0.0,
604
+ width: int = 0.0,
605
+ pose_strength: float = 0.0,
606
+ general_strength: float = 0.0,
607
+ motion_strength: float = 0.0,
608
+ dreamlay_strength: float = 0.0,
609
+ mself_strength: float = 0.0,
610
+ dramatic_strength: float = 0.0,
611
+ fluid_strength: float = 0.0,
612
+ liquid_strength: float = 0.0,
613
+ demopose_strength: float = 0.0,
614
+ voice_strength: float = 0.0,
615
+ realism_strength: float = 0.0,
616
+ transition_strength: float = 0.0,
617
+ physics_strength: float = 0.0,
618
+ reasoning_strength: float = 0.0,
619
+ twostep_strength: float = 0.0,
620
+ progress=None,
621
+ ):
622
+ return int(gpu_duration)
623
+
624
+ @spaces.GPU(duration=get_gpu_duration)
625
  @torch.inference_mode()
626
  def generate_video(
627
+ first_image,
628
+ last_image,
629
+ input_audio,
630
  prompt: str,
631
+ duration: float = 0.0,
632
+ gpu_duration: float = 0.0,
633
  enhance_prompt: bool = True,
634
  seed: int = 42,
635
  randomize_seed: bool = True,
636
+ height: int = 0.0,
637
+ width: int = 0.0,
638
+ pose_strength: float = 0.0,
639
+ general_strength: float = 0.0,
640
+ motion_strength: float = 0.0,
641
+ dreamlay_strength: float = 0.0,
642
+ mself_strength: float = 0.0,
643
+ dramatic_strength: float = 0.0,
644
+ fluid_strength: float = 0.0,
645
+ liquid_strength: float = 0.0,
646
+ demopose_strength: float = 0.0,
647
+ voice_strength: float = 0.0,
648
+ realism_strength: float = 0.0,
649
+ transition_strength: float = 0.0,
650
+ physics_strength: float = 0.0,
651
+ reasoning_strength: float = 0.0,
652
+ twostep_strength: float = 0.0,
653
  progress=gr.Progress(track_tqdm=True),
654
  ):
655
  try:
 
665
  print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")
666
 
667
  images = []
668
+ output_dir = Path("outputs")
669
+ output_dir.mkdir(exist_ok=True)
670
+
671
+ if first_image is not None:
672
+ temp_first_path = output_dir / f"temp_first_{current_seed}.jpg"
673
+ if hasattr(first_image, "save"):
674
+ first_image.save(temp_first_path)
675
  else:
676
+ temp_first_path = Path(first_image)
677
+ images.append(ImageConditioningInput(path=str(temp_first_path), frame_idx=0, strength=1.0))
678
+
679
+ if last_image is not None:
680
+ temp_last_path = output_dir / f"temp_last_{current_seed}.jpg"
681
+ if hasattr(last_image, "save"):
682
+ last_image.save(temp_last_path)
683
+ else:
684
+ temp_last_path = Path(last_image)
685
+ images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
686
 
687
  tiling_config = TilingConfig.default()
688
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
689
 
690
  log_memory("before pipeline call")
691
 
692
+ apply_prepared_lora_state_to_pipeline()
693
+
694
  video, audio = pipeline(
695
  prompt=prompt,
696
  seed=current_seed,
 
699
  num_frames=num_frames,
700
  frame_rate=frame_rate,
701
  images=images,
702
+ audio_path=input_audio,
703
  tiling_config=tiling_config,
704
  enhance_prompt=enhance_prompt,
705
  )
 
726
 
727
 
728
  with gr.Blocks(title="LTX-2.3 Distilled") as demo:
729
+ gr.Markdown("# LTX-2.3 F2LF with Fast Audio-Video Generation with Frame Conditioning")
730
+
 
 
 
 
731
 
732
  with gr.Row():
733
  with gr.Column():
734
+ with gr.Row():
735
+ first_image = gr.Image(label="First Frame (Optional)", type="pil")
736
+ last_image = gr.Image(label="Last Frame (Optional)", type="pil")
737
+ input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
738
  prompt = gr.Textbox(
739
  label="Prompt",
740
  info="for best results - make it as elaborate as possible",
 
742
  lines=3,
743
  placeholder="Describe the motion and animation you want...",
744
  )
745
+ duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=30.0, value=10.0, step=0.1)
746
+
 
 
 
 
747
 
748
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
749
 
 
753
  with gr.Row():
754
  width = gr.Number(label="Width", value=1536, precision=0)
755
  height = gr.Number(label="Height", value=1024, precision=0)
756
+ with gr.Row():
757
+ enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
758
+ high_res = gr.Checkbox(label="High Resolution", value=True)
759
+ with gr.Column():
760
+ gr.Markdown("### LoRA adapter strengths (set to 0 to disable; slow and WIP)")
761
+ pose_strength = gr.Slider(
762
+ label="Anthro Enhancer strength",
763
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
764
+ )
765
+ general_strength = gr.Slider(
766
+ label="Reasoning Enhancer strength",
767
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
768
+ )
769
+ motion_strength = gr.Slider(
770
+ label="Anthro Posing Helper strength",
771
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
772
+ )
773
+ dreamlay_strength = gr.Slider(
774
+ label="Dreamlay strength",
775
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
776
+ )
777
+ mself_strength = gr.Slider(
778
+ label="Mself strength",
779
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
780
+ )
781
+ dramatic_strength = gr.Slider(
782
+ label="Dramatic strength",
783
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
784
+ )
785
+ fluid_strength = gr.Slider(
786
+ label="Fluid Helper strength",
787
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
788
+ )
789
+ liquid_strength = gr.Slider(
790
+ label="Liquid Helper strength",
791
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
792
+ )
793
+ demopose_strength = gr.Slider(
794
+ label="Audio Helper strength",
795
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
796
+ )
797
+ voice_strength = gr.Slider(
798
+ label="Voice Helper strength",
799
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
800
+ )
801
+ realism_strength = gr.Slider(
802
+ label="Anthro Realism strength",
803
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
804
+ )
805
+ transition_strength = gr.Slider(
806
+ label="POV strength",
807
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
808
+ )
809
+ physics_strength = gr.Slider(
810
+ label="Physics strength",
811
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
812
+ )
813
+ reasoning_strength = gr.Slider(
814
+ label="Official Reasoning strength",
815
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
816
+ )
817
+ twostep_strength = gr.Slider(
818
+ label="Two Step Reasoning strength",
819
+ minimum=0.0, maximum=2.0, value=0.0, step=0.01
820
+ )
821
+ prepare_lora_btn = gr.Button("Prepare / Load LoRA Cache", variant="secondary")
822
+ lora_status = gr.Textbox(
823
+ label="LoRA Cache Status",
824
+ value="No LoRA state prepared yet.",
825
+ interactive=False,
826
+ )
827
 
828
  with gr.Column():
829
+ output_video = gr.Video(label="Generated Video", autoplay=False)
830
+ gpu_duration = gr.Slider(
831
+ label="ZeroGPU duration (seconds; 10 second Img2Vid with 1024x1024 and LoRAs = ~70)",
832
+ minimum=30.0,
833
+ maximum=240.0,
834
+ value=75.0,
835
+ step=1.0,
836
+ )
837
 
838
+ gr.Examples(
839
+ examples=[
840
+ [
841
+ None,
842
+ "pinkknit.jpg",
843
+ None,
844
+ "The camera falls downward through darkness as if dropped into a tunnel. "
845
+ "As it slows, five friends wearing pink knitted hats and sunglasses lean "
846
+ "over and look down toward the camera with curious expressions. The lens "
847
+ "has a strong fisheye effect, creating a circular frame around them. They "
848
+ "crowd together closely, forming a symmetrical cluster while staring "
849
+ "directly into the lens.",
850
+ 3.0,
851
+ 80.0,
852
+ False,
853
+ 42,
854
+ True,
855
+ 1024,
856
+ 1024,
857
+ 0.0, # pose_strength (example)
858
+ 0.0, # general_strength (example)
859
+ 0.0, # motion_strength (example)
860
+ 0.0,
861
+ 0.0,
862
+ 0.0,
863
+ 0.0,
864
+ 0.0,
865
+ 0.0,
866
+ 0.0,
867
+ 0.0,
868
+ 0.0,
869
+ 0.0,
870
+ 0.0,
871
+ 0.0,
872
+ ],
873
+ ],
874
+ inputs=[
875
+ first_image, last_image, input_audio, prompt, duration, gpu_duration,
876
+ enhance_prompt, seed, randomize_seed, height, width,
877
+ pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength,
878
+ ],
879
+ )
880
+
881
+ first_image.change(
882
+ fn=on_image_upload,
883
+ inputs=[first_image, last_image, high_res],
884
+ outputs=[width, height],
885
+ )
886
+
887
+ last_image.change(
888
  fn=on_image_upload,
889
+ inputs=[first_image, last_image, high_res],
890
  outputs=[width, height],
891
  )
892
 
 
893
  high_res.change(
894
  fn=on_highres_toggle,
895
+ inputs=[first_image, last_image, high_res],
896
  outputs=[width, height],
897
  )
898
 
899
+ prepare_lora_btn.click(
900
+ fn=prepare_lora_cache,
901
+ inputs=[pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength],
902
+ outputs=[lora_status],
903
+ )
904
+
905
  generate_btn.click(
906
  fn=generate_video,
907
  inputs=[
908
+ first_image, last_image, input_audio, prompt, duration, gpu_duration, enhance_prompt,
909
  seed, randomize_seed, height, width,
910
+ pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength,
911
  ],
912
  outputs=[output_video, seed],
913
  )
 
915
 
916
  css = """
917
  .fillable{max-width: 1200px !important}
 
918
  """
919
 
920
  if __name__ == "__main__":
921
+ demo.launch(theme=gr.themes.Citrus(), css=css)