# MathNet / app.py
# wzzanthony7's picture
# Update app.py
# 92746fc verified
import gradio as gr
import os
import sys
import json
from PIL import Image,ImageDraw
import tempfile
from inference_sdk import InferenceHTTPClient
from ultralytics import YOLO
import ultralytics
# Class names the rest of this file works with: the ten digit words plus the
# 'tick' and 'fraction' markers. Used to filter Roboflow predictions and to
# bucket detections when building the textual description.
classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'tick', 'fraction']
# Both values are read from environment variables and are None when unset.
# API_KEY is handed to the Roboflow inference client.
API_KEY = os.environ.get("ROBOFLOW_API_KEY")
# ACCESS_TOKEN is the invite token the UI compares the user's input against.
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
def RoboFlowGetOutlineBoxesPIL(pil_img):
    """Run the hosted Roboflow workflow on one image and return its raw result.

    Args:
        pil_img: The image submitted to the workflow under the ``"image"`` input.

    Returns:
        Whatever ``InferenceHTTPClient.run_workflow`` returns, unmodified.
    """
    roboflow_client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        api_key=API_KEY,
    )
    workflow_inputs = {"image": pil_img}
    # use_cache=True caches the workflow definition for 15 minutes.
    return roboflow_client.run_workflow(
        workspace_name="mathnet-mmpuo",
        workflow_id="custom-workflow-2",
        images=workflow_inputs,
        use_cache=True,
    )
