| import re |
| import torch |
| from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast |
| from transformers.image_processing_utils import BaseImageProcessor |
| from transformers.image_utils import ImageInput |
| from typing import Any, Dict, List, Optional, Union |
| from PIL import Image |
|
|
| from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN |
|
|
| |
def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
    """
    Pad a channel-first image to a square canvas, centring the original content.

    The shorter spatial side is padded on both sides; when the difference is odd,
    the extra pixel goes to the bottom / right edge. The padded canvas keeps the
    input's dtype and device, so an image that already lives on an accelerator is
    not copied back to the CPU.

    Args:
        image: Tensor of shape ``(C, H, W)``.
        background_color: Fill value for the padding. Either a scalar applied to
            every channel, or a sequence with one value per channel. Values are
            cast to ``image.dtype``.

    Returns:
        ``image`` itself when it is already square, otherwise a new
        ``(C, S, S)`` tensor with ``S = max(H, W)``.
    """
    channels, height, width = image.shape
    if width == height:
        return image

    side = max(height, width)
    fill = torch.as_tensor(background_color, dtype=image.dtype, device=image.device)
    if fill.ndim == 1:
        # One value per channel: broadcast it over both spatial dimensions.
        fill = fill.view(-1, 1, 1)

    result = image.new_empty((channels, side, side))
    result[...] = fill

    # Only one of these offsets is non-zero: the one along the shorter side.
    top = (side - height) // 2
    left = (side - width) // 2
    result[:, top:top + height, left:left + width] = image
    return result
|
|
|
|
class FastVLMImageProcessor(CLIPImageProcessorFast):
    """
    CLIP fast image processor that pads every image to a square before running the
    standard CLIP preprocessing, and also reports each image's original size.
    """

    def _preprocess(self, images, **kwargs):
        """
        Square-pad each image, delegate to the CLIP pipeline, and stack the result.

        Args:
            images: Sequence of channel-first ``(C, H, W)`` image tensors (the shape
                ``expand_to_square`` unpacks).
            **kwargs: Forwarded unchanged to ``CLIPImageProcessorFast._preprocess``.

        Returns:
            BatchFeature with:
              * ``pixel_values``: the processed images stacked along a new batch dim.
              * ``image_sizes``: per-image ``(width, height)`` measured before padding.
        """
        # Capture the original size as (width, height) before padding overwrites it.
        image_sizes = [image.shape[-2:][::-1] for image in images]
        squared = [expand_to_square(image) for image in images]
        processed = super()._preprocess(squared, **kwargs)

        pixel_values = processed["pixel_values"]
        # The parent already stacks when `return_tensors` is set; stacking a tensor
        # again would raise, so only stack when we received a list of images.
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.stack(pixel_values, dim=0)

        return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})
|
|
class FastVLMProcessor(ProcessorMixin):
    """
    Processor pairing a tokenizer with an image processor for FastVLM.

    Each prompt may contain at most one ``DEFAULT_IMAGE_TOKEN`` placeholder. The text
    before and after the placeholder is tokenized separately (without special
    tokens), and the placeholder itself becomes the single id ``IMAGE_TOKEN_INDEX``.
    """

    attributes = ["tokenizer", "image_processor"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer, image_processor, chat_template=None, **kwargs):
        super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)

    def __call__(
        self,
        images: ImageInput = None,
        text: Optional[Union[str, List[str]]] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenize prompts and optionally preprocess images into one batch.

        Args:
            images: Optional image input, passed to ``self.image_processor``.
            text: A prompt string, or a list/tuple of prompt strings. A prompt that
                contains the image placeholder gets ``IMAGE_TOKEN_INDEX`` in its
                place. A prompt without it is tokenized as plain text.
            return_tensors: Tensor type for the returned ``BatchFeature``.
            **kwargs: Accepted for API compatibility; currently ignored.

        Returns:
            BatchFeature containing ``input_ids``, ``attention_mask`` and any keys
            produced by the image processor. Prompts of different lengths are
            padded to a common length (side taken from ``tokenizer.padding_side``).
            Padded positions get attention-mask value 0.

        Raises:
            TypeError: if ``text`` is not a string or a list/tuple of strings
                (this includes ``None``).
            ValueError: if a prompt contains more than one image placeholder.
        """
        prompts = self._normalize_text(text)

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images)

        sequences = [self._tokenize_prompt(prompt) for prompt in prompts]
        input_ids, attention_mask = self._pad_sequences(sequences)

        return BatchFeature(
            data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs},
            tensor_type=return_tensors,
        )

    @staticmethod
    def _normalize_text(text) -> List[str]:
        """Return ``text`` as a list of prompt strings, or raise ``TypeError``."""
        if isinstance(text, str):
            return [text]
        # Tuples of strings passed the previous check, so keep accepting them.
        if isinstance(text, (list, tuple)) and all(isinstance(item, str) for item in text):
            return list(text)
        raise TypeError("Invalid input text. Please provide a string, or a list of strings")

    def _tokenize_prompt(self, prompt: str) -> List[int]:
        """Tokenize one prompt, replacing its optional image placeholder with IMAGE_TOKEN_INDEX."""
        # Literal (non-regex) match: the placeholder must not be interpreted as a pattern.
        placeholder_count = prompt.count(DEFAULT_IMAGE_TOKEN)
        if placeholder_count > 1:
            raise ValueError(
                f"Expected up to 1 image tokens per prompt, got {placeholder_count} instead."
            )
        if placeholder_count == 0:
            # Text-only prompt: do not inject an image id the model has no features for.
            return list(self.tokenizer(prompt, add_special_tokens=False).input_ids)

        pre, _, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
        pre_ids = self.tokenizer(pre, add_special_tokens=False).input_ids
        post_ids = self.tokenizer(post, add_special_tokens=False).input_ids
        return [*pre_ids, IMAGE_TOKEN_INDEX, *post_ids]

    def _pad_sequences(self, sequences: List[List[int]]):
        """
        Pack token-id lists into ``(input_ids, attention_mask)`` int64 tensors.

        Equal-length batches (and the empty batch) are packed without padding.
        Ragged batches are padded to the longest sequence. The pad id is the
        tokenizer's pad token, falling back to its EOS token, then 0.
        """
        if len({len(seq) for seq in sequences}) <= 1:
            input_ids = torch.tensor(sequences, dtype=torch.int64)
            return input_ids, torch.ones_like(input_ids)

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        if pad_id is None:
            pad_id = 0
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        batch_size = len(sequences)
        max_len = max(len(seq) for seq in sequences)
        input_ids = torch.full((batch_size, max_len), pad_id, dtype=torch.int64)
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.int64)
        for row, seq in enumerate(sequences):
            length = len(seq)
            span = slice(max_len - length, max_len) if pad_left else slice(0, length)
            input_ids[row, span] = torch.tensor(seq, dtype=torch.int64)
            attention_mask[row, span] = 1
        return input_ids, attention_mask
|
|