dagloop5 commited on
Commit
bdb37cb
·
verified ·
1 Parent(s): 5f0bc90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -24
app.py CHANGED
@@ -125,6 +125,7 @@ class HQPipelineWithCachedLoRA:
125
  2. Handles ALL LoRAs via cached state (distilled + 12 custom)
126
  3. Supports CFG/negative prompts and guidance parameters
127
  4. Reuses single transformer for both stages
 
128
  """
129
 
130
  def __init__(
@@ -140,7 +141,6 @@ class HQPipelineWithCachedLoRA:
140
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
141
  self.dtype = torch.bfloat16
142
 
143
- # Create ONE ModelLedger for everything
144
  print(" Creating ModelLedger (no LoRAs)...")
145
  self.model_ledger = ModelLedger(
146
  dtype=self.dtype,
@@ -148,17 +148,15 @@ class HQPipelineWithCachedLoRA:
148
  checkpoint_path=checkpoint_path,
149
  gemma_root_path=gemma_root,
150
  spatial_upsampler_path=spatial_upsampler_path,
151
- loras=(), # NO LoRAs
152
  quantization=quantization,
153
  )
154
 
155
- # Pipeline components
156
  self.pipeline_components = PipelineComponents(
157
  dtype=self.dtype,
158
  device=self.device,
159
  )
160
 
161
- # Storage for cached LoRA state
162
  self._cached_state = None
163
 
164
  def apply_cached_lora_state(self, state_dict):
@@ -175,7 +173,6 @@ class HQPipelineWithCachedLoRA:
175
  width: int,
176
  num_frames: int,
177
  frame_rate: float,
178
- num_inference_steps: int,
179
  video_guider_params: MultiModalGuiderParams,
180
  audio_guider_params: MultiModalGuiderParams,
181
  images: list,
@@ -186,7 +183,6 @@ class HQPipelineWithCachedLoRA:
186
  from ltx_core.tools import VideoLatentShape
187
  from ltx_core.components.noisers import GaussianNoiser
188
  from ltx_core.components.diffusion_steps import Res2sDiffusionStep
189
- from ltx_core.components.schedulers import LTX2Scheduler
190
  from ltx_core.types import VideoPixelShape
191
  from ltx_core.model.upsampler import upsample_video
192
  from ltx_core.model.video_vae import decode_video as vae_decode_video
@@ -199,12 +195,7 @@ class HQPipelineWithCachedLoRA:
199
  generator = torch.Generator(device=device).manual_seed(seed)
200
  noiser = GaussianNoiser(generator=generator)
201
 
202
- # Apply cached LoRA state if available
203
- if self._cached_state is not None:
204
- print("[LoRA] Applying cached state to transformer...")
205
- transformer = self.model_ledger.transformer()
206
- with torch.no_grad():
207
- transformer.load_state_dict(self._cached_state, strict=False)
208
 
209
  ctx_p, ctx_n = encode_prompts(
210
  [prompt, negative_prompt],
@@ -217,7 +208,7 @@ class HQPipelineWithCachedLoRA:
217
  v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
218
  v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding
219
 
220
- # ===================== STAGE 1 =====================
221
  stage_1_output_shape = VideoPixelShape(
222
  batch=1, frames=num_frames,
223
  width=width // 2, height=height // 2, fps=frame_rate
@@ -238,13 +229,10 @@ class HQPipelineWithCachedLoRA:
238
 
239
  transformer = self.model_ledger.transformer()
240
 
241
- empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(stage_1_output_shape).to_torch_shape())
 
 
242
  stepper = Res2sDiffusionStep()
243
- sigmas = (
244
- LTX2Scheduler()
245
- .execute(latent=empty_latent, steps=num_inference_steps)
246
- .to(dtype=torch.float32, device=device)
247
- )
248
 
249
  def first_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
250
  return res2s_audio_video_denoising_loop(
@@ -265,7 +253,7 @@ class HQPipelineWithCachedLoRA:
265
  output_shape=stage_1_output_shape,
266
  conditionings=stage_1_conditionings,
267
  noiser=noiser,
268
- sigmas=sigmas,
269
  stepper=stepper,
270
  denoising_loop_fn=first_stage_denoising_loop,
271
  components=self.pipeline_components,
@@ -298,11 +286,11 @@ class HQPipelineWithCachedLoRA:
298
  del video_encoder
299
  cleanup_memory()
300
 
301
- # ===================== STAGE 2 =====================
302
  transformer = self.model_ledger.transformer()
303
 
304
  from ltx_pipelines.utils.constants import STAGE_2_DISTILLED_SIGMA_VALUES
305
- distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
306
 
307
  def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
308
  return res2s_audio_video_denoising_loop(
@@ -321,13 +309,13 @@ class HQPipelineWithCachedLoRA:
321
  output_shape=stage_2_output_shape,
322
  conditionings=stage_2_conditionings,
323
  noiser=noiser,
324
- sigmas=distilled_sigmas,
325
  stepper=stepper,
326
  denoising_loop_fn=second_stage_denoising_loop,
327
  components=self.pipeline_components,
328
  dtype=dtype,
329
  device=device,
330
- noise_scale=distilled_sigmas[0],
331
  initial_video_latent=upscaled_video_latent,
332
  initial_audio_latent=audio_state.latent,
333
  )
 
125
  2. Handles ALL LoRAs via cached state (distilled + 12 custom)
126
  3. Supports CFG/negative prompts and guidance parameters
127
  4. Reuses single transformer for both stages
128
+ 5. Uses 8 steps at half resolution + 3 steps at full resolution
129
  """
