# code_SAS_VLM2Vec / src/model/baseline_backbone/colpali/processing_colpali.py
# Uploaded by MgGladys via the upload-large-folder tool (commit 0a937d7, verified).
from typing import ClassVar, List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import BatchFeature, PaliGemmaProcessor
from .processing_utils import BaseVisualRetrieverProcessor
class ColPaliProcessor(BaseVisualRetrieverProcessor, PaliGemmaProcessor):
    """
    Processor for ColPali: wraps PaliGemma's processor to prepare document
    images and retrieval queries for multi-vector (late-interaction) encoding.
    """

    # Fixed document-side prompt; every image is paired with this text.
    visual_prompt_prefix: ClassVar[str] = "<image><bos>Describe the image."
    # Prepended to every retrieval query (after the BOS token).
    query_prefix: ClassVar[str] = "Query: "

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def query_augmentation_token(self) -> str:
        """
        Return the query augmentation token (the tokenizer's pad token).

        Query augmentation buffers are used as reasoning buffers during
        inference.
        """
        return self.tokenizer.pad_token

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Build a padded (prompt, image) batch for document-side encoding.

        NOTE(review): each element of `images` is indexed with `[0]` before
        `.convert`, so callers appear to pass a one-element sequence per item
        rather than a bare PIL image, despite the annotation — confirm against
        callers.
        """
        prompts = [self.visual_prompt_prefix for _ in images]
        rgb = [item[0].convert("RGB") for item in images]
        # @ruimeng, ColPali is buggy processing images of which size<=3;
        # upscale tiny images to a square of their larger dimension.
        rgb = [
            img if min(img.size) >= 5 else img.resize((max(img.size), max(img.size)))
            for img in rgb
        ]
        return self(
            text=prompts,
            images=rgb,
            return_tensors="pt",
            padding="longest",
        )

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Tokenize retrieval queries with augmentation tokens appended.

        When `suffix` is None, ten copies of the pad token are appended as the
        augmentation buffer. A trailing newline keeps the input ISO to
        PaliGemma's processor.
        """
        if suffix is None:
            suffix = self.query_augmentation_token * 10

        # BOS + "Query: " + query text + suffix (pad tokens) + newline.
        texts_query: List[str] = [
            self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
            for query in queries
        ]

        return self.tokenizer(
            texts_query,
            text_pair=None,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
        )

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector
        query and passage embeddings.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int,
    ) -> Tuple[int, int]:
        """
        Return the (x, y) patch counts implied by the processor's fixed resize
        target. `image_size` is unused because the image processor rescales
        every input to a fixed width/height.
        """
        target = self.image_processor.size
        return target["width"] // patch_size, target["height"] // patch_size

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """Boolean mask selecting image-token positions in `input_ids`."""
        return batch_images.input_ids == self.image_token_id