vrthinker

A video reward model that compares two videos against a text prompt and outputs per-dimension preferences: TA (Text Alignment), MQ (Motion Quality), VQ (Visual Quality), and OA (Overall). Each label is 1 (Video 1 wins), 2 (Video 2 wins), or 0 (tie).
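
Below is a minimal sketch for turning a result dict into readable verdicts. It assumes only the label convention above. describe_result, LABEL_NAMES and DIMENSIONS are hypothetical helpers and are not part of the released script.

```python
# Hypothetical helper (not part of infer.py): readable verdicts from a result dict.
LABEL_NAMES = {1: "Video 1", 2: "Video 2", 0: "tie"}
DIMENSIONS = {
    "TA": "Text Alignment",
    "MQ": "Motion Quality",
    "VQ": "Visual Quality",
    "OA": "Overall",
}


def describe_result(result: dict) -> dict:
    """Map each dimension's label to a verdict; missing or null labels become 'unparsed'."""
    return {name: LABEL_NAMES.get(result.get(key), "unparsed")
            for key, name in DIMENSIONS.items()}


if __name__ == "__main__":
    print(describe_result({"TA": 1, "MQ": 0, "VQ": 2, "OA": 1}))
```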

The model reasons step-by-step and may call a select_frames tool to request additional frames from the videos before committing to an answer.
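
The system prompt specifies that select_frames returns frames in pairs: for indices [i, j, k] the model receives v1[i], v2[i], v1[j], v2[j], v1[k], v2[k], with at most 12 indices per call. The pure-Python sketch below shows that ordering and how to split a paired list back into per-video lists. pair_order and split_pairs are hypothetical names used only for illustration.

```python
MAX_FRAMES_PER_CALL = 12  # per-call cap stated in the system prompt


def pair_order(indices: list[int]) -> list[tuple[int, int]]:
    """Return (video number, frame index) in the order the tool emits frames."""
    order = []
    for i in indices[:MAX_FRAMES_PER_CALL]:
        order.append((1, i))  # Video 1 frame first...
        order.append((2, i))  # ...then the same moment from Video 2
    return order


def split_pairs(frames: list) -> tuple[list, list]:
    """Split an interleaved [v1, v2, v1, v2, ...] list into (video1, video2) lists."""
    return frames[0::2], frames[1::2]
```

Pairing by index keeps the same moment of both videos adjacent in the context, which is what the prompt asks the model to exploit.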

Install

pip install torch transformers accelerate pillow opencv-python

infer.py needs Python 3.10 or newer because it uses the dict | None annotation syntax.
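
Two of the pip package names differ from the modules the script imports: pillow installs PIL, and opencv-python installs cv2. The sketch below is an optional preflight check that lists which of the packages above cannot be found on the current interpreter. It uses only importlib.util.find_spec, which does not import the modules. missing_packages is a hypothetical helper. It cannot tell whether your transformers release is recent enough to include Qwen2_5_VLForConditionalGeneration.

```python
import importlib.util

# pip package name -> top-level module it provides (differs for pillow and opencv-python)
REQUIRED_MODULES = {
    "torch": "torch",
    "transformers": "transformers",
    "accelerate": "accelerate",
    "pillow": "PIL",
    "opencv-python": "cv2",
}


def missing_packages() -> list[str]:
    """Return pip package names whose modules cannot be found on this interpreter."""
    return [pkg for pkg, module in REQUIRED_MODULES.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```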

Inference

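Save the snippet below as infer.py in the model directory (MODEL_DIR is set to the folder containing the script, so the script must sit next to the weights), then run:

python infer.py --video1 path/to/v1.mp4 --video2 path/to/v2.mp4 \
                --prompt "A robot rides a unicorn across a rainbow bridge."
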
# infer.py
import argparse, json, re
from pathlib import Path

import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_DIR = str(Path(__file__).resolve().parent)        # weights must live in the same folder as this script
FRAMES_PER_VIDEO = 128
INITIAL_PER_VIDEO = 4
MAX_TURNS = 6
MAX_FRAMES_PER_CALL = 12
IMAGE_SIDE = 448

SYSTEM_PROMPT = """Task Description:
Your task is to compare two videos generated based on the same text prompt by analyzing their frames in detail and provide an overall judgment along with a judgment for each evaluation dimension.

The provided frames are downsampled from these videos:
- Video 1: First four input frames.
- Video 2: Next four input frames.

Evaluation Dimensions:
1. Text Alignment (TA): How faithfully each video reflects the text prompt.
2. Visual Quality (VQ): Aesthetics, artifacts, blurriness, distortion, color, resolution, flickering.
3. Motion Quality (MQ): Smoothness, jitter, unnatural motion, temporal consistency.
4. Overall Assessment (OA): Holistic judgment across the above.

Frames and Analysis Rules:
- 8 sampled frames are provided initially (4 per video), evenly downsampled from 128 frames per video. The first 4 are Video 1, the next 4 are Video 2.
- Each video has 128 frames (indices 0-127). To inspect more frames, call select_frames with the indices you need; the tool retrieves the same indices from both videos symmetrically.
- Tool returns are paired: for [i, j, k] you get (v1[i], v2[i], v1[j], v2[j], v1[k], v2[k]). Use this pairing to compare the same moment across both videos.
- Each tool call accepts at most 12 indices.

Format Requirement:
1. <Snapshot></Snapshot> — summarize useful visual details after receiving frames.
2. <Think></Think> โ€” reasoning.
3. <Answer></Answer> โ€” final judgment.

Label semantics: 1 = Video 1 better, 2 = Video 2 better, 0 = tie.

Examples:
<Answer>TA=1, VQ=1, MQ=0, OA=1</Answer>

Tool call format:
When you want to inspect more frames, emit a tool call inside <tool_call></tool_call> tags:
<tool_call>{"name": "select_frames", "arguments": {"frame_indices": [10, 30, 60, 90]}}</tool_call>
"""


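def extract_frames(video_path: str, indices: list[int]) -> list[Image.Image]:
    """Return PIL frames at the given virtual indices, mapped evenly over the video's actual length."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise FileNotFoundError(f"cannot open video: {video_path}")
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        raise ValueError(f"could not read a frame count from: {video_path}")
    out: list[Image.Image] = []
    for idx in indices:
        # clamp model-requested indices into the advertised range 0-127
        idx = max(0, min(int(idx), FRAMES_PER_VIDEO - 1))
        # map idx in [0, FRAMES_PER_VIDEO) -> real frame in [0, total)
        real = min(int(idx / FRAMES_PER_VIDEO * total), total - 1)
        cap.set(cv2.CAP_PROP_POS_FRAMES, real)
        ok, frame = cap.read()
        if not ok:
            continue  # skip frames OpenCV fails to decode
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(frame).resize((IMAGE_SIDE, IMAGE_SIDE))
        out.append(img)
    cap.release()
    return out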


def initial_frames(v1: str, v2: str) -> list[Image.Image]:
    idxs = [int(FRAMES_PER_VIDEO * (i + 0.5) / INITIAL_PER_VIDEO) for i in range(INITIAL_PER_VIDEO)]
    return extract_frames(v1, idxs) + extract_frames(v2, idxs)


def tool_frames(v1: str, v2: str, indices: list[int]) -> list[Image.Image]:
    indices = indices[:MAX_FRAMES_PER_CALL]
    out: list[Image.Image] = []
    for i in indices:
        out += extract_frames(v1, [i])
        out += extract_frames(v2, [i])
    return out


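def parse_tool_call(text: str) -> dict | None:
    """Return the "arguments" object of the first <tool_call> block, or None."""
    m = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL)
    if not m:
        return None
    try:
        obj = json.loads(m.group(1))
    except json.JSONDecodeError:
        return None
    # guard against valid JSON that is not an object (e.g. a bare list)
    return obj.get("arguments", {}) if isinstance(obj, dict) else None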


def parse_answer(text: str) -> dict | None:
    m = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL | re.IGNORECASE)
    if not m:
        return None
    body = m.group(1)
    answer = {}
    for d in ("TA", "MQ", "VQ", "OA"):
        # valid labels are 0 (tie), 1 (Video 1 better) or 2 (Video 2 better)
        hit = re.search(rf"{d}\s*=\s*([012])", body)
        if hit:
            answer[d] = int(hit.group(1))
    return answer


@torch.inference_mode()
def run(video1: str, video2: str, prompt: str) -> dict:
    processor = AutoProcessor.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_DIR, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    ).eval()

    images = initial_frames(video1, video2)
    user_text = (
        f"Compare the two videos generated from the following prompt and evaluate them "
        f"across Text Alignment (TA), Motion Quality (MQ), Visual Quality (VQ), and "
        f"Overall Assessment (OA).\n\nPrompt: {prompt}\n\n"
        f"The first 4 images are uniformly sampled from Video 1, and the next 4 are from "
        f"Video 2. Each video has 128 frames (indices 0-127). "
        f"Use the select_frames tool to request additional frames if needed."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image"}] * len(images)
                                    + [{"type": "text", "text": user_text}]},
    ]

    for turn in range(MAX_TURNS):
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=images, return_tensors="pt", padding=True).to(model.device)
        # greedy decoding; temperature is ignored when do_sample=False, so it is not passed
        output_ids = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
        reply = processor.batch_decode(
            output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]

        messages.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})

        answer = parse_answer(reply)
        if answer:
            return answer

        call = parse_tool_call(reply)
        if not call or "frame_indices" not in call:
            # neither a parseable answer nor a usable tool call: give up
            return {"TA": None, "MQ": None, "VQ": None, "OA": None}

        indices = call["frame_indices"][:MAX_FRAMES_PER_CALL]  # enforce the 12-index cap
        new_imgs = tool_frames(video1, video2, indices)
        images += new_imgs
        messages.append({
            "role": "user",
            "content": [{"type": "image"}] * len(new_imgs)
                       + [{"type": "text",
                           "text": f"<tool_response>Retrieved {len(indices)} "
                                   f"frame pairs ({indices}) symmetrically from both "
                                   f"videos.</tool_response>"}],
        })

    return {"TA": None, "MQ": None, "VQ": None, "OA": None}


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--video1", required=True)
    p.add_argument("--video2", required=True)
    p.add_argument("--prompt", required=True)
    args = p.parse_args()
    print(json.dumps(run(args.video1, args.video2, args.prompt), indent=2))
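
infer.py exposes every video to the model as 128 virtual frames (indices 0-127), whatever its real length. The first turn shows the centre of each of four equal segments. The pure-Python sketch below restates the index arithmetic from extract_frames and initial_frames so you can check it without OpenCV or a GPU. virtual_to_real and initial_indices are hypothetical names, not functions in the script.

```python
FRAMES_PER_VIDEO = 128   # virtual frame count exposed to the model
INITIAL_PER_VIDEO = 4    # frames per video shown in the first turn


def virtual_to_real(idx: int, total: int) -> int:
    """Map a virtual index in [0, 128) to a real frame index in [0, total)."""
    return min(int(idx / FRAMES_PER_VIDEO * total), total - 1)


def initial_indices() -> list[int]:
    """Centre of each of INITIAL_PER_VIDEO equal segments of the virtual timeline."""
    return [int(FRAMES_PER_VIDEO * (i + 0.5) / INITIAL_PER_VIDEO)
            for i in range(INITIAL_PER_VIDEO)]


if __name__ == "__main__":
    print(initial_indices())
    print([virtual_to_real(i, 64) for i in (0, 1, 127)])
```

For a 64-frame clip, virtual index 127 maps to real frame 63. Neighbouring virtual indices such as 0 and 1 both land on real frame 0, so requesting adjacent indices on short clips can return duplicate frames.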

Output

{
  "TA": 1,
  "MQ": 0,
  "VQ": 2,
  "OA": 1
}

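1 = Video 1 wins on that dimension, 2 = Video 2 wins, 0 = tie. A dimension missing from the model's answer is omitted from the output. All four fields are null in two cases: the model emits neither a parseable answer nor a valid select_frames call, or it runs out of turns (MAX_TURNS = 6).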
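
You may run infer.py over many prompt/video pairs, for example to compare two generators by always passing generator A as --video1. A per-dimension win/tie tally is one simple way to summarize the labels. The sketch below assumes the results are the dicts printed by the script. tally is a hypothetical helper.

```python
from collections import Counter

DIMENSIONS = ("TA", "MQ", "VQ", "OA")


def tally(results: list[dict]) -> dict[str, Counter]:
    """Count Video 1 wins, Video 2 wins, ties and unparsed labels per dimension."""
    names = {1: "video1", 2: "video2", 0: "tie"}
    counts = {d: Counter() for d in DIMENSIONS}
    for result in results:
        for d in DIMENSIONS:
            # None or a missing key means the run produced no usable label
            counts[d][names.get(result.get(d), "unparsed")] += 1
    return counts
```

Counting unparsed results separately keeps failed runs from being mistaken for ties.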

Hardware

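The model weights alone take about 16 GB of GPU memory in bf16 (8B parameters × 2 bytes). Runs that request many extra frames need additional memory for activations and the KV cache. Tested on a single A100/H100.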
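
The back-of-the-envelope sketch below shows where the memory goes. It assumes Qwen2.5-VL's default vision settings (14-pixel patches, with each 2×2 block of patches merged into one token) and the constants in infer.py. Treat the results as estimates, not measurements.

```python
PARAMS = 8e9              # "8B params" from the model page (rounded)
BYTES_PER_PARAM = 2       # bf16
IMAGE_SIDE = 448          # infer.py resizes every frame to 448 x 448
PATCH = 14                # assumed Qwen2.5-VL vision patch size
MERGE = 2                 # assumed 2 x 2 spatial token merge
INITIAL_IMAGES = 8        # 4 frames per video in the first turn
MAX_TURNS = 6
MAX_FRAMES_PER_CALL = 12


def weight_gb() -> float:
    """Memory for the weights alone, in GB."""
    return PARAMS * BYTES_PER_PARAM / 1e9


def tokens_per_image() -> int:
    """Visual tokens per frame after patching and merging."""
    return (IMAGE_SIDE // PATCH) ** 2 // (MERGE * MERGE)


def max_images_in_context() -> int:
    # Frames fetched after the final turn's generation are never fed back,
    # so at most MAX_TURNS - 1 tool responses (2 frames per index) reach generate().
    return INITIAL_IMAGES + (MAX_TURNS - 1) * 2 * MAX_FRAMES_PER_CALL
```

With these constants, the weights come to 16 GB. The worst case is 128 images, or 32,768 visual tokens in a single forward pass, before counting any text tokens.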

Weights: Safetensors · Model size: 8B params · Tensor type: BF16