from typing import Optional, Union

import numpy as np

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
    ImagesKwargs,
    MultiModalData,
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
    VideosKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput


class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
    fps: Union[list[float], float]


class Qwen2_5_VLImagesKwargs(ImagesKwargs):
    min_pixels: Optional[int]
    max_pixels: Optional[int]
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]


class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Qwen2_5_VLImagesKwargs
    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
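    # The defaults below are merged with the tokenizer's init kwargs and any call-site kwargs by
    # `ProcessorMixin._merge_kwargs` before being dispatched to each sub-processor.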
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": False,
        },
    }


class Qwen2_5_VLProcessor(ProcessorMixin):
| | r""" |
| | Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. |
| | [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the |
| | [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. |
| | Args: |
| | image_processor ([`Qwen2VLImageProcessor`], *optional*): |
| | The image processor is a required input. |
| | tokenizer ([`Qwen2TokenizerFast`], *optional*): |
| | The tokenizer is a required input. |
| | video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*): |
| | The video processor is a required input. |
| | chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages |
| | in a chat into a tokenizable string. |
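
    Example:

    A minimal usage sketch: the checkpoint name below is the public Qwen2.5-VL release, `demo.jpeg`
    is a hypothetical local file, and the bare `<|image_pad|>` prompt is a simplification (real
    chat prompts should normally be built with [`~ProcessorMixin.apply_chat_template`]).

    ```python
    >>> from PIL import Image
    >>> from transformers import AutoProcessor

    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    >>> image = Image.open("demo.jpeg")
    >>> inputs = processor(images=image, text="<|image_pad|>Describe this image.", return_tensors="pt")
    ```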
| | """ |

    attributes = ["image_processor", "tokenizer", "video_processor"]

    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
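        # Fall back to the stock Qwen vision placeholder tokens when the tokenizer does not already
        # expose them, and resolve their ids once so image/video positions can be located later.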
        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else "<|image_pad|>"
        self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else "<|video_pad|>"
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
        videos: Optional[VideoInput] = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to
        encode the text. To prepare the vision inputs, it forwards the `images` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `images` is not `None`, and the `videos` and
        `kwargs` arguments to the video processor if `videos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of 3D grids (temporal, height, width) of the image patches fed to the LLM.
              Returned when `images` is not `None`.
            - **video_grid_thw** -- List of 3D grids (temporal, height, width) of the video patches fed to the LLM.
              Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- The number of seconds covered by each temporal grid step, for each video.
              Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

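        # Default both to empty mappings so the final BatchFeature merge works even when one (or
        # both) vision modalities are absent.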
        image_inputs = videos_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

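            # Each temporal grid step merges `temporal_patch_size` sampled frames, so at a sampling
            # rate of `fps` frames per second one step spans temporal_patch_size / fps seconds.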
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length "
                    f"of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]

        text = text.copy()
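        # Expand each vision placeholder to one token per merged patch (grid volume divided by
        # merge_size**2). A temporary "<|placeholder|>" marker prevents re-matching tokens that
        # were just inserted.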
        if images is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if videos is not None:
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_video_tokens = video_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

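        # Optionally emit a mask that is 1 at image-token positions and 0 elsewhere (video tokens
        # are left at 0 here).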
        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(array_ids)
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) for each image.
            video_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (num_frames, height, width) for each video.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding the number of tokens for each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
            images_kwargs.update(kwargs)
            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size

            num_image_patches = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]
            num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        if video_sizes is not None:
            videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
            videos_kwargs.update(kwargs)
            # Resolve merge_size here as well, so that video-only calls do not hit an unbound
            # `merge_size` from the image branch above.
            merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
            num_video_patches = [
                self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
                for video_size in video_sizes
            ]
            num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
            vision_data["num_video_tokens"] = num_video_tokens

        return MultiModalData(**vision_data)

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's
                `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's
                `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode` method.

        Returns:
            `list[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
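        # "second_per_grid_ts" is produced by this processor itself, so it is appended to the names
        # collected from the tokenizer and the image processor.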
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
        return names_from_processor + ["second_per_grid_ts"]


__all__ = ["Qwen2_5_VLProcessor"]