void-model / diffusers /pipeline_void.py
sam-motamed's picture
Update pipeline_void.py: handle diffusers/ subfolder for local paths
3fe800f verified
raw
history blame
20.3 kB
"""
VOID (Video Object and Interaction Deletion) Pipeline.
Simple usage:
from pipeline_void import VOIDPipeline
pipe = VOIDPipeline.from_pretrained("netflix/void-model")
result = pipe.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.")
result.save("output.mp4")
Pass 2 refinement:
pipe2 = VOIDPipeline.from_pretrained("netflix/void-model", void_pass=2)
result2 = pipe2.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.",
pass1_video="output.mp4")
result2.save("output_refined.mp4")
"""
import os
import json
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from diffusers import CogVideoXDDIMScheduler
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from cogvideox_transformer3d import CogVideoXTransformer3DModel
from cogvideox_vae import AutoencoderKLCogVideoX
from pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
# Hugging Face repo of the CogVideoX-Fun base model that VOID is fine-tuned from.
BASE_MODEL_REPO = "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP"

# VOID checkpoint filename for each refinement pass (1 and 2).
PASS_CHECKPOINTS = {n: f"void_pass{n}.safetensors" for n in (1, 2)}

# Default negative prompt (from config/quadmask_cogvideox.py).
# Note the trailing space after the final sentence.
DEFAULT_NEGATIVE_PROMPT = "".join([
    "The video is not of a high quality, it has a low resolution. ",
    "Watermark present in each frame. The background is solid. ",
    "Strange body and strange trajectory. Distortion. ",
])
@dataclass
class VOIDOutput:
    """Result of a VOID inpainting run.

    ``video`` holds the frames ready to be written to disk; ``video_float``
    keeps the raw batch tensor produced by the underlying pipeline.
    """
    video: torch.Tensor  # (T, H, W, 3) uint8
    video_float: torch.Tensor  # (1, C, T, H, W) float [0, 1]

    def save(self, path: str, fps: int = 12):
        """Write ``video`` to ``path`` with imageio at ``fps`` frames per second."""
        import imageio
        frame_list = list(self.video.cpu().numpy())  # one (H, W, 3) array per frame
        imageio.mimwrite(path, frame_list, fps=fps)
        print(f"Saved {len(frame_list)} frames to {path}")
def _merge_void_weights(transformer, checkpoint_path):
    """Load a VOID checkpoint into ``transformer`` (non-strict) and return it.

    When the checkpoint's ``patch_embed.proj.weight`` has a different number
    of input channels (dim 1) than the transformer's, a hybrid weight is
    built: the first and last ``16 * 8 = 128`` input channels come from the
    checkpoint, and every channel in between keeps the transformer's current
    value. NOTE(review): presumably the differing middle slice corresponds to
    extra mask-conditioning channels — confirm against the training config.

    The number of missing / unexpected keys reported by ``load_state_dict``
    is printed.
    """
    weights = load_file(checkpoint_path)
    key = "patch_embed.proj.weight"
    incoming = weights[key]
    current = transformer.state_dict()[key]

    if incoming.size(1) != current.size(1):
        width = int(16 * 8)  # latent channels * feature scale
        hybrid = current.clone()
        hybrid[:, :width] = incoming[:, :width]
        hybrid[:, -width:] = incoming[:, -width:]
        weights[key] = hybrid

    missing, unexpected = transformer.load_state_dict(weights, strict=False)
    if missing:
        print(f"[VOID] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[VOID] Unexpected keys: {len(unexpected)}")
    return transformer
