| | import base64 |
| | import io |
| | from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
| |
|
| | import cv2 |
| | import easyocr |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from PIL.Image import Image as ImageType |
| | from supervision.detection.core import Detections |
| | from supervision.draw.color import Color, ColorPalette |
| | from torchvision.ops import box_convert |
| | from torchvision.transforms import ToPILImage |
| | from transformers import AutoModelForCausalLM, AutoProcessor |
| | from transformers.image_utils import load_image |
| | from ultralytics import YOLO |
| |
|
| | |
| | |
# Instantiate (and immediately discard) an EasyOCR reader at import time —
# presumably to pre-download/cache the English OCR weights before the handler
# is constructed. NOTE(review): the instance is unused; confirm this warm-up
# is intentional and not a leftover.
easyocr.Reader(["en"])
| |
|
| |
|
| | class EndpointHandler: |
| | def __init__(self, model_dir: str = "/repository") -> None: |
| | self.device = ( |
| | torch.device("cuda") if torch.cuda.is_available() |
| | else (torch.device("mps") if torch.backends.mps.is_available() |
| | else torch.device("cpu")) |
| | ) |
| |
|
| | |
| | self.yolo = YOLO(f"{model_dir}/icon_detect/model.pt") |
| |
|
| | |
| | self.processor = AutoProcessor.from_pretrained( |
| | "microsoft/Florence-2-base", trust_remote_code=True |
| | ) |
| | self.model = AutoModelForCausalLM.from_pretrained( |
| | f"{model_dir}/icon_caption", |
| | torch_dtype=torch.float16, |
| | trust_remote_code=True, |
| | ).to(self.device) |
| |
|
| | |
| | self.ocr = easyocr.Reader(["en"]) |
| |
|
| | |
| | self.annotator = BoxAnnotator() |
| |
|
| | def __call__(self, data: Dict[str, Any]) -> Any: |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | data = data.pop("inputs") |
| |
|
| | |
| | image = load_image(data["image"]) |
| |
|
| | ocr_texts, ocr_bboxes = self.check_ocr_bboxes( |
| | image, |
| | out_format="xyxy", |
| | ocr_kwargs={"text_threshold": 0.8}, |
| | ) |
| | annotated_image, filtered_bboxes_out = self.get_som_labeled_img( |
| | image, |
| | image_size=data.get("image_size", None), |
| | ocr_texts=ocr_texts, |
| | ocr_bboxes=ocr_bboxes, |
| | bbox_threshold=data.get("bbox_threshold", 0.05), |
| | iou_threshold=data.get("iou_threshold", None), |
| | ) |
| | return { |
| | "image": annotated_image, |
| | "bboxes": filtered_bboxes_out, |
| | } |
| |
|
| | def check_ocr_bboxes( |
| | self, |
| | image: ImageType, |
| | out_format: Literal["xywh", "xyxy"] = "xywh", |
| | ocr_kwargs: Optional[Dict[str, Any]] = {}, |
| | ) -> Tuple[List[str], List[List[int]]]: |
| | if image.mode == "RBGA": |
| | image = image.convert("RGB") |
| |
|
| | result = self.ocr.readtext(np.array(image), **ocr_kwargs) |
| | texts = [str(item[1]) for item in result] |
| | bboxes = [ |
| | self.coordinates_to_bbox(item[0], format=out_format) for item in result |
| | ] |
| | return (texts, bboxes) |
| |
|
| | @staticmethod |
| | def coordinates_to_bbox( |
| | coordinates: np.ndarray, format: Literal["xywh", "xyxy"] = "xywh" |
| | ) -> List[int]: |
| | match format: |
| | case "xywh": |
| | return [ |
| | int(coordinates[0][0]), |
| | int(coordinates[0][1]), |
| | int(coordinates[2][0] - coordinates[0][0]), |
| | int(coordinates[2][1] - coordinates[0][1]), |
| | ] |
| | case "xyxy": |
| | return [ |
| | int(coordinates[0][0]), |
| | int(coordinates[0][1]), |
| | int(coordinates[2][0]), |
| | int(coordinates[2][1]), |
| | ] |
| |
|
| | @staticmethod |
| | def bbox_area(bbox: List[int], w: int, h: int) -> int: |
| | bbox = [bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h] |
| | return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) |
| |
|
| | @staticmethod |
| | def remove_bbox_overlap( |
| | xyxy_bboxes: List[Dict[str, Any]], |
| | ocr_bboxes: Optional[List[Dict[str, Any]]] = None, |
| | iou_threshold: Optional[float] = 0.7, |
| | ) -> List[Dict[str, Any]]: |
| | filtered_bboxes = [] |
| | if ocr_bboxes is not None: |
| | filtered_bboxes.extend(ocr_bboxes) |
| |
|
| | for i, bbox_outter in enumerate(xyxy_bboxes): |
| | bbox_left = bbox_outter["bbox"] |
| | valid_bbox = True |
| |
|
| | for j, bbox_inner in enumerate(xyxy_bboxes): |
| | if i == j: |
| | continue |
| |
|
| | bbox_right = bbox_inner["bbox"] |
| | if ( |
| | intersection_over_union( |
| | bbox_left, |
| | bbox_right, |
| | ) |
| | > iou_threshold |
| | ) and (area(bbox_left) > area(bbox_right)): |
| | valid_bbox = False |
| | break |
| |
|
| | if valid_bbox is False: |
| | continue |
| |
|
| | if ocr_bboxes is None: |
| | filtered_bboxes.append(bbox_outter) |
| | continue |
| |
|
| | box_added = False |
| | ocr_labels = [] |
| | for ocr_bbox in ocr_bboxes: |
| | if not box_added: |
| | bbox_right = ocr_bbox["bbox"] |
| | if overlap(bbox_right, bbox_left): |
| | try: |
| | ocr_labels.append(ocr_bbox["content"]) |
| | filtered_bboxes.remove(ocr_bbox) |
| | except Exception: |
| | continue |
| | elif overlap(bbox_left, bbox_right): |
| | box_added = True |
| | break |
| |
|
| | if not box_added: |
| | filtered_bboxes.append( |
| | { |
| | "type": "icon", |
| | "bbox": bbox_outter["bbox"], |
| | "interactivity": True, |
| | "content": " ".join(ocr_labels) if ocr_labels else None, |
| | } |
| | ) |
| |
|
| | return filtered_bboxes |
| |
|
| | def get_som_labeled_img( |
| | self, |
| | image: ImageType, |
| | image_size: Optional[Dict[Literal["w", "h"], int]] = None, |
| | ocr_texts: Optional[List[str]] = None, |
| | ocr_bboxes: Optional[List[List[int]]] = None, |
| | bbox_threshold: float = 0.01, |
| | iou_threshold: Optional[float] = None, |
| | caption_prompt: Optional[str] = None, |
| | caption_batch_size: int = 64, |
| | ) -> Tuple[str, List[Dict[str, Any]]]: |
| | if image.mode == "RBGA": |
| | image = image.convert("RGB") |
| |
|
| | w, h = image.size |
| | if image_size is None: |
| | imgsz = {"h": h, "w": w} |
| | else: |
| | imgsz = [image_size.get("h", h), image_size.get("w", w)] |
| |
|
| | out = self.yolo.predict( |
| | image, |
| | imgsz=imgsz, |
| | conf=bbox_threshold, |
| | iou=iou_threshold or 0.7, |
| | verbose=False, |
| | )[0] |
| | if out.boxes is None: |
| | raise RuntimeError( |
| | "YOLO prediction failed to produce the bounding boxes..." |
| | ) |
| |
|
| | xyxy_bboxes = out.boxes.xyxy |
| | xyxy_bboxes = xyxy_bboxes / torch.Tensor([w, h, w, h]).to(xyxy_bboxes.device) |
| | image_np = np.asarray(image) |
| |
|
| | if ocr_bboxes: |
| | ocr_bboxes = torch.tensor(ocr_bboxes) / torch.Tensor([w, h, w, h]) |
| | ocr_bboxes = ocr_bboxes.tolist() |
| |
|
| | ocr_bboxes = [ |
| | { |
| | "type": "text", |
| | "bbox": bbox, |
| | "interactivity": False, |
| | "content": text, |
| | "source": "box_ocr_content_ocr", |
| | } |
| | for bbox, text in zip(ocr_bboxes, ocr_texts) |
| | if self.bbox_area(bbox, w, h) > 0 |
| | ] |
| | xyxy_bboxes = [ |
| | { |
| | "type": "icon", |
| | "bbox": bbox, |
| | "interactivity": True, |
| | "content": None, |
| | "source": "box_yolo_content_yolo", |
| | } |
| | for bbox in xyxy_bboxes.tolist() |
| | if self.bbox_area(bbox, w, h) > 0 |
| | ] |
| |
|
| | filtered_bboxes = self.remove_bbox_overlap( |
| | xyxy_bboxes=xyxy_bboxes, |
| | ocr_bboxes=ocr_bboxes, |
| | iou_threshold=iou_threshold or 0.7, |
| | ) |
| |
|
| | filtered_bboxes_out = sorted( |
| | filtered_bboxes, key=lambda x: x["content"] is None |
| | ) |
| | starting_idx = next( |
| | ( |
| | idx |
| | for idx, bbox in enumerate(filtered_bboxes_out) |
| | if bbox["content"] is None |
| | ), |
| | -1, |
| | ) |
| |
|
| | filtered_bboxes = torch.tensor([box["bbox"] for box in filtered_bboxes_out]) |
| | non_ocr_bboxes = filtered_bboxes[starting_idx:] |
| |
|
| | bbox_images = [] |
| | for _, coordinates in enumerate(non_ocr_bboxes): |
| | try: |
| | xmin, xmax = ( |
| | int(coordinates[0] * image_np.shape[1]), |
| | int(coordinates[2] * image_np.shape[1]), |
| | ) |
| | ymin, ymax = ( |
| | int(coordinates[1] * image_np.shape[0]), |
| | int(coordinates[3] * image_np.shape[0]), |
| | ) |
| | cropped_image = image_np[ymin:ymax, xmin:xmax, :] |
| | cropped_image = cv2.resize(cropped_image, (64, 64)) |
| | bbox_images.append(ToPILImage()(cropped_image)) |
| | except Exception: |
| | continue |
| |
|
| | if caption_prompt is None: |
| | caption_prompt = "<CAPTION>" |
| |
|
| | captions = [] |
| | for idx in range(0, len(bbox_images), caption_batch_size): |
| | batch = bbox_images[idx : idx + caption_batch_size] |
| | inputs = self.processor( |
| | images=batch, |
| | text=[caption_prompt] * len(batch), |
| | return_tensors="pt", |
| | do_resize=False, |
| | ) |
| | if self.device.type in {"cuda", "mps"}: |
| | inputs = inputs.to(device=self.device, dtype=torch.float16) |
| |
|
| | with torch.inference_mode(): |
| | generated_ids = self.model.generate( |
| | input_ids=inputs["input_ids"], |
| | pixel_values=inputs["pixel_values"], |
| | max_new_tokens=20, |
| | num_beams=1, |
| | do_sample=False, |
| | early_stopping=False, |
| | ) |
| |
|
| | generated_texts = self.processor.batch_decode( |
| | generated_ids, skip_special_tokens=True |
| | ) |
| | captions.extend([text.strip() for text in generated_texts]) |
| |
|
| | ocr_texts = [f"Text Box ID {idx}: {text}" for idx, text in enumerate(ocr_texts)] |
| | for _, bbox in enumerate(filtered_bboxes_out): |
| | if bbox["content"] is None: |
| | bbox["content"] = captions.pop(0) |
| |
|
| | filtered_bboxes = box_convert( |
| | boxes=filtered_bboxes, in_fmt="xyxy", out_fmt="cxcywh" |
| | ) |
| |
|
| | annotated_image = image_np.copy() |
| | bboxes_annotate = filtered_bboxes * torch.Tensor([w, h, w, h]) |
| | xyxy_annotate = box_convert( |
| | bboxes_annotate, in_fmt="cxcywh", out_fmt="xyxy" |
| | ).numpy() |
| | detections = Detections(xyxy=xyxy_annotate) |
| | labels = [str(idx) for idx in range(bboxes_annotate.shape[0])] |
| |
|
| | annotated_image = self.annotator.annotate( |
| | scene=annotated_image, |
| | detections=detections, |
| | labels=labels, |
| | image_size=(w, h), |
| | ) |
| | assert w == annotated_image.shape[1] and h == annotated_image.shape[0] |
| |
|
| | out_image = Image.fromarray(annotated_image) |
| | out_buffer = io.BytesIO() |
| | out_image.save(out_buffer, format="PNG") |
| | encoded_image = base64.b64encode(out_buffer.getvalue()).decode("ascii") |
| |
|
| | return encoded_image, filtered_bboxes_out |
| |
|
| |
|
def area(bbox: List[int]) -> int:
    """Return the signed area (width * height) of an xyxy bbox."""
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
| |
|
| |
|
def intersection_area(bbox_left: List[int], bbox_right: List[int]) -> int:
    """Return the intersection area of two xyxy bboxes (0 when disjoint).

    Fix: the intersection's left/top edges are the *maximum* of the two boxes'
    left/top coordinates; the original used ``min()``, which overestimates the
    overlap (and reports a positive overlap for some disjoint boxes).
    """
    width = max(0, min(bbox_left[2], bbox_right[2]) - max(bbox_left[0], bbox_right[0]))
    height = max(0, min(bbox_left[3], bbox_right[3]) - max(bbox_left[1], bbox_right[1]))
    return width * height