130
 
131
  def __init__(
 
141
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
142
  self.dtype = torch.bfloat16
143
 
 
144
  print(" Creating ModelLedger (no LoRAs)...")
145
  self.model_ledger = ModelLedger(
146
  dtype=self.dtype,
 
148
  checkpoint_path=checkpoint_path,
149
  gemma_root_path=gemma_root,
150
  spatial_upsampler_path=spatial_upsampler_path,
151
+ loras=(),
152
  quantization=quantization,
153
  )
154
 
 
155
  self.pipeline_components = PipelineComponents(
156
  dtype=self.dtype,
157
  device=self.device,
158
  )
159
 
 
160
  self._cached_state = None
161
 
162
  def apply_cached_lora_state(self, state_dict):
 
173
  width: int,
174
  num_frames: int,
175
  frame_rate: float,
 
176
  video_guider_params: MultiModalGuiderParams,
177
  audio_guider_params: MultiModalGuiderParams,
178
  images: list,
 
183
  from ltx_core.tools import VideoLatentShape
184
  from ltx_core.components.noisers import GaussianNoiser
185
  from ltx_core.components.diffusion_steps import Res2sDiffusionStep
 
186
  from ltx_core.types import VideoPixelShape
187
  from ltx_core.model.upsampler import upsample_video
188
  from ltx_core.model.video_vae import decode_video as vae_decode_video
 
195
  generator = torch.Generator(device=device).manual_seed(seed)
196
  noiser = GaussianNoiser(generator=generator)
197
 
198
+ # NO LoRA application here - done in apply_prepared_lora_state_to_pipeline()
 
 
 
 
 
199
 
200
  ctx_p, ctx_n = encode_prompts(
201
  [prompt, negative_prompt],
 
208
  v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
209
  v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding
210
 
211
+ # ===================== STAGE 1: 8 steps at half resolution =====================
212
  stage_1_output_shape = VideoPixelShape(
213
  batch=1, frames=num_frames,
214
  width=width // 2, height=height // 2, fps=frame_rate
 
229
 
230
  transformer = self.model_ledger.transformer()
231
 
232
+ # Use DISTILLED_SIGMA_VALUES for 8 steps at half resolution
233
+ from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES
234
+ stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=device)
235
  stepper = Res2sDiffusionStep()
 
 
 
 
 
236
 
237
  def first_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
238
  return res2s_audio_video_denoising_loop(
 
253
  output_shape=stage_1_output_shape,
254
  conditionings=stage_1_conditionings,
255
  noiser=noiser,
256
+ sigmas=stage_1_sigmas,
257
  stepper=stepper,
258
  denoising_loop_fn=first_stage_denoising_loop,
259
  components=self.pipeline_components,
 
286
  del video_encoder
287
  cleanup_memory()
288
 
289
+ # ===================== STAGE 2: 3 steps at full resolution =====================
290
  transformer = self.model_ledger.transformer()
291
 
292
  from ltx_pipelines.utils.constants import STAGE_2_DISTILLED_SIGMA_VALUES
293
+ stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
294
 
295
  def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
296
  return res2s_audio_video_denoising_loop(
 
309
  output_shape=stage_2_output_shape,
310
  conditionings=stage_2_conditionings,
311
  noiser=noiser,
312
+ sigmas=stage_2_sigmas,
313
  stepper=stepper,
314
  denoising_loop_fn=second_stage_denoising_loop,
315
  components=self.pipeline_components,
316
  dtype=dtype,
317
  device=device,
318
+ noise_scale=stage_2_sigmas[0],
319
  initial_video_latent=upscaled_video_latent,
320
  initial_audio_latent=audio_state.latent,
321
  )