def _load_video(path: str, max_frames: int) -> np.ndarray:
"""Load video as numpy array (T, H, W, 3) uint8."""
import imageio
frames = list(imageio.imiter(path))
frames = frames[:max_frames]
return np.array(frames)
def _prep_video_tensor(
video_np: np.ndarray,
sample_size: Tuple[int, int],
) -> torch.Tensor:
"""Convert video numpy array to pipeline input tensor.
Returns: (1, C, T, H, W) float32 in [0, 1]
"""
video = torch.from_numpy(video_np).float()
video = video.permute(3, 0, 1, 2) / 255.0 # (C, T, H, W)
video = F.interpolate(video, sample_size, mode="area")
return video.unsqueeze(0) # (1, C, T, H, W)
def _prep_mask_tensor(
mask_np: np.ndarray,
sample_size: Tuple[int, int],
use_quadmask: bool = True,
) -> torch.Tensor:
"""Convert mask numpy array to pipeline input tensor.
Quantizes to quadmask values [0, 63, 127, 255], inverts,
and normalizes to [0, 1].
Returns: (1, 1, T, H, W) float32 in [0, 1]
"""
mask = torch.from_numpy(mask_np).float()
if mask.ndim == 4:
mask = mask[..., 0] # drop channel dim -> (T, H, W)
mask = F.interpolate(mask.unsqueeze(0), sample_size, mode="area")
mask = mask.unsqueeze(0) # (1, 1, T, H, W)
if use_quadmask:
# Quantize to 4 values
mask = torch.where(mask <= 31, 0., mask)
mask = torch.where((mask > 31) * (mask <= 95), 63., mask)
mask = torch.where((mask > 95) * (mask <= 191), 127., mask)
mask = torch.where(mask > 191, 255., mask)
else:
# Trimask: 3 values
mask = torch.where(mask > 192, 255., mask)
mask = torch.where((mask <= 192) * (mask >= 64), 128., mask)
mask = torch.where(mask < 64, 0., mask)
# Invert and normalize to [0, 1]
mask = (255. - mask) / 255.
return mask
def _temporal_padding(
tensor: torch.Tensor,
min_length: int = 85,
max_length: int = 197,
dim: int = 2,
) -> torch.Tensor:
"""Pad video temporally by mirroring, matching CogVideoX requirements."""
length = tensor.size(dim)
min_len = (length // 4) * 4 + 1
if min_len < length:
min_len += 4
if (min_len / 4) % 2 == 0:
min_len += 4
target_length = min(min_len, max_length)
target_length = max(min_length, target_length)
# Truncate if needed
if dim == 2:
tensor = tensor[:, :, :target_length]
else:
raise NotImplementedError(f"dim={dim} not supported")
# Pad by mirroring
while tensor.size(dim) < target_length:
flipped = torch.flip(tensor, [dim])
tensor = torch.cat([tensor, flipped], dim=dim)
if dim == 2:
tensor = tensor[:, :, :target_length]
return tensor
def _generate_warped_noise(
    pass1_video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Produce warped noise for Pass 2 from a Pass 1 output video.

    Two strategies are tried in order:

    1. In-process, via the ``rp`` package and its ``noise_warp`` module
       (``_generate_warped_noise_direct``). Any failure is printed and the
       second strategy is attempted.
    2. A ``make_warped_noise.py`` script run as a subprocess with a 600 s
       timeout. The ``noises.npy`` it writes is loaded with
       ``_load_warped_noise``.

    Args:
        pass1_video_path: Path to the Pass 1 output video.
        target_shape: (latent_T, latent_H, latent_W, latent_C)
        device: Device for the returned tensor.
        dtype: Dtype for the returned tensor.

    Returns: (1, T, C, H, W) warped noise tensor.

    Raises:
        RuntimeError: If no fallback script is found, the script exits
            non-zero, or it produces no ``noises.npy``.
        subprocess.TimeoutExpired: If the script runs longer than 600 s.
    """
    # ---- Strategy 1: rp / noise_warp in this process ---------------------
    try:
        # SLURM workaround: rp crashes parsing GPU UUIDs such as
        # "GPU-9fca2b4f-...", so a non-numeric device list is replaced by "0".
        # NOTE(review): this rewrites the process-wide environment. The change
        # is inherited by the subprocess below and never restored, and under
        # SLURM "0" may not name the allocated GPU; confirm this is acceptable.
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if visible_devices and not visible_devices.replace(",", "").isdigit():
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        import rp
        rp.r._pip_import_autoyes = True
        rp.git_import('CommonSource')
        # Imported here so that a missing module surfaces as ImportError.
        import rp.git.CommonSource.noise_warp as nw  # noqa: F401
        return _generate_warped_noise_direct(pass1_video_path, target_shape, device, dtype)
    except ImportError as err:
        print(f"[VOID] rp/noise_warp not available: {err}")
    except Exception as err:
        print(f"[VOID] Warped noise generation via rp failed: {err}")
        import traceback
        traceback.print_exc()

    # ---- Strategy 2: make_warped_noise.py in a subprocess ----------------
    module_dir = os.path.dirname(__file__)
    script_candidates = (
        os.path.join(module_dir, "make_warped_noise.py"),
        os.path.join(module_dir, "..", "inference", "cogvideox_fun", "make_warped_noise.py"),
    )
    script = next((path for path in script_candidates if os.path.exists(path)), None)
    if script is None:
        raise RuntimeError(
            "Cannot generate warped noise: 'rp' package not installed and "
            "make_warped_noise.py not found. Install 'rp' package or provide "
            "pre-computed warped noise via warped_noise_path parameter."
        )

    with tempfile.TemporaryDirectory() as out_dir:
        command = [sys.executable, script, os.path.abspath(pass1_video_path), out_dir]
        print("[VOID] Generating warped noise (this may take a few minutes)...")
        completed = subprocess.run(command, capture_output=True, text=True, timeout=600)
        if completed.returncode != 0:
            raise RuntimeError(f"Warped noise generation failed:\n{completed.stderr}")
        # The script may write <out_dir>/<video stem>/noises.npy or
        # <out_dir>/noises.npy; the nested layout is checked first.
        stem = os.path.splitext(os.path.basename(pass1_video_path))[0]
        for noise_file in (
            os.path.join(out_dir, stem, "noises.npy"),
            os.path.join(out_dir, "noises.npy"),
        ):
            if os.path.exists(noise_file):
                return _load_warped_noise(noise_file, target_shape, device, dtype)
        raise RuntimeError("Warped noise file not found after generation")
def _generate_warped_noise_direct(
    video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Compute warped noise in-process with ``rp``'s ``noise_warp`` module.

    The video is loaded with ``rp.load_video`` and passed through
    ``rp.resize_list(length=72)``, ``rp.resize_images_to_hold`` and
    ``rp.crop_images`` (both at height=480, width=720, center origin). It is
    then handed to ``noise_warp.get_noise_from_video`` with 16 noise
    channels. The resulting ``numpy_noises`` are resampled to
    ``target_shape`` by ``_numpy_noise_to_tensor``.

    NOTE(review): ``_generate_warped_noise`` runs
    ``rp.git_import('CommonSource')`` before calling this; presumably the
    ``rp.git.CommonSource`` import below needs that setup. Confirm before
    calling this function directly.

    Returns: (1, T, C, H, W) tensor on ``device`` with ``dtype``.
    """
    import rp
    import rp.git.CommonSource.noise_warp as nw

    clip = rp.load_video(video_path)
    clip = rp.resize_list(clip, length=72)
    clip = rp.resize_images_to_hold(clip, height=480, width=720)
    clip = rp.crop_images(clip, height=480, width=720, origin='center')
    clip = rp.as_numpy_array(clip)

    frame_scale = 2**-1  # passed as resize_frames
    flow_scale = 2**3    # passed as resize_flow
    latent_factor = 8
    # round(0.5 * 8) * 8 == 32
    downscale = round(frame_scale * flow_scale) * latent_factor

    result = nw.get_noise_from_video(
        clip,
        remove_background=False,
        visualize=False,
        save_files=False,
        noise_channels=16,
        resize_frames=frame_scale,
        resize_flow=flow_scale,
        downscale_factor=downscale,
    )
    return _numpy_noise_to_tensor(result.numpy_noises, target_shape, device, dtype)
def _load_warped_noise(
noise_path: str,
target_shape: Tuple[int, int, int, int],
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Load and resize pre-computed warped noise."""
noises = np.load(noise_path)
if noises.dtype == np.float16:
noises = noises.astype(np.float32)
# Ensure THWC format
if noises.shape[1] == 16: # TCHW -> THWC
noises = np.transpose(noises, (0, 2, 3, 1))
return _numpy_noise_to_tensor(noises, target_shape, device, dtype)
def _numpy_noise_to_tensor(
noises: np.ndarray,
target_shape: Tuple[int, int, int, int],
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Convert numpy noise (T, H, W, C) to pipeline tensor (1, T, C, H, W)."""
latent_T, latent_H, latent_W, latent_C = target_shape
# Temporal resize if needed
if noises.shape[0] != latent_T:
indices = np.linspace(0, noises.shape[0] - 1, latent_T)
lower = np.floor(indices).astype(int)
upper = np.ceil(indices).astype(int)
frac = indices - lower
noises = noises[lower] * (1 - frac[:, None, None, None]) + noises[upper] * frac[:, None, None, None]
# Spatial resize if needed
if noises.shape[1] != latent_H or noises.shape[2] != latent_W:
resized = np.zeros((latent_T, latent_H, latent_W, latent_C), dtype=noises.dtype)
for t in range(latent_T):
for c in range(latent_C):
resized[t, :, :, c] = cv2.resize(
noises[t, :, :, c], (latent_W, latent_H),
interpolation=cv2.INTER_LINEAR,
)
noises = resized
# Convert to tensor: (T, H, W, C) -> (1, T, C, H, W)
tensor = torch.from_numpy(noises).permute(0, 3, 1, 2).unsqueeze(0)
return tensor.to(device=device, dtype=dtype)
class VOIDPipeline(CogVideoXFunInpaintPipeline):
    """
    VOID: Video Object and Interaction Deletion.
    Removes objects and their physical interactions from videos using
    quadmask-conditioned video inpainting.

    Build an instance with ``VOIDPipeline.from_pretrained(...)``, which
    assembles the base CogVideoX-Fun components and merges a VOID checkpoint
    into the transformer. Run it with ``pipe.inpaint(...)``.
    """

    @staticmethod
    def _void_local_checkpoint(model_dir: str, checkpoint_name: str) -> str:
        """Return the path of ``checkpoint_name`` in ``model_dir`` or its parent.

        The parent directory is searched too, because checkpoints may sit at
        the repo root while this code lives in its ``diffusers/`` subfolder.

        Raises:
            FileNotFoundError: If the checkpoint is in neither location.
        """
        candidates = [
            os.path.join(model_dir, checkpoint_name),
            os.path.normpath(os.path.join(model_dir, os.pardir, checkpoint_name)),
        ]
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"[VOID] {checkpoint_name} not found in {model_dir!r} or its parent "
            f"directory (looked at: {candidates})"
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        void_pass: int = 1,
        base_model: str = BASE_MODEL_REPO,
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """
        Load the VOID pipeline.
        Args:
            pretrained_model_name_or_path: HF repo ID or local path containing
                VOID checkpoint files (void_pass1.safetensors, etc.). For a
                local directory the checkpoint is looked up in the directory
                itself and then in its parent.
            void_pass: Which pass checkpoint to load (1 or 2). Default: 1.
            base_model: HF repo ID or local directory of the base
                CogVideoX-Fun model.
            torch_dtype: Weight dtype. Default: torch.bfloat16.
            **kwargs: Accepted for signature compatibility but unused; a
                notice is printed if any are given.
        Returns:
            The assembled pipeline; the selected pass is stored on
            ``pipe._void_pass``.
        Raises:
            ValueError: If ``void_pass`` is not 1 or 2.
            FileNotFoundError: If a local model directory (or its parent)
                does not contain the requested checkpoint.
        """
        if void_pass not in PASS_CHECKPOINTS:
            raise ValueError(f"void_pass must be 1 or 2, got {void_pass}")
        if kwargs:
            print(f"[VOID] Ignoring unused from_pretrained kwargs: {sorted(kwargs)}")

        # --- Locate or download the VOID checkpoint ---
        checkpoint_name = PASS_CHECKPOINTS[void_pass]
        print(f"[VOID] Loading Pass {void_pass} checkpoint...")
        if os.path.isdir(pretrained_model_name_or_path):
            checkpoint_path = cls._void_local_checkpoint(
                pretrained_model_name_or_path, checkpoint_name,
            )
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=checkpoint_name,
            )

        # --- Locate or download the base model ---
        print(f"[VOID] Loading base model: {base_model}")
        if os.path.isdir(base_model):
            base_model_path = base_model
        else:
            base_model_path = snapshot_download(repo_id=base_model)

        # Transformer (with VAE mask channels)
        print("[VOID] Loading transformer...")
        transformer = CogVideoXTransformer3DModel.from_pretrained(
            base_model_path,
            subfolder="transformer",
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype,
            use_vae_mask=True,
        )
        # Merge VOID weights
        print(f"[VOID] Merging Pass {void_pass} weights...")
        transformer = _merge_void_weights(transformer, checkpoint_path)
        transformer = transformer.to(torch_dtype)

        # VAE
        print("[VOID] Loading VAE...")
        vae = AutoencoderKLCogVideoX.from_pretrained(
            base_model_path, subfolder="vae"
        ).to(torch_dtype)

        # Tokenizer + Text encoder
        print("[VOID] Loading tokenizer and text encoder...")
        from transformers import T5Tokenizer, T5EncoderModel
        tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=torch_dtype,
        )

        # Scheduler
        scheduler = CogVideoXDDIMScheduler.from_pretrained(
            base_model_path, subfolder="scheduler"
        )

        # Build pipeline
        pipe = cls(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        pipe._void_pass = void_pass
        print("[VOID] Pipeline ready!")
        return pipe

    def _void_warped_noise(
        self,
        num_frames: int,
        height: int,
        width: int,
        pass1_video: Optional[str],
        warped_noise_path: Optional[str],
    ) -> torch.Tensor:
        """Build Pass 2 initial latents from pre-computed or generated warped noise.

        ``warped_noise_path`` takes precedence over ``pass1_video``. The noise
        is returned on CPU in the transformer's parameter dtype, so it
        follows ``torch_dtype`` instead of always being bfloat16.
        """
        vae_temporal_ratio = self.vae.config.temporal_compression_ratio
        latent_T = (num_frames - 1) // vae_temporal_ratio + 1
        # NOTE(review): the spatial downscale of 8 and the 16 latent channels
        # are hard-coded (presumably matching the CogVideoX VAE), not read from
        # the VAE config.
        latent_H = height // 8
        latent_W = width // 8
        latent_C = 16
        target_shape = (latent_T, latent_H, latent_W, latent_C)
        noise_dtype = next(self.transformer.parameters()).dtype
        cpu = torch.device("cpu")

        if warped_noise_path is not None:
            print(f"[VOID] Loading pre-computed warped noise from {warped_noise_path}")
            latents = _load_warped_noise(
                warped_noise_path, target_shape, device=cpu, dtype=noise_dtype,
            )
        else:
            print("[VOID] Generating warped noise from Pass 1 output...")
            latents = _generate_warped_noise(
                pass1_video, target_shape, device=cpu, dtype=noise_dtype,
            )
        print(f"[VOID] Warped noise: {latents.shape}, mean={latents.mean():.4f}, std={latents.std():.4f}")
        return latents

    def inpaint(
        self,
        video_path: str,
        mask_path: str,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
        height: int = 384,
        width: int = 672,
        num_inference_steps: int = 30,
        guidance_scale: float = 1.0,
        strength: float = 1.0,
        temporal_window_size: int = 85,
        max_video_length: int = 197,
        fps: int = 12,
        seed: int = 42,
        pass1_video: Optional[str] = None,
        warped_noise_path: Optional[str] = None,
        use_quadmask: bool = True,
    ) -> VOIDOutput:
        """
        Run VOID inpainting on a video.
        Args:
            video_path: Path to input video (mp4).
            mask_path: Path to quadmask video (mp4). Grayscale with values:
                0=object to remove, 63=overlap, 127=affected region, 255=background.
            prompt: Text description of the desired result after removal.
                E.g., "A lime falls on the table."
            negative_prompt: Negative prompt for generation quality.
            height: Output height (default 384).
            width: Output width (default 672).
            num_inference_steps: Denoising steps (default 30).
            guidance_scale: CFG scale (default 1.0 = no CFG).
            strength: Denoising strength (default 1.0).
            temporal_window_size: Frames per inference window (default 85).
            max_video_length: Max frames to process (default 197).
            fps: Currently unused by this method. Pass the desired frame rate
                to ``VOIDOutput.save(fps=...)`` instead.
            seed: Random seed (default 42).
            pass1_video: Path to Pass 1 output video, for Pass 2 warped noise init.
            warped_noise_path: Path to pre-computed warped noise (.npy); takes
                precedence over ``pass1_video``.
            use_quadmask: Use 4-value quadmask (default True). Set False for trimask.
        Returns:
            VOIDOutput with .video (uint8) and .save() method.
        """
        sample_size = (height, width)
        # Align the frame budget to k * temporal_ratio + 1 frames.
        vae_temporal_ratio = self.vae.config.temporal_compression_ratio
        video_length = int((max_video_length - 1) // vae_temporal_ratio * vae_temporal_ratio) + 1

        # --- Load and prep video ---
        print("[VOID] Loading video and mask...")
        vid_np = _load_video(video_path, video_length)
        mask_np = _load_video(mask_path, video_length)
        video = _prep_video_tensor(vid_np, sample_size)
        mask = _prep_mask_tensor(mask_np, sample_size, use_quadmask=use_quadmask)

        # Temporal padding
        video = _temporal_padding(video, min_length=temporal_window_size, max_length=max_video_length)
        mask = _temporal_padding(mask, min_length=temporal_window_size, max_length=max_video_length)
        num_frames = min(video.shape[2], temporal_window_size)
        print(f"[VOID] Video: {video.shape}, Mask: {mask.shape}, Frames: {num_frames}")

        # --- Handle warped noise for Pass 2 ---
        latents = None
        if warped_noise_path is not None or pass1_video is not None:
            latents = self._void_warped_noise(
                num_frames, height, width, pass1_video, warped_noise_path,
            )

        # --- Run inference ---
        generator = torch.Generator(device="cpu").manual_seed(seed)
        print(f"[VOID] Running inference ({num_frames} frames, {num_inference_steps} steps)...")
        with torch.no_grad():
            output = self(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                video=video,
                mask_video=mask,
                strength=strength,
                # NOTE(review): passed as True for both quadmask and trimask
                # inputs; its meaning is defined by the parent pipeline.
                use_trimask=True,
                use_vae_mask=True,
                latents=latents,
            ).videos

        # --- Process output ---
        if isinstance(output, np.ndarray):
            output = torch.from_numpy(output)
        # output is presumably (B, C, T, H, W) in [0, 1]
        video_float = output
        frames = output[0].permute(1, 2, 3, 0).float().clamp(0, 1)  # (T, H, W, C)
        # Round rather than truncate so the uint8 values are not biased downward.
        video_uint8 = (frames * 255).round().to(torch.uint8)
        print(f"[VOID] Done! Output: {video_uint8.shape}")
        return VOIDOutput(video=video_uint8, video_float=video_float)