| |
|
| |
|
def intersection_over_union(bbox_left: List[int], bbox_right: List[int]) -> float:
    """Return an overlap score for two xyxy bboxes.

    The score is the maximum of the classic IoU and each box's coverage ratio
    (intersection / own area), so a small box fully contained in a large one
    still scores near 1. The 1e-6 term keeps the union division safe for
    degenerate boxes; the coverage ratios are only used when both areas are
    strictly positive.
    """
    inter = intersection_area(bbox_left, bbox_right)
    left_area = area(bbox_left)
    right_area = area(bbox_right)
    union = left_area + right_area - inter + 1e-6

    ratios = (0.0, 0.0)
    if left_area > 0 and right_area > 0:
        ratios = (inter / left_area, inter / right_area)
    return max(inter / union, *ratios)
| |
|
| |
|
def overlap(bbox_left: List[int], bbox_right: List[int]) -> bool:
    """Return True when more than 80% of ``bbox_left`` lies inside ``bbox_right``.

    Boxes are xyxy. Fix: a zero/negative-area ``bbox_left`` now returns False
    instead of raising ZeroDivisionError — a degenerate box overlaps nothing.
    """
    left_area = area(bbox_left)
    if left_area <= 0:
        return False
    return intersection_area(bbox_left, bbox_right) / left_area > 0.80
