dagloop5 committed on
Commit
61ae37f
·
verified ·
1 Parent(s): 844ace6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -225
app.py CHANGED
@@ -51,25 +51,29 @@ from safetensors import safe_open
51
  import json
52
  import requests
53
 
54
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
 
55
  from ltx_core.components.noisers import GaussianNoiser
56
- from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
57
- from ltx_core.model.upsampler import upsample_video
58
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
59
- from ltx_core.quantization import QuantizationPolicy
60
- from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
61
- from ltx_pipelines.distilled import DistilledPipeline
62
- from ltx_pipelines.utils import euler_denoising_loop
63
- from ltx_pipelines.utils.args import ImageConditioningInput
64
- from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
 
 
 
 
65
  from ltx_pipelines.utils.helpers import (
66
- cleanup_memory,
67
  combined_image_conditionings,
68
- denoise_video_only,
69
- encode_prompts,
70
- simple_denoising_func,
71
  )
72
- from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
 
73
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
74
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
75
 
@@ -101,167 +105,169 @@ RESOLUTIONS = {
101
  }
102
 
103
 
104
- class LTX23DistilledA2VPipeline(DistilledPipeline):
105
- """DistilledPipeline with optional audio conditioning."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  def __call__(
108
  self,
109
  prompt: str,
 
110
  seed: int,
111
  height: int,
112
  width: int,
113
  num_frames: int,
114
  frame_rate: float,
115
  images: list[ImageConditioningInput],
116
- audio_path: str | None = None,
117
  tiling_config: TilingConfig | None = None,
118
  enhance_prompt: bool = False,
119
- ):
120
- # Standard path when no audio input is provided.
121
- print(prompt)
122
- if audio_path is None:
123
- return super().__call__(
124
- prompt=prompt,
125
- seed=seed,
126
- height=height,
127
- width=width,
128
- num_frames=num_frames,
129
- frame_rate=frame_rate,
130
- images=images,
131
- tiling_config=tiling_config,
132
- enhance_prompt=enhance_prompt,
133
- )
134
 
135
  generator = torch.Generator(device=self.device).manual_seed(seed)
136
  noiser = GaussianNoiser(generator=generator)
137
- stepper = EulerDiffusionStep()
138
  dtype = torch.bfloat16
139
 
140
- (ctx_p,) = encode_prompts(
141
- [prompt],
142
- self.model_ledger,
143
  enhance_first_prompt=enhance_prompt,
144
- enhance_prompt_image=images[0].path if len(images) > 0 else None,
 
 
145
  )
146
- video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
147
-
148
- video_duration = num_frames / frame_rate
149
- decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
150
- if decoded_audio is None:
151
- raise ValueError(f"Could not extract audio stream from {audio_path}")
152
-
153
- encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
154
- audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
155
- expected_frames = audio_shape.frames
156
- actual_frames = encoded_audio_latent.shape[2]
157
-
158
- if actual_frames > expected_frames:
159
- encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
160
- elif actual_frames < expected_frames:
161
- pad = torch.zeros(
162
- encoded_audio_latent.shape[0],
163
- encoded_audio_latent.shape[1],
164
- expected_frames - actual_frames,
165
- encoded_audio_latent.shape[3],
166
- device=encoded_audio_latent.device,
167
- dtype=encoded_audio_latent.dtype,
168
- )
169
- encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
170
-
171
- video_encoder = self.model_ledger.video_encoder()
172
- transformer = self.model_ledger.transformer()
173
- stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
174
-
175
- def denoising_loop(sigmas, video_state, audio_state, stepper):
176
- return euler_denoising_loop(
177
- sigmas=sigmas,
178
- video_state=video_state,
179
- audio_state=audio_state,
180
- stepper=stepper,
181
- denoise_fn=simple_denoising_func(
182
- video_context=video_context,
183
- audio_context=audio_context,
184
- transformer=transformer,
185
- ),
186
- )
187
 
188
  stage_1_output_shape = VideoPixelShape(
189
- batch=1,
190
- frames=num_frames,
191
- width=width // 2,
192
- height=height // 2,
193
- fps=frame_rate,
194
  )
195
- stage_1_conditionings = combined_image_conditionings(
196
- images=images,
197
- height=stage_1_output_shape.height,
198
- width=stage_1_output_shape.width,
199
- video_encoder=video_encoder,
200
- dtype=dtype,
201
- device=self.device,
 
 
202
  )
203
- video_state = denoise_video_only(
204
- output_shape=stage_1_output_shape,
205
- conditionings=stage_1_conditionings,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  noiser=noiser,
207
- sigmas=stage_1_sigmas,
208
  stepper=stepper,
209
- denoising_loop_fn=denoising_loop,
210
- components=self.pipeline_components,
211
- dtype=dtype,
212
- device=self.device,
213
- initial_audio_latent=encoded_audio_latent,
 
 
 
 
214
  )
215
 
216
- torch.cuda.synchronize()
217
- cleanup_memory()
218
 
219
- upscaled_video_latent = upsample_video(
220
- latent=video_state.latent[:1],
221
- video_encoder=video_encoder,
222
- upsampler=self.model_ledger.spatial_upsampler(),
223
- )
224
- stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
225
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
226
- stage_2_conditionings = combined_image_conditionings(
227
- images=images,
228
- height=stage_2_output_shape.height,
229
- width=stage_2_output_shape.width,
230
- video_encoder=video_encoder,
231
- dtype=dtype,
232
- device=self.device,
233
  )
234
- video_state = denoise_video_only(
235
- output_shape=stage_2_output_shape,
236
- conditionings=stage_2_conditionings,
 
237
  noiser=noiser,
238
- sigmas=stage_2_sigmas,
239
  stepper=stepper,
240
- denoising_loop_fn=denoising_loop,
241
- components=self.pipeline_components,
242
- dtype=dtype,
243
- device=self.device,
244
- noise_scale=stage_2_sigmas[0],
245
- initial_video_latent=upscaled_video_latent,
246
- initial_audio_latent=encoded_audio_latent,
 
 
 
 
 
 
 
 
 
 
247
  )
248
 
249
- torch.cuda.synchronize()
250
- del transformer
251
- del video_encoder
252
- cleanup_memory()
253
-
254
- decoded_video = vae_decode_video(
255
- video_state.latent,
256
- self.model_ledger.video_decoder(),
257
- tiling_config,
258
- generator,
259
- )
260
- original_audio = Audio(
261
- waveform=decoded_audio.waveform.squeeze(0),
262
- sampling_rate=decoded_audio.sampling_rate,
263
- )
264
- return decoded_video, original_audio
265
 
266
 
267
  # Model repos
@@ -276,11 +282,11 @@ print("=" * 80)
276
  # LoRA cache directory and currently-applied key
277
  LORA_CACHE_DIR = Path("lora_cache")
278
  LORA_CACHE_DIR.mkdir(exist_ok=True)
279
- current_lora_key: str | None = None
280
 
 
281
  PENDING_LORA_KEY: str | None = None
282
- PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
283
- PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
284
 
285
  weights_dir = Path("weights")
286
  weights_dir.mkdir(exist_ok=True)
@@ -376,29 +382,19 @@ def prepare_lora_cache(
376
  progress=gr.Progress(track_tqdm=True),
377
  ):
378
  """
