# =============================================================================
# Installation and Setup
# =============================================================================
import os
import subprocess
import sys

# Disable torch.compile / dynamo before any torch import.
# This prevents CUDA initialization issues in the Space environment.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Clone LTX-2 repo at a specific commit for reproducibility.
# The pinned commit ensures we have the exact pipeline code matching our analysis.
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
# Using a specific commit for stability - can be updated to main later.
LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"

if not os.path.exists(LTX_REPO_DIR):
    # init + fetch --depth 1 + checkout (rather than `git clone`) because a
    # shallow fetch of an arbitrary commit SHA is not possible with plain clone.
    print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
    os.makedirs(LTX_REPO_DIR)
    subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
    subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL],
                   cwd=LTX_REPO_DIR, check=True)
    subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA],
                   cwd=LTX_REPO_DIR, check=True)
    subprocess.run(["git", "checkout", LTX_COMMIT_SHA],
                   cwd=LTX_REPO_DIR, check=True)

# Add repo packages to the Python path so we can import from
# ltx-core and ltx-pipelines without installing them.
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))

# =============================================================================
# Imports
# =============================================================================
import logging
import random
import tempfile
from pathlib import Path

import torch

# Disable torch.compile/dynamo at runtime level as well (belt-and-braces
# with the environment variables set above).
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

import gradio as gr
import spaces
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download

# Imports from the cloned LTX-2 pipeline.
# These imports come from ti2vid_two_stages_hq.py.
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
from ltx_pipelines.utils.args import ImageConditioningInput
from ltx_pipelines.utils.media_io import encode_video
from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS
from ltx_core.components.guiders import MultiModalGuiderParams

# =============================================================================
# Constants and Configuration
# =============================================================================
# Model repositories on Hugging Face.
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"

# Default parameters from LTX_2_3_HQ_PARAMS.
DEFAULT_FRAME_RATE = 24.0

# Resolution constraints: both dimensions must be divisible by 64 because the
# two-stage pipeline generates at half resolution in Stage 1 and then
# upscales x2, so the input must divide cleanly.
MIN_DIM = 256
MAX_DIM = 1280
STEP = 64  # Both width and height must be divisible by 64

# Duration constraints (frame count must be 8*K + 1).
MIN_FRAMES = 9    # 8*1 + 1
MAX_FRAMES = 257  # 8*32 + 1

# Seed range.
MAX_SEED = np.iinfo(np.int32).max

# Default prompts.
DEFAULT_PROMPT = (
    "A majestic eagle soaring over mountain peaks at sunset, "
    "wings spread wide against the orange sky, feathers catching the light, "
    "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
)
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, inconsistent motion, blurry, jittery, distorted, "
    "deformed, artifacts, text, watermark, logo, frame, border, "
    "low resolution, pixelated, unnatural, fake, CGI, cartoon"
)
Initialization # ============================================================================= print("=" * 80) print("Downloading LTX-2.3 models...") print("=" * 80) # Download all required model files # 1. Dev checkpoint - full trainable 22B model checkpoint_path = hf_hub_download( repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-dev.safetensors" ) print(f"Dev checkpoint: {checkpoint_path}") # 2. Spatial upscaler - x2 upscaler for latent space spatial_upsampler_path = hf_hub_download( repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors" ) print(f"Spatial upsampler: {spatial_upsampler_path}") # 3. Distilled LoRA - distilled knowledge in LoRA format (rank 384) # This LoRA is specifically trained to work with the dev model distilled_lora_path = hf_hub_download( repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled-lora-384.safetensors" ) print(f"Distilled LoRA: {distilled_lora_path}") # 4. Gemma text encoder - required for prompt encoding gemma_root = snapshot_download(repo_id=GEMMA_REPO) print(f"Gemma root: {gemma_root}") print("=" * 80) print("All models downloaded!") print("=" * 80) # ============================================================================= # Pipeline Initialization # ============================================================================= # Create the LoraPathStrengthAndSDOps for distilled LoRA # The sd_ops parameter uses the ComfyUI renaming map for compatibility from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP distilled_lora = [ LoraPathStrengthAndSDOps( path=distilled_lora_path, strength=1.0, # Will be set per-stage (0.25 for stage 1, 0.5 for stage 2) sd_ops=LTXV_LORA_COMFY_RENAMING_MAP, ) ] # Initialize the Two-Stage HQ Pipeline # Key parameters: # - checkpoint_path: Full dev model (trainable) # - distilled_lora: LoRA containing distilled knowledge # - distilled_lora_strength_stage_1: 0.25 (lighter application at half-res) # - distilled_lora_strength_stage_2: 0.5 (stronger application after upscaling) 
# - spatial_upsampler_path: Required for two-stage upscaling # - gemma_root: Gemma text encoder for prompt encoding print("Initializing LTX-2.3 Two-Stage HQ Pipeline...") pipeline = TI2VidTwoStagesHQPipeline( checkpoint_path=checkpoint_path, distilled_lora=distilled_lora, distilled_lora_strength_stage_1=0.25, # From HQ params distilled_lora_strength_stage_2=0.50, # From HQ params spatial_upsampler_path=spatial_upsampler_path, gemma_root=gemma_root, loras=(), # No additional custom LoRAs for this Space quantization=QuantizationPolicy.fp8_cast(), # FP8 for memory efficiency torch_compile=False, # Disable for Space compatibility ) print("Pipeline initialized successfully!") print("=" * 80) # ============================================================================= # ZeroGPU Tensor Preloading - CPU Tensor Approach # ============================================================================= # ZeroGPU should pack any tensors in memory, not just GPU tensors. # We load model weights to CPU as proxy tensors to trigger packing. # During actual generation, ZeroGPU will move them to GPU. 
print("Creating CPU proxy tensors for ZeroGPU tensor packing...")
print("This may take a few minutes (loading to CPU only)...")

