| |
| |
import argparse
import functools
import io
import json
import os
import re
from typing import Dict, List
|
|
| from project_settings import project_path |
|
|
# Keep Hugging Face hub downloads inside the project tree. This assignment sits
# between the imports on purpose: it must run before the transformers imports
# below, which presumably read this variable when they are first imported.
os.environ["HUGGINGFACE_HUB_CACHE"] = (
    project_path / "cache" / "huggingface" / "hub"
).as_posix()
|
|
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image |
| import requests |
| import torch |
| from transformers.models.auto.processing_auto import AutoImageProcessor |
| from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor |
| from transformers.models.auto.modeling_auto import AutoModelForObjectDetection |
| import validators |
|
|
| from project_settings import project_path |
|
|
|
|
| |
# RGB colours (floats in [0, 1]) used for box outlines; draw_boxes cycles
# through this list by box position.
COLORS = [
    [0.000, 0.447, 0.741],  # blue
    [0.850, 0.325, 0.098],  # orange-red
    [0.929, 0.694, 0.125],  # yellow-orange
    [0.494, 0.184, 0.556],  # purple
    [0.466, 0.674, 0.188],  # green
    [0.301, 0.745, 0.933],  # light blue
]
|
|
|
|
def get_original_image(url_input, timeout: float = 30.0):
    """Download and open the image at ``url_input``.

    Args:
        url_input: text from the URL textbox; only fetched if it passes
            ``validators.url``.
        timeout: seconds to wait for the remote server before giving up.

    Returns:
        A PIL image, or None when ``url_input`` is not a valid URL.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    if not validators.url(url_input):
        return None
    # Without a timeout a stalled server would block the worker indefinitely.
    response = requests.get(url_input, timeout=timeout)
    # Surface HTTP errors directly instead of letting PIL fail on an error page.
    response.raise_for_status()
    # `response.content` is decoded from any Content-Encoding (e.g. gzip),
    # unlike `response.raw`, so PIL always receives the actual image bytes.
    image = Image.open(io.BytesIO(response.content))
    # Decode now so a corrupt download fails here, not later in the pipeline.
    image.load()
    return image
|
|
|
|
def figure2image(fig, base_width: int = 750):
    """Render a Matplotlib figure to a PIL image resized to ``base_width`` pixels wide.

    Args:
        fig: the figure to render (anything with a ``savefig`` method).
        base_width: target width in pixels; the height is scaled to keep the
            figure's aspect ratio.

    Returns:
        The rendered, resized PIL image (independent of the figure).
    """
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    # Closing the opened image also releases the in-memory buffer; `resize`
    # returns a new, fully loaded image, so it stays valid afterwards.
    with Image.open(buf) as rendered:
        width, height = rendered.size
        # Multiply before dividing to avoid truncating e.g. 749.999... to 749,
        # and never let a very wide figure collapse to a zero-pixel height.
        new_height = max(1, int(height * base_width / float(width)))
        return rendered.resize((base_width, new_height), Image.Resampling.LANCZOS)
|
|
|
|
def non_max_suppression(boxes, scores, threshold):
    """Greedy, class-agnostic non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and drops every other
    remaining box whose IoU with it is greater than ``threshold``.

    Args:
        boxes: numpy array whose first four columns are xmin, ymin, xmax, ymax.
        scores: numpy array with one score per box.
        threshold: IoU above which a lower-scoring box is suppressed.

    Returns:
        List of indices into ``boxes`` that survive, highest score first.
    """
    left, top, right, bottom = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # The "+1" treats coordinates as inclusive pixel indices; the intersection
    # below uses the same convention so IoU stays consistent.
    area = (right - left + 1) * (bottom - top + 1)

    candidates = scores.argsort()[::-1]
    survivors = []
    while candidates.size > 0:
        best, others = candidates[0], candidates[1:]
        survivors.append(best)

        overlap_w = np.maximum(
            0.0, np.minimum(right[best], right[others]) - np.maximum(left[best], left[others]) + 1)
        overlap_h = np.maximum(
            0.0, np.minimum(bottom[best], bottom[others]) - np.maximum(top[best], top[others]) + 1)
        overlap = overlap_w * overlap_h

        iou = overlap / (area[best] + area[others] - overlap)
        # Only boxes that do not overlap the winner too much stay in play.
        candidates = others[iou <= threshold]

    return survivors
|
|
|
|
def _parse_labels_to_show(labels_to_show):
    """Turn the label filter into a set of lower-cased class names, or None for "no filter"."""
    if labels_to_show is None:
        return None
    if isinstance(labels_to_show, str):
        labels_to_show = labels_to_show.split(",")
    names = {str(name).strip().lower() for name in labels_to_show}
    # "a,,b", a trailing comma or a lone "," would otherwise yield an empty
    # name that matches nothing.
    names.discard("")
    return names or None


def draw_boxes(image, boxes, scores, labels, threshold: float,
               idx_to_label: Dict[int, str] = None, labels_to_show: str = None):
    """Draw labelled bounding boxes on ``image`` and return the rendering as a PIL image.

    Args:
        image: the image to draw on (anything ``imshow`` accepts, e.g. PIL).
        boxes: iterable of ``(xmin, ymin, xmax, ymax)`` boxes in image pixels.
        scores: one confidence score per box.
        labels: one class id per box (or class name when ``idx_to_label`` is None).
        threshold: boxes scoring below this value are not drawn.
        idx_to_label: optional mapping from class id to class name.
        labels_to_show: comma-separated, case-insensitive class names to draw
            (a list of names is also accepted); None/blank draws every class.

    Returns:
        The annotated figure converted by ``figure2image``.
    """
    wanted = _parse_labels_to_show(labels_to_show)
    if idx_to_label is not None:
        labels = [idx_to_label[x] for x in labels]

    # Use an explicit figure rather than pyplot's implicit "current" one so it
    # can be closed afterwards; pyplot keeps every open figure alive, which
    # would leak a very large figure on every request.
    fig, axis = plt.subplots(figsize=(50, 50))
    try:
        axis.imshow(image)
        for idx, (score, (xmin, ymin, xmax, ymax), label) in enumerate(zip(scores, boxes, labels)):
            if wanted is not None and str(label).lower() not in wanted:
                continue
            if score < threshold:
                continue
            # Colour follows the box's position in the input, cycling through COLORS.
            color = COLORS[idx % len(COLORS)]
            axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                         fill=False, color=color, linewidth=10))
            axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60,
                      bbox=dict(facecolor="yellow", alpha=0.8))
        axis.axis("off")
        return figure2image(fig)
    finally:
        plt.close(fig)
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_detector(pretrained_model_name_or_path: str):
    """Load the detection model and its image processor, cached per checkpoint name.

    Reloading weights on every Gradio request is slow, so the most recently
    used checkpoints are kept in memory.
    """
    model = AutoModelForObjectDetection.from_pretrained(pretrained_model_name_or_path)
    model.eval()
    image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)
    return model, image_processor


def detr_object_detection(url_input: str,
                          image_input: Image,
                          pretrained_model_name_or_path: str = "qgyd2021/detr_cppe5_object_detection",
                          threshold: float = 0.5,
                          iou_threshold: float = 0.5,
                          labels_to_show: str = None,
                          ):
    """Detect objects in an image (from a URL or given directly) and draw them.

    Args:
        url_input: image URL; takes precedence when it passes ``validators.url``.
        image_input: PIL image used when ``url_input`` is not a valid URL.
        pretrained_model_name_or_path: Hugging Face model id or local path.
            NOTE(review): this default differs from the UI's dropdown value
            ("intelli-zen/...") — confirm which checkpoint is intended.
        threshold: minimum score, applied in post-processing and when drawing.
        iou_threshold: IoU above which overlapping boxes are suppressed
            (class-agnostic).
        labels_to_show: comma-separated class names to draw; blank draws all.

    Returns:
        PIL image of the annotated figure produced by ``draw_boxes``.

    Raises:
        AssertionError: if neither a valid URL nor an image is supplied.
    """
    model, image_processor = _load_detector(pretrained_model_name_or_path)

    if validators.url(url_input):
        image = get_original_image(url_input)
    elif image_input is not None:
        image = image_input
    else:
        raise AssertionError("at least one of `url_input` or `image_input` must be provided")

    # Feed the processor 3-channel RGB: downloaded images may be RGBA,
    # greyscale or palette-based, which can trip its channel-format handling.
    image = image.convert("RGB")
    # Post-processing wants (height, width); PIL's `.size` is (width, height).
    target_sizes = torch.tensor([tuple(reversed(image.size))])

    inputs = image_processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
        detections = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes)[0]

    boxes = detections["boxes"].detach().cpu().numpy()
    scores = detections["scores"].detach().cpu().numpy()
    labels = detections["labels"].detach().cpu().numpy()

    # An explicit integer array keeps fancy indexing well-defined even when
    # nothing survives suppression.
    keep = np.asarray(non_max_suppression(boxes, scores, threshold=iou_threshold), dtype=np.int64)
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    viz_image: Image = draw_boxes(
        image, boxes, scores, labels,
        threshold=threshold,
        idx_to_label=model.config.id2label,
        labels_to_show=labels_to_show,
    )
    return viz_image
|
|
|
|
def main():
    """Build the Gradio demo UI and launch it (blocks while the server runs)."""
    title = "## Detr Cppe5 Object Detection"

    description = """
