| """Image-to-video generation using Wan 2.1 via fal.ai API. |
| |
| Reads generated images and their prompts, produces a short video clip |
| per segment. Each clip is ~5s at 16fps; the assembler later trims to |
| the exact beat interval duration. |
| |
| Two backends: |
| - "api" : fal.ai hosted Wan 2.1 (for development / local runs) |
| - "hf" : on-device Wan 2.1 with FP8 on ZeroGPU (for HF Spaces deployment) |
| |
| Set FAL_KEY env var for API mode. |
| """ |
|
|
| import base64 |
| import json |
| import os |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import requests |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
| |
| |
| |
|
|
# fal.ai endpoint identifier for hosted Wan 2.1 image-to-video.
FAL_MODEL_ID = "fal-ai/wan-i2v"


# Generation defaults sent to the Wan 2.1 i2v endpoint.
ASPECT_RATIO = "9:16"  # vertical / short-form video orientation
RESOLUTION = "480p"
NUM_FRAMES = 81  # 81 frames at 16 fps ≈ 5 s per clip (see module docstring)
FPS = 16
NUM_INFERENCE_STEPS = 30
GUIDANCE_SCALE = 5.0  # passed to the API as "guide_scale"
SEED = 42  # default base seed; generate_all offsets it per segment
|
|
|
|
| def _image_to_data_uri(image_path: str | Path) -> str: |
| """Convert a local image file to a base64 data URI for the API.""" |
| path = Path(image_path) |
| suffix = path.suffix.lower() |
| mime = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"} |
| content_type = mime.get(suffix, "image/png") |
|
|
| with open(path, "rb") as f: |
| encoded = base64.b64encode(f.read()).decode() |
|
|
| return f"data:{content_type};base64,{encoded}" |
|
|
|
|
def _download_video(url: str, output_path: Path) -> Path:
    """Download a video from a URL to a local file, streaming to disk.

    Streams the response in chunks rather than buffering the whole body
    in memory (`resp.content`), which matters for multi-megabyte clips.

    Args:
        url: HTTP(S) URL of the video to fetch.
        output_path: Destination file; parent directories are created.

    Returns:
        output_path, once fully written.

    Raises:
        requests.HTTPError: If the server returns an error status.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Context manager ensures the connection is released even on error.
    with requests.get(url, stream=True, timeout=300) as resp:
        resp.raise_for_status()
        with open(output_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return output_path
|
|
|
|
| |
| |
| |
|
|
def generate_clip_api(
    image_path: str | Path,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> dict:
    """Call fal.ai's hosted Wan 2.1 image-to-video endpoint.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description for the video.
        negative_prompt: What to avoid.
        seed: Random seed for reproducibility.

    Returns:
        API response dict with 'video' (url, content_type, file_size) and 'seed'.
    """
    # Imported lazily so API mode is optional for deployments without fal.
    import fal_client

    payload: dict = {
        "image_url": _image_to_data_uri(image_path),
        "prompt": prompt,
        "aspect_ratio": ASPECT_RATIO,
        "resolution": RESOLUTION,
        "num_frames": NUM_FRAMES,
        "frames_per_second": FPS,
        "num_inference_steps": NUM_INFERENCE_STEPS,
        "guide_scale": GUIDANCE_SCALE,
        "negative_prompt": negative_prompt,
        "enable_safety_checker": False,
        "enable_prompt_expansion": False,
    }
    if seed is not None:
        payload["seed"] = seed

    return fal_client.subscribe(FAL_MODEL_ID, arguments=payload)
|
|
|
|
| |
| |
| |
|
|
def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Generate a video clip from an image and save it locally.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Where to save the .mp4 clip.
        negative_prompt: What to avoid.
        seed: Random seed.

    Returns:
        Path to the saved video clip.
    """
    destination = Path(output_path)
    response = generate_clip_api(image_path, prompt, negative_prompt, seed)
    return _download_video(response["video"]["url"], destination)
|
|
|
|
def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate video clips for all segments.

    Expects images at images_dir/segment_001.png, segment_002.png, etc.
    Segments should have 'prompt' and optionally 'negative_prompt' keys
    (from prompt_generator).

    Args:
        segments: List of segment dicts with 'segment', 'prompt' keys.
        images_dir: Directory containing generated images.
        output_dir: Directory to save video clips.
        seed: Base seed (incremented per segment).
        progress_callback: Optional callable(segment_index, total) invoked
            after each segment is handled (generated, reused, or skipped).

    Returns:
        List of saved video clip paths.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths: list[Path] = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        if clip_path.exists():
            # Resume support: reuse the existing clip, but still report
            # progress so callers' progress bars don't stall on reruns.
            print(f" Segment {idx}/{total}: already exists, skipping")
            paths.append(clip_path)
            if progress_callback:
                progress_callback(idx, total)
            continue

        if not image_path.exists():
            print(f" Segment {idx}: image not found at {image_path}, skipping")
            if progress_callback:
                progress_callback(idx, total)
            continue

        # Prefer the dedicated video prompt, falling back to scene text,
        # then the image prompt, so older segment files still work.
        prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", "")))
        neg = seg.get("negative_prompt", "")

        print(f" Segment {idx}/{total}: generating video clip...")
        t0 = time.time()
        generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx)
        elapsed = time.time() - t0
        print(f" Saved {clip_path.name} ({elapsed:.1f}s)")

        paths.append(clip_path)
        if progress_callback:
            progress_callback(idx, total)

    return paths
|
|
|
|
def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Full video generation pipeline: read segments, generate clips, save.

    Args:
        data_dir: Song data directory containing segments.json and images/.
        seed: Base random seed.

    Returns:
        List of saved video clip paths.
    """
    base = Path(data_dir)

    segments = json.loads((base / "segments.json").read_text())

    clips_dir = base / "clips"
    clip_paths = generate_all(
        segments,
        images_dir=base / "images",
        output_dir=clips_dir,
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(clip_paths)} video clips in {clips_dir}")
    return clip_paths
|
|
|
|
if __name__ == "__main__":
    import sys

    # Require a data directory argument before anything else.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m src.video_generator <data_dir>")
        print(" e.g. python -m src.video_generator data/Gone")
        print("\nRequires FAL_KEY environment variable.")
        sys.exit(1)

    # API mode cannot work without credentials; fail fast with guidance.
    if not os.getenv("FAL_KEY"):
        print("Error: FAL_KEY environment variable not set.")
        print("Get your key at https://fal.ai/dashboard/keys")
        sys.exit(1)

    run(cli_args[0])
|
|