File size: 20,309 Bytes
b20559f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fe800f
 
 
b20559f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
"""
VOID (Video Object and Interaction Deletion) Pipeline.

Simple usage:

    from pipeline_void import VOIDPipeline

    pipe = VOIDPipeline.from_pretrained("netflix/void-model")
    result = pipe.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.")
    result.save("output.mp4")

Pass 2 refinement:

    pipe2 = VOIDPipeline.from_pretrained("netflix/void-model", void_pass=2)
    result2 = pipe2.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.",
                            pass1_video="output.mp4")
    result2.save("output_refined.mp4")
"""

import os
import json
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from diffusers import CogVideoXDDIMScheduler
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from cogvideox_transformer3d import CogVideoXTransformer3DModel
from cogvideox_vae import AutoencoderKLCogVideoX
from pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline

# Hugging Face repo of the base model that VOID is fine-tuned from.
# Used as the default for the ``base_model`` argument of
# ``VOIDPipeline.from_pretrained``.
BASE_MODEL_REPO = "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP"

# Maps the pass number (1 or 2) to the checkpoint filename looked up in the
# VOID repo / local directory. The keys double as the set of valid
# ``void_pass`` values.
PASS_CHECKPOINTS = {
    1: "void_pass1.safetensors",
    2: "void_pass2.safetensors",
}

# Default negative prompt (from config/quadmask_cogvideox.py), used by
# ``VOIDPipeline.inpaint`` when the caller does not supply one.
# Note: the adjacent string literals are concatenated and the final piece ends
# with a trailing space.
DEFAULT_NEGATIVE_PROMPT = (
    "The video is not of a high quality, it has a low resolution. "
    "Watermark present in each frame. The background is solid. "
    "Strange body and strange trajectory. Distortion. "
)


@dataclass
class VOIDOutput:
    """Result bundle returned by the VOID pipeline.

    Carries the generated clip twice: once as 8-bit frames that can be
    written straight to disk, and once as the floating-point tensor the
    frames were derived from.
    """
    # Frames laid out as (T, H, W, 3), dtype uint8.
    video: torch.Tensor
    # Float tensor laid out as (1, C, T, H, W) with values in [0, 1].
    video_float: torch.Tensor

    def save(self, path: str, fps: int = 12):
        """Write the uint8 frames to ``path`` as a video.

        Args:
            path: Destination file; handed to ``imageio.mimwrite`` as-is.
            fps: Frame rate of the written video.
        """
        # imageio is imported lazily so it is only needed when saving.
        import imageio

        frame_list = list(self.video.cpu().numpy())
        imageio.mimwrite(path, frame_list, fps=fps)
        print(f"Saved {len(frame_list)} frames to {path}")


def _merge_void_weights(transformer, checkpoint_path):
    """Load a VOID safetensors checkpoint into the base transformer.

    The patch-embedding projection is handled specially: when the
    checkpoint's input width (dim 1) differs from the transformer's, only the
    leading and trailing 128 input columns are taken from the checkpoint and
    every other column keeps the transformer's current value.

    Args:
        transformer: Module exposing ``state_dict()`` / ``load_state_dict()``.
        checkpoint_path: Path to the ``.safetensors`` checkpoint file.

    Returns:
        The same ``transformer`` object, updated in place.
    """
    key = "patch_embed.proj.weight"
    ckpt = load_file(checkpoint_path)

    ckpt_proj = ckpt[key]
    base_proj = transformer.state_dict()[key]

    if ckpt_proj.size(1) != base_proj.size(1):
        # 16 latent channels x 8 = 128 input columns at each end of the
        # projection are copied from the checkpoint.
        span = int(16 * 8)

        patched = base_proj.clone()
        patched[:, :span] = ckpt_proj[:, :span]
        patched[:, -span:] = ckpt_proj[:, -span:]
        ckpt[key] = patched

    # Non-strict load: key mismatches are reported, not raised.
    missing, unexpected = transformer.load_state_dict(ckpt, strict=False)
    if missing:
        print(f"[VOID] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[VOID] Unexpected keys: {len(unexpected)}")

    return transformer


