wzzanthony7 committed on
Commit
cc810ad
·
verified ·
1 Parent(s): b8690db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -33
app.py CHANGED
@@ -7,6 +7,98 @@ import tempfile
7
 
8
  classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'tick', 'fraction']
9
  API_KEY = os.environ.get("ROBOFLOW_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def packFilterBoxInfo(filter_box_info):
11
  # 数字类别映射
12
  digit_classes = {
@@ -95,23 +187,34 @@ def generate_textual_description(box_info):
95
  class_summary[c_name].append([x-w/2, y-h/2, x+w/2, y+h/2])
96
  # Generate a summary for each class
97
  #the index of the left one
98
- class_summary['zero'].sort()
99
- left_most_zero_cor = class_summary['zero'][0]
100
- left_zero = True
101
- class_summary['one'].sort()
102
- right_most_one_cor = class_summary['one'][-1]
103
- right_one = True
104
- for fra_box in class_summary['fraction']:
105
- if getOverlap(fra_box, left_most_zero_cor) >= 0.5:
106
- left_zero = False
107
- if getOverlap(fra_box, right_most_one_cor) >= 0.5:
108
- left_one = False
109
- textual_description = ""
 
 
 
 
 
 
 
 
 
110
  textual_description += "The key elements are interpreted via visual translator. Their coordinates are represented as outlined boxes (top-left, bottom-right)"
111
  #print(f"The key elements are interpreted via visual translator. Their coordinates are represented as outlined boxes (top-left, bottom-right)")
112
- if left_zero:
 
113
  textual_description += f"There is a zero on the left side of the number line. Its coordinate is (({left_most_zero_cor[0]:.2f}, {left_most_zero_cor[1]:.2f}), ({left_most_zero_cor[2]:.2f}, {left_most_zero_cor[3]:.2f}))"
114
- if right_one:
 
115
  textual_description += f"There is a one on the right side of the number line. Its coordinate is (({right_most_one_cor[0]:.2f}, {right_most_one_cor[1]:.2f}), ({right_most_one_cor[2]:.2f}, {right_most_one_cor[3]:.2f}))"
116
  present_classes = ['fraction', 'tick']
117
  for cid, boxes in class_summary.items():
@@ -147,41 +250,76 @@ def greet(name):
147
 
148
  def process_image(image):
149
  if image is None:
150
- return None, "", ""
151
  pil_image = image.copy() if hasattr(image, 'copy') else Image.fromarray(image)
152
  boxed_img = drawWithAllBox_info(pil_image, test_box_info)
153
  textual = generate_textual_description(test_box_info)
154
  json_str = json.dumps(test_box_info, indent=2)
155
  return boxed_img, textual, json_str
156
- '''
157
- def download_json(json_str):
158
- with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as f:
159
- f.write(json_str)
160
- temp_path = f.name
161
- return gr.File.update(value=temp_path, visible=True)
162
- '''
163
  with gr.Blocks() as demo:
164
- img_input = gr.Image(type="pil", label="Upload Image")
165
- run_btn = gr.Button("Run Detection")
166
- img_out = gr.Image(type="pil", label="Image with Boxes")
167
- text_out = gr.Textbox(label="Textual Description", lines=8)
168
- json_state = gr.State("")
169
- download_btn = gr.DownloadButton(
170
- label="Download Box Info as JSON"
171
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def _process(image):
173
  boxed_img, textual, json_str = process_image(image)
174
  return boxed_img, textual, json_str
 
175
  def create_and_download_json(json_str):
176
  if not json_str:
177
  return None
178
  with tempfile.NamedTemporaryFile(
179
- mode='w', delete=False, prefix="detection_box_", suffix='.json', encoding='utf-8'
 
180
  ) as f:
181
  f.write(json_str)
182
  return f.name
183
 
184
- run_btn.click(_process, inputs=img_input, outputs=[img_out, text_out, json_state])
185
- download_btn.click(create_and_download_json, inputs=json_state, outputs=download_btn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  demo.launch()
 
7
 
8
  # Names of the classes this app works with: the ten digit words plus 'tick'
  # and 'fraction' (presumably the detector's label set — TODO confirm).
  classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'tick', 'fraction']
9
  # Roboflow API key, read from the ROBOFLOW_API_KEY environment variable
  # (None when the variable is unset).
  API_KEY = os.environ.get("ROBOFLOW_API_KEY")
10
  # Invite token that check_token compares user input against, read from the
  # ACCESS_TOKEN environment variable (None when the variable is unset).
+ ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
11
+
12
+
def RoboFlowGetOutlineBoxesPIL(pil_img):
    """Run the hosted Roboflow workflow on an image and return its raw result.

    Args:
        pil_img: Image to analyse; it is passed through unchanged as the
            workflow's "image" input.

    Returns:
        Whatever ``client.run_workflow`` returns for workflow
        "custom-workflow-2" in workspace "mathnet-mmpuo".
    """
    client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        # The key lives in the module-level API_KEY constant. The name used
        # here before (`roboflow_api`) is not defined in the visible part of
        # the file and would raise NameError on the first call.
        api_key=API_KEY,
    )
    return client.run_workflow(
        workspace_name="mathnet-mmpuo",
        workflow_id="custom-workflow-2",
        images={"image": pil_img},
        use_cache=True,  # cache workflow definition for 15 minutes
    )
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two boxes.

    Args:
        box1: Mapping with keys 'x', 'y', 'width' and 'height'. ('x', 'y') is
            treated as the box centre: the corners are derived as the centre
            minus/plus half the width and height.
        box2: Second box, same format as ``box1``.

    Returns:
        0.0 when the boxes do not overlap (boxes that merely touch along an
        edge count as non-overlapping), otherwise the intersection area
        divided by the union area.
    """
    def _corners(box):
        # Convert centre/size form to (left, top, right, bottom).
        half_w = box['width'] / 2
        half_h = box['height'] / 2
        return (box['x'] - half_w, box['y'] - half_h,
                box['x'] + half_w, box['y'] + half_h)

    left1, top1, right1, bottom1 = _corners(box1)
    left2, top2, right2, bottom2 = _corners(box2)

    # Corners of the intersection rectangle (may be empty/inverted).
    inter_left = max(left1, left2)
    inter_top = max(top1, top2)
    inter_right = min(right1, right2)
    inter_bottom = min(bottom1, bottom2)

    # An empty or degenerate intersection means no overlap at all.
    if not (inter_left < inter_right and inter_top < inter_bottom):
        return 0.0
    inter_area = (inter_right - inter_left) * (inter_bottom - inter_top)

    # Union = sum of both areas minus the part counted twice.
    area1 = box1['width'] * box1['height']
    area2 = box2['width'] * box2['height']
    union_area = area1 + area2 - inter_area

    return inter_area / union_area
63
+
def parse_roboflow_result(result, kept_classes):
    """Collect the predictions whose class is listed in ``kept_classes``.

    Args:
        result: Workflow output. The predictions are read from
            ``result[0]['predictions']['predictions']['predictions']``; each
            entry must be a mapping with a 'class' key.
        kept_classes: Container of class names to keep.

    Returns:
        A new list with the matching prediction entries, in their original
        order.
    """
    predictions = result[0]['predictions']['predictions']['predictions']
    return [box_info for box_info in predictions
            if box_info['class'] in kept_classes]
70
+
def filter_overlapping_boxes(filter_box_info, iou_threshold=0.5):
    """Drop digit boxes that overlap a more confident digit box.

    Digit boxes (classes 'zero' .. 'nine') are visited from highest to lowest
    'confidence'; a box is kept only if its IoU with every already-kept digit
    box is not above ``iou_threshold``. Boxes of any other class are passed
    through untouched.

    Args:
        filter_box_info: Iterable of box mappings. Each needs a 'class' key;
            digit boxes also need 'confidence' plus the keys that
            ``calculate_iou`` reads.
        iou_threshold: IoU above which a digit box counts as a duplicate.

    Returns:
        A new list: the non-digit boxes in input order, followed by the kept
        digit boxes in descending confidence order.
    """
    digit_classes = {'zero', 'one', 'two', 'three', 'four',
                     'five', 'six', 'seven', 'eight', 'nine'}

    # Separate digit boxes from everything else.
    digit_boxes = [box for box in filter_box_info
                   if box['class'] in digit_classes]
    other_boxes = [box for box in filter_box_info
                   if box['class'] not in digit_classes]

    kept_boxes = []
    # Most confident first, so the survivor of an overlapping group is always
    # the highest-confidence detection.
    for box in sorted(digit_boxes, key=lambda b: b['confidence'], reverse=True):
        overlaps_kept = any(
            calculate_iou(box, kept_box) > iou_threshold
            for kept_box in kept_boxes
        )
        if not overlaps_kept:
            kept_boxes.append(box)

    return other_boxes + kept_boxes
101
+
102
  def packFilterBoxInfo(filter_box_info):
103
  # 数字类别映射
104
  digit_classes = {
 
187
  class_summary[c_name].append([x-w/2, y-h/2, x+w/2, y+h/2])
188
  # Generate a summary for each class
189
  #the index of the left one
190
+ kept_zero_boxes = []
191
+ for zero_box in class_summary['zero']:
192
+ kept_zero = True
193
+ for fra_box in class_summary['fraction']:
194
+ if getOverlap(fra_box, left_most_zero_cor) >= 0.5:
195
+ kept_zero = False
196
+ break
197
+ if kept_zero:
198
+ kept_zero_boxes.append(zero_box)
199
+ kept_one_boxes = []
200
+ for one_box in class_summary['fraction']:
201
+ kept_one = True
202
+ for fra_box in class_summary['fraction']:
203
+ if getOverlap(fra_box, right_most_one_cor) >= 0.5:
204
+ kept_one = False
205
+ break
206
+ if kept_one:
207
+ kept_one_boxes.append(one_box)
208
+ kept_zero_boxes.sort(key = lambda x: x[0])
209
+ kept_one_boxes.sort(key = lambda x: x[0])
210
+ textual_description = "" #final output
211
  textual_description += "The key elements are interpreted via visual translator. Their coordinates are represented as outlined boxes (top-left, bottom-right)"
212
  #print(f"The key elements are interpreted via visual translator. Their coordinates are represented as outlined boxes (top-left, bottom-right)")
213
+ if len(kept_zero_boxes) >= 1:
214
+ left_most_zero_cor = kept_zero_boxes[0]
215
  textual_description += f"There is a zero on the left side of the number line. Its coordinate is (({left_most_zero_cor[0]:.2f}, {left_most_zero_cor[1]:.2f}), ({left_most_zero_cor[2]:.2f}, {left_most_zero_cor[3]:.2f}))"
216
+ if len(kept_one_boxes) >= 1:
217
+ right_most_one_cor = kept_one_boxes[-1]
218
  textual_description += f"There is a one on the right side of the number line. Its coordinate is (({right_most_one_cor[0]:.2f}, {right_most_one_cor[1]:.2f}), ({right_most_one_cor[2]:.2f}, {right_most_one_cor[3]:.2f}))"
219
  present_classes = ['fraction', 'tick']
220
  for cid, boxes in class_summary.items():
 
250
 
def process_image(image):
    """Draw the box info on an image and build its text and JSON views.

    Args:
        image: The uploaded image, or None when nothing was uploaded. Objects
            with a ``copy`` method are copied; anything else is converted
            with ``Image.fromarray``.

    Returns:
        A 3-tuple ``(boxed_img, textual, json_str)``: the image returned by
        ``drawWithAllBox_info``, the text from
        ``generate_textual_description`` and the box info serialised as
        indented JSON. For a None image the result is ``(None, "", "")``.
    """
    if image is None:
        # Keep the same arity as the normal return: callers unpack exactly
        # three values, so a fourth element here raised ValueError whenever
        # detection was run without an image.
        return None, "", ""
    # Work on a copy so the caller's image object is left untouched.
    pil_image = image.copy() if hasattr(image, 'copy') else Image.fromarray(image)
    # NOTE(review): the boxes come from the module-level test_box_info, not
    # from a detection run on `image` — confirm this is intended.
    boxed_img = drawWithAllBox_info(pil_image, test_box_info)
    textual = generate_textual_description(test_box_info)
    json_str = json.dumps(test_box_info, indent=2)
    return boxed_img, textual, json_str
259
+
260
+
 
 
 
 
 
# Gradio UI: an invite-token gate in front of the detection demo.
with gr.Blocks() as demo:
    # --- Authentication Layer ---
    with gr.Row():
        token_input = gr.Textbox(
            label="Invite Token",
            type="password",
            placeholder="Enter your invite token to unlock the app",
        )
        unlock_btn = gr.Button("Unlock")

    status_text = gr.Markdown()

    # --- Main Application (initially hidden) ---
    # NOTE(review): the gate only toggles the visibility of this column; the
    # handlers below do not re-check the token themselves — confirm that is
    # sufficient for this deployment.
    with gr.Column(visible=False) as main_app:
        img_input = gr.Image(type="pil", label="Upload Image")
        run_btn = gr.Button("Run Detection")
        img_out = gr.Image(type="pil", label="Image with Boxes")
        text_out = gr.Textbox(label="Textual Description", lines=8)
        json_state = gr.State("")
        download_btn = gr.DownloadButton(
            label="Download Box Info as JSON"
        )

    # --- Backend Functions ---
    def _process(image):
        """Run process_image and return (boxed image, description, JSON)."""
        # Exactly three outputs are wired to this handler; slicing keeps the
        # unpacking valid even if process_image returns extra trailing values.
        boxed_img, textual, json_str = process_image(image)[:3]
        return boxed_img, textual, json_str

    def create_and_download_json(json_str):
        """Write the JSON text to a temp file and return its path (or None)."""
        if not json_str:
            return None
        # delete=False: the file must outlive this function so that it can
        # still be served after the path has been returned.
        with tempfile.NamedTemporaryFile(
            prefix="detection_info_",
            mode='w', delete=False, suffix='.json', encoding='utf-8'
        ) as f:
            f.write(json_str)
            return f.name

    def check_token(token):
        """Reveal the main column when the token matches ACCESS_TOKEN.

        Returns a (visibility update for main_app, status message) pair. An
        unset or empty ACCESS_TOKEN never unlocks the app.
        """
        import hmac  # local import keeps this block self-contained

        # Compare in constant time so response timing does not leak how much
        # of the token matched; encode to bytes because compare_digest only
        # accepts ASCII-only str values.
        supplied = (token or "").encode("utf-8")
        expected = (ACCESS_TOKEN or "").encode("utf-8")
        if ACCESS_TOKEN and hmac.compare_digest(supplied, expected):
            return gr.update(visible=True), "Token accepted. You can now use the application."
        return gr.update(visible=False), "Invalid token. Please try again."

    # --- Event Listeners ---
    unlock_btn.click(
        check_token,
        inputs=token_input,
        outputs=[main_app, status_text],
    )

    run_btn.click(
        _process,
        inputs=img_input,
        outputs=[img_out, text_out, json_state],
    )

    download_btn.click(
        create_and_download_json,
        inputs=json_state,
        outputs=download_btn,
    )

demo.launch()