def localOutlineBox(image, selected_model="yolo_accurate"):
    """Detect boxes in ``image`` with a local YOLO weights file.

    Args:
        image: Input handed directly to the YOLO model.
        selected_model: Key choosing the weights file ("yolo_accurate" or
            "yolo_extra_large"). Unknown keys use the default weights.

    Returns:
        A list of dicts, one per detected box, with the keys 'class', 'x',
        'y', 'width', 'height' (taken from the box's ``xywh`` values),
        'confidence' and 'class_id'.
    """
    default_path = "./vt_dataset_yolov12_v7_weights.pt"
    model_paths = {
        "yolo_accurate": default_path,
        "yolo_extra_large": "./VT_dataset_2_Yolov12_Extra_large.pt",
    }
    # Honour the caller's choice. Previously the looked-up path was
    # immediately overwritten with the default, so the selection was ignored.
    model_path = model_paths.get(selected_model, default_path)
    if model_path != default_path and not os.path.exists(model_path):
        # Keep working when the optional weights file has not been deployed.
        print(f"model is not available at {model_path}, falling back to {default_path}")
        model_path = default_path

    if os.path.exists(model_path):
        print("model exists")
        print(f'current model path is: {model_path}')
    else:
        print("model is not available")

    model = YOLO(model_path)
    useful_ret = model(image, verbose=False)[0]  # only the first result is used
    names = model.names

    all_box_info = []
    for bb in useful_ret.boxes:
        x, y, w, h = bb.xywh[0].tolist()
        class_id = int(bb.cls)
        all_box_info.append({
            'class': names[class_id],
            'x': x,
            'y': y,
            'width': w,
            'height': h,
            'confidence': float(bb.conf),
            'class_id': class_id,
        })
    print(f"total length of all box_info is: {len(all_box_info)}")
    return all_box_info
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two boxes.

    Each box is a mapping with the keys 'x', 'y', 'width' and 'height', where
    ('x', 'y') is the box centre. Returns 0.0 when the boxes do not overlap
    with positive area.
    """
    def corners(box):
        # Convert centre/size form to (left, top, right, bottom).
        cx, cy = box['x'], box['y']
        half_w, half_h = box['width'] / 2, box['height'] / 2
        return cx - half_w, cy - half_h, cx + half_w, cy + half_h

    left1, top1, right1, bottom1 = corners(box1)
    left2, top2, right2, bottom2 = corners(box2)

    # Corners of the intersection rectangle.
    ix1, iy1 = max(left1, left2), max(top1, top2)
    ix2, iy2 = min(right1, right2), min(bottom1, bottom2)
    if not (ix1 < ix2 and iy1 < iy2):
        return 0.0

    overlap = (ix2 - ix1) * (iy2 - iy1)
    area1 = box1['width'] * box1['height']
    area2 = box2['width'] * box2['height']
    return overlap / (area1 + area2 - overlap)
def parse_roboflow_result(result, kept_classes):
    """Extract the predictions whose class is in ``kept_classes``.

    The predictions are read from
    ``result[0]['predictions']['predictions']['predictions']`` and returned
    in their original order.
    """
    predictions = result[0]['predictions']['predictions']['predictions']
    return [pred for pred in predictions if pred['class'] in kept_classes]
def filter_overlapping_boxes(filter_box_info, iou_threshold=0.5):
    """Drop boxes that overlap an already-kept box of the same group.

    Boxes are split into digit boxes ('zero'..'nine') and all other boxes.
    Digit boxes are visited in descending 'confidence' order; other boxes are
    visited in input order. Within each group a box is discarded when its IoU
    with any previously kept box of that group exceeds ``iou_threshold``.

    Returns:
        The kept non-digit boxes followed by the kept digit boxes.
    """
    digit_names = {'zero', 'one', 'two', 'three', 'four',
                   'five', 'six', 'seven', 'eight', 'nine'}

    def suppress(candidates):
        # Greedy suppression: keep a box unless it clashes with a kept one.
        survivors = []
        for candidate in candidates:
            clashes = any(
                calculate_iou(candidate, kept) > iou_threshold
                for kept in survivors
            )
            if not clashes:
                survivors.append(candidate)
        return survivors

    digits, others = [], []
    for entry in filter_box_info:
        (digits if entry['class'] in digit_names else others).append(entry)

    ranked_digits = sorted(digits, key=lambda entry: entry['confidence'], reverse=True)
    return suppress(others) + suppress(ranked_digits)
def getCenterXDis(box1, box2):
    """Return the absolute horizontal distance between two box centres.

    Both boxes are 4-element sequences ``(x0, y0, x1, y1)``; only the x
    coordinates contribute to the result.
    """
    left_a, _, right_a, _ = box1
    left_b, _, right_b, _ = box2
    mid_a = (left_a + right_a) / 2
    mid_b = (left_b + right_b) / 2
    return abs(mid_a - mid_b)
def to_ordinal(n):
    """Return the ordinal string for an integer.

    Examples: 1 -> "1st", 2 -> "2nd", 13 -> "13th".

    Raises:
        TypeError: If ``n`` is not an ``int``.
    """
    if not isinstance(n, int):
        raise TypeError("Input must be an integer.")
    # 11, 12 and 13 (in any hundred) take "th" despite ending in 1, 2 or 3.
    if 11 <= (n % 100) <= 13:
        return f"{n}th"
    suffix_by_last_digit = {1: 'st', 2: 'nd', 3: 'rd'}
    ending = suffix_by_last_digit.get(n % 10, 'th')
    return f"{n}{ending}"
def packFilterBoxInfo(filter_box_info):
    """Read the value of every detected fraction as a "numerator/denominator" string.

    Fractions are processed left to right (by left edge). A digit box belongs
    to a fraction when its centre lies inside the fraction box; digits whose
    centre is above the fraction's centre form the numerator, the rest form
    the denominator. Each part is read left to right, and an empty part is
    written as "?".

    Returns:
        A list of strings such as ``["1/2", "3/?"]``, one per fraction box.
    """
    # Map digit class names to their characters.
    digit_chars = {
        'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5',
        'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 'zero': '0'
    }

    fractions, digits = [], []
    for entry in filter_box_info:
        if entry['class'] == 'fraction':
            fractions.append(entry)
        elif entry['class'] in digit_chars:
            digits.append(entry)

    def left_edge(entry):
        return entry['x'] - entry['width'] / 2

    def read_digits(members):
        # Concatenate the digit characters from left to right.
        text = ''.join(digit_chars[m['class']] for m in sorted(members, key=left_edge))
        return text or "?"

    fraction_values = []
    for frac in sorted(fractions, key=left_edge):
        cx, cy = frac['x'], frac['y']
        left, right = cx - frac['width'] / 2, cx + frac['width'] / 2
        top, bottom = cy - frac['height'] / 2, cy + frac['height'] / 2

        above, below = [], []
        for digit_box in digits:
            # Only digits whose centre falls inside the fraction box count.
            inside = left <= digit_box['x'] <= right and top <= digit_box['y'] <= bottom
            if not inside:
                continue
            # Above the fraction's centre -> numerator, otherwise denominator.
            (above if digit_box['y'] < cy else below).append(digit_box)

        fraction_values.append(f"{read_digits(above)}/{read_digits(below)}")
    return fraction_values
# Boxes are assumed to be given as (top-left, bottom-right) corner coordinates.
def getOverlap(box1, box2):
    """Return the fraction of ``box2``'s area that is covered by ``box1``.

    Both boxes are 4-element sequences ``(x1, y1, x2, y2)``. Returns 0.0 when
    the boxes do not intersect with positive area.
    """
    a_left, a_top, a_right, a_bottom = box1
    b_left, b_top, b_right, b_bottom = box2

    ix1, iy1 = max(a_left, b_left), max(a_top, b_top)
    ix2, iy2 = min(a_right, b_right), min(a_bottom, b_bottom)
    if not (ix1 < ix2 and iy1 < iy2):
        return 0.0

    shared = (ix2 - ix1) * (iy2 - iy1)
    # Normalised by the second box only, so the measure is not symmetric.
    second_area = abs(b_right - b_left) * abs(b_bottom - b_top)
    return shared / second_area
def tick2fraction(ticks, fractions):
    """Greedily pair each fraction with the nearest unused tick.

    Fractions are visited in the given order; each one claims the not-yet-used
    tick whose centre is horizontally closest (lowest index wins ties).
    Pairing stops as soon as no unused tick remains.

    Returns:
        A list of strings ``"T<tick index>-F<fraction index>"``.
    """
    pairings = []
    claimed = set()
    for frac_idx, frac_box in enumerate(fractions):
        candidates = [
            (getCenterXDis(tick_box, frac_box), tick_idx)
            for tick_idx, tick_box in enumerate(ticks)
            if tick_idx not in claimed
        ]
        if not candidates:
            # Every tick is taken, so later fractions cannot be paired either.
            break
        _, nearest = min(candidates, key=lambda pair: pair[0])
        claimed.add(nearest)
        pairings.append(f"T{nearest}-F{frac_idx}")
    return pairings
def generate_textual_description(box_info):
    """Build the plain-text summary of the detected number-line elements.

    The text contains, in order: a fixed introduction; the left-most 'zero'
    and the right-most 'one' that are not mostly covered by a fraction or
    tick box; the count and coordinates of ticks and fractions (plus the
    fraction values); and the fraction-to-tick associations.

    Args:
        box_info: Detections as dicts with 'class', 'x', 'y', 'width' and
            'height' (centre-format boxes).

    Returns:
        The description as a single string.
    """
    fraction_values = packFilterBoxInfo(box_info)

    # Bucket every recognised detection by class name as [x1, y1, x2, y2].
    class_summary = {label: [] for label in classes}
    for detection in box_info:
        bucket = class_summary.get(detection['class'])
        if bucket is None:
            continue
        cx, cy = detection['x'], detection['y']
        bw, bh = detection['width'], detection['height']
        bucket.append([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2])

    def standalone(label):
        # Boxes of `label` that are not at least half covered by any fraction
        # or tick box, ordered by their left edge.
        containers = class_summary['fraction'] + class_summary['tick']
        free = [
            candidate for candidate in class_summary[label]
            if not any(getOverlap(outer, candidate) >= 0.5 for outer in containers)
        ]
        free.sort(key=lambda corners: corners[0])
        return free

    def fmt(corners):
        return f"(({corners[0]:.2f}, {corners[1]:.2f}), ({corners[2]:.2f}, {corners[3]:.2f}))"

    kept_zero_boxes = standalone('zero')
    kept_one_boxes = standalone('one')

    pieces = ["The key elements are interpreted via visual translator. Their coordinates are represented as outlined boxes (top-left, bottom-right)."]
    if kept_zero_boxes:
        pieces.append(f"\nThere is a zero on the left side of the number line. Its coordinate is {fmt(kept_zero_boxes[0])}")
    if kept_one_boxes:
        pieces.append(f"\nThere is a one on the right side of the number line. Its coordinate is {fmt(kept_one_boxes[-1])}")

    for class_name, boxes in class_summary.items():
        if class_name not in ('fraction', 'tick'):
            continue
        # Sorted in place on purpose: the tick/fraction pairing below relies
        # on these lists being ordered by the top-left x coordinate.
        boxes.sort(key=lambda corners: corners[0])
        if not boxes:
            continue
        pieces.append(f"\nThere are {len(boxes)} {class_name}s. Their coordinates are: ")
        pieces.extend(f"{fmt(corners)}, " for corners in boxes)
        if class_name == "fraction":
            pieces.append(f"\nThe fraction numbers from left to right are: {fraction_values}. ")

    for pairing in tick2fraction(class_summary['tick'], class_summary['fraction']):
        tick_part, fraction_part = pairing.split('-')
        fraction_idx = int(fraction_part[1:])
        tick_idx = int(tick_part[1:])
        pieces.append(f"{to_ordinal(fraction_idx + 1)} fraction is associated with {to_ordinal(tick_idx + 1)} tick. ")

    return "".join(pieces)
def drawWithAllBox_info(pil_image, box_info):
    """Draw one rectangle per detection onto ``pil_image`` and return it.

    The drawing happens on the image that was passed in. The outline colour is
    picked from a fixed palette by 'class_id'; ids beyond the palette are
    drawn in black.
    """
    palette = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow', 'brown', 'pink', 'gray', 'lime', 'navy']
    canvas = ImageDraw.Draw(pil_image)
    for detection in box_info:
        cx, cy = detection['x'], detection['y']
        bw, bh = detection['width'], detection['height']
        class_id = detection['class_id']
        outline = palette[class_id] if class_id < len(palette) else 'black'
        # Detections are centre-format; convert to corner coordinates.
        canvas.rectangle(
            [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2],
            outline=outline,
            width=2,
        )
    return pil_image
def online_process_image(image):
    """Detect boxes via the Roboflow workflow and describe the result.

    Args:
        image: A PIL image, or array data accepted by ``Image.fromarray``.
            ``None`` yields empty outputs.

    Returns:
        A 3-tuple ``(annotated_image, textual_description, json_string)``;
        ``(None, "", "")`` when ``image`` is None.
    """
    if image is None:
        # Always return 3 values so the caller's outputs stay aligned.
        return None, "", ""

    # Check the type explicitly: NumPy arrays also have a .copy() method, so
    # the previous hasattr(image, 'copy') test never converted them to PIL
    # and drawing on the result failed later.
    if isinstance(image, Image.Image):
        pil_image = image.copy()
    else:
        pil_image = Image.fromarray(image)

    roboflow_ret = RoboFlowGetOutlineBoxesPIL(pil_image)
    all_box_info = parse_roboflow_result(roboflow_ret, classes)
    del roboflow_ret
    kept_box_info = filter_overlapping_boxes(all_box_info)
    del all_box_info

    boxed_img = drawWithAllBox_info(pil_image, kept_box_info)
    textual = generate_textual_description(kept_box_info)
    json_str = json.dumps(kept_box_info, indent=2)
    return boxed_img, textual, json_str
def process_image(image, selected_model = "yolo_accurate"):
    """Detect boxes with the local YOLO model and describe the result.

    Args:
        image: The image to analyse; boxes are drawn onto this same image.
            ``None`` yields empty outputs.
        selected_model: Model key forwarded to ``localOutlineBox``.

    Returns:
        A 3-tuple ``(annotated_image, textual_description, json_string)``;
        ``(None, "", "")`` when ``image`` is None.
    """
    print("start processing image")
    if image is None:
        # Always hand back three values so the caller's outputs stay aligned.
        return None, "", ""

    detections = filter_overlapping_boxes(localOutlineBox(image, selected_model))
    annotated = drawWithAllBox_info(image, detections)
    description = generate_textual_description(detections)
    serialized = json.dumps(detections, indent=2)
    return annotated, description, serialized
# Gradio UI: an invite-token gate in front of the detection application.
with gr.Blocks(
    css=r"""
