Spaces:
Running on Zero
Running on Zero
| # ============================================================================= | |
| # Installation and Setup | |
| # ============================================================================= | |
| import os | |
| import subprocess | |
| import sys | |
# Disable torch.compile / dynamo before any torch import
# This prevents CUDA initialization issues in the Space environment
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Clone the LTX-2 repo at a pinned commit for reproducibility, so the pipeline
# code imported below is exactly the revision this app was written against.
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
# Using specific commit for stability - can be updated to main later
LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"

if not os.path.exists(LTX_REPO_DIR):
    import shutil

    print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
    # Fetch into a staging directory and only rename it into place once the
    # checkout has succeeded. Creating LTX_REPO_DIR up front meant that a
    # failed or interrupted fetch left a half-initialised directory behind,
    # and the existence check above then skipped cloning on every later start.
    _staging_dir = LTX_REPO_DIR + ".partial"
    shutil.rmtree(_staging_dir, ignore_errors=True)  # leftovers from a killed run
    os.makedirs(_staging_dir)
    try:
        subprocess.run(["git", "init", _staging_dir], check=True)
        subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=_staging_dir, check=True)
        subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=_staging_dir, check=True)
        subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=_staging_dir, check=True)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a broken staging checkout.
        shutil.rmtree(_staging_dir, ignore_errors=True)
        raise
    # Same-filesystem rename, so LTX_REPO_DIR only ever appears fully checked out.
    os.rename(_staging_dir, LTX_REPO_DIR)

# Add repo packages to Python path
# This allows us to import from ltx-core and ltx-pipelines
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
| # ============================================================================= | |
| # Imports | |
| # ============================================================================= | |
| import logging | |
| import random | |
| import tempfile | |
| from pathlib import Path | |
| import torch | |
# Disable torch.compile/dynamo at runtime level as well. This complements the
# TORCH_COMPILE_DISABLE / TORCHDYNAMO_DISABLE environment variables that are
# set before torch is imported.
# suppress_errors: if dynamo is invoked anyway, do not raise on its failures
# (NOTE(review): presumably falls back to eager execution - confirm against
# the installed torch version's dynamo config docs).
torch._dynamo.config.suppress_errors = True
# disable: switch dynamo off for this process.
torch._dynamo.config.disable = True
| import gradio as gr | |
| import spaces | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| # Import from the cloned LTX-2 pipeline | |
| # These imports come from ti2vid_two_stages_hq.py | |
| from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number | |
| from ltx_core.quantization import QuantizationPolicy | |
| from ltx_core.loader import LoraPathStrengthAndSDOps | |
| from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline | |
| from ltx_pipelines.utils.args import ImageConditioningInput | |
| from ltx_pipelines.utils.media_io import encode_video | |
| from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS | |
| from ltx_core.components.guiders import MultiModalGuiderParams | |
# =============================================================================
# Constants and Configuration
# =============================================================================
# Hugging Face repositories that hold the weights
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"

# Frame rate used both for frame-count math and for encoding the output
DEFAULT_FRAME_RATE = 24.0

# Spatial limits. Both sides must be multiples of STEP (64) for the two-stage
# pipeline; Stage 1 works at half resolution, which stays divisible by 32.
STEP = 64
MIN_DIM = 4 * STEP    # 256
MAX_DIM = 20 * STEP   # 1280

# Temporal limits. Valid frame counts have the form 8*K + 1.
MIN_FRAMES = 8 * 1 + 1    # 9
MAX_FRAMES = 8 * 32 + 1   # 257

# Largest seed offered in the UI (signed 32-bit maximum)
MAX_SEED = np.iinfo(np.int32).max

# Prompt shown by default in the UI
DEFAULT_PROMPT = (
    "A majestic eagle soaring over mountain peaks at sunset, "
    "wings spread wide against the orange sky, feathers catching the light, "
    "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
)
# Negative prompt shown by default: a comma-separated list of things to avoid
DEFAULT_NEGATIVE_PROMPT = ", ".join([
    "worst quality", "inconsistent motion", "blurry", "jittery", "distorted",
    "deformed", "artifacts", "text", "watermark", "logo", "frame", "border",
    "low resolution", "pixelated", "unnatural", "fake", "CGI", "cartoon",
])
# =============================================================================
# Model Download and Initialization
# =============================================================================
print("=" * 80)
print("Downloading LTX-2.3 models...")
print("=" * 80)