379
- CPU-only step:
380
- - checks cache
381
- - loads cached fused transformer state_dict, or
382
- - builds fused transformer on CPU and saves it
383
- The resulting state_dict is stored in memory and can be applied later.
384
  """
385
- global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
386
-
387
- ledger = pipeline.model_ledger
388
- key, _ = _make_lora_key(pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength)
389
- cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
390
-
391
- progress(0.05, desc="Preparing LoRA state")
392
- if cache_path.exists():
393
- try:
394
- progress(0.20, desc="Loading cached fused state")
395
- state = load_file(str(cache_path))
396
- PENDING_LORA_KEY = key
397
- PENDING_LORA_STATE = state
398
- PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
399
- return PENDING_LORA_STATUS
400
- except Exception as e:
401
- print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
402
 
403
  entries = [
404
  (pose_lora_path, round(float(pose_strength), 2)),
@@ -414,6 +410,7 @@ def prepare_lora_cache(
414
  (realism_lora_path, round(float(realism_strength), 2)),
415
  (transition_lora_path, round(float(transition_strength), 2)),
416
  ]
 
417
  loras_for_builder = [
418
  LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
419
  for path, strength in entries
@@ -422,35 +419,31 @@ def prepare_lora_cache(
422
 
423
  if not loras_for_builder:
424
  PENDING_LORA_KEY = None
425
- PENDING_LORA_STATE = None
426
  PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
427
  return PENDING_LORA_STATUS
428
 
429
- tmp_ledger = None
430
- new_transformer_cpu = None
431
  try:
432
- progress(0.35, desc="Building fused CPU transformer")
433
- tmp_ledger = pipeline.model_ledger.__class__(
434
- dtype=ledger.dtype,
435
- device=torch.device("cpu"),
436
- checkpoint_path=str(checkpoint_path),
437
- spatial_upsampler_path=str(spatial_upsampler_path),
438
- gemma_root_path=str(gemma_root),
439
- loras=tuple(loras_for_builder),
440
- quantization=getattr(ledger, "quantization", None),
441
- )
442
- new_transformer_cpu = tmp_ledger.transformer()
443
-
444
- progress(0.70, desc="Extracting fused state_dict")
445
- state = {
446
- k: v.detach().cpu().contiguous()
447
- for k, v in new_transformer_cpu.state_dict().items()
448
- }
449
- save_file(state, str(cache_path))
450
 
451
  PENDING_LORA_KEY = key
452
- PENDING_LORA_STATE = state
453
- PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
454
  return PENDING_LORA_STATUS
455
 
456
  except Exception as e:
@@ -458,45 +451,32 @@ def prepare_lora_cache(
458
  print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
459
  print(traceback.format_exc())
460
  PENDING_LORA_KEY = None
461
- PENDING_LORA_STATE = None
462
  PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
463
  return PENDING_LORA_STATUS
464
 
465
- finally:
466
- try:
467
- del new_transformer_cpu
468
- except Exception:
469
- pass
470
- try:
471
- del tmp_ledger
472
- except Exception:
473
- pass
474
- gc.collect()
475
 
 
 
476
 
477
- def apply_prepared_lora_state_to_pipeline():
478
- """
479
- Fast step: copy the already prepared CPU state into the live transformer.
480
- This is the only part that should remain near generation time.
481
- """
482
- global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
483
-
484
- if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
485
- print("[LoRA] No prepared LoRA state available; skipping.")
486
  return False
487
 
488
  if current_lora_key == PENDING_LORA_KEY:
489
- print("[LoRA] Prepared LoRA state already active; skipping.")
490
  return True
491
 
492
- existing_transformer = _transformer
493
- with torch.no_grad():
494
- missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
495
- if missing or unexpected:
496
- print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
 
 
497
 
498
  current_lora_key = PENDING_LORA_KEY
499
- print("[LoRA] Prepared LoRA state applied to the pipeline.")
500
  return True
501
 
502
  # ---- REPLACE PRELOAD BLOCK START ----
@@ -588,8 +568,8 @@ def on_highres_toggle(first_image, last_image, high_res):
588
  def get_gpu_duration(
589
  first_image,
590
  last_image,
591
- input_audio,
592
  prompt: str,
 
593
  duration: float,
594
  gpu_duration: float,
595
  enhance_prompt: bool = True,
@@ -618,8 +598,8 @@ def get_gpu_duration(
618
  def generate_video(
619
  first_image,
620
  last_image,
621
- input_audio,
622
  prompt: str,
 
623
  duration: float,
624
  gpu_duration: float,
625
  enhance_prompt: bool = True,
@@ -682,15 +662,18 @@ def generate_video(
682
 
683
  video, audio = pipeline(
684
  prompt=prompt,
 
685
  seed=current_seed,
686
  height=int(height),
687
  width=int(width),
688
  num_frames=num_frames,
689
  frame_rate=frame_rate,
690
  images=images,
691
- audio_path=input_audio,
692
  tiling_config=tiling_config,
693
  enhance_prompt=enhance_prompt,
 
 
 
694
  )
695
 
696
  log_memory("after pipeline call")
@@ -723,7 +706,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
723
  with gr.Row():
724
  first_image = gr.Image(label="First Frame (Optional)", type="pil")
725
  last_image = gr.Image(label="Last Frame (Optional)", type="pil")
726
- input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
727
  prompt = gr.Textbox(
728
  label="Prompt",
729
  info="for best results - make it as elaborate as possible",
@@ -731,6 +713,12 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
731
  lines=3,
732
  placeholder="Describe the motion and animation you want...",
733
  )
 
 
 
 
 
 
734
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=30.0, value=10.0, step=0.1)
735
 
736
 
@@ -817,13 +805,13 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
817
  [
818
  None,
819
  "pinkknit.jpg",
820
- None,
821
  "The camera falls downward through darkness as if dropped into a tunnel. "
822
  "As it slows, five friends wearing pink knitted hats and sunglasses lean "
823
  "over and look down toward the camera with curious expressions. The lens "
824
  "has a strong fisheye effect, creating a circular frame around them. They "
825
  "crowd together closely, forming a symmetrical cluster while staring "
826
  "directly into the lens.",
 
827
  3.0,
828
  80.0,
829
  False,
@@ -846,7 +834,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
846
  ],
847
  ],
848
  inputs=[
849
- first_image, last_image, input_audio, prompt, duration, gpu_duration,
850
  enhance_prompt, seed, randomize_seed, height, width,
851
  pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
852
  ],
@@ -879,7 +867,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
879
  generate_btn.click(
880
  fn=generate_video,
881
  inputs=[
882
- first_image, last_image, input_audio, prompt, duration, gpu_duration, enhance_prompt,
883
  seed, randomize_seed, height, width,
884
  pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
885
  ],
 
51
  import json
52
  import requests
53
 
54
+ from ltx_core.components.diffusion_steps import Res2sDiffusionStep
55
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
56
  from ltx_core.components.noisers import GaussianNoiser
57
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
58
+ from ltx_core.types import Audio, VideoLatentShape, VideoPixelShape
59
+ from ltx_pipelines.utils.args import ImageConditioningInput, hq_2_stage_arg_parser
60
+ from ltx_pipelines.utils.blocks import (
61
+ AudioDecoder,
62
+ DiffusionStage,
63
+ ImageConditioner,
64
+ PromptEncoder,
65
+ VideoDecoder,
66
+ VideoUpsampler,
67
+ )
68
+ from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS, STAGE_2_DISTILLED_SIGMAS
69
+ from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
70
  from ltx_pipelines.utils.helpers import (
71
+ assert_resolution,
72
  combined_image_conditionings,
73
+ get_device,
 
 
74
  )
75
+ from ltx_pipelines.utils.media_io import encode_video
76
+ from ltx_pipelines.utils.samplers import res2s_audio_video_denoising_loop
77
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
78
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
79
 
 
105
  }
106
 
107
 
108
class LTX23NegativePromptTwoStagePipeline:
    """Two-stage audio/video pipeline with negative-prompt guidance.

    Stage 1 denoises at half the requested width/height using a guided
    denoiser that receives both the positive and the negative prompt
    contexts.  The resulting video latent is upsampled and refined by
    stage 2 at the full resolution with an unguided denoiser, after which
    the video and audio latents are decoded.
    """

    def __init__(
        self,
        checkpoint_path: str,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: tuple[LoraPathStrengthAndSDOps, ...],
        device: torch.device | None = None,
        # Quoted so that defining the class does not require these names to be
        # imported at module level (their imports are not visible in this file
        # section).
        quantization: "QuantizationPolicy | None" = None,
        registry: "Registry | None" = None,
        torch_compile: bool = False,
    ):
        """Build every pipeline component from the given checkpoint paths.

        Args:
            checkpoint_path: Model checkpoint handed to every component.
            spatial_upsampler_path: Checkpoint for the video upsampler.
            gemma_root: Root directory handed to the prompt encoder.
            loras: LoRA descriptors applied to both diffusion stages.
            device: Target device; falls back to ``get_device()`` when falsy.
            quantization: Optional quantization policy for the diffusion stages.
            registry: Optional registry forwarded to every component.
            torch_compile: Forwarded to both diffusion stages.
        """
        self.device = device or get_device()
        self.dtype = torch.bfloat16
        # NOTE(review): LTX2Scheduler is not among the imports visible in this
        # section of the file - confirm it is imported at module level.
        self._scheduler = LTX2Scheduler()

        self.prompt_encoder = PromptEncoder(checkpoint_path, gemma_root, self.dtype, self.device, registry=registry)
        self.image_conditioner = ImageConditioner(checkpoint_path, self.dtype, self.device, registry=registry)
        self.upsampler = VideoUpsampler(
            checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry
        )
        self.video_decoder = VideoDecoder(checkpoint_path, self.dtype, self.device, registry=registry)
        self.audio_decoder = AudioDecoder(checkpoint_path, self.dtype, self.device, registry=registry)

        # Both stages are configured identically.
        stage_kwargs = {
            "loras": tuple(loras),
            "quantization": quantization,
            "registry": registry,
            "torch_compile": torch_compile,
        }
        self.stage_1 = DiffusionStage(checkpoint_path, self.dtype, self.device, **stage_kwargs)
        self.stage_2 = DiffusionStage(checkpoint_path, self.dtype, self.device, **stage_kwargs)

    def _image_conditionings(self, images, height: int, width: int):
        """Encode the conditioning images for a stage of the given size."""
        return self.image_conditioner(
            lambda enc: combined_image_conditionings(
                images=images,
                height=height,
                width=width,
                video_encoder=enc,
                dtype=self.dtype,
                device=self.device,
            )
        )

    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
        streaming_prefetch_count: int | None = None,
        max_batch_size: int = 1,
        stage_1_sigmas: torch.Tensor | None = None,
        stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
        video_guider_params: MultiModalGuiderParams | None = None,
        audio_guider_params: MultiModalGuiderParams | None = None,
        num_inference_steps: int | None = None,
    ) -> "tuple[Iterator[torch.Tensor], Audio]":
        """Generate video and audio for ``prompt``, steering away from ``negative_prompt``.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Text describing what to avoid; used only in stage 1.
            seed: Seed for the noise generator and for prompt enhancement.
            height: Final output height; stage 1 runs at ``height // 2``.
            width: Final output width; stage 1 runs at ``width // 2``.
            num_frames: Number of frames to generate.
            frame_rate: Frames per second.
            images: Image conditioning inputs; the first element of the first
                entry is handed to prompt enhancement.
            tiling_config: Optional tiling configuration for video decoding.
            enhance_prompt: Whether the first (positive) prompt is enhanced.
            streaming_prefetch_count: Forwarded to the prompt encoder and stages.
            max_batch_size: Forwarded to stage 1.
            stage_1_sigmas: Explicit stage-1 sigma schedule.  When ``None`` the
                schedule is produced by the scheduler for ``num_inference_steps``.
            stage_2_sigmas: Sigma schedule for stage 2.
            video_guider_params: Parameters for the stage-1 video guider.
            audio_guider_params: Parameters for the stage-1 audio guider.
            num_inference_steps: Number of scheduler steps for stage 1; only
                consulted when ``stage_1_sigmas`` is ``None``.

        Returns:
            ``(decoded_video, decoded_audio)`` as produced by the decoders.

        Raises:
            ValueError: If neither ``stage_1_sigmas`` nor a step count is available.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)

        ctx_p, ctx_n = self.prompt_encoder(
            [prompt, negative_prompt],
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=images[0][0] if images else None,
            enhance_prompt_seed=seed,
            streaming_prefetch_count=streaming_prefetch_count,
        )
        v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
        v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding

        # ---- Stage 1: guided denoising at half resolution ----
        stage_1_output_shape = VideoPixelShape(
            batch=1, frames=num_frames, width=width // 2, height=height // 2, fps=frame_rate
        )
        stage_1_conditionings = self._image_conditionings(
            images, height=stage_1_output_shape.height, width=stage_1_output_shape.width
        )

        stepper = Res2sDiffusionStep()
        if stage_1_sigmas is None:
            # The step count used to be read from an undefined name, so this
            # branch always raised NameError.  It is now an explicit argument.
            if num_inference_steps is None:
                # NOTE(review): assumes the HQ preset exposes
                # `num_inference_steps` - confirm against ltx_pipelines.
                num_inference_steps = getattr(LTX_2_3_HQ_PARAMS, "num_inference_steps", None)
            if num_inference_steps is None:
                raise ValueError(
                    "Either stage_1_sigmas or num_inference_steps must be provided."
                )
            empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(stage_1_output_shape).to_torch_shape())
            stage_1_sigmas = self._scheduler.execute(latent=empty_latent, steps=num_inference_steps)
        sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)

        # NOTE(review): ModalitySpec is not among the imports visible in this
        # section of the file - confirm it is imported at module level.
        video_state, audio_state = self.stage_1(
            denoiser=GuidedDenoiser(
                v_context=v_context_p,
                a_context=a_context_p,
                video_guider=MultiModalGuider(
                    params=video_guider_params,
                    negative_context=v_context_n,
                ),
                audio_guider=MultiModalGuider(
                    params=audio_guider_params,
                    negative_context=a_context_n,
                ),
            ),
            sigmas=sigmas,
            noiser=noiser,
            stepper=stepper,
            width=stage_1_output_shape.width,
            height=stage_1_output_shape.height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(context=v_context_p, conditionings=stage_1_conditionings),
            audio=ModalitySpec(context=a_context_p),
            loop=res2s_audio_video_denoising_loop,
            streaming_prefetch_count=streaming_prefetch_count,
            max_batch_size=max_batch_size,
        )

        # ---- Stage 2: unguided refinement at full resolution ----
        upscaled_video_latent = self.upsampler(video_state.latent[:1])
        stage_2_conditionings = self._image_conditionings(images, height=height, width=width)

        # Both modalities are re-noised at the first stage-2 sigma.
        stage_2_noise_scale = stage_2_sigmas[0].item()
        video_state, audio_state = self.stage_2(
            denoiser=SimpleDenoiser(v_context=v_context_p, a_context=a_context_p),
            sigmas=stage_2_sigmas.to(dtype=torch.float32, device=self.device),
            noiser=noiser,
            stepper=stepper,
            width=width,
            height=height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(
                context=v_context_p,
                conditionings=stage_2_conditionings,
                noise_scale=stage_2_noise_scale,
                initial_latent=upscaled_video_latent,
            ),
            audio=ModalitySpec(
                context=a_context_p,
                noise_scale=stage_2_noise_scale,
                initial_latent=audio_state.latent,
            ),
            loop=res2s_audio_video_denoising_loop,
            streaming_prefetch_count=streaming_prefetch_count,
        )

        decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
        decoded_audio = self.audio_decoder(audio_state.latent)
        return decoded_video, decoded_audio
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
 