reference:
https://huggingface.co/docs/transformers/tasks/object_detection

"""

    example_urls = [
        [
            "https://huggingface.co/datasets/intelli-zen/cppe-5/resolve/main/data/images/{}.png".format(idx),
            "intelli-zen/detr_cppe5_object_detection",
            0.5, 0.6, None,
        ]
        for idx in range(1001, 1030)
    ]

    example_images = [
        [
            "data/2lnWoly.jpg",
            "intelli-zen/detr_cppe5_object_detection",
            0.5, 0.6, None,
        ]
    ]

    # Each tab's button gets its own entry point. Previously both buttons sent
    # the URL textbox *and* the uploaded image, so on the upload tab a URL
    # typed earlier silently took precedence over the uploaded image.
    def detect_from_url(url, model_name_value, threshold, iou_threshold, labels_to_show):
        return detr_object_detection(url, None, model_name_value,
                                     threshold, iou_threshold, labels_to_show)

    def detect_from_upload(image, model_name_value, threshold, iou_threshold, labels_to_show):
        # An empty string fails URL validation, so the uploaded image is used.
        return detr_object_detection("", image, model_name_value,
                                     threshold, iou_threshold, labels_to_show)

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        gr.Markdown(value=description)

        model_name = gr.components.Dropdown(
            choices=[
                "intelli-zen/detr_cppe5_object_detection",
            ],
            value="intelli-zen/detr_cppe5_object_detection",
            label="model_name",
        )
        threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.01, value=0.5,
            label="Threshold"
        )
        iou_threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.1, value=0.5,
            label="IOU Threshold"
        )
        classes_to_detect = gr.Textbox(placeholder="e.g. person, truck (split by , comma).",
                                       label="labels to show")

        with gr.Tabs():
            with gr.TabItem("Image URL"):
                with gr.Row():
                    with gr.Column():
                        url_input = gr.Textbox(lines=1, label="Enter valid image URL here..")
                        original_image = gr.Image()
                        url_input.change(get_original_image, url_input, original_image)
                    with gr.Column():
                        img_output_from_url = gr.Image()

                url_but = gr.Button("Detect")

                with gr.Row():
                    gr.Examples(examples=example_urls,
                                inputs=[url_input, model_name, threshold_slider, iou_threshold_slider],
                                examples_per_page=5,
                                )

            with gr.TabItem("Image Upload"):
                with gr.Row():
                    img_input = gr.Image(type="pil")
                    img_output_from_upload = gr.Image()

                img_but = gr.Button("Detect")

                with gr.Row():
                    gr.Examples(examples=example_images,
                                inputs=[img_input, model_name, threshold_slider, iou_threshold_slider],
                                examples_per_page=5,
                                )

        shared_inputs = [model_name, threshold_slider, iou_threshold_slider, classes_to_detect]
        url_but.click(detect_from_url, inputs=[url_input, *shared_inputs],
                      outputs=[img_output_from_url], queue=True)
        img_but.click(detect_from_upload, inputs=[img_input, *shared_inputs],
                      outputs=[img_output_from_upload], queue=True)

    blocks.queue().launch()
|
|
|
|
# Launch the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|