def _fetch_ltx_file(filename: str, label: str) -> str:
    """Download ``filename`` from LTX_MODEL_REPO, log where it landed, return the local path."""
    local_file = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=filename)
    print(f"{label}: {local_file}")
    return local_file


# Dev checkpoint: the full 22B dev model
checkpoint_path = _fetch_ltx_file("ltx-2.3-22b-dev.safetensors", "Dev checkpoint")
# Spatial upscaler: x2 upscaler operating in latent space
spatial_upsampler_path = _fetch_ltx_file("ltx-2.3-spatial-upscaler-x2-1.1.safetensors", "Spatial upsampler")
# Distilled LoRA (rank 384), intended to be applied on top of the dev model
distilled_lora_path = _fetch_ltx_file("ltx-2.3-22b-distilled-lora-384.safetensors", "Distilled LoRA")
# Gemma text encoder used for prompt encoding: a whole repo snapshot, not one file
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
print(f"Gemma root: {gemma_root}")

print("=" * 80)
print("All models downloaded!")
print("=" * 80)
# =============================================================================
# Pipeline Initialization
# =============================================================================
# Describe the distilled LoRA for the pipeline constructor below.
# sd_ops applies the ComfyUI key-renaming map, presumably so the LoRA's
# state-dict keys line up with this codebase's parameter names
# (NOTE(review): inferred from the map's name - confirm in ltx_core.loader).
from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP
distilled_lora = [
    LoraPathStrengthAndSDOps(
        path=distilled_lora_path,
        # NOTE(review): per-stage strengths (0.25 / 0.50) are passed to the
        # pipeline separately below; confirm whether they replace or scale
        # this base value.
        strength=1.0,
        sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
    )
]
# Build the Two-Stage HQ pipeline from the files downloaded above:
# - checkpoint_path: full dev checkpoint
# - distilled_lora + per-stage strengths: 0.25 for stage 1, 0.50 for stage 2
# - spatial_upsampler_path: x2 latent upscaler used between the two stages
# - gemma_root: Gemma text encoder for prompt encoding
print("Initializing LTX-2.3 Two-Stage HQ Pipeline...")
pipeline = TI2VidTwoStagesHQPipeline(
    checkpoint_path=checkpoint_path,
    distilled_lora=distilled_lora,
    distilled_lora_strength_stage_1=0.25,  # From HQ params
    distilled_lora_strength_stage_2=0.50,  # From HQ params
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=(),  # No additional custom LoRAs for this Space
    quantization=QuantizationPolicy.fp8_cast(),  # FP8 for memory efficiency
    torch_compile=False,  # Disable for Space compatibility
)
print("Pipeline initialized successfully!")
print("=" * 80)
# =============================================================================
# ZeroGPU Tensor Preloading - CPU Tensor Approach
# =============================================================================
# Allocate a few small zero-filled CPU tensors at import time.
# NOTE(review): the premise that ZeroGPU "packs" arbitrary CPU tensors that
# exist at startup is not established anywhere in this file - confirm it
# against the `spaces` ZeroGPU documentation. These tensors are placeholders,
# NOT model weights.
print("Creating CPU proxy tensors for ZeroGPU tensor packing...")
print("This may take a few minutes (loading to CPU only)...")
import gc  # NOTE(review): appears unused in this file

# (name, shape) of each placeholder; order matches the proxy_* globals below.
_PROXY_SPECS = (
    ("transformer_stage1", (1, 1024, 512)),
    ("transformer_stage2", (1, 1024, 512)),
    ("video_encoder", (1, 768, 512)),
    ("video_decoder", (1, 512, 512)),
    ("audio_decoder", (1, 256, 512)),
    ("spatial_upsampler", (1, 256, 512)),
    ("text_encoder", (1, 2048, 256)),
    ("vocoder", (1, 128, 256)),
)


def create_proxy(name, shape, dtype=torch.float32):
    """Create a zero-filled CPU placeholder tensor and log its shape.

    Args:
        name: Label used only in the log line.
        shape: Shape of the tensor to allocate.
        dtype: Element dtype (default float32).

    Returns:
        The new tensor. The caller keeps it alive by holding a reference.
        This helper no longer appends to a module-level list: that list was
        deleted right after startup, so any later call raised NameError.
    """
    print(f" Creating proxy for {name}: {shape}")
    return torch.zeros(shape, dtype=dtype)