273
  # Model repos
 
282
  # LoRA cache directory and currently-applied key
283
  LORA_CACHE_DIR = Path("lora_cache")
284
  LORA_CACHE_DIR.mkdir(exist_ok=True)
 
285
 
286
+ current_lora_key: str | None = None
287
  PENDING_LORA_KEY: str | None = None
288
+ PENDING_LORA_LORAS: tuple[LoraPathStrengthAndSDOps, ...] | None = None
289
+ PENDING_LORA_STATUS: str = "No LoRA config prepared yet."
290
 
291
  weights_dir = Path("weights")
292
  weights_dir.mkdir(exist_ok=True)
 
382
  progress=gr.Progress(track_tqdm=True),
383
  ):
384
  """
385
+ Prepare the LoRA selection for the guided pipeline.
386
+ This caches the LoRA config, not fused weights.
 
 
 
387
  """
388
+ global PENDING_LORA_KEY, PENDING_LORA_LORAS, PENDING_LORA_STATUS
389
+
390
+ key = _make_lora_key(
391
+ pose_strength, general_strength, motion_strength, dreamlay_strength,
392
+ mself_strength, dramatic_strength, fluid_strength, liquid_strength,
393
+ demopose_strength, voice_strength, realism_strength, transition_strength
394
+ )
395
+ cache_path = LORA_CACHE_DIR / f"{key}.json"
396
+
397
+ progress(0.05, desc="Preparing LoRA config")
 
 
 
 
 
 
 
