basantyahya commited on
Commit
b385db3
·
verified ·
1 Parent(s): 94bea90

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from io import BytesIO
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from PIL import ImageColor
6
+ import json
7
+ import google.generativeai as genai
8
+ from google.generativeai import types
9
+ from dotenv import load_dotenv
10
+
11
+
12
+ # 1. SETUP API KEY
13
+ # ----------------
14
+ load_dotenv()
15
+ api_key = os.getenv("Gemini_API_Key")
16
+ # Configure the Google AI library
17
+ genai.configure(api_key=api_key)
18
+
19
+
20
+ # 2. DEFINE MODEL AND INSTRUCTIONS
21
+
22
+ bounding_box_system_instructions = """
23
+ Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
24
+ If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
25
+ """
26
+ model = genai.GenerativeModel( model_name='gemini-2.5-flash', system_instruction=bounding_box_system_instructions , safety_settings=[ types.SafetySettingDict( category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH", ) ],)
27
+ generation_config = genai.types.GenerationConfig(
28
+ temperature=0.5,
29
+
30
+ )
31
+
32
+
33
+ def generate_bounding_boxes(prompt, image):
34
+ image = image.resize((1024, int(1024 * image.height / image.width)))
35
+ response = model.generate_content([prompt, image], generation_config=generation_config)
36
+ bounding_boxes = parse_json(response.text)
37
+ img=plot_bounding_boxes(image, bounding_boxes)
38
+ return img
39
+
40
+
41
+ def parse_json(json_output):
42
+ lines = json_output.splitlines()
43
+ for i, line in enumerate(lines):
44
+ if line == "```json":
45
+ json_output = "\n".join(lines[i+1:]) # Remove everything before "```json"
46
+ json_output = json_output.split("```")[0] # Remove everything after the closing "```"
47
+ break
48
+ return json_output
49
+
50
+ def plot_bounding_boxes(im, bounding_boxes):
51
+ """
52
+ Plots bounding boxes on an image with labels.
53
+ """
54
+ additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]
55
+
56
+ im = im.copy()
57
+ width, height = im.size
58
+ draw = ImageDraw.Draw(im)
59
+ colors = [
60
+ 'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
61
+ 'lime', 'magenta', 'violet', 'gold', 'silver'
62
+ ] + additional_colors
63
+
64
+ try:
65
+ # Use a default font if NotoSansCJK is not available
66
+ try:
67
+ font = ImageFont.load_default()
68
+ except OSError:
69
+ print("NotoSansCJK-Regular.ttc not found. Using default font.")
70
+ font = ImageFont.load_default()
71
+
72
+ bounding_boxes_json = json.loads(bounding_boxes)
73
+ for i, bounding_box in enumerate(bounding_boxes_json):
74
+ color = colors[i % len(colors)]
75
+ abs_y1 = int(bounding_box["box_2d"][0] / 1000 * height)
76
+ abs_x1 = int(bounding_box["box_2d"][1] / 1000 * width)
77
+ abs_y2 = int(bounding_box["box_2d"][2] / 1000 * height)
78
+ abs_x2 = int(bounding_box["box_2d"][3] / 1000 * width)
79
+
80
+ if abs_x1 > abs_x2:
81
+ abs_x1, abs_x2 = abs_x2, abs_x1
82
+
83
+ if abs_y1 > abs_y2:
84
+ abs_y1, abs_y2 = abs_y2, abs_y1
85
+
86
+ # Draw bounding box and label
87
+ draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
88
+ if "label" in bounding_box:
89
+ draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)
90
+ except Exception as e:
91
+ print(f"Error drawing bounding boxes: {e}")
92
+
93
+ return im
94
+ def gradio_interface():
95
+ """
96
+ Gradio app interface for bounding box generation with example pairs.
97
+ """
98
+ # Example image + prompt pairs
99
+ examples = [
100
+ ["cookies.jpg", "Detect the cookies and label their types."],
101
+ ["messed_room.jpg", "Find the unorganized item and suggest action in label in the image to fix them."],
102
+ ["yoga.jpg", "Show the different yoga poses and name them."],
103
+ ["zoom_face.png", "Label the tired faces in the image."]
104
+ ]
105
+
106
+ with gr.Blocks(gr.themes.Glass(secondary_hue= "rose")) as demo:
107
+ gr.Markdown("# Gemini Bounding Box Generator")
108
+
109
+ with gr.Row():
110
+ with gr.Column():
111
+ gr.Markdown("### Input Section")
112
+ input_image = gr.Image(type="pil", label="Input Image")
113
+ input_prompt = gr.Textbox(lines=2, label="Input Prompt", placeholder="Describe what to detect.")
114
+ submit_btn = gr.Button("Generate")
115
+
116
+ with gr.Column():
117
+ gr.Markdown("### Output Section")
118
+ output_image = gr.Image(type="pil", label="Output Image")
119
+ #output_json = gr.Textbox(label="Bounding Boxes JSON")
120
+
121
+ gr.Markdown("### Examples")
122
+ gr.Examples(
123
+ examples=examples,
124
+ inputs=[input_image, input_prompt],
125
+ label="Example Images with Prompts"
126
+ )
127
+
128
+ # Event to generate bounding boxes
129
+ submit_btn.click(
130
+ generate_bounding_boxes,
131
+ inputs=[input_prompt, input_image],
132
+ outputs=[output_image]
133
+ )
134
+
135
+ return demo
136
+
137
+
138
+
139
+ if __name__ == "__main__":
140
+ app = gradio_interface()
141
+ app.launch()