# Bind every proxy to a module global so the tensors live as long as the process.
(
    proxy_stage1,
    proxy_stage2,
    proxy_venc,
    proxy_vdec,
    proxy_adec,
    proxy_upsamp,
    proxy_tenc,
    proxy_voc,
) = [create_proxy(proxy_name, proxy_shape) for proxy_name, proxy_shape in _PROXY_SPECS]

# Touch two private pipeline attributes. This is deliberately best-effort:
# any failure (e.g. no GPU at import time) is reported and then ignored.
print("\nAttempting model initialization (GPU errors expected)...")
try:
    # NOTE(review): private attributes of the LTX pipeline; whether reading
    # them triggers any loading cannot be confirmed from this file.
    _ = pipeline.stage_1._transformer_ctx
    _ = pipeline.prompt_encoder._text_encoder_ctx
    print(" Model contexts accessed")
except Exception as e:
    # Include the message, not only the type, so unexpected failures are diagnosable.
    print(f" Context access: {type(e).__name__}: {e}")

print("\n" + "=" * 80)
print("Startup complete. Models will load to GPU during first generation.")
print("=" * 80)
# =============================================================================
# Helper Functions
# =============================================================================
def log_memory(tag: str):
    """Print a one-line VRAM report (values in GiB) labelled with ``tag``.

    It is a no-op when CUDA is unavailable, so it is safe to call on CPU.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / gib
    peak_gb = torch.cuda.max_memory_allocated() / gib
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    free_gb = free_bytes / gib
    total_gb = total_bytes / gib
    print(
        f"[VRAM {tag}] allocated={allocated_gb:.2f}GB peak={peak_gb:.2f}GB "
        f"free={free_gb:.2f}GB total={total_gb:.2f}GB"
    )
def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
    """
    Calculate the number of frames to generate for a given duration.

    The LTX pipeline requires frame counts of the form 8*K + 1 (K a
    non-negative integer). The result is the valid count nearest to
    ``duration * frame_rate`` (exact halves round up), clamped to
    [MIN_FRAMES, MAX_FRAMES].

    Args:
        duration: Duration in seconds
        frame_rate: Frames per second

    Returns:
        Frame count that satisfies the 8*K + 1 constraint
    """
    # Keep the exact, possibly fractional, frame target. The previous code
    # truncated it to an int and then used round(), which rounds halves to
    # even. That could pick a count that is not the nearest valid one:
    # 0.9 s @ 24 fps is 21.6 frames and became 17, although 25 is closer.
    ideal_frames = max(duration * frame_rate, MIN_FRAMES)
    # Round half up to the nearest K. The operand is positive because
    # ideal_frames >= MIN_FRAMES >= 1, so int() acts as floor here.
    k = int((ideal_frames - 1) / 8 + 0.5)
    frames = k * 8 + 1
    # Clamp to max
    return min(frames, MAX_FRAMES)
def validate_resolution(height: int, width: int) -> tuple[int, int]:
    """
    Snap a requested resolution onto the grid the two-stage pipeline accepts.

    Each side is rounded to the nearest multiple of STEP (64) and then
    clamped into [MIN_DIM, MAX_DIM]. Python's round() is used, so exact
    halves go to the even multiple. Multiples of 64 keep the half-resolution
    Stage 1 divisible by 32.

    Args:
        height: Requested height in pixels
        width: Requested width in pixels

    Returns:
        (height, width) after snapping and clamping
    """
    def snap(side: int) -> int:
        nearest = round(side / STEP) * STEP
        return max(MIN_DIM, min(nearest, MAX_DIM))

    return snap(height), snap(width)
def detect_aspect_ratio(image) -> str:
    """Return the preset aspect-ratio key ("16:9", "9:16" or "1:1") closest to ``image``.

    Accepts a PIL-style image (``.size`` == (width, height)) or an array-like
    object with ``.shape``, whose first two axes are taken as (height, width).
    Channels-first layouts would be misread. Falls back to "16:9" for None,
    unrecognised objects, and zero-sized images.
    """
    fallback = "16:9"
    if image is None:
        return fallback
    # Check `.shape` before `.size`: numpy arrays also have a `.size`
    # attribute (the total element count, an int). Testing `.size` first sent
    # arrays down the PIL branch, where `w, h = image.size` raised TypeError.
    shape = getattr(image, "shape", None)
    if shape is not None:
        if len(shape) < 2:
            return fallback
        h, w = shape[0], shape[1]
    elif hasattr(image, "size"):
        w, h = image.size
    else:
        return fallback
    # Degenerate sizes previously raised ZeroDivisionError.
    if h <= 0 or w <= 0:
        return fallback
    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    return min(candidates, key=lambda name: abs(ratio - candidates[name]))
# Resolution presets (pixels) per aspect-ratio key, applied to the Width/Height
# inputs when an image is uploaded. Every side is a multiple of 64, as the
# two-stage pipeline requires. That is why 16:9 uses 1280x704 rather than
# 1280x720 (720 is not divisible by 64).
RESOLUTIONS = {
    "16:9": {"width": 1280, "height": 704},
    "9:16": {"width": 704, "height": 1280},
    "1:1": {"width": 960, "height": 960},
}
def get_duration(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=None,
) -> int:
    """
    Estimate the ZeroGPU time budget (seconds) for one generate_video call.

    Passed to ``@spaces.GPU(duration=...)`` on generate_video, which calls it
    with the same arguments as generate_video; hence the identical signature.
    Only ``duration``, ``height`` and ``width`` affect the result.

    Args:
        duration: Requested video duration in seconds
        height, width: Requested resolution in pixels
        progress: Gradio progress tracker. It is unused here and defaults to
            None, so the call also works when it is not forwarded.

    Returns:
        GPU allocation time in seconds (between 60 and 115)
    """
    base = 60
    # Longer videos need more time
    if duration > 4:
        base += 15
    if duration > 6:
        base += 15
    # Higher resolution needs more time
    if height > 700 or width > 1000:
        base += 15
    # More frames means more processing
    frames_from_duration = int(duration * DEFAULT_FRAME_RATE)
    if frames_from_duration > 81:
        base += 10
    # The budget was previously computed but never returned (implicit None).
    return base


@spaces.GPU(duration=get_duration)
def generate_video(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    # Guidance parameters
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a video (with audio) using the Two-Stage HQ Pipeline.

    Decorated with ``@spaces.GPU`` so ZeroGPU attaches a GPU for the duration
    of the call. The time budget comes from get_duration.

    Pipeline overview, as described for the upstream
    ``ltx_pipelines.ti2vid_two_stages_hq`` pipeline (not visible in this file):
      Stage 1 (half resolution + CFG): guided denoising with the positive
        and negative prompts; distilled LoRA at strength 0.25.
      Stage 2 (upscale + refine): 2x latent upscaling with the spatial
        upsampler, then a short distilled refinement; distilled LoRA at 0.5.

    Args:
        prompt: Text description of desired video content
        negative_prompt: What to avoid in the video
        input_image: Optional input image for image-to-video
        duration: Video duration in seconds
        seed: Random seed for reproducibility
        randomize_seed: Whether to use a random seed
        height, width: Target resolution (snapped to multiples of 64)
        enhance_prompt: Whether to use prompt enhancement
        video_cfg_scale: Video CFG (prompt adherence)
        video_stg_scale: Video STG (spatio-temporal guidance)
        video_rescale_scale: Video rescaling factor
        video_a2v_scale: Audio-to-video cross-attention scale
        audio_cfg_scale: Audio CFG (prompt adherence)
        audio_stg_scale: Audio STG (spatio-temporal guidance)
        audio_rescale_scale: Audio rescaling factor
        audio_v2a_scale: Video-to-audio cross-attention scale

    Returns:
        Tuple of (output_video_path, used_seed). output_video_path is None
        when generation failed; the error and traceback are printed.
    """
    # Bound before the try block so the error handler can always return it.
    # Previously, an exception raised before the seed was chosen turned into
    # an UnboundLocalError inside the except clause.
    current_seed = seed
    temp_image_path = None
    try:
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        log_memory("start")
        # Handle random seed
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        print(f"Using seed: {current_seed}")
        # Validate and adjust resolution
        height, width = validate_resolution(int(height), int(width))
        print(f"Resolution: {width}x{height}")
        # Calculate frames (must be 8*K + 1)
        num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
        print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")

        # Prepare image conditioning if provided
        images = []
        if input_image is not None:
            # Save the input image to a file the pipeline can read; it is
            # deleted again in the `finally` clause below.
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
            if hasattr(input_image, "save"):
                # JPEG cannot store alpha or palette modes (e.g. RGBA), so
                # normalise to RGB before saving.
                if getattr(input_image, "mode", "RGB") != "RGB" and hasattr(input_image, "convert"):
                    input_image = input_image.convert("RGB")
                input_image.save(temp_image_path)
            else:
                import shutil
                shutil.copy(input_image, temp_image_path)
            # frame_idx=0 conditions the first frame; strength=1.0 is full influence
            images = [ImageConditioningInput(
                path=str(temp_image_path),
                frame_idx=0,
                strength=1.0
            )]

        # Tiling config for VAE decoding, to avoid OOM errors
        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        # MultiModalGuider parameters:
        #   cfg_scale: classifier-free guidance (higher = stronger prompt adherence)
        #   stg_scale: spatio-temporal guidance (0 = disabled)
        #   rescale_scale: rescaling factor against oversaturation
        #   modality_scale: cross-modal attention scale (a2v / v2a)
        #   skip_step: step skipping (0 = none)
        #   stg_blocks: transformer blocks perturbed for STG
        video_guider_params = MultiModalGuiderParams(
            cfg_scale=video_cfg_scale,
            stg_scale=video_stg_scale,
            rescale_scale=video_rescale_scale,
            modality_scale=video_a2v_scale,
            skip_step=0,
            stg_blocks=[],  # Empty for LTX 2.3 HQ
        )
        audio_guider_params = MultiModalGuiderParams(
            cfg_scale=audio_cfg_scale,
            stg_scale=audio_stg_scale,
            rescale_scale=audio_rescale_scale,
            modality_scale=audio_v2a_scale,
            skip_step=0,
            stg_blocks=[],  # Empty for LTX 2.3 HQ
        )

        log_memory("before pipeline call")
        # Stage 1 step count comes from LTX_2_3_HQ_PARAMS; Stage 2 uses the
        # pipeline's own fixed schedule.
        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,
            video_guider_params=video_guider_params,
            audio_guider_params=audio_guider_params,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )
        log_memory("after pipeline call")

        # Encode video with audio. mkdtemp creates a fresh private directory,
        # which avoids the race inherent in the deprecated tempfile.mktemp().
        output_path = os.path.join(tempfile.mkdtemp(), "output.mp4")
        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )
        log_memory("after encode_video")
        return str(output_path), current_seed
    except Exception as e:
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed
    finally:
        # Best-effort removal of the temporary conditioning image; a failed
        # cleanup must not mask the generation result.
        if temp_image_path is not None:
            try:
                temp_image_path.unlink(missing_ok=True)
            except OSError:
                pass
