dagloop5 commited on
Commit
4c3fdc0
·
verified ·
1 Parent(s): 8338af1

Create app(draft).py

Browse files
Files changed (1) hide show
  1. app(draft).py +373 -0
app(draft).py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
+ # Disable torch.compile / dynamo before any torch import
6
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
+ os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
+
9
+ # Install xformers for memory-efficient attention
10
+ subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
+
12
+ # Clone LTX-2 repo and install packages
13
+ LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
14
+ LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
15
+ LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"
16
+
17
+ if not os.path.exists(LTX_REPO_DIR):
18
+ print(f"Cloning {LTX_REPO_URL}...")
19
+ os.makedirs(LTX_REPO_DIR)
20
+ subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
21
+ subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
22
+ subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
23
+ subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
24
+
25
+
26
+ print("Installing ltx-core and ltx-pipelines from cloned repo...")
27
+ subprocess.run(
28
+ [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "-e",
29
+ os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
30
+ "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")],
31
+ check=True,
32
+ )
33
+
34
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
35
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
36
+
37
+ import logging
38
+ import random
39
+ import tempfile
40
+ from pathlib import Path
41
+ from collections.abc import Iterator
42
+
43
+ import torch
44
+ torch._dynamo.config.suppress_errors = True
45
+ torch._dynamo.config.disable = True
46
+
47
+ import spaces
48
+ import gradio as gr
49
+ import numpy as np
50
+ from huggingface_hub import hf_hub_download, snapshot_download
51
+
52
+ from ltx_core.components.diffusion_steps import Res2sDiffusionStep
53
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
54
+ from ltx_core.components.noisers import GaussianNoiser
55
+ from ltx_core.components.schedulers import LTX2Scheduler
56
+ from ltx_core.loader import LoraPathStrengthAndSDOps
57
+ from ltx_core.loader.registry import Registry
58
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
59
+ from ltx_core.quantization import QuantizationPolicy
60
+ from ltx_core.types import Audio, VideoLatentShape, VideoPixelShape
61
+ from ltx_pipelines.utils.args import ImageConditioningInput, hq_2_stage_arg_parser
62
+ from ltx_pipelines.utils.blocks import (
63
+ AudioDecoder,
64
+ DiffusionStage,
65
+ ImageConditioner,
66
+ PromptEncoder,
67
+ VideoDecoder,
68
+ VideoUpsampler,
69
+ )
70
+ from ltx_pipelines.utils.constants import (
71
+ LTX_2_3_HQ_PARAMS,
72
+ STAGE_2_DISTILLED_SIGMAS,
73
+ )
74
+ from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
75
+ from ltx_pipelines.utils.helpers import (
76
+ assert_resolution,
77
+ combined_image_conditionings,
78
+ get_device,
79
+ )
80
+ from ltx_pipelines.utils.media_io import encode_video
81
+ from ltx_pipelines.utils.samplers import res2s_audio_video_denoising_loop
82
+ from ltx_pipelines.utils.types import ModalitySpec
83
+
84
+ # Force-patch xformers attention into the LTX attention module.
85
+ from ltx_core.model.transformer import attention as _attn_mod
86
+ print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
87
+ try:
88
+ from xformers.ops import memory_efficient_attention as _mea
89
+ _attn_mod.memory_efficient_attention = _mea
90
+ print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
91
+ except Exception as e:
92
+ print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
93
+
94
+ logging.getLogger().setLevel(logging.INFO)
95
+
96
+ MAX_SEED = np.iinfo(np.int32).max
97
+ DEFAULT_PROMPT = (
98
+ "An astronaut hatches from a fragile egg on the surface of the Moon, "
99
+ "the shell cracking and peeling apart in gentle low-gravity motion. "
100
+ "Fine lunar dust lifts and drifts outward with each movement, floating "
101
+ "in slow arcs before settling back onto the ground."
102
+ )
103
+ DEFAULT_FRAME_RATE = 24.0
104
+
105
+ # Resolution presets: (width, height)
106
+ RESOLUTIONS = {
107
+ "high": {"16:9": (1536, 1024), "9:16": (1024, 1536), "1:1": (1024, 1024)},
108
+ "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
109
+ }
110
+
111
+ class TI2VidTwoStagesHQPipeline:
112
+ """
113
+ Two-stage text/image-to-video generation pipeline using the res_2s sampler.
114
+ Same structure as :class:`TI2VidTwoStagesPipeline`: stage 1 generates video at
115
+ half of the target resolution with CFG guidance (assuming full model is used),
116
+ then Stage 2 upsamples by 2x and refines using a distilled LoRA for higher
117
+ quality output.
118
+ Uses the res_2s second-order sampler instead of Euler, allowing fewer
119
+ steps for comparable quality. Supports optional image conditioning via
120
+ the images parameter.
121
+ """
122
+
123
+ def __init__( # noqa: PLR0913
124
+ self,
125
+ checkpoint_path: str,
126
+ distilled_lora: list[LoraPathStrengthAndSDOps],
127
+ distilled_lora_strength_stage_1: float,
128
+ distilled_lora_strength_stage_2: float,
129
+ spatial_upsampler_path: str,
130
+ gemma_root: str,
131
+ loras: tuple[LoraPathStrengthAndSDOps, ...],
132
+ device: torch.device | None = None,
133
+ quantization: QuantizationPolicy | None = None,
134
+ registry: Registry | None = None,
135
+ torch_compile: bool = False,
136
+ ):
137
+ self.device = device or get_device()
138
+ self.dtype = torch.bfloat16
139
+ self._scheduler = LTX2Scheduler()
140
+
141
+ distilled_lora_stage_1 = LoraPathStrengthAndSDOps(
142
+ path=distilled_lora[0].path,
143
+ strength=distilled_lora_strength_stage_1,
144
+ sd_ops=distilled_lora[0].sd_ops,
145
+ )
146
+ distilled_lora_stage_2 = LoraPathStrengthAndSDOps(
147
+ path=distilled_lora[0].path,
148
+ strength=distilled_lora_strength_stage_2,
149
+ sd_ops=distilled_lora[0].sd_ops,
150
+ )
151
+
152
+ self.prompt_encoder = PromptEncoder(checkpoint_path, gemma_root, self.dtype, self.device, registry=registry)
153
+ self.image_conditioner = ImageConditioner(checkpoint_path, self.dtype, self.device, registry=registry)
154
+ self.upsampler = VideoUpsampler(
155
+ checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry
156
+ )
157
+ self.video_decoder = VideoDecoder(checkpoint_path, self.dtype, self.device, registry=registry)
158
+ self.audio_decoder = AudioDecoder(checkpoint_path, self.dtype, self.device, registry=registry)
159
+
160
+ self.stage_1 = DiffusionStage(
161
+ checkpoint_path,
162
+ self.dtype,
163
+ self.device,
164
+ loras=(*loras, distilled_lora_stage_1),
165
+ quantization=quantization,
166
+ registry=registry,
167
+ torch_compile=torch_compile,
168
+ )
169
+ self.stage_2 = DiffusionStage(
170
+ checkpoint_path,
171
+ self.dtype,
172
+ self.device,
173
+ loras=(*loras, distilled_lora_stage_2),
174
+ quantization=quantization,
175
+ registry=registry,
176
+ torch_compile=torch_compile,
177
+ )
178
+
179
+ @torch.inference_mode()
180
+ def __call__( # noqa: PLR0913
181
+ self,
182
+ prompt: str,
183
+ negative_prompt: str,
184
+ seed: int,
185
+ height: int,
186
+ width: int,
187
+ num_frames: int,
188
+ frame_rate: float,
189
+ num_inference_steps: int,
190
+ video_guider_params: MultiModalGuiderParams,
191
+ audio_guider_params: MultiModalGuiderParams,
192
+ images: list[ImageConditioningInput],
193
+ tiling_config: TilingConfig | None = None,
194
+ enhance_prompt: bool = False,
195
+ streaming_prefetch_count: int | None = None,
196
+ max_batch_size: int = 1,
197
+ stage_1_sigmas: torch.Tensor | None = None,
198
+ stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
199
+ ) -> tuple[Iterator[torch.Tensor], Audio]:
200
+ assert_resolution(height=height, width=width, is_two_stage=True)
201
+
202
+ generator = torch.Generator(device=self.device).manual_seed(seed)
203
+ noiser = GaussianNoiser(generator=generator)
204
+ dtype = torch.bfloat16
205
+
206
+ ctx_p, ctx_n = self.prompt_encoder(
207
+ [prompt, negative_prompt],
208
+ enhance_first_prompt=enhance_prompt,
209
+ enhance_prompt_image=images[0][0] if len(images) > 0 else None,
210
+ enhance_prompt_seed=seed,
211
+ streaming_prefetch_count=streaming_prefetch_count,
212
+ )
213
+ v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
214
+ v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding
215
+
216
+ # Stage 1: Generate video at half resolution with CFG guidance using res2s sampler.
217
+ stage_1_output_shape = VideoPixelShape(
218
+ batch=1,
219
+ frames=num_frames,
220
+ width=width // 2,
221
+ height=height // 2,
222
+ fps=frame_rate,
223
+ )
224
+ stage_1_conditionings = self.image_conditioner(
225
+ lambda enc: combined_image_conditionings(
226
+ images=images,
227
+ height=stage_1_output_shape.height,
228
+ width=stage_1_output_shape.width,
229
+ video_encoder=enc,
230
+ dtype=dtype,
231
+ device=self.device,
232
+ )
233
+ )
234
+
235
+ stepper = Res2sDiffusionStep()
236
+
237
+ if stage_1_sigmas is None:
238
+ empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(stage_1_output_shape).to_torch_shape())
239
+ stage_1_sigmas = self._scheduler.execute(latent=empty_latent, steps=num_inference_steps)
240
+ sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)
241
+
242
+ video_state, audio_state = self.stage_1(
243
+ denoiser=GuidedDenoiser(
244
+ v_context=v_context_p,
245
+ a_context=a_context_p,
246
+ video_guider=MultiModalGuider(
247
+ params=video_guider_params,
248
+ negative_context=v_context_n,
249
+ ),
250
+ audio_guider=MultiModalGuider(
251
+ params=audio_guider_params,
252
+ negative_context=a_context_n,
253
+ ),
254
+ ),
255
+ sigmas=sigmas,
256
+ noiser=noiser,
257
+ stepper=stepper,
258
+ width=stage_1_output_shape.width,
259
+ height=stage_1_output_shape.height,
260
+ frames=num_frames,
261
+ fps=frame_rate,
262
+ video=ModalitySpec(context=v_context_p, conditionings=stage_1_conditionings),
263
+ audio=ModalitySpec(context=a_context_p),
264
+ loop=res2s_audio_video_denoising_loop,
265
+ streaming_prefetch_count=streaming_prefetch_count,
266
+ max_batch_size=max_batch_size,
267
+ )
268
+
269
+ # Stage 2: Upsample and refine the video at higher resolution with distilled LoRA.
270
+ upscaled_video_latent = self.upsampler(video_state.latent[:1])
271
+
272
+ stage_2_sigmas = stage_2_sigmas.to(dtype=torch.float32, device=self.device)
273
+ stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
274
+ stage_2_conditionings = self.image_conditioner(
275
+ lambda enc: combined_image_conditionings(
276
+ images=images,
277
+ height=stage_2_output_shape.height,
278
+ width=stage_2_output_shape.width,
279
+ video_encoder=enc,
280
+ dtype=dtype,
281
+ device=self.device,
282
+ )
283
+ )
284
+
285
+ video_state, audio_state = self.stage_2(
286
+ denoiser=SimpleDenoiser(v_context=v_context_p, a_context=a_context_p),
287
+ sigmas=stage_2_sigmas,
288
+ noiser=noiser,
289
+ stepper=stepper,
290
+ width=width,
291
+ height=height,
292
+ frames=num_frames,
293
+ fps=frame_rate,
294
+ video=ModalitySpec(
295
+ context=v_context_p,
296
+ conditionings=stage_2_conditionings,
297
+ noise_scale=stage_2_sigmas[0].item(),
298
+ initial_latent=upscaled_video_latent,
299
+ ),
300
+ audio=ModalitySpec(
301
+ context=a_context_p,
302
+ noise_scale=stage_2_sigmas[0].item(),
303
+ initial_latent=audio_state.latent,
304
+ ),
305
+ loop=res2s_audio_video_denoising_loop,
306
+ streaming_prefetch_count=streaming_prefetch_count,
307
+ )
308
+
309
+ decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
310
+ decoded_audio = self.audio_decoder(audio_state.latent)
311
+ return decoded_video, decoded_audio
312
+
313
+
314
+ @torch.inference_mode()
315
+ def main() -> None:
316
+ logging.getLogger().setLevel(logging.INFO)
317
+ parser = hq_2_stage_arg_parser(params=LTX_2_3_HQ_PARAMS)
318
+ args = parser.parse_args()
319
+ pipeline = TI2VidTwoStagesHQPipeline(
320
+ checkpoint_path=args.checkpoint_path,
321
+ distilled_lora=args.distilled_lora,
322
+ distilled_lora_strength_stage_1=args.distilled_lora_strength_stage_1,
323
+ distilled_lora_strength_stage_2=args.distilled_lora_strength_stage_2,
324
+ spatial_upsampler_path=args.spatial_upsampler_path,
325
+ gemma_root=args.gemma_root,
326
+ loras=tuple(args.lora) if args.lora else (),
327
+ quantization=args.quantization,
328
+ torch_compile=args.compile,
329
+ )
330
+ tiling_config = TilingConfig.default()
331
+ video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
332
+ video, audio = pipeline(
333
+ prompt=args.prompt,
334
+ negative_prompt=args.negative_prompt,
335
+ seed=args.seed,
336
+ height=args.height,
337
+ width=args.width,
338
+ num_frames=args.num_frames,
339
+ frame_rate=args.frame_rate,
340
+ num_inference_steps=args.num_inference_steps,
341
+ video_guider_params=MultiModalGuiderParams(
342
+ cfg_scale=args.video_cfg_guidance_scale,
343
+ stg_scale=args.video_stg_guidance_scale,
344
+ rescale_scale=args.video_rescale_scale,
345
+ modality_scale=args.a2v_guidance_scale,
346
+ skip_step=args.video_skip_step,
347
+ stg_blocks=args.video_stg_blocks,
348
+ ),
349
+ audio_guider_params=MultiModalGuiderParams(
350
+ cfg_scale=args.audio_cfg_guidance_scale,
351
+ stg_scale=args.audio_stg_guidance_scale,
352
+ rescale_scale=args.audio_rescale_scale,
353
+ modality_scale=args.v2a_guidance_scale,
354
+ skip_step=args.audio_skip_step,
355
+ stg_blocks=args.audio_stg_blocks,
356
+ ),
357
+ images=args.images,
358
+ tiling_config=tiling_config,
359
+ streaming_prefetch_count=args.streaming_prefetch_count,
360
+ max_batch_size=args.max_batch_size,
361
+ )
362
+
363
+ encode_video(
364
+ video=video,
365
+ fps=args.frame_rate,
366
+ audio=audio,
367
+ output_path=args.output_path,
368
+ video_chunks_number=video_chunks_number,
369
+ )
370
+
371
+
372
+ if __name__ == "__main__":
373
+ main()