# InternVideo3-8B-Instruct / processing_internvideo3.py
# yanziang's picture
# Upload folder using huggingface_hub
# e3bb923 verified
# coding=utf-8
# Copyright 2025 The InternVideo Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for InternVideo3."""
from typing import Optional, Union
import numpy as np
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
# Module-level logger; the processor uses it to warn once when a video's fps is unknown.
logger = logging.get_logger(__name__)
class InternVideo3VideosProcessorKwargs(VideosKwargs, total=False):
    """Video-processor keyword arguments for InternVideo3; declares no keys beyond `VideosKwargs`."""

    pass
class InternVideo3ImagesKwargs(ImagesKwargs):
    """Image-processor keyword arguments for InternVideo3: `ImagesKwargs` plus five optional integer keys."""

    # NOTE(review): presumably lower/upper bounds on the pixel count used when resizing -- confirm
    # against the image processor implementation.
    min_pixels: Optional[int]
    max_pixels: Optional[int]
    # NOTE(review): presumably the spatial and temporal patch sizes of the vision encoder -- confirm.
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    # NOTE(review): presumably the same quantity as the image processor's `merge_size` attribute,
    # which the processor squares to get the number of patches merged into one token -- confirm.
    merge_size: Optional[int]
class InternVideo3ProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema for `InternVideo3Processor.__call__`, with per-modality defaults."""

    images_kwargs: InternVideo3ImagesKwargs
    videos_kwargs: InternVideo3VideosProcessorKwargs
    # Defaults: the tokenizer does not pad and does not return token type ids; the video
    # processor is asked to return metadata, which the processor reads to build per-frame
    # timestamps for the prompt.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_token_type_ids": False,
        },
        "videos_kwargs": {"return_metadata": True},
    }
class InternVideo3Processor(ProcessorMixin):
    r"""
    Constructs an InternVideo3 processor which wraps an image processor, a video processor,
    and a tokenizer into a single processor.

    Calling the processor runs the image and video processors, expands every image/video token
    found in the prompt into one token per merged visual patch, and then tokenizes the prompt.

    Args:
        image_processor: The image processor.
        tokenizer: The tokenizer.
        video_processor: The video processor.
        chat_template (`str`, *optional*): A Jinja template for chat formatting.
    """

    attributes = ["image_processor", "tokenizer", "video_processor"]
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        """
        Store the wrapped components and resolve the special vision tokens and their ids.

        Each token string is taken from the tokenizer when it exposes the matching attribute,
        otherwise a built-in default is used. Each token id is taken from the tokenizer when
        available, otherwise it is looked up with `convert_tokens_to_ids`.
        Extra `**kwargs` are accepted and ignored.
        """
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

        self.image_token = getattr(tokenizer, "image_token", "<|image_pad|>")
        self.video_token = getattr(tokenizer, "video_token", "<|video_pad|>")
        self.vision_start_token = getattr(tokenizer, "vision_start_token", "<|vision_start|>")
        self.vision_end_token = getattr(tokenizer, "vision_end_token", "<|vision_end|>")

        self.image_token_id = self._resolve_token_id(tokenizer, "image_token_id", self.image_token)
        self.video_token_id = self._resolve_token_id(tokenizer, "video_token_id", self.video_token)
        self.vision_start_token_id = self._resolve_token_id(
            tokenizer, "vision_start_token_id", self.vision_start_token
        )
        self.vision_end_token_id = self._resolve_token_id(tokenizer, "vision_end_token_id", self.vision_end_token)

    @staticmethod
    def _resolve_token_id(tokenizer, attr_name, token):
        """Return `tokenizer.<attr_name>` when it is set, else the id the tokenizer maps `token` to."""
        token_id = getattr(tokenizer, attr_name, None)
        # Compare against None explicitly: a token id of 0 is valid and must not trigger the fallback.
        if token_id is None:
            token_id = tokenizer.convert_tokens_to_ids(token)
        return token_id

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[InternVideo3ProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare inputs for the model.

        Args:
            images: The image or batch of images to be prepared.
            text: The sequence or batch of sequences to be encoded.
            videos: The video or batch of videos to be prepared.
            return_tensors: If set, will return tensors of a particular framework.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- Token ids to be fed to a model.
            - **attention_mask** -- Attention mask.
            - **pixel_values** -- Pixel values for images.
            - **pixel_values_videos** -- Pixel values for videos.
            - **image_grid_thw** -- Image 3D grid dimensions.
            - **video_grid_thw** -- Video 3D grid dimensions.
        """
        output_kwargs = self._merge_kwargs(
            InternVideo3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]
            # Metadata is requested by default so timestamps can be built; only keep it in the
            # returned features when the caller explicitly passed `return_metadata`.
            if "return_metadata" not in kwargs:
                video_metadata = videos_inputs.pop("video_metadata", None)
            else:
                video_metadata = videos_inputs.get("video_metadata", None)
        else:
            videos_inputs = {}
            video_grid_thw = None
            video_metadata = None

        if not isinstance(text, list):
            text = [text]
        # Work on a copy so the caller's list is not modified by the expansion below.
        text = text.copy()

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i, sample in enumerate(text):
                # Expand one image token at a time; the temporary placeholder keeps already
                # expanded tokens from matching the `while` condition again.
                while self.image_token in sample:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    sample = sample.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    index += 1
                text[i] = sample.replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.video_processor.merge_size**2
            wrapped_video_token = f"{self.vision_start_token}{self.video_token}{self.vision_end_token}"
            index = 0
            for i, sample in enumerate(text):
                while self.video_token in sample:
                    metadata = video_metadata[index] if video_metadata else None
                    if metadata is not None:
                        video_placeholder = self._build_video_placeholder(
                            metadata, video_grid_thw[index], merge_length
                        )
                        # The per-frame placeholder already carries its own vision start/end tokens,
                        # so replace the wrapped form when present, else the bare video token.
                        target = wrapped_video_token if wrapped_video_token in sample else self.video_token
                        sample = sample.replace(target, video_placeholder, 1)
                    else:
                        # No metadata: expand to a flat run of tokens without timestamps.
                        num_video_tokens = video_grid_thw[index].prod() // merge_length
                        sample = sample.replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                    index += 1
                text[i] = sample.replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

    def _build_video_placeholder(self, metadata, grid_thw, merge_length):
        """
        Build the prompt fragment for one video: for every temporal grid step, a
        `<T seconds>` timestamp followed by the frame's placeholder tokens wrapped in
        the vision start/end tokens.

        Args:
            metadata: Video metadata; `fps` and `frames_indices` are read, and `fps` is set
                to 24 on the object when it is `None`.
            grid_thw: The `video_grid_thw` entry for this video; index 0 is used as the number
                of temporal steps and the product of the remaining entries as the patch count.
            merge_length: Number of patches merged into a single token.
        """
        if metadata.fps is None:
            logger.warning_once(
                "InternVideo3 requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
                "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
            )
            metadata.fps = 24

        timestamps = self._calculate_timestamps(
            metadata.frames_indices,
            metadata.fps,
            self.video_processor.merge_size,
        )
        frame_seqlen = grid_thw[1:].prod() // merge_length
        # Identical for every frame of this video, so build it once.
        frame_block = self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
        return "".join(
            f"<{timestamps[frame_idx]:.1f} seconds>{frame_block}" for frame_idx in range(grid_thw[0])
        )

    def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2):
        """
        Compute one timestamp (in seconds) per group of `merge_size` sampled frames.

        The frame indices are padded by repeating the last index until their count is a
        multiple of `merge_size`; each group's timestamp is the mean of the times of its
        first and last frame. The input sequence is never modified.

        Args:
            indices: Indices of the sampled frames in the source video.
            video_fps: Frames per second of the source video.
            merge_size: Number of consecutive frames that share one timestamp.

        Returns:
            `list[float]`: One timestamp per group; empty when `indices` is empty.
        """
        # Always work on a fresh list: padding in place would silently grow the caller's
        # `frames_indices`. Objects exposing `tolist()` (e.g. arrays) are converted with it.
        indices = indices.tolist() if hasattr(indices, "tolist") else list(indices)

        remainder = len(indices) % merge_size
        if remainder:
            indices.extend([indices[-1]] * (merge_size - remainder))

        timestamps = [idx / video_fps for idx in indices]
        return [
            (timestamps[start] + timestamps[start + merge_size - 1]) / 2
            for start in range(0, len(timestamps), merge_size)
        ]

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Decode generated token ids back into text with the tokenizer's `batch_decode`.

        Args:
            generated_outputs: The generated sequences to decode.
            skip_special_tokens: Forwarded to `batch_decode`.
            clean_up_tokenization_spaces: Forwarded to `batch_decode`.
            **kwargs: Additional arguments forwarded to `batch_decode`.

        Returns:
            The result of `self.tokenizer.batch_decode`.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
# Public API of this module: only the processor class is exported.
__all__ = ["InternVideo3Processor"]