# =============================================================================
# Gradio UI
# =============================================================================
# Extra CSS for the app; handed to demo.launch(css=css) in the entry point.
css = """
/* Custom styling for LTX-2.3 Space */
.fillable {max-width: 1200px !important}
.progress-text {color: white}
"""
# Page layout: header, a two-column row (inputs | generated video), a
# collapsed "Advanced Settings" accordion, then the event wiring.
with gr.Blocks(title="LTX-2.3 Two-Stage HQ Video Generation") as demo:
    gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
    gr.Markdown(
        "High-quality text/image-to-video generation using the dev model + distilled LoRA. "
        "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[GitHub]](https://github.com/Lightricks/LTX-2)"
    )
    with gr.Row():
        # Input Column
        with gr.Column():
            # Optional conditioning image, delivered to handlers as a PIL image
            # (type="pil"). Leaving it empty means text-to-video.
            input_image = gr.Image(
                label="Input Image (Optional - for image-to-video)",
                type="pil",
                sources=["upload", "webcam", "clipboard"]
            )
            # Prompt inputs
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe the video you want to generate",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Enter your prompt here..."
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="What to avoid in the generated video",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
                placeholder="Enter negative prompt here..."
            )
            # Duration in seconds; generate_video converts it to an
            # 8*K + 1 frame count with calculate_frames.
            duration = gr.Slider(
                label="Duration (seconds)",
                minimum=0.5,
                maximum=8.0,
                value=2.0,
                step=0.1,
                info="Video duration (clamped to 8K+1 frames)"
            )
            # Forwarded to the pipeline call as enhance_prompt
            enhance_prompt = gr.Checkbox(
                label="Enhance Prompt",
                value=False,
                info="Use Gemma to enhance the prompt for better results"
            )
            # Generate button
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
        # Output Column
        with gr.Column():
            output_video = gr.Video(
                label="Generated Video",
                autoplay=True,
                interactive=False
            )
    # Advanced Settings Accordion (collapsed by default)
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            # Resolution inputs; generate_video snaps them to multiples of 64
            # and clamps them with validate_resolution before use.
            width = gr.Number(
                label="Width",
                value=1280,
                precision=0,
                info="Must be divisible by 64"
            )
            height = gr.Number(
                label="Height",
                value=704,
                precision=0,
                info="Must be divisible by 64"
            )
        with gr.Row():
            # Seed controls. The seed actually used is written back into this
            # field after each run (see the click handler's outputs).
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                minimum=0,
                maximum=MAX_SEED
            )
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                value=True
            )
        # Guidance sliders: the CFG defaults come from LTX_2_3_HQ_PARAMS; the
        # other defaults are hard-coded here.
        gr.Markdown("### Video Guidance Parameters")
        gr.Markdown("Control how strongly the model follows the video prompt and handles guidance.")
        with gr.Row():
            video_cfg_scale = gr.Slider(
                label="Video CFG Scale",
                minimum=1.0,
                maximum=10.0,
                value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
                step=0.1,
                info="Classifier-free guidance for video (higher = stronger prompt adherence)"
            )
            video_stg_scale = gr.Slider(
                label="Video STG Scale",
                minimum=0.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                info="Spatio-temporal guidance (0 = disabled)"
            )
        with gr.Row():
            video_rescale_scale = gr.Slider(
                label="Video Rescale",
                minimum=0.0,
                maximum=2.0,
                value=0.45,
                step=0.1,
                info="Rescaling factor for oversaturation prevention"
            )
            video_a2v_scale = gr.Slider(
                label="A2V Scale",
                minimum=0.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                info="Audio-to-video cross-attention scale"
            )
        gr.Markdown("### Audio Guidance Parameters")
        gr.Markdown("Control audio generation quality and sync.")
        with gr.Row():
            audio_cfg_scale = gr.Slider(
                label="Audio CFG Scale",
                minimum=1.0,
                maximum=15.0,
                value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
                step=0.1,
                info="Classifier-free guidance for audio"
            )
            audio_stg_scale = gr.Slider(
                label="Audio STG Scale",
                minimum=0.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                info="Spatio-temporal guidance for audio (0 = disabled)"
            )
        with gr.Row():
            audio_rescale_scale = gr.Slider(
                label="Audio Rescale",
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                info="Audio rescaling factor"
            )
            audio_v2a_scale = gr.Slider(
                label="V2A Scale",
                minimum=0.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                info="Video-to-audio cross-attention scale"
            )
    # Event handlers
    def on_image_upload(image, current_h, current_w):
        """Update resolution based on uploaded image aspect ratio.

        Returns a (width, height) pair of updates, matching outputs=[width, height].
        current_h / current_w are received but not used. Clearing the image
        (None) leaves both fields unchanged.
        """
        if image is None:
            return gr.update(), gr.update()
        aspect = detect_aspect_ratio(image)
        if aspect in RESOLUTIONS:
            return (
                gr.update(value=RESOLUTIONS[aspect]["width"]),
                gr.update(value=RESOLUTIONS[aspect]["height"])
            )
        return gr.update(), gr.update()
    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, height, width],
        outputs=[width, height],
    )
    # Generate button click handler. The input order must match
    # generate_video's positional parameters. generate_video returns
    # (video_path, used_seed), which fills the video player and the Seed field.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            input_image,
            duration,
            seed,
            randomize_seed,
            height,
            width,
            enhance_prompt,
            video_cfg_scale,
            video_stg_scale,
            video_rescale_scale,
            video_a2v_scale,
            audio_cfg_scale,
            audio_stg_scale,
            audio_rescale_scale,
            audio_v2a_scale,
        ],
        outputs=[output_video, seed],
    )
# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
    # Enable request queuing, then start the server. The theme and CSS are
    # supplied at launch time, and the app is also exposed as an MCP server.
    queued_app = demo.queue()
    queued_app.launch(
        share=True,
        mcp_server=True,
        css=css,
        theme=gr.themes.Citrus(),
    )