/* 精确命中这个 Textbox 的文本区 */
textarea[aria-label="Textual Description"] {
font-size: 28px !important;
line-height: 1.6 !important;
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif !important;
color: #222 !important;
}
/* 兜底选择器 */
#desc textarea,
.large-font textarea,
div[data-testid="textbox"] textarea {
font-size: 24px !important;
}
"""
) as demo:
    # --- Authentication Layer ---
    with gr.Row():
        token_input = gr.Textbox(
            label="Invite Token",
            type="password",
            placeholder="Enter your invite token to unlock the app"
        )
        unlock_btn = gr.Button("Unlock")
    status_text = gr.Markdown()

    # --- Main Application (initially hidden) ---
    with gr.Column(visible=False) as main_app:
        img_input = gr.Image(type="pil", label="Upload Image")
        with gr.Row():
            model_list = gr.Dropdown(
                choices=["yolo_accurate", "yolo_extra_large"],
                value="yolo_accurate",
                label="Select Model",
                info="Choose the YOLO model for detection"
            )
            run_btn = gr.Button("Run Detection")
        img_out = gr.Image(type="pil", label="Image with Boxes")
        text_out = gr.Textbox(label="Textual Description", lines=8, elem_id="desc", elem_classes=["large-font"])
        json_state = gr.State("")
        download_btn = gr.DownloadButton(
            label="Download Box Info as JSON",
            interactive=False  # Start as disabled.
        )

    # --- Backend Functions ---
    def create_json_file(json_str):
        """Write ``json_str`` to a new temp ``.json`` file and return its path.

        Returns None when ``json_str`` is empty. The file is created with
        ``delete=False``, so it stays on disk after being closed.
        """
        if not json_str:
            return None
        with tempfile.NamedTemporaryFile(
            prefix="detection_info_",
            mode='w', delete=False, suffix='.json', encoding='utf-8'
        ) as f:
            f.write(json_str)
            return f.name

    def _process_and_prepare_download(image, selected_model):
        """Process the image, create the JSON file and update the UI.

        Returns the annotated image, the description, the JSON string and an
        update for the download button.
        """
        boxed_img, textual, json_str = process_image(image, selected_model)
        filepath = create_json_file(json_str)
        # Use the legacy gr.update() for compatibility with older Gradio versions.
        # Enable the button only when there is a file to download; without an
        # input image no file is created and the button must stay disabled.
        download_update = gr.update(value=filepath, interactive=filepath is not None)
        return boxed_img, textual, json_str, download_update

    def check_token(token):
        """Reveal the main application when ``token`` matches ACCESS_TOKEN.

        Returns a visibility update for the main column and a status message.
        An unset or empty ACCESS_TOKEN never unlocks the application.
        """
        import hmac

        supplied = (token or "").encode("utf-8")
        expected = (ACCESS_TOKEN or "").encode("utf-8")
        # compare_digest runs in constant time, so response timing does not
        # leak how much of the token was guessed correctly.
        if expected and hmac.compare_digest(supplied, expected):
            return gr.update(visible=True), "Token accepted. You can now use the application."
        return gr.update(visible=False), "Invalid token. Please try again."

    # --- Event Listeners ---
    unlock_btn.click(
        check_token,
        inputs=token_input,
        outputs=[main_app, status_text]
    )
    run_btn.click(
        _process_and_prepare_download,
        inputs=[img_input, model_list],
        # The output now includes the download button itself.
        outputs=[img_out, text_out, json_state, download_btn]
    )
    # The download_btn no longer needs its own click event.

demo.launch()