from typing import ClassVar, List, Optional, Sequence, Tuple, Union

import torch
from PIL import Image
from transformers import BatchFeature, PaliGemmaProcessor

from .processing_utils import BaseVisualRetrieverProcessor


class ColPaliProcessor(BaseVisualRetrieverProcessor, PaliGemmaProcessor):
    """
    Processor for ColPali.

    Combines ``BaseVisualRetrieverProcessor`` with ``PaliGemmaProcessor`` and
    adds ColPali-specific image/query preprocessing and scoring helpers.
    """

    # Text prompt paired with every image in ``process_images``.
    visual_prompt_prefix: ClassVar[str] = "Describe the image."
    # Prepended (after the BOS token) to every query in ``process_queries``.
    query_prefix: ClassVar[str] = "Query: "

    def __init__(self, *args, **kwargs):
        # Forward all arguments unchanged to the parent classes via super().
        super().__init__(*args, **kwargs)

    @property
    def query_augmentation_token(self) -> str:
        """
        Return the query augmentation token (the tokenizer's pad token).

        Query augmentation buffers are used as reasoning buffers during inference.
        """
        return self.tokenizer.pad_token

    def process_images(
        self,
        images: List[Union[Image.Image, Sequence[Image.Image]]],
    ) -> BatchFeature:
        """
        Process images for ColPali.

        Args:
            images: One entry per document. Each entry is either a PIL image or
                an indexable container of PIL images, in which case only its
                first element is used.

        Returns:
            The result of calling this processor with one
            ``visual_prompt_prefix`` prompt per image, ``return_tensors="pt"``
            and ``padding="longest"``.
        """
        texts_doc = [self.visual_prompt_prefix] * len(images)

        # Accept bare PIL images (as the annotation advertises) as well as the
        # container-of-images form; previously a bare image crashed on image[0].
        images = [
            (image if isinstance(image, Image.Image) else image[0]).convert("RGB")
            for image in images
        ]

        # @ruimeng, ColPali is buggy processing very small images: any image
        # whose shorter side is below 5 px (i.e. <= 4) is resized to a square
        # whose side equals its longer dimension.
        images = [
            img.resize((max(img.size), max(img.size))) if min(img.size) < 5 else img
            for img in images
        ]

        batch_doc = self(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="longest",
        )
        return batch_doc

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Process queries for ColPali.

        Each query becomes ``bos_token + query_prefix + query + suffix``
        followed by a newline character, and the batch is tokenized with the
        processor's tokenizer.

        Args:
            queries: Raw query strings.
            max_length: Forwarded to the tokenizer. NOTE(review): truncation is
                not enabled and padding is ``"longest"``, so this does not
                appear to cap the sequence length. Confirm against the
                installed ``transformers`` version before relying on it.
            suffix: Text appended after each query; defaults to 10 copies of
                ``query_augmentation_token``.

        Returns:
            The tokenizer output: PyTorch tensors, padded to the longest query,
            without ``token_type_ids``.
        """
        if suffix is None:
            suffix = self.query_augmentation_token * 10

        texts_query: List[str] = []
        for query in queries:
            query = self.tokenizer.bos_token + self.query_prefix + query
            query += suffix  # add suffix (pad tokens)

            # NOTE: Make input ISO to PaliGemma's processor
            query += "\n"

            texts_query.append(query)

        batch_query = self.tokenizer(
            texts_query,
            text_pair=None,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
        )
        return batch_query

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        Delegates to ``self.score_multi_vector``, forwarding ``device`` and any
        extra keyword arguments.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int,
    ) -> Tuple[int, int]:
        """
        Return ``(n_patches_x, n_patches_y)`` for the configured image size.

        The counts are ``image_processor.size["width"] // patch_size`` and
        ``image_processor.size["height"] // patch_size``. ``image_size`` is
        not used (presumably kept for interface compatibility with other
        processors).
        """
        n_patches_x = self.image_processor.size["width"] // patch_size
        n_patches_y = self.image_processor.size["height"] // patch_size
        return n_patches_x, n_patches_y

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """Return a boolean mask that is True where ``input_ids`` equals ``self.image_token_id``."""
        return batch_images.input_ids == self.image_token_id