import gc

# Create small proxy tensors for each model component.
# These don't need to be the actual weights - just tensors to trigger packing.
# ZeroGPU will pack whatever tensors exist when it runs.
_proxy_tensors = []


def create_proxy(name, shape, dtype=torch.float32):
    """Create a CPU proxy tensor and keep a reference so ZeroGPU sees it."""
    print(f" Creating proxy for {name}: {shape}")
    t = torch.zeros(shape, dtype=dtype)
    _proxy_tensors.append(t)
    return t


# Proxies for the various model components; they only need to exist in memory
# for ZeroGPU's tensor packing to pick them up.
create_proxy("transformer_stage1", (1, 1024, 512))
create_proxy("transformer_stage2", (1, 1024, 512))
create_proxy("video_encoder", (1, 768, 512))
create_proxy("video_decoder", (1, 512, 512))
create_proxy("audio_decoder", (1, 256, 512))
create_proxy("spatial_upsampler", (1, 256, 512))
create_proxy("text_encoder", (1, 2048, 256))
create_proxy("vocoder", (1, 128, 256))

# Keep proxies alive by storing them in module globals.
proxy_stage1 = _proxy_tensors[0]
proxy_stage2 = _proxy_tensors[1]
proxy_venc = _proxy_tensors[2]
proxy_vdec = _proxy_tensors[3]
proxy_adec = _proxy_tensors[4]
proxy_upsamp = _proxy_tensors[5]
proxy_tenc = _proxy_tensors[6]
proxy_voc = _proxy_tensors[7]

# Clean up the temporary list.
del _proxy_tensors

# Trigger the actual model loading but tolerate GPU errors - on a ZeroGPU
# Space there is no GPU available at module-import time.
print("\nAttempting model initialization (GPU errors expected)...")
try:
    # Accessing the lazy contexts triggers loading; it may fail on GPU steps.
    _ = pipeline.stage_1._transformer_ctx
    _ = pipeline.prompt_encoder._text_encoder_ctx
    print(" Model contexts accessed")
except Exception as e:
    print(f" Context access: {type(e).__name__}")

print("\n" + "=" * 80)
print("Startup complete. Models will load to GPU during first generation.")
print("=" * 80)