def _load_video(path: str, max_frames: int) -> np.ndarray:
    """Load at most ``max_frames`` leading frames of a video.

    Decoding stops as soon as ``max_frames`` frames have been read, so long
    inputs are never fully materialised in memory.

    Args:
        path: Video file (anything ``imageio.imiter`` can open).
        max_frames: Upper bound on the number of frames returned. Values
            below 1 yield no frames and therefore raise ``ValueError``.

    Returns:
        Stacked frames as returned by ``np.array``; for same-sized RGB frames
        this is a (T, H, W, 3) uint8 array.

    Raises:
        ValueError: If no frame could be read.
    """
    import imageio
    from itertools import islice

    frame_iter = imageio.imiter(path)
    try:
        frames = list(islice(frame_iter, max(max_frames, 0)))
    finally:
        # Stopping early leaves the reader open; close it explicitly when the
        # iterator supports it (generators do) instead of waiting for GC.
        close = getattr(frame_iter, "close", None)
        if close is not None:
            close()

    if not frames:
        raise ValueError(
            f"No frames could be read from {path!r} (max_frames={max_frames})"
        )
    return np.array(frames)


def _prep_video_tensor(
    video_np: np.ndarray,
    sample_size: Tuple[int, int],
) -> torch.Tensor:
    """Turn a (T, H, W, C) frame array into the pipeline's video tensor.

    Pixel values are scaled by 1/255 and every frame is area-resampled to
    ``sample_size`` (height, width).

    Returns: (1, C, T, H, W) float32 in [0, 1]
    """
    frames = torch.from_numpy(video_np).float()  # (T, H, W, C)
    channels_first = frames.permute(3, 0, 1, 2) / 255.0  # (C, T, H, W)
    # For a 4-D input F.interpolate only resizes the last two dims, so C and
    # T act as batch/channel axes and each frame is resized independently.
    resized = F.interpolate(channels_first, sample_size, mode="area")
    return resized[None]  # (1, C, T, H, W)


def _prep_mask_tensor(
    mask_np: np.ndarray,
    sample_size: Tuple[int, int],
    use_quadmask: bool = True,
) -> torch.Tensor:
    """Turn a mask frame array into the pipeline's mask tensor.

    The mask is area-resampled to ``sample_size``, snapped to a small set of
    grey levels, inverted and scaled to [0, 1]:

    * quadmask (default): levels 0 / 63 / 127 / 255 with band edges at
      31, 95 and 191;
    * trimask: levels 0 / 128 / 255 with band edges at 64 and 192.

    Args:
        mask_np: (T, H, W) array, or (T, H, W, C) of which only channel 0 is
            used.
        sample_size: Target (height, width).
        use_quadmask: Select the four-level scheme; ``False`` selects trimask.

    Returns: (1, 1, T, H, W) float32 in [0, 1]
    """
    raw = torch.from_numpy(mask_np).float()
    if raw.ndim == 4:
        raw = raw[..., 0]  # keep the first channel only -> (T, H, W)

    # (T, H, W) -> (1, T, H, W) for the resize, then -> (1, 1, T, H, W)
    resized = F.interpolate(raw[None], sample_size, mode="area")[None]

    # The bands are mutually exclusive and each replacement level lies inside
    # its own band, so all conditions can be evaluated on the resized values
    # before any replacement happens.
    if use_quadmask:
        bands = (
            (resized <= 31, 0.),
            ((resized > 31) & (resized <= 95), 63.),
            ((resized > 95) & (resized <= 191), 127.),
            (resized > 191, 255.),
        )
    else:
        bands = (
            (resized > 192, 255.),
            ((resized <= 192) & (resized >= 64), 128.),
            (resized < 64, 0.),
        )

    levels = resized
    for hit, level in bands:
        levels = torch.where(hit, level, levels)

    # Invert, then bring into [0, 1]
    return (255. - levels) / 255.