398
 
399
  entries = [
400
  (pose_lora_path, round(float(pose_strength), 2)),
 
410
  (realism_lora_path, round(float(realism_strength), 2)),
411
  (transition_lora_path, round(float(transition_strength), 2)),
412
  ]
413
+
414
  loras_for_builder = [
415
  LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
416
  for path, strength in entries
 
419
 
420
  if not loras_for_builder:
421
  PENDING_LORA_KEY = None
422
+ PENDING_LORA_LORAS = None
423
  PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
424
  return PENDING_LORA_STATUS
425
 
 
 
426
  try:
427
+ if cache_path.exists():
428
+ progress(0.20, desc="Loading cached LoRA config")
429
+ data = json.loads(cache_path.read_text())
430
+ loras_for_builder = [
431
+ LoraPathStrengthAndSDOps(item["path"], item["strength"], LTXV_LORA_COMFY_RENAMING_MAP)
432
+ for item in data
433
+ if float(item["strength"]) != 0.0
434
+ ]
435
+ else:
436
+ progress(0.30, desc="Saving LoRA config cache")
437
+ cache_path.write_text(
438
+ json.dumps(
439
+ [{"path": path, "strength": strength} for path, strength in entries if float(strength) != 0.0],
440
+ indent=2,
441
+ )
442
+ )
 
 
443
 
444
  PENDING_LORA_KEY = key
445
+ PENDING_LORA_LORAS = tuple(loras_for_builder)
446
+ PENDING_LORA_STATUS = f"Prepared LoRA config: {cache_path.name}"
447
  return PENDING_LORA_STATUS
448
 
449
  except Exception as e:
 
451
  print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
452
  print(traceback.format_exc())
453
  PENDING_LORA_KEY = None
454
+ PENDING_LORA_LORAS = None
455
  PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
456
  return PENDING_LORA_STATUS
457
 
 
 
 
 
 
 
 
 
 
 
458
 
459
def apply_prepared_lora_config_to_pipeline():
    """Swap the global pipeline for one built with the pending LoRA selection.

    Returns:
        ``False`` when no LoRA config has been prepared.  ``True`` when the
        pending config is already the active one, or after the pipeline has
        been rebuilt with it.
    """
    global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_LORAS, pipeline

    nothing_prepared = PENDING_LORA_KEY is None or PENDING_LORA_LORAS is None
    if nothing_prepared:
        print("[LoRA] No prepared LoRA config available; skipping.")
        return False

    already_active = PENDING_LORA_KEY == current_lora_key
    if already_active:
        print("[LoRA] Prepared LoRA config already active; skipping.")
        return True

    # LoRAs are applied when the pipeline is constructed, so a new selection
    # means constructing a fresh pipeline object.
    build_kwargs = {
        "checkpoint_path": str(checkpoint_path),
        "spatial_upsampler_path": str(spatial_upsampler_path),
        "gemma_root": str(gemma_root),
        "loras": PENDING_LORA_LORAS,
        "quantization": QuantizationPolicy.fp8_cast(),
    }
    pipeline = LTX23NegativePromptTwoStagePipeline(**build_kwargs)

    # Record the key only after the rebuild succeeded.
    current_lora_key = PENDING_LORA_KEY
    print("[LoRA] Prepared LoRA config applied by rebuilding the pipeline.")
    return True
481
 
482
  # ---- REPLACE PRELOAD BLOCK START ----
 
568
  def get_gpu_duration(
569
  first_image,
570
  last_image,
 
571
  prompt: str,
572
+ negative_prompt: str,
573
  duration: float,
574
  gpu_duration: float,
575
  enhance_prompt: bool = True,
 
598
  def generate_video(
599
  first_image,
600
  last_image,
 
601
  prompt: str,
602
+ negative_prompt: str,
603
  duration: float,
604
  gpu_duration: float,
605
  enhance_prompt: bool = True,
 
662
 
663
  video, audio = pipeline(
664
  prompt=prompt,
665
+ negative_prompt=negative_prompt,
666
  seed=current_seed,
667
  height=int(height),
668
  width=int(width),
669
  num_frames=num_frames,
670
  frame_rate=frame_rate,
671
  images=images,
 
672
  tiling_config=tiling_config,
673
  enhance_prompt=enhance_prompt,
674
+ # if your wrapper exposes them:
675
+ video_guider_params=video_guider_params,
676
+ audio_guider_params=audio_guider_params,
677
  )
678
 
679
  log_memory("after pipeline call")
 
706
  with gr.Row():
707
  first_image = gr.Image(label="First Frame (Optional)", type="pil")
708
  last_image = gr.Image(label="Last Frame (Optional)", type="pil")
 
709
  prompt = gr.Textbox(
710
  label="Prompt",
711
  info="for best results - make it as elaborate as possible",
 
713
  lines=3,
714
  placeholder="Describe the motion and animation you want...",
715
  )
716
+ negative_prompt = gr.Textbox(
717
+ label="Negative Prompt",
718
+ value="",
719
+ lines=2,
720
+ placeholder="Describe what you want to avoid...",
721
+ )
722
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=30.0, value=10.0, step=0.1)
723
 
724
 
 
805
  [
806
  None,
807
  "pinkknit.jpg",
 
808
  "The camera falls downward through darkness as if dropped into a tunnel. "
809
  "As it slows, five friends wearing pink knitted hats and sunglasses lean "
810
  "over and look down toward the camera with curious expressions. The lens "
811
  "has a strong fisheye effect, creating a circular frame around them. They "
812
  "crowd together closely, forming a symmetrical cluster while staring "
813
  "directly into the lens.",
814
+ "",
815
  3.0,
816
  80.0,
817
  False,
 
834
  ],
835
  ],
836
  inputs=[
837
+ first_image, last_image, prompt, negative_prompt, duration, gpu_duration,
838
  enhance_prompt, seed, randomize_seed, height, width,
839
  pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
840
  ],
 
867
  generate_btn.click(
868
  fn=generate_video,
869
  inputs=[
870
+ first_image, last_image, prompt, negative_prompt, duration, gpu_duration, enhance_prompt,
871
  seed, randomize_seed, height, width,
872
  pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
873
  ],