# =============================================================================
# Helper Functions
# =============================================================================
def log_memory(tag: str):
    """Log current GPU memory usage (allocated/peak/free/total) for debugging."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        peak = torch.cuda.max_memory_allocated() / 1024**3
        free, total = torch.cuda.mem_get_info()
        print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")


def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
    """
    Calculate the number of frames from a duration.

    The frame count must be 8*K + 1 (K a non-negative integer) for the LTX
    model; this constraint comes from the temporal upsampling architecture.

    Args:
        duration: Duration in seconds
        frame_rate: Frames per second

    Returns:
        Frame count satisfying the 8*K + 1 constraint, clamped to
        [MIN_FRAMES, MAX_FRAMES].
    """
    ideal_frames = int(duration * frame_rate)
    # Ensure it's at least MIN_FRAMES.
    ideal_frames = max(ideal_frames, MIN_FRAMES)
    # Round to the nearest 8*K + 1.
    k = round((ideal_frames - 1) / 8)
    frames = k * 8 + 1
    # Clamp to the maximum.
    return min(frames, MAX_FRAMES)


def validate_resolution(height: int, width: int) -> tuple[int, int]:
    """
    Ensure the resolution is valid for the two-stage pipeline.

    The pipeline requires both dimensions divisible by 64 (final resolution),
    so that Stage 1 - which operates at half resolution - stays divisible
    by 32.

    Args:
        height: Target height
        width: Target width

    Returns:
        Validated (height, width) tuple, rounded to multiples of STEP and
        clamped to [MIN_DIM, MAX_DIM].
    """
    # Round to the nearest multiple of 64.
    height = round(height / STEP) * STEP
    width = round(width / STEP) * STEP
    # Clamp to the valid range.
    height = max(MIN_DIM, min(height, MAX_DIM))
    width = max(MIN_DIM, min(width, MAX_DIM))
    return height, width


def detect_aspect_ratio(image) -> str:
    """Detect the closest aspect ratio preset ("16:9", "9:16", "1:1") for an image.

    Accepts PIL images (``.size``) or numpy arrays (``.shape``); falls back
    to "16:9" for anything else.
    """
    if image is None:
        return "16:9"
    if hasattr(image, "size"):
        w, h = image.size
    elif hasattr(image, "shape"):
        h, w = image.shape[:2]
    else:
        return "16:9"
    ratio = w / h
    candidates = {"16:9": 16/9, "9:16": 9/16, "1:1": 1.0}
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))


# Resolution presets per aspect ratio (all dimensions multiples of 64).
RESOLUTIONS = {
    "16:9": {"width": 1280, "height": 704},
    "9:16": {"width": 704, "height": 1280},
    "1:1": {"width": 960, "height": 960},
}


def get_duration(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress,
) -> int:
    """
    Dynamically calculate the GPU allocation time based on generation params.

    Used by @spaces.GPU to set the time limit: longer videos and higher
    resolutions need more time. The signature mirrors generate_video because
    spaces passes the same arguments to the duration callable.

    Args:
        duration: Video duration in seconds
        height, width: Target resolution

    Returns:
        GPU allocation duration in seconds.
    """
    base = 60
    # Longer videos need more time.
    if duration > 4:
        base += 15
    if duration > 6:
        base += 15
    # Higher resolution needs more time.
    if height > 700 or width > 1000:
        base += 15
    # More frames means more processing; derive the frame count here since
    # num_frames is not a parameter.
    frames_from_duration = int(duration * DEFAULT_FRAME_RATE)
    if frames_from_duration > 81:
        base += 10
    # BUG FIX: the original fell off the end and returned None, which breaks
    # the spaces.GPU duration contract (must return seconds as a number).
    return base


@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def generate_video(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    # Guidance parameters
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate high-quality video using the Two-Stage HQ Pipeline.

    Stage 1 (Half Resolution + CFG):
    - Generates video at half the target resolution
    - Uses GuidedDenoiser with CFG (positive + negative prompts)
    - Applies the distilled LoRA at strength 0.25
    - Res2s sampler for efficient second-order denoising

    Stage 2 (Upscale + Refine):
    - Upscales the latent representation 2x using the spatial upsampler
    - Refines using SimpleDenoiser (no CFG, distilled approach)
    - Applies the distilled LoRA at strength 0.5
    - 4-step refined denoising schedule

    Args:
        prompt: Text description of the desired video content
        negative_prompt: What to avoid in the video
        input_image: Optional input image for image-to-video
        duration: Video duration in seconds
        seed: Random seed for reproducibility
        randomize_seed: Whether to use a random seed instead of `seed`
        height, width: Target resolution (snapped to multiples of 64)
        enhance_prompt: Whether to use Gemma prompt enhancement
        video_cfg_scale: Video CFG (prompt adherence)
        video_stg_scale: Video STG (spatio-temporal guidance)
        video_rescale_scale: Video rescaling factor
        video_a2v_scale: Audio-to-video cross-attention scale
        audio_cfg_scale: Audio CFG (prompt adherence)
        audio_stg_scale: Audio STG (spatio-temporal guidance)
        audio_rescale_scale: Audio rescaling factor
        audio_v2a_scale: Video-to-audio cross-attention scale
        progress: Gradio progress tracker (tqdm-linked)

    Returns:
        Tuple of (output_video_path or None on failure, used_seed).
    """
    # BUG FIX: resolve the seed BEFORE entering the try block. Previously it
    # was assigned inside `try`, so an early failure (e.g. in
    # reset_peak_memory_stats) made the except branch raise UnboundLocalError
    # on `current_seed`, masking the real error.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        print(f"Using seed: {current_seed}")

        # Validate and adjust the resolution.
        height, width = validate_resolution(int(height), int(width))
        print(f"Resolution: {width}x{height}")

        # Calculate frames (must be 8*K + 1).
        num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
        print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")

        # Prepare image conditioning if an input image was provided.
        images = []
        if input_image is not None:
            # Persist the input image so the pipeline can read it by path.
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
            if hasattr(input_image, "save"):
                input_image.save(temp_image_path)
            else:
                import shutil
                shutil.copy(input_image, temp_image_path)
            # frame_idx=0: condition on the first frame;
            # strength=1.0: full conditioning influence.
            images = [ImageConditioningInput(
                path=str(temp_image_path),
                frame_idx=0,
                strength=1.0
            )]

        # Tiling is necessary to avoid OOM errors during VAE decoding.
        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        # MultiModalGuider parameters:
        # cfg_scale: classifier-free guidance (higher = stronger adherence)
        # stg_scale: spatio-temporal guidance (0 = disabled)
        # rescale_scale: rescaling factor against oversaturation
        # modality_scale: cross-attention scale between modalities
        # skip_step: step skipping for faster inference (0 = none)
        # stg_blocks: transformer blocks to perturb for STG (empty for 2.3 HQ)
        video_guider_params = MultiModalGuiderParams(
            cfg_scale=video_cfg_scale,
            stg_scale=video_stg_scale,
            rescale_scale=video_rescale_scale,
            modality_scale=video_a2v_scale,
            skip_step=0,
            stg_blocks=[],
        )
        audio_guider_params = MultiModalGuiderParams(
            cfg_scale=audio_cfg_scale,
            stg_scale=audio_stg_scale,
            rescale_scale=audio_rescale_scale,
            modality_scale=audio_v2a_scale,
            skip_step=0,
            stg_blocks=[],
        )

        log_memory("before pipeline call")

        # Stage 1 uses num_inference_steps from LTX_2_3_HQ_PARAMS; Stage 2
        # uses the fixed 4-step distilled schedule inside the pipeline.
        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
            video_guider_params=video_guider_params,
            audio_guider_params=audio_guider_params,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )

        log_memory("after pipeline call")

        # Encode video with audio.
        # FIX: tempfile.mktemp is deprecated and race-prone; reserve the name
        # with NamedTemporaryFile instead (delete=False keeps the file for
        # encode_video to write into).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )
        log_memory("after encode_video")

        return str(output_path), current_seed

    except Exception as e:
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed


# =============================================================================
# Gradio UI
# =============================================================================
css = """
/* Custom styling for LTX-2.3 Space */
.fillable {max-width: 1200px !important}
.progress-text {color: white}
"""

# BUG FIX: theme and css belong on the Blocks constructor, not on launch()
# (launch() accepts neither keyword, so the css above was never applied).
with gr.Blocks(
    title="LTX-2.3 Two-Stage HQ Video Generation",
    theme=gr.themes.Citrus(),
    css=css,
) as demo:
    gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
    gr.Markdown(
        "High-quality text/image-to-video generation using the dev model + distilled LoRA. "
        "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[GitHub]](https://github.com/Lightricks/LTX-2)"
    )

    with gr.Row():
        # Input column.
        with gr.Column():
            # Optional input image for image-to-video.
            input_image = gr.Image(
                label="Input Image (Optional - for image-to-video)",
                type="pil",
                sources=["upload", "webcam", "clipboard"]
            )
            # Prompt inputs.
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe the video you want to generate",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Enter your prompt here..."
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="What to avoid in the generated video",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
                placeholder="Enter negative prompt here..."
            )
            # Duration slider.
            duration = gr.Slider(
                label="Duration (seconds)",
                minimum=0.5,
                maximum=8.0,
                value=2.0,
                step=0.1,
                info="Video duration (clamped to 8K+1 frames)"
            )
            # Prompt enhancement toggle.
            enhance_prompt = gr.Checkbox(
                label="Enhance Prompt",
                value=False,
                info="Use Gemma to enhance the prompt for better results"
            )
            # Generate button.
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        # Output column.
        with gr.Column():
            output_video = gr.Video(
                label="Generated Video",
                autoplay=True,
                interactive=False
            )

    # Advanced settings accordion.
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            # Resolution inputs.
            width = gr.Number(
                label="Width",
                value=1280,
                precision=0,
                info="Must be divisible by 64"
            )
            height = gr.Number(
                label="Height",
                value=704,
                precision=0,
                info="Must be divisible by 64"
            )
        with gr.Row():
            # Seed controls.
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                minimum=0,
                maximum=MAX_SEED
            )
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                value=True
            )

        gr.Markdown("### Video Guidance Parameters")
        gr.Markdown("Control how strongly the model follows the video prompt and handles guidance.")
        with gr.Row():
            video_cfg_scale = gr.Slider(
                label="Video CFG Scale",
                minimum=1.0,
                maximum=10.0,
                value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
                step=0.1,
                info="Classifier-free guidance for video (higher = stronger prompt adherence)"
            )
            video_stg_scale = gr.Slider(
                label="Video STG Scale",
                minimum=0.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                info="Spatio-temporal guidance (0 = disabled)"
            )
        with gr.Row():
            video_rescale_scale = gr.Slider(
                label="Video Rescale",
                minimum=0.0,
                maximum=2.0,
                value=0.45,
                step=0.1,
                info="Rescaling factor for oversaturation prevention"
            )
            video_a2v_scale = gr.Slider(
                label="A2V Scale",
                minimum=0.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                info="Audio-to-video cross-attention scale"
            )

        gr.Markdown("### Audio Guidance Parameters")
        gr.Markdown("Control audio generation quality and sync.")
        with gr.Row():
            audio_cfg_scale = gr.Slider(
                label="Audio CFG Scale",
                minimum=1.0,
                maximum=15.0,
                value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
                step=0.1,
                info="Classifier-free guidance for audio"
            )
            audio_stg_scale = gr.Slider(
                label="Audio STG Scale",
                minimum=0.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                info="Spatio-temporal guidance for audio (0 = disabled)"
            )
        with gr.Row():
            audio_rescale_scale = gr.Slider(
                label="Audio Rescale",
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                info="Audio rescaling factor"
            )
            audio_v2a_scale = gr.Slider(
                label="V2A Scale",
                minimum=0.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                info="Video-to-audio cross-attention scale"
            )

    # Event handlers.
    def on_image_upload(image, current_h, current_w):
        """Snap the width/height fields to the preset matching the uploaded image's aspect ratio."""
        if image is None:
            return gr.update(), gr.update()
        aspect = detect_aspect_ratio(image)
        if aspect in RESOLUTIONS:
            return (
                gr.update(value=RESOLUTIONS[aspect]["width"]),
                gr.update(value=RESOLUTIONS[aspect]["height"])
            )
        return gr.update(), gr.update()

    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, height, width],
        outputs=[width, height],
    )

    # Generate button click handler.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            input_image,
            duration,
            seed,
            randomize_seed,
            height,
            width,
            enhance_prompt,
            video_cfg_scale,
            video_stg_scale,
            video_rescale_scale,
            video_a2v_scale,
            audio_cfg_scale,
            audio_stg_scale,
            audio_rescale_scale,
            audio_v2a_scale,
        ],
        outputs=[output_video, seed],
    )


# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
    # theme/css moved to gr.Blocks above - launch() does not accept them.
    demo.queue().launch(
        mcp_server=True,
        share=True,
    )