def _temporal_padding(
    tensor: torch.Tensor,
    min_length: int = 85,
    max_length: int = 197,
    dim: int = 2,
) -> torch.Tensor:
    """Bring a video tensor to a 4k+1 frame count by mirror padding.

    The current length is rounded up to the next value of the form 4k+1,
    clamped to ``max_length`` and then raised to at least ``min_length``.
    Longer inputs are truncated; shorter ones are extended by repeatedly
    appending a time-reversed copy of the tensor and trimming the result.

    Args:
        tensor: Input tensor; ``dim`` is its temporal axis.
        min_length: Lower bound on the returned length (wins over
            ``max_length`` if the two conflict).
        max_length: Upper bound applied to the rounded-up length.
        dim: Temporal axis. Any valid axis of ``tensor`` is accepted.

    Returns:
        Tensor whose size along ``dim`` equals the computed target length.

    Raises:
        ValueError: If ``tensor`` has no frames along ``dim``; mirroring an
            empty tensor can never reach the target length.
    """
    length = tensor.size(dim)
    if length == 0:
        raise ValueError(
            f"Cannot temporally pad a tensor with zero frames along dim={dim}"
        )

    min_len = (length // 4) * 4 + 1
    if min_len < length:
        min_len += 4
    # NOTE(review): min_len is always of the form 4k+1 here, so
    # (min_len / 4) % 2 is never 0 and this branch cannot fire. It is kept
    # unchanged because the intended condition is not evident from this file
    # and "fixing" it would change the produced frame counts.
    if (min_len / 4) % 2 == 0:
        min_len += 4
    target_length = max(min_length, min(min_len, max_length))

    # Truncate if needed (narrow cannot exceed the current size).
    tensor = tensor.narrow(dim, 0, min(length, target_length))

    # Pad by mirroring; each round doubles the length.
    while tensor.size(dim) < target_length:
        tensor = torch.cat([tensor, torch.flip(tensor, [dim])], dim=dim)

    return tensor.narrow(dim, 0, target_length)


def _generate_warped_noise(
    pass1_video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Generate warped noise from Pass 1 output video.

    Two strategies are tried in order:

    1. In-process generation through the ``rp`` package and its
       ``CommonSource.noise_warp`` module (see
       ``_generate_warped_noise_direct``).
    2. If that fails for any reason, run a ``make_warped_noise.py`` script
       found next to this file (or under ``../inference/cogvideox_fun``) in a
       subprocess and load the ``noises.npy`` it writes.

    Args:
        pass1_video_path: Path to Pass 1 output video.
        target_shape: (latent_T, latent_H, latent_W, latent_C)
        device: Target device.
        dtype: Target dtype.

    Returns: (1, T, C, H, W) warped noise tensor.

    Raises:
        RuntimeError: If the in-process path failed and either no fallback
            script exists, the script exits non-zero, or it produced no
            ``noises.npy``.
        subprocess.TimeoutExpired: If the fallback script runs longer than
            600 seconds (propagated from ``subprocess.run``).
    """
    # Try to import rp and nw for direct warped noise generation
    try:
        # Fix for SLURM: rp crashes parsing GPU UUIDs like "GPU-9fca2b4f-..."
        # Set CUDA_VISIBLE_DEVICES to numeric index if it contains UUIDs
        # NOTE(review): this rewrites the environment variable for the whole
        # process and is never restored — confirm that is acceptable for
        # multi-GPU callers.
        cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if cuda_env and not cuda_env.replace(",", "").isdigit():
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        import rp
        rp.r._pip_import_autoyes = True
        rp.git_import('CommonSource')
        # Imported only to verify availability up front; the alias itself is
        # not used in this function.
        import rp.git.CommonSource.noise_warp as nw
        return _generate_warped_noise_direct(pass1_video_path, target_shape, device, dtype)
    except ImportError as e:
        print(f"[VOID] rp/noise_warp not available: {e}")
    except Exception as e:
        # Deliberate best-effort: any failure of the in-process path is
        # logged and the subprocess fallback below is attempted instead.
        print(f"[VOID] Warped noise generation via rp failed: {e}")
        import traceback
        traceback.print_exc()

    # Fallback: try to find and run make_warped_noise.py as subprocess
    script_candidates = [
        os.path.join(os.path.dirname(__file__), "make_warped_noise.py"),
        os.path.join(os.path.dirname(__file__), "..", "inference", "cogvideox_fun", "make_warped_noise.py"),
    ]
    # First existing candidate wins.
    gwf_script = None
    for candidate in script_candidates:
        if os.path.exists(candidate):
            gwf_script = candidate
            break

    if gwf_script is None:
        # NOTE(review): this message is also shown when 'rp' is installed but
        # failed at runtime (see the broad except above).
        raise RuntimeError(
            "Cannot generate warped noise: 'rp' package not installed and "
            "make_warped_noise.py not found. Install 'rp' package or provide "
            "pre-computed warped noise via warped_noise_path parameter."
        )

    # The script writes into a throw-away directory; the noise is loaded
    # before the directory is removed on exit from the with-block.
    with tempfile.TemporaryDirectory() as tmpdir:
        # Argument list (no shell); same interpreter as the current process.
        cmd = [sys.executable, gwf_script, os.path.abspath(pass1_video_path), tmpdir]
        print(f"[VOID] Generating warped noise (this may take a few minutes)...")
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
        if result.returncode != 0:
            raise RuntimeError(f"Warped noise generation failed:\n{result.stderr}")

        # Find the output noises.npy
        # Looked up first under <tmpdir>/<video stem>/, then directly in
        # <tmpdir>.
        video_stem = os.path.splitext(os.path.basename(pass1_video_path))[0]
        noise_path = os.path.join(tmpdir, video_stem, "noises.npy")
        if not os.path.exists(noise_path):
            # Try flat path
            noise_path = os.path.join(tmpdir, "noises.npy")
        if not os.path.exists(noise_path):
            raise RuntimeError(f"Warped noise file not found after generation")

        return _load_warped_noise(noise_path, target_shape, device, dtype)


def _generate_warped_noise_direct(
    video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Generate warped noise directly using rp package.

    The video is loaded and preprocessed with ``rp`` helpers, passed to
    ``noise_warp.get_noise_from_video`` and the resulting noise is converted
    with ``_numpy_noise_to_tensor``.

    Args:
        video_path: Path to the source video.
        target_shape: (latent_T, latent_H, latent_W, latent_C), forwarded to
            ``_numpy_noise_to_tensor``.
        device: Target device of the returned tensor.
        dtype: Target dtype of the returned tensor.

    Returns:
        Whatever ``_numpy_noise_to_tensor`` returns for the generated noise
        (documented there as a (1, T, C, H, W) tensor).
    """
    import rp
    import rp.git.CommonSource.noise_warp as nw

    # Fixed preprocessing: length 72, then 480x720 via resize + center crop.
    # NOTE(review): presumably resize_list resamples the frame list to 72
    # frames and resize_images_to_hold scales frames to cover 480x720 before
    # the crop — confirm against the rp documentation.
    video = rp.load_video(video_path)
    video = rp.resize_list(video, length=72)
    video = rp.resize_images_to_hold(video, height=480, width=720)
    video = rp.crop_images(video, height=480, width=720, origin='center')
    video = rp.as_numpy_array(video)

    FRAME = 2**-1  # 0.5, passed as resize_frames
    FLOW = 2**3  # 8, passed as resize_flow
    LATENT = 8  # multiplier used in downscale_factor below

    output = nw.get_noise_from_video(
        video,
        remove_background=False,
        visualize=False,
        save_files=False,
        noise_channels=16,
        resize_frames=FRAME,
        resize_flow=FLOW,
        # round(0.5 * 8) * 8 == 32
        downscale_factor=round(FRAME * FLOW) * LATENT,
    )

    noises = output.numpy_noises  # (T, H, W, C)
    return _numpy_noise_to_tensor(noises, target_shape, device, dtype)


def _load_warped_noise(
    noise_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Read a pre-computed warped-noise ``.npy`` file and fit it to the latents.

    Half-precision arrays are widened to float32. An array whose second axis
    has size 16 is treated as (T, C, H, W) and moved to (T, H, W, C) before
    the resize/conversion step.

    Args:
        noise_path: Path of the ``.npy`` file.
        target_shape: (latent_T, latent_H, latent_W, latent_C).
        device: Target device.
        dtype: Target dtype.

    Returns:
        The tensor produced by ``_numpy_noise_to_tensor``.
    """
    arr = np.load(noise_path)
    arr = arr.astype(np.float32) if arr.dtype == np.float16 else arr

    # Layout detection is a size heuristic on axis 1.
    # NOTE(review): a (T, H, W, C) array with H == 16 would also match.
    channels_first = arr.shape[1] == 16
    if channels_first:
        arr = arr.transpose(0, 2, 3, 1)  # TCHW -> THWC

    return _numpy_noise_to_tensor(arr, target_shape, device, dtype)


def _numpy_noise_to_tensor(
    noises: np.ndarray,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Fit (T, H, W, C) numpy noise to the latent grid and return (1, T, C, H, W).

    Args:
        noises: Noise array laid out as (T, H, W, C).
        target_shape: (latent_T, latent_H, latent_W, latent_C).
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape (1, latent_T, C, latent_H, latent_W).
    """
    latent_T, latent_H, latent_W, latent_C = target_shape

    # Time axis: linear blend between the two nearest source frames.
    src_T = noises.shape[0]
    if src_T != latent_T:
        pos = np.linspace(0, src_T - 1, latent_T)
        lo = np.floor(pos).astype(int)
        hi = np.ceil(pos).astype(int)
        w = (pos - lo)[:, None, None, None]
        noises = noises[lo] * (1 - w) + noises[hi] * w

    # Spatial axes: bilinear resize, one (frame, channel) plane at a time.
    if not (noises.shape[1] == latent_H and noises.shape[2] == latent_W):
        out = np.zeros((latent_T, latent_H, latent_W, latent_C), dtype=noises.dtype)
        for t, c in np.ndindex(latent_T, latent_C):
            plane = noises[t, :, :, c]
            # cv2 expects the size as (width, height)
            out[t, :, :, c] = cv2.resize(
                plane, (latent_W, latent_H), interpolation=cv2.INTER_LINEAR
            )
        noises = out

    # (T, H, W, C) -> (T, C, H, W) -> (1, T, C, H, W)
    stacked = torch.from_numpy(noises).permute(0, 3, 1, 2)[None]
    return stacked.to(device=device, dtype=dtype)


class VOIDPipeline(CogVideoXFunInpaintPipeline):
    """
    VOID: Video Object and Interaction Deletion.

    Removes objects and their physical interactions from videos using
    quadmask-conditioned video inpainting.

    Build an instance with :meth:`from_pretrained` and run it with
    :meth:`inpaint`.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        void_pass: int = 1,
        base_model: str = BASE_MODEL_REPO,
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """
        Load the VOID pipeline.

        Args:
            pretrained_model_name_or_path: HF repo ID or local path containing
                VOID checkpoint files (void_pass1.safetensors, etc.). For a
                local directory the checkpoint is looked up in the directory
                itself and then in its parent.
            void_pass: Which pass checkpoint to load (1 or 2). Default: 1.
            base_model: HF repo ID for the base CogVideoX-Fun model, or a
                local directory that already holds a snapshot of it.
            torch_dtype: Weight dtype. Default: torch.bfloat16.
            **kwargs: Accepted for call-site compatibility only; no extra
                keyword argument is used and a warning is printed when any
                is passed.

        Returns:
            A ready-to-use ``VOIDPipeline``.

        Raises:
            ValueError: If ``void_pass`` is not a known pass number.
            FileNotFoundError: If a local directory was given but holds no
                checkpoint for the requested pass.
        """
        if void_pass not in PASS_CHECKPOINTS:
            raise ValueError(f"void_pass must be 1 or 2, got {void_pass}")

        if kwargs:
            # Make silently-dropped options visible instead of ignoring them.
            print(f"[VOID] Warning: ignoring unsupported arguments: {sorted(kwargs)}")

        # --- Locate / download VOID checkpoint ---
        checkpoint_name = PASS_CHECKPOINTS[void_pass]
        print(f"[VOID] Loading Pass {void_pass} checkpoint...")

        if os.path.isdir(pretrained_model_name_or_path):
            # Checkpoints may sit next to the code or one level up
            # (checkpoints at root, code in diffusers/).
            candidates = [
                os.path.join(pretrained_model_name_or_path, checkpoint_name),
                os.path.join(pretrained_model_name_or_path, "..", checkpoint_name),
            ]
            checkpoint_path = next(
                (path for path in candidates if os.path.exists(path)), None
            )
            if checkpoint_path is None:
                raise FileNotFoundError(
                    f"{checkpoint_name} not found; looked in: {', '.join(candidates)}"
                )
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=checkpoint_name,
            )

        # --- Locate / download and load base model ---
        print(f"[VOID] Loading base model: {base_model}")
        if os.path.isdir(base_model):
            # A local snapshot is used as-is; a path is not a valid repo ID.
            base_model_path = base_model
        else:
            base_model_path = snapshot_download(repo_id=base_model)

        # Transformer (with VAE mask channels)
        print("[VOID] Loading transformer...")
        transformer = CogVideoXTransformer3DModel.from_pretrained(
            base_model_path,
            subfolder="transformer",
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype,
            use_vae_mask=True,
        )

        # Merge VOID weights
        print(f"[VOID] Merging Pass {void_pass} weights...")
        transformer = _merge_void_weights(transformer, checkpoint_path)
        transformer = transformer.to(torch_dtype)

        # VAE
        print("[VOID] Loading VAE...")
        vae = AutoencoderKLCogVideoX.from_pretrained(
            base_model_path, subfolder="vae"
        ).to(torch_dtype)

        # Tokenizer + Text encoder
        print("[VOID] Loading tokenizer and text encoder...")
        from transformers import T5Tokenizer, T5EncoderModel
        tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=torch_dtype,
        )

        # Scheduler
        scheduler = CogVideoXDDIMScheduler.from_pretrained(
            base_model_path, subfolder="scheduler"
        )

        # Build pipeline
        pipe = cls(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Remember which pass this instance was built for.
        pipe._void_pass = void_pass

        print("[VOID] Pipeline ready!")
        return pipe

    def inpaint(
        self,
        video_path: str,
        mask_path: str,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
        height: int = 384,
        width: int = 672,
        num_inference_steps: int = 30,
        guidance_scale: float = 1.0,
        strength: float = 1.0,
        temporal_window_size: int = 85,
        max_video_length: int = 197,
        fps: int = 12,
        seed: int = 42,
        pass1_video: Optional[str] = None,
        warped_noise_path: Optional[str] = None,
        use_quadmask: bool = True,
    ) -> VOIDOutput:
        """
        Run VOID inpainting on a video.

        Args:
            video_path: Path to input video (mp4).
            mask_path: Path to quadmask video (mp4). Grayscale with values:
                0=object to remove, 63=overlap, 127=affected region, 255=background.
                If its frame count differs from the input video, both are
                cut to the shorter of the two so they stay aligned.
            prompt: Text description of the desired result after removal.
                E.g., "A lime falls on the table."
            negative_prompt: Negative prompt for generation quality.
            height: Output height (default 384).
            width: Output width (default 672).
            num_inference_steps: Denoising steps (default 30).
            guidance_scale: CFG scale (default 1.0 = no CFG).
            strength: Denoising strength (default 1.0).
            temporal_window_size: Frames per inference window (default 85).
            max_video_length: Max frames to process (default 197).
            fps: Output FPS (default 12). NOTE: currently not used by this
                method; pass ``fps`` to ``VOIDOutput.save`` to control the
                frame rate of the written file.
            seed: Random seed (default 42).
            pass1_video: Path to Pass 1 output video, for Pass 2 warped noise init.
            warped_noise_path: Path to pre-computed warped noise (.npy).
                Takes precedence over ``pass1_video`` when both are given.
            use_quadmask: Use 4-value quadmask (default True). Set False for trimask.

        Returns:
            VOIDOutput with .video (uint8) and .save() method.

        Raises:
            ValueError: If the video or the mask contains no frames.
        """
        sample_size = (height, width)

        # Align video length to VAE temporal compression ratio
        vae_temporal_ratio = self.vae.config.temporal_compression_ratio
        video_length = int((max_video_length - 1) // vae_temporal_ratio * vae_temporal_ratio) + 1

        # --- Load and prep video ---
        print("[VOID] Loading video and mask...")
        vid_np = _load_video(video_path, video_length)
        mask_np = _load_video(mask_path, video_length)

        if len(vid_np) == 0 or len(mask_np) == 0:
            raise ValueError(
                f"No frames loaded (video: {len(vid_np)}, mask: {len(mask_np)})"
            )
        if len(vid_np) != len(mask_np):
            # Video and mask are mirror-padded independently below; unequal
            # lengths would reflect them at different frames and misalign
            # the mask with the footage.
            common = min(len(vid_np), len(mask_np))
            print(
                f"[VOID] Warning: video has {len(vid_np)} frames but mask has "
                f"{len(mask_np)}; using the first {common} of each."
            )
            vid_np = vid_np[:common]
            mask_np = mask_np[:common]

        video = _prep_video_tensor(vid_np, sample_size)
        mask = _prep_mask_tensor(mask_np, sample_size, use_quadmask=use_quadmask)

        # Temporal padding
        video = _temporal_padding(video, min_length=temporal_window_size, max_length=max_video_length)
        mask = _temporal_padding(mask, min_length=temporal_window_size, max_length=max_video_length)

        num_frames = min(video.shape[2], temporal_window_size)

        print(f"[VOID] Video: {video.shape}, Mask: {mask.shape}, Frames: {num_frames}")

        # --- Handle warped noise for Pass 2 ---
        latents = None
        if warped_noise_path is not None or pass1_video is not None:
            # Latent grid: 4x temporal and 8x spatial compression, 16 channels.
            latent_T = (num_frames - 1) // 4 + 1
            latent_H = height // 8
            latent_W = width // 8
            latent_C = 16
            target_shape = (latent_T, latent_H, latent_W, latent_C)

            if warped_noise_path is not None:
                print(f"[VOID] Loading pre-computed warped noise from {warped_noise_path}")
                latents = _load_warped_noise(
                    warped_noise_path, target_shape,
                    device=torch.device("cpu"), dtype=torch.bfloat16,
                )
            else:
                print(f"[VOID] Generating warped noise from Pass 1 output...")
                latents = _generate_warped_noise(
                    pass1_video, target_shape,
                    device=torch.device("cpu"), dtype=torch.bfloat16,
                )
            print(f"[VOID] Warped noise: {latents.shape}, mean={latents.mean():.4f}, std={latents.std():.4f}")

        # --- Run inference ---
        generator = torch.Generator(device="cpu").manual_seed(seed)

        print(f"[VOID] Running inference ({num_frames} frames, {num_inference_steps} steps)...")
        with torch.no_grad():
            # NOTE(review): use_trimask=True is passed regardless of
            # ``use_quadmask`` — confirm this matches the base pipeline's
            # expectation for quadmask inputs.
            output = self(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                video=video,
                mask_video=mask,
                strength=strength,
                use_trimask=True,
                use_vae_mask=True,
                latents=latents,
            ).videos

        # --- Process output ---
        if isinstance(output, np.ndarray):
            output = torch.from_numpy(output)

        # output is (B, C, T, H, W) in [0, 1]
        video_float = output
        # First batch item, channels-last, clamped and scaled to 8-bit.
        video_uint8 = (output[0].permute(1, 2, 3, 0).clamp(0, 1) * 255).to(torch.uint8)

        print(f"[VOID] Done! Output: {video_uint8.shape}")
        return VOIDOutput(video=video_uint8, video_float=video_float)