| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Processor class for Florence-2. |
| | """ |
| |
|
| | import re |
| | import logging |
| | from typing import List, Optional, Union |
| | import numpy as np |
| |
|
| | import torch |
| | import PIL |
| |
|
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.image_utils import ImageInput |
| | from transformers.processing_utils import ProcessorMixin |
| | from transformers.tokenization_utils_base import ( |
| | PaddingStrategy, |
| | TextInput, |
| | TruncationStrategy, |
| | ) |
| | from transformers.utils import TensorType |
| | import re |
| |
|
# Logger named after this module; no call sites are visible in this chunk.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
class Florence2Processor(ProcessorMixin):
    """Processor bundling a CLIP image processor and a BART tokenizer for Florence-2.

    On top of plain tokenization and image preprocessing it:

    * registers the Florence-2 special tokens on the tokenizer: task tags such
      as ``<od>`` and ``<grounding>``, the location tokens ``<loc_0>`` ...
      ``<loc_999>``, and the ``<panel>``/``<text>``/``<character>``/``<tail>``
      tags;
    * converts pixel-space boxes into ``<loc_*>`` token strings when building
      inputs and targets (see ``_format_text_with_bboxes``);
    * converts ``<loc_*>`` tokens in generated text back into pixel-space
      boxes (see ``postprocess_output``).
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    # Number of location bins per axis. Used both for the <loc_N> tokens
    # registered in __init__ and for the box quantizer, so the two always agree.
    _NUM_LOC_BINS = 1000

    # One or more comma-separated boxes, each written as exactly four <loc_N> tokens.
    _LOC_PATTERN = r'((?:<loc_\d+>){4}(?:,(?:<loc_\d+>){4})*)'
    # Matches either "<grounding>phrase</grounding>" immediately followed by its
    # boxes (groups 1-2), or bare boxes (group 3). The phrase may not contain
    # "</grounding>", so an unboxed grounding span is never swallowed into a
    # later, boxed one.
    _GROUNDING_OR_LOC_REGEX = re.compile(
        r'<grounding>((?:(?!</grounding>).)*?)</grounding>' + _LOC_PATTERN
        + '|' + _LOC_PATTERN
    )

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
    ):
        """Validate the components and register Florence-2 special tokens.

        Args:
            image_processor: image processor; must expose ``image_seq_length``.
            tokenizer: tokenizer. It is mutated in place: the Florence-2
                special tokens are added to it.

        Raises:
            ValueError: if a component is missing, or the image processor has
                no ``image_seq_length`` attribute.
        """
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")
        if not hasattr(image_processor, "image_seq_length"):
            raise ValueError("Image processor is missing an `image_seq_length` attribute.")

        # Positions taken by the image in the encoder input. __call__ subtracts
        # this from the caller's overall input-length budget.
        self.image_seq_length = image_processor.image_seq_length

        # NOTE: "<region_to_desciption>" is misspelled but must stay as-is.
        # Token strings have to match the vocabulary the checkpoint was trained with.
        florence_tokens = (
            ['<od>', '</od>', '<ocr>', '</ocr>']
            + [f'<loc_{x}>' for x in range(self._NUM_LOC_BINS)]
            + ['<cap>', '</cap>', '<ncap>', '</ncap>', '<dcap>', '</dcap>',
               '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>',
               '<region_cap>', '</region_cap>', '<region_to_desciption>',
               '</region_to_desciption>', '<proposal>', '</proposal>',
               '<poly>', '</poly>', '<and>']
            + ['<panel>', '<text>', '<character>', '<tail>']
        )
        # dict.fromkeys de-duplicates while keeping first-seen order. A tokenizer
        # that already carries these tokens (e.g. a re-saved one) therefore does
        # not end up listing them twice.
        tokens_to_add = {
            'additional_special_tokens': list(
                dict.fromkeys(tokenizer.additional_special_tokens + florence_tokens)
            )
        }
        tokenizer.add_special_tokens(tokens_to_add)
        # Hard-coded decoder start id. Presumably `</s>` in a stock BART
        # vocabulary; TODO confirm if a custom tokenizer is ever used.
        self.decoder_start_token_id = 2

        self.box_quantizer = BoxQuantizer(
            mode='floor',
            bins=(self._NUM_LOC_BINS, self._NUM_LOC_BINS),
        )

        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        batch_input_text: List[TextInput] = None,
        batch_input_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
        batch_output_text: List[TextInput] = None,
        batch_output_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
        batch_images: ImageInput = None,
        batch_character_cluster_labels=None,
        batch_text_character_association_labels=None,
        batch_text_tail_association_labels=None,
        batch_is_essential_text_labels=None,
        batch_tail_character_association_labels=None,
        padding: Union[bool, str, PaddingStrategy] = None,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_input_length_including_image_tokens=None,
        max_output_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        do_resize: bool = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional["ChannelDimension"] = "channels_first",
        input_data_format: Optional[
            Union[str, "ChannelDimension"]
        ] = None,
        resample: "PILImageResampling" = None,
        do_convert_rgb: bool = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
    ) -> BatchFeature:
        """Build a model-ready batch from images, prompts and optional targets.

        Prompt and target strings are ``str.format`` templates. The i-th ``{}``
        placeholder of a sample is replaced by the ``<loc_*>`` tokens of that
        sample's i-th box list (see ``_format_text_with_bboxes``).

        Args:
            batch_input_text: one prompt template per image (required).
            batch_input_list_of_list_of_bboxes: per sample, a list of box lists
                with boxes ``[x1, y1, x2, y2]`` in pixels. Defaults to no boxes.
            batch_output_text: optional target templates. When given, the
                result also holds ``labels``, ``decoder_input_ids`` and
                ``decoder_attention_mask``.
            batch_output_list_of_list_of_bboxes: boxes for the targets. Only
                allowed together with ``batch_output_text``.
            batch_images: one image per sample (required).
            batch_character_cluster_labels, batch_text_character_association_labels,
            batch_text_tail_association_labels, batch_is_essential_text_labels,
            batch_tail_character_association_labels: copied unchanged into the
                result when not ``None``.
            padding: forwarded to the tokenizer.
            truncation: currently ignored. The tokenizer is always called with
                ``truncation=False`` and sequences are cut by slicing instead.
                Slicing can drop a trailing end-of-sequence token.
            max_input_length_including_image_tokens: total encoder budget. The
                text is capped at this minus ``image_seq_length``; ``None``
                means no cap.
            max_output_length: cap on target length; ``None`` means no cap.
            return_tensors: forwarded to the tokenizer and image processor.
                Truncation and label construction use tensor ops, so presumably
                only PyTorch tensors are fully supported. TODO confirm.
            do_resize, do_normalize, image_mean, image_std, data_format,
            input_data_format, resample, do_convert_rgb: forwarded to the image
                processor.
            dtype: if given, ``pixel_values`` is cast to it.
            device: if given, every tensor in the result is moved to it.

        Returns:
            A ``BatchFeature`` containing:
            - the tokenizer outputs and ``pixel_values``;
            - the decoder fields and labels when targets are given (padding set
              to -100 in ``labels``);
            - any provided auxiliary labels;
            - the tokenizer itself under ``"tokenizer"``.

        Raises:
            AssertionError: if required inputs are missing or batch lengths differ.
        """
        assert batch_images is not None, "`batch_images` are expected as arguments to a `Florence2Processor` instance."
        assert batch_input_text is not None, "`batch_input_text` are expected as arguments to a `Florence2Processor` instance."
        if batch_input_list_of_list_of_bboxes is None:
            batch_input_list_of_list_of_bboxes = [[] for _ in range(len(batch_input_text))]
        assert len(batch_input_text) == len(batch_input_list_of_list_of_bboxes) == len(batch_images), "`batch_input_text`, `batch_input_list_of_list_of_bboxes` and `batch_images` have different lengths."
        if batch_output_text is None:
            assert batch_output_list_of_list_of_bboxes is None, "`batch_output_text` and `batch_output_list_of_list_of_bboxes` should be provided together."
        else:
            if batch_output_list_of_list_of_bboxes is None:
                batch_output_list_of_list_of_bboxes = [[] for _ in range(len(batch_output_text))]
            assert len(batch_output_text) == len(batch_output_list_of_list_of_bboxes) == len(batch_images), "`batch_output_text`, `batch_output_list_of_list_of_bboxes` and `batch_images` have different lengths."

        max_input_length = (
            max_input_length_including_image_tokens - self.image_seq_length
            if max_input_length_including_image_tokens is not None
            else None
        )
        inputs = self._tokenize_with_bboxes(
            batch_input_text,
            batch_input_list_of_list_of_bboxes,
            batch_images,
            padding=padding,
            return_tensors=return_tensors,
            max_length=max_input_length,
        )

        decoder_inputs = None
        if batch_output_text is not None:
            decoder_inputs = self._tokenize_with_bboxes(
                batch_output_text,
                batch_output_list_of_list_of_bboxes,
                batch_images,
                padding=padding,
                return_tensors=return_tensors,
                max_length=max_output_length,
            )

        pixel_values = self.image_processor(
            batch_images,
            do_resize=do_resize,
            do_normalize=do_normalize,
            return_tensors=return_tensors,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
            data_format=data_format,
            resample=resample,
            do_convert_rgb=do_convert_rgb,
        )["pixel_values"]

        if dtype is not None:
            pixel_values = pixel_values.to(dtype)

        return_data = {**inputs, "pixel_values": pixel_values}

        if decoder_inputs is not None:
            return_data.update(self._build_decoder_fields(decoder_inputs))

        if device is not None:
            return_data = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in return_data.items()
            }

        # Auxiliary labels are passed through untouched; they are not moved to `device`.
        passthrough_labels = {
            "character_cluster_labels": batch_character_cluster_labels,
            "text_character_association_labels": batch_text_character_association_labels,
            "text_tail_association_labels": batch_text_tail_association_labels,
            "is_essential_text_labels": batch_is_essential_text_labels,
            "tail_character_association_labels": batch_tail_character_association_labels,
        }
        return_data.update(
            {key: value for key, value in passthrough_labels.items() if value is not None}
        )

        return_data["tokenizer"] = self.tokenizer
        return BatchFeature(data=return_data)

    def _tokenize_with_bboxes(
        self,
        batch_text,
        batch_list_of_list_of_bboxes,
        batch_images,
        padding,
        return_tensors,
        max_length,
    ):
        """Fill box placeholders, tokenize, and cut sequences to ``max_length``.

        ``max_length=None`` disables the cut. The cut is a plain slice, so a
        trailing end-of-sequence token may be dropped.
        """
        batch_texts = [
            self._format_text_with_bboxes(text, list_of_list_of_bboxes, image)
            for text, list_of_list_of_bboxes, image in zip(
                batch_text, batch_list_of_list_of_bboxes, batch_images
            )
        ]
        encoding = self.tokenizer(
            batch_texts,
            return_tensors=return_tensors,
            padding=padding,
            truncation=False,
        )
        if max_length is not None and encoding["input_ids"].shape[1] > max_length:
            encoding["input_ids"] = encoding["input_ids"][:, :max_length]
            encoding["attention_mask"] = encoding["attention_mask"][:, :max_length]
        return encoding

    def _build_decoder_fields(self, decoder_inputs):
        """Derive ``labels`` and teacher-forcing decoder inputs from tokenized targets.

        ``decoder_inputs["input_ids"]`` is modified in place: padding becomes -100.
        """
        labels = decoder_inputs["input_ids"]
        target_mask = decoder_inputs["attention_mask"]

        # Shift the targets right by one and put the start token in front.
        # This must happen before `labels` is masked in place below.
        decoder_input_ids = labels.new_zeros(labels.shape)
        decoder_input_ids[:, 1:] = labels[:, :-1].clone()
        decoder_input_ids[:, 0] = self.decoder_start_token_id

        # The start position is always attended; the rest follows the shifted mask.
        decoder_attention_mask = target_mask.new_ones(decoder_input_ids.shape)
        decoder_attention_mask[:, 1:] = target_mask[:, :-1].clone()

        # -100 presumably matches the loss's ignore_index; TODO confirm in the model.
        labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
        return {
            "labels": labels,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }

    def cleanup_generated_text(self, generated_text):
        """Strip the literal ``<s>``, ``</s>`` and ``<pad>`` strings from decoded text."""
        return generated_text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")

    def postprocess_output(self, generated_ids, images):
        """Decode generated ids and turn ``<loc_*>`` tokens back into pixel boxes.

        Args:
            generated_ids: tensor of token ids. Positions equal to -100 are
                treated as padding. The tensor is not modified.
            images: one image per sequence, used only for its size.

        Returns:
            A ``(texts, boxes, spans)`` tuple with one entry per sequence:
            - ``texts`` is the text with box tokens and grounding tags removed;
            - ``boxes`` is a list of box lists in pixel coordinates;
            - ``spans`` holds one ``(start, end)`` range in the cleaned text per
              box list. A grounding phrase gets its own range; bare boxes get an
              empty range at their position.
        """
        # Out-of-place so the caller's tensor is left untouched.
        generated_ids = generated_ids.masked_fill(generated_ids == -100, self.tokenizer.pad_token_id)
        batch_decoded_texts = self.batch_decode(generated_ids, skip_special_tokens=False)
        batch_decoded_texts = [self.cleanup_generated_text(text) for text in batch_decoded_texts]
        batch_list_of_list_of_bboxes = []
        batch_indices_of_bboxes_in_new_string = []
        batch_new_texts = []
        for text, image in zip(batch_decoded_texts, images):
            size_wh = self._get_image_size_wh(image)
            parsed_text, list_of_stringified_bboxes, start_end_in_new_string = self._parse_text_with_bboxes(text)
            list_of_list_of_bboxes = [
                self.box_quantizer.dequantize_from_stringified_bboxes(stringified_bbox, size_wh)
                for stringified_bbox in list_of_stringified_bboxes
            ]
            batch_list_of_list_of_bboxes.append(list_of_list_of_bboxes)
            batch_indices_of_bboxes_in_new_string.append(start_end_in_new_string)
            batch_new_texts.append(parsed_text)
        return batch_new_texts, batch_list_of_list_of_bboxes, batch_indices_of_bboxes_in_new_string

    def _parse_text_with_bboxes(self, text):
        """Remove box tokens and grounding tags from *text*, recording where boxes were.

        ``<grounding>phrase</grounding><loc..>...`` is replaced by ``phrase``, and
        its span in the new string is recorded. Bare ``<loc..>`` groups are
        removed, and an empty span is recorded at their position.

        Returns:
            ``(new_text, stringified_boxes, spans)``. ``stringified_boxes[i]`` is
            the raw comma-separated ``<loc_*>`` string whose span is ``spans[i]``.
        """
        pieces = []
        list_of_stringified_bboxes = []
        start_end_in_new_string = []
        new_len = 0  # length of "".join(pieces) so far
        original_pos = 0

        for match in self._GROUNDING_OR_LOC_REGEX.finditer(text):
            untouched = text[original_pos:match.start()]
            pieces.append(untouched)
            new_len += len(untouched)

            phrase, grounding_locs, bare_locs = match.groups()
            if grounding_locs is not None:
                pieces.append(phrase)
                list_of_stringified_bboxes.append(grounding_locs)
                start_end_in_new_string.append((new_len, new_len + len(phrase)))
                new_len += len(phrase)
            else:
                list_of_stringified_bboxes.append(bare_locs)
                start_end_in_new_string.append((new_len, new_len))

            original_pos = match.end()

        pieces.append(text[original_pos:])
        return "".join(pieces), list_of_stringified_bboxes, start_end_in_new_string

    def _format_text_with_bboxes(self, text, list_of_list_of_bboxes, image):
        """Fill the ``{}`` placeholders of *text* with quantized ``<loc_*>`` strings.

        Each box list becomes one comma-separated string of four-token boxes.
        ``str.format`` is always applied, even with no boxes, so ``{{``/``}}`` in
        *text* are unescaped and bare ``{}`` without a matching box list raises.
        """
        size_wh = self._get_image_size_wh(image)
        quantized_bbox_lists = []
        for list_of_bboxes in list_of_list_of_bboxes:
            quantized_bboxes = self.box_quantizer.quantize(list_of_bboxes, size_wh=size_wh)
            quantized_bbox_lists.append(
                ",".join(
                    f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>"
                    for x1, y1, x2, y2 in quantized_bboxes
                )
            )
        return text.format(*quantized_bbox_lists)

    def _get_image_size_wh(self, image):
        """Return ``(width, height)`` of *image*.

        Torch tensors are read as ``(C, H, W)`` or ``(N, C, H, W)``. Numpy
        arrays are read as ``(H, W)`` or ``(H, W, C)``. Other layouts give wrong
        sizes silently.

        Raises:
            ValueError: for tensors/arrays of unsupported rank.
            TypeError: for any other image type.
        """
        if isinstance(image, torch.Tensor):
            if image.dim() == 3:
                return (image.shape[2], image.shape[1])
            if image.dim() == 4:
                return (image.shape[3], image.shape[2])
            raise ValueError("Unsupported tensor dimensions")
        if isinstance(image, np.ndarray):
            if image.ndim in (2, 3):
                return (image.shape[1], image.shape[0])
            raise ValueError("Unsupported array dimensions")
        # `import PIL` alone does not load the PIL.Image submodule, so load it
        # explicitly before referring to PIL.Image.Image.
        import PIL.Image
        if isinstance(image, PIL.Image.Image):
            return image.size
        raise TypeError("Unsupported image type")

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
| |
|
class BoxQuantizer(object):
    """Map pixel-space boxes to integer bin indices and back.

    ``bins=(bins_w, bins_h)`` splits the image width and height into equally
    sized bins. Quantization maps a coordinate to the index of its bin, clamped
    to ``[0, bins - 1]``. Dequantization returns the bin centre. Only
    ``mode='floor'`` is implemented: ``'round'`` raises ``NotImplementedError``
    and any other mode raises ``ValueError``.
    """

    # One box serialized as four consecutive <loc_N> tokens (xmin, ymin, xmax, ymax).
    _BBOX_REGEX = re.compile(r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>')

    def __init__(self, mode, bins):
        self.mode = mode
        self.bins = bins

    def _check_mode(self):
        """Raise unless ``self.mode`` is the supported ``'floor'`` mode."""
        if self.mode == 'floor':
            return
        if self.mode == 'round':
            raise NotImplementedError()
        raise ValueError('Incorrect quantization type.')

    def quantize(self, boxes, size_wh):
        """Quantize ``[..., 4]`` pixel boxes ``(xmin, ymin, xmax, ymax)`` to bin indices.

        Args:
            boxes: tensor or nested sequence whose last dimension holds the
                four coordinates. An empty input yields ``[]``.
            size_wh: ``(width, height)`` of the image in pixels.

        Returns:
            Nested Python list of ints with the same leading shape as *boxes*.
        """
        self._check_mode()
        if not isinstance(boxes, torch.Tensor):
            boxes = torch.tensor(boxes)
        if boxes.numel() == 0:
            # torch.split on an empty last dim yields a single chunk, which cannot
            # be unpacked into four coordinates; there is nothing to quantize.
            return []
        bins_w, bins_h = self.bins
        size_w, size_h = size_wh
        size_per_bin_w = size_w / bins_w
        size_per_bin_h = size_h / bins_h
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)

        quantized_boxes = torch.cat(
            (
                (xmin / size_per_bin_w).floor().clamp(0, bins_w - 1),
                (ymin / size_per_bin_h).floor().clamp(0, bins_h - 1),
                (xmax / size_per_bin_w).floor().clamp(0, bins_w - 1),
                (ymax / size_per_bin_h).floor().clamp(0, bins_h - 1),
            ),
            dim=-1,
        ).int()
        return quantized_boxes.tolist()

    def dequantize_from_stringified_bboxes(self, stringified_bboxes, size_wh):
        """Parse ``"<loc_a><loc_b><loc_c><loc_d>,..."`` and return pixel boxes as lists.

        Raises:
            ValueError: if a comma-separated part does not start with four
                ``<loc_N>`` tokens.
        """
        parsed_bboxes = []
        for bbox_string in stringified_bboxes.split(','):
            match = self._BBOX_REGEX.match(bbox_string)
            if match is None:
                raise ValueError(f"Invalid bbox string format: {bbox_string}")
            parsed_bboxes.append([int(value) for value in match.groups()])
        return self.dequantize(parsed_bboxes, size_wh).tolist()

    def dequantize(self, boxes: torch.Tensor, size):
        """Map ``[..., 4]`` bin indices back to pixel coordinates (bin centres).

        Args:
            boxes: tensor or nested sequence of bin indices. An empty input
                yields an empty ``(0, 4)`` float tensor.
            size: ``(width, height)`` of the image in pixels.

        Returns:
            Float tensor with the same shape as *boxes*.
        """
        self._check_mode()
        if not isinstance(boxes, torch.Tensor):
            boxes = torch.tensor(boxes)
        if boxes.numel() == 0:
            return torch.zeros((0, 4))
        bins_w, bins_h = self.bins
        size_w, size_h = size
        size_per_bin_w = size_w / bins_w
        size_per_bin_h = size_h / bins_h
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)

        # +0.5 moves from the bin's lower edge to its centre.
        return torch.cat(
            (
                (xmin + 0.5) * size_per_bin_w,
                (ymin + 0.5) * size_per_bin_h,
                (xmax + 0.5) * size_per_bin_w,
                (ymax + 0.5) * size_per_bin_h,
            ),
            dim=-1,
        )