| |
|
| |
|
class BoxAnnotator:
    """Draws bounding boxes with (optionally overlap-avoiding) text labels.

    Adapted from supervision's box annotator: each detection gets a rectangle
    plus a numbered label whose background is placed, when ``avoid_overlap`` is
    set, so that it neither covers other detections nor leaves the image.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,
        text_thickness: int = 2,
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        # NOTE(review): `text_color` is kept for API compatibility, but
        # annotate() currently overrides it with a luminance-based
        # black/white choice per box.
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """Draw ``detections`` onto ``scene`` in place and return it.

        Args:
            scene: HxWx3 image array (modified in place).
            detections: boxes in absolute xyxy pixel coordinates.
            labels: one string per detection; falls back to the class id when
                missing or mismatched in length.
            skip_label: draw only rectangles, no text.
            image_size: (w, h) used to keep labels inside the image when
                ``avoid_overlap`` is enabled.

        Returns:
            The annotated ``scene``.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                # Default placement: label above the box, left-aligned.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding
                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height
                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
            else:
                (
                    text_x,
                    text_y,
                    text_background_x1,
                    text_background_y1,
                    text_background_x2,
                    text_background_y2,
                ) = self.get_optimal_label_pos(
                    self.text_padding,
                    text_width,
                    text_height,
                    x1,
                    y1,
                    x2,
                    y2,
                    detections,
                    image_size,
                )

            # Filled background behind the label text.
            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # Pick black or white text for contrast against the box color
            # (ITU-R BT.601 luma approximation).
            box_color = color.as_rgb()
            luminance = (
                0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            )
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene

    @staticmethod
    def get_optimal_label_pos(
        text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
    ):
        """Pick a label placement for the box (x1, y1, x2, y2).

        Tries four candidate positions in order — above-left, outer-left,
        outer-right, above-right — and returns the first whose background
        neither overlaps another detection (overlap score > 0.3) nor leaves
        the image bounds; when all collide, the last candidate is returned.

        Returns:
            (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2).
        """

        def is_invalid(bg_x1, bg_y1, bg_x2, bg_y2):
            # Reject a placement whose background covers any detection...
            for i in range(len(detections)):
                detection = detections.xyxy[i].astype(int)
                if (
                    intersection_over_union([bg_x1, bg_y1, bg_x2, bg_y2], detection)
                    > 0.3
                ):
                    return True
            # ...or sticks out of the image. Fix: `image_size` is an optional
            # argument of annotate(); the original crashed indexing None here.
            if image_size is not None and (
                bg_x1 < 0
                or bg_x2 > image_size[0]
                or bg_y1 < 0
                or bg_y2 > image_size[1]
            ):
                return True
            return False

        # Candidates as (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2),
        # in the same order the original code tried them.
        candidates = [
            # Above the box, left-aligned.
            (
                x1 + text_padding,
                y1 - text_padding,
                x1,
                y1 - 2 * text_padding - text_height,
                x1 + 2 * text_padding + text_width,
                y1,
            ),
            # Outside, to the left of the box.
            (
                x1 - text_padding - text_width,
                y1 + text_padding + text_height,
                x1 - 2 * text_padding - text_width,
                y1,
                x1,
                y1 + 2 * text_padding + text_height,
            ),
            # Outside, to the right of the box.
            (
                x2 + text_padding,
                y1 + text_padding + text_height,
                x2,
                y1,
                x2 + 2 * text_padding + text_width,
                y1 + 2 * text_padding + text_height,
            ),
            # Above the box, right-aligned.
            (
                x2 - text_padding - text_width,
                y1 - text_padding,
                x2 - 2 * text_padding - text_width,
                y1 - 2 * text_padding - text_height,
                x2,
                y1,
            ),
        ]
        for candidate in candidates:
            if not is_invalid(*candidate[2:]):
                return candidate
        # Every position collides; fall back to the last candidate, matching
        # the original fall-through behavior.
        return candidates[-1]