| from typing import ClassVar, List, Optional, Tuple, Union |
|
|
| import torch |
| from PIL import Image |
| from transformers import BatchFeature, PaliGemmaProcessor |
|
|
| from .processing_utils import BaseVisualRetrieverProcessor |
|
|
|
|
class ColPaliProcessor(BaseVisualRetrieverProcessor, PaliGemmaProcessor):
    """
    Processor for ColPali.

    Combines the PaliGemma processor with the retriever helpers inherited from
    ``BaseVisualRetrieverProcessor`` (e.g. ``score_multi_vector``) and adds the
    ColPali-specific formatting of document images and text queries.
    """

    # Text prompt paired with every document image in `process_images`.
    visual_prompt_prefix: ClassVar[str] = "<image><bos>Describe the image."
    # Prefix placed after the BOS token of every query in `process_queries`.
    query_prefix: ClassVar[str] = "Query: "

    def __init__(self, *args, **kwargs):
        # Pure pass-through: all construction is delegated to the parent classes.
        super().__init__(*args, **kwargs)

    @property
    def query_augmentation_token(self) -> str:
        """
        Return the query augmentation token.

        Query augmentation buffers are used as reasoning buffers during inference.
        The tokenizer's pad token is reused as this token.
        """
        return self.tokenizer.pad_token

    @staticmethod
    def _unwrap_image(image):
        """Return ``image`` itself, or its first element when it is a list/tuple."""
        # Bare PIL images are not subscriptable; wrapped items (e.g. single-element
        # lists) are still accepted for callers that already pass them that way.
        if isinstance(image, (list, tuple)):
            return image[0]
        return image

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Process images for ColPali.

        Args:
            images: Document images. Each element may be a PIL image or a
                list/tuple whose first element is a PIL image.

        Returns:
            The ``BatchFeature`` produced by calling this processor with one
            ``visual_prompt_prefix`` prompt per image, padded to the longest
            sequence and returned as PyTorch tensors.
        """
        texts_doc = [self.visual_prompt_prefix] * len(images)
        rgb_images = [self._unwrap_image(image).convert("RGB") for image in images]

        # Any image whose shorter side is below 5 px is stretched to a square whose
        # side equals its longer side; larger images are left untouched.
        # NOTE(review): presumably works around the image processor rejecting
        # degenerate sizes — confirm.
        rgb_images = [
            img.resize((max(img.size), max(img.size))) if min(img.size) < 5 else img
            for img in rgb_images
        ]

        return self(
            text=texts_doc,
            images=rgb_images,
            return_tensors="pt",
            padding="longest",
        )

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Process queries for ColPali.

        Each query is formatted as BOS token + ``query_prefix`` + query + ``suffix``
        followed by a trailing newline, then the batch is tokenized with padding
        to the longest sequence.

        Args:
            queries: Raw query strings.
            max_length: Forwarded to the tokenizer as ``max_length``.
            suffix: Text appended after each query. Defaults to ten copies of
                ``query_augmentation_token``.

        Returns:
            The tokenizer output as PyTorch tensors, without ``token_type_ids``.
        """
        if suffix is None:
            suffix = self.query_augmentation_token * 10

        # NOTE(review): the trailing newline presumably mirrors the prompt format
        # of PaliGemma's processor — confirm.
        texts_query: List[str] = [
            self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
            for query in queries
        ]

        # NOTE(review): truncation is not enabled, so with padding="longest" the
        # tokenizer does not cut sequences to `max_length`. Enabling truncation
        # would drop the suffix/newline of long queries — confirm intended behavior.
        batch_query = self.tokenizer(
            texts_query,
            text_pair=None,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
        )

        return batch_query

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        Delegates to the inherited ``score_multi_vector``, forwarding ``device``
        and any extra keyword arguments unchanged.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int,
    ) -> Tuple[int, int]:
        """
        Return the number of patches along the width and height.

        The counts come from the image processor's configured target size
        (``size["width"]`` / ``size["height"]``) integer-divided by ``patch_size``;
        the ``image_size`` argument is accepted for interface compatibility but
        is not used.
        """
        n_patches_x = self.image_processor.size["width"] // patch_size
        n_patches_y = self.image_processor.size["height"] // patch_size

        return n_patches_x, n_patches_y

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """Return a boolean mask that is True where ``input_ids`` equals ``self.image_token_id``."""
        return batch_images.input_ids == self.image_token_id
|
|