Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import sys | |
| import json | |
| from PIL import Image,ImageDraw | |
| import tempfile | |
| from inference_sdk import InferenceHTTPClient | |
| from ultralytics import YOLO | |
| import ultralytics | |
# Label names handled by this app: used to seed the per-class summary in
# generate_textual_description and to filter hosted-workflow predictions.
classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'tick', 'fraction']
# Roboflow API key read from the environment (None when the variable is unset).
API_KEY = os.environ.get("ROBOFLOW_API_KEY")
# Invite token the UI compares against before revealing the app (None when unset).
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
def RoboFlowGetOutlineBoxesPIL(pil_img):
    """Run the hosted Roboflow workflow on an image and return its raw result.

    Args:
        pil_img: Image forwarded to the workflow under the "image" input key.

    Returns:
        Whatever ``InferenceHTTPClient.run_workflow`` returns, unmodified.
    """
    roboflow_client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        api_key=API_KEY,
    )
    workflow_inputs = {"image": pil_img}
    return roboflow_client.run_workflow(
        workspace_name="mathnet-mmpuo",
        workflow_id="custom-workflow-2",
        images=workflow_inputs,
        use_cache=True,  # cache workflow definition for 15 minutes
    )
def localOutlineBox(image, selected_model="yolo_accurate"):
    """Detect boxes on an image with a locally stored YOLO model.

    Args:
        image: Image passed straight to the YOLO model.
        selected_model: Key choosing the weights file ("yolo_accurate" or
            "yolo_extra_large"). Unknown keys use the default weights.

    Returns:
        A list of dicts, one per detected box, with keys 'class', 'x', 'y',
        'width', 'height', 'confidence' and 'class_id'. 'x'/'y' come from the
        model's ``xywh`` output, which the rest of this file treats as the
        box centre.
    """
    default_model_path = "./vt_dataset_yolov12_v7_weights.pt"
    model_paths = {
        "yolo_accurate": default_model_path,
        "yolo_extra_large": "./VT_dataset_2_Yolov12_Extra_large.pt",
    }
    # Honour the caller's selection; previously the path was unconditionally
    # overwritten with the default, so the selection had no effect.
    model_path = model_paths.get(selected_model, default_model_path)
    if model_path != default_model_path and not os.path.exists(model_path):
        # The requested weights are not deployed: fall back to the default
        # weights rather than failing on a missing file.
        print(f"model is not available: {model_path}, falling back to default")
        model_path = default_model_path

    if os.path.exists(model_path):
        print("model exists")
        print(f'current model path is: {model_path}')
    else:
        print("model is not available")

    model = YOLO(model_path)
    # A single image is passed in, so only the first result is relevant.
    detections = model(image, verbose=False)[0]
    names = model.names  # maps class id -> class name

    all_box_info = []
    for bb in detections.boxes:
        x, y, w, h = bb.xywh[0].tolist()
        class_id = int(bb.cls)
        all_box_info.append({
            'class': names[class_id],
            'x': x,
            'y': y,
            'width': w,
            'height': h,
            'confidence': float(bb.conf),
            'class_id': class_id,
        })
    print(f"total length of all box_info is: {len(all_box_info)}")
    return all_box_info
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two boxes.

    Each box is a mapping with centre coordinates 'x', 'y' and sizes
    'width', 'height'. Returns 0.0 when the boxes do not overlap with
    positive area.
    """
    def corners(box):
        # Convert centre/size form to (left, top, right, bottom).
        return (
            box['x'] - box['width'] / 2,
            box['y'] - box['height'] / 2,
            box['x'] + box['width'] / 2,
            box['y'] + box['height'] / 2,
        )

    a_left, a_top, a_right, a_bottom = corners(box1)
    b_left, b_top, b_right, b_bottom = corners(box2)

    # Corners of the intersection rectangle.
    left = max(a_left, b_left)
    top = max(a_top, b_top)
    right = min(a_right, b_right)
    bottom = min(a_bottom, b_bottom)

    if not (left < right and top < bottom):
        return 0.0

    overlap = (right - left) * (bottom - top)
    # Union = area(box1) + area(box2) - intersection.
    union = box1['width'] * box1['height'] + box2['width'] * box2['height'] - overlap
    return overlap / union
def parse_roboflow_result(result, kept_classes):
    """Extract the predictions whose class is in ``kept_classes``.

    Args:
        result: Workflow result; the prediction list is read from
            ``result[0]['predictions']['predictions']['predictions']``.
        kept_classes: Container of class names to keep.

    Returns:
        The matching prediction dicts, in their original order.
    """
    predictions = result[0]['predictions']['predictions']['predictions']
    return [pred for pred in predictions if pred['class'] in kept_classes]
def filter_overlapping_boxes(filter_box_info, iou_threshold=0.5):
    """Drop boxes that overlap an already-kept box of the same group.

    Boxes are split into digit boxes and all other boxes, and each group is
    de-duplicated independently: a box is discarded when its IoU with any
    previously kept box of its group exceeds ``iou_threshold``. Digit boxes
    are visited from highest to lowest confidence; other boxes are visited in
    input order.

    Returns:
        The kept non-digit boxes followed by the kept digit boxes.
    """
    digit_classes = {'zero', 'one', 'two', 'three', 'four',
                     'five', 'six', 'seven', 'eight', 'nine'}

    def suppress(candidates):
        # Greedy pass: keep a candidate unless it overlaps a survivor too much.
        survivors = []
        for candidate in candidates:
            clashes = any(
                calculate_iou(candidate, kept) > iou_threshold
                for kept in survivors
            )
            if not clashes:
                survivors.append(candidate)
        return survivors

    # Separate digit boxes from everything else.
    digits = [box for box in filter_box_info if box['class'] in digit_classes]
    others = [box for box in filter_box_info if box['class'] not in digit_classes]

    ranked_digits = sorted(digits, key=lambda box: box['confidence'], reverse=True)
    return suppress(others) + suppress(ranked_digits)
def getCenterXDis(box1, box2):
    """Return the absolute horizontal distance between two box centres.

    Both boxes are 4-item sequences unpacked as (x0, y0, x1, y1); only the
    x values contribute to the result.
    """
    left_a, _, right_a, _ = box1
    left_b, _, right_b, _ = box2
    return abs((left_a + right_a) / 2 - (left_b + right_b) / 2)
def to_ordinal(n):
    """Return the English ordinal string for an integer.

    Examples: 1 -> "1st", 2 -> "2nd", 13 -> "13th".

    Raises:
        TypeError: If ``n`` is not an ``int``.
    """
    if not isinstance(n, int):
        raise TypeError("Input must be an integer.")

    # 11, 12 and 13 always take "th", whatever their last digit is.
    if n % 100 in (11, 12, 13):
        return f"{n}th"

    suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"
def packFilterBoxInfo(filter_box_info):
    """Read each fraction box as a "numerator/denominator" string.

    For every box of class 'fraction' (ordered by left edge), the digit boxes
    whose centre lies inside it are collected. Digits whose centre is above
    the fraction's centre form the numerator and the rest form the
    denominator, each read left to right. A missing part is written as "?".

    Returns:
        A list of strings such as "3/4", one per fraction box.
    """
    # Class name -> digit character.
    digit_for_class = {
        'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5',
        'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 'zero': '0'
    }

    def left_edge(box):
        return box['x'] - box['width'] / 2

    def read_digits(boxes):
        # Concatenate digit characters from left to right.
        return ''.join(digit_for_class[b['class']] for b in sorted(boxes, key=left_edge))

    fractions = sorted(
        (box for box in filter_box_info if box['class'] == 'fraction'),
        key=left_edge,
    )
    digits = [box for box in filter_box_info if box['class'] in digit_for_class]

    fraction_values = []
    for frac in fractions:
        # Bounds of the fraction box.
        x_min = frac['x'] - frac['width'] / 2
        x_max = frac['x'] + frac['width'] / 2
        y_min = frac['y'] - frac['height'] / 2
        y_max = frac['y'] + frac['height'] / 2

        # Digits whose centre falls inside the fraction box.
        inside = [
            d for d in digits
            if x_min <= d['x'] <= x_max and y_min <= d['y'] <= y_max
        ]
        # Centre above the fraction centre -> numerator, otherwise denominator.
        upper = [d for d in inside if d['y'] < frac['y']]
        lower = [d for d in inside if not d['y'] < frac['y']]

        numerator = read_digits(upper) or "?"
        denominator = read_digits(lower) or "?"
        fraction_values.append(f"{numerator}/{denominator}")
    return fraction_values
# Boxes are assumed to be given as (top-left, bottom-right) corners.
def getOverlap(box1, box2):
    """Return the share of ``box2``'s area that is covered by ``box1``.

    Both boxes are 4-item sequences unpacked as (x1, y1, x2, y2). The result
    is normalised by the area of ``box2`` only, so it is not symmetric.
    Returns 0.0 when the boxes do not overlap with positive area.
    """
    a_x1, a_y1, a_x2, a_y2 = box1
    b_x1, b_y1, b_x2, b_y2 = box2

    left, top = max(a_x1, b_x1), max(a_y1, b_y1)
    right, bottom = min(a_x2, b_x2), min(a_y2, b_y2)

    if not (left < right and top < bottom):
        return 0.0

    shared_area = (right - left) * (bottom - top)
    second_area = abs(b_x2 - b_x1) * abs(b_y2 - b_y1)
    return shared_area / second_area
def tick2fraction(ticks, fractions):
    """Greedily pair each fraction with its horizontally nearest unused tick.

    Fractions are visited in the given order; each one takes the still-unused
    tick with the smallest centre-x distance. Pairing stops as soon as no
    unused tick is left.

    Returns:
        A list of strings "T<tick index>-F<fraction index>".
    """
    pairs = []
    taken = set()
    for frac_idx, frac in enumerate(fractions):
        candidates = [
            (getCenterXDis(tick, frac), tick_idx)
            for tick_idx, tick in enumerate(ticks)
            if tick_idx not in taken
        ]
        if not candidates:
            # Every tick is already assigned; no later fraction can match either.
            break
        _, nearest_idx = min(candidates, key=lambda item: item[0])
        taken.add(nearest_idx)
        pairs.append(f"T{nearest_idx}-F{frac_idx}")
    return pairs
def generate_textual_description(box_info):
    """Build the plain-text description of the detected number-line elements.

    The text lists, in order: a fixed introduction; the left-most standalone
    'zero' and the right-most standalone 'one' (if any); then, for the 'tick'
    and 'fraction' classes, the box count and corner coordinates sorted by
    left edge. For fractions it also lists the fraction values and which tick
    each fraction is associated with.

    Args:
        box_info: List of box dicts with 'class', 'x', 'y', 'width', 'height'.

    Returns:
        The description as a single string.
    """
    fraction_values = packFilterBoxInfo(box_info)

    # Corner-format boxes grouped by class name; unknown classes are ignored.
    class_summary = {name: [] for name in classes}
    for box in box_info:
        bucket = class_summary.get(box['class'])
        if bucket is None:
            continue
        cx, cy, w, h = box['x'], box['y'], box['width'], box['height']
        bucket.append([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])

    def standalone(candidates):
        # Keep digit boxes that are not at least half covered by a fraction
        # or tick box, ordered by their left edge.
        covering = class_summary['fraction'] + class_summary['tick']
        kept = [
            cand for cand in candidates
            if not any(getOverlap(cover, cand) >= 0.5 for cover in covering)
        ]
        return sorted(kept, key=lambda b: b[0])

    kept_zero_boxes = standalone(class_summary['zero'])
    kept_one_boxes = standalone(class_summary['one'])

    parts = ["The key elements are interpreted via visual translator. Their coordinates are represented as outlined boxes (top-left, bottom-right)."]

    if kept_zero_boxes:
        zx0, zy0, zx1, zy1 = kept_zero_boxes[0]  # left-most zero
        parts.append(f"\nThere is a zero on the left side of the number line. Its coordinate is (({zx0:.2f}, {zy0:.2f}), ({zx1:.2f}, {zy1:.2f}))")
    if kept_one_boxes:
        ox0, oy0, ox1, oy1 = kept_one_boxes[-1]  # right-most one
        parts.append(f"\nThere is a one on the right side of the number line. Its coordinate is (({ox0:.2f}, {oy0:.2f}), ({ox1:.2f}, {oy1:.2f}))")

    # Iterating class_summary keeps the order of `classes`, where 'tick'
    # precedes 'fraction', so ticks are already sorted in place by the time
    # fractions are paired with them.
    for class_name, boxes in class_summary.items():
        if class_name not in ('fraction', 'tick'):
            continue
        count = len(boxes)
        boxes.sort(key=lambda b: b[0])  # b[0] is the x of the top-left corner
        if count == 0:
            continue
        parts.append(f"\nThere are {count} {class_name}s. Their coordinates are: ")
        parts.extend(
            f"(({b[0]:.2f}, {b[1]:.2f}), ({b[2]:.2f}, {b[3]:.2f})), "
            for b in boxes
        )
        # NOTE(review): nesting reconstructed from a flattened listing; confirm
        # the fraction summary is meant to be skipped when no fraction exists.
        if class_name == "fraction":
            parts.append(f"\nThe fraction numbers from left to right are: {fraction_values}. ")
            for pairing in tick2fraction(class_summary['tick'], class_summary['fraction']):
                tick_part, fraction_part = pairing.split('-')
                fraction_idx = int(fraction_part[1:])
                tick_idx = int(tick_part[1:])
                parts.append(f"{to_ordinal(fraction_idx + 1)} fraction is associated with {to_ordinal(tick_idx + 1)} tick. ")
    return "".join(parts)
def drawWithAllBox_info(pil_image, box_info):
    """Draw every box outline onto ``pil_image`` and return the same image.

    The image is modified in place. The outline colour is chosen by the
    box's 'class_id'; ids beyond the palette fall back to black.
    """
    palette = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta', 'yellow', 'brown', 'pink', 'gray', 'lime', 'navy']
    canvas = ImageDraw.Draw(pil_image)
    for box in box_info:
        cx, cy, w, h = box['x'], box['y'], box['width'], box['height']
        class_id = box['class_id']
        outline = palette[class_id] if class_id < len(palette) else 'black'
        # Convert centre/size form to the corner form expected by rectangle().
        canvas.rectangle([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], outline=outline, width=2)
    return pil_image
def online_process_image(image):
    """Detect boxes through the hosted Roboflow workflow and describe them.

    Args:
        image: Input image, or None. Objects with a ``copy`` method are
            copied; anything else is converted with ``Image.fromarray``.

    Returns:
        A 3-tuple ``(image_with_boxes, textual_description, json_string)``;
        ``(None, "", "")`` when ``image`` is None.
    """
    if image is None:
        # Always hand back three values so callers can unpack safely.
        return None, "", ""

    if hasattr(image, 'copy'):
        pil_image = image.copy()
    else:
        pil_image = Image.fromarray(image)

    detections = parse_roboflow_result(RoboFlowGetOutlineBoxesPIL(pil_image), classes)
    kept_box_info = filter_overlapping_boxes(detections)
    del detections

    boxed_img = drawWithAllBox_info(pil_image, kept_box_info)
    textual = generate_textual_description(kept_box_info)
    return boxed_img, textual, json.dumps(kept_box_info, indent=2)
def process_image(image, selected_model = "yolo_accurate"):
    """Detect boxes with the local YOLO model and describe them.

    Args:
        image: Input image, or None. Boxes are drawn directly onto it.
        selected_model: Model key forwarded to ``localOutlineBox``.

    Returns:
        A 3-tuple ``(image_with_boxes, textual_description, json_string)``;
        ``(None, "", "")`` when ``image`` is None.
    """
    print("start processing image")
    if image is None:
        # Always hand back three values so callers can unpack safely.
        return None, "", ""

    kept_box_info = filter_overlapping_boxes(localOutlineBox(image, selected_model))
    boxed_img = drawWithAllBox_info(image, kept_box_info)
    textual = generate_textual_description(kept_box_info)
    return boxed_img, textual, json.dumps(kept_box_info, indent=2)
# Gradio UI: an invite-token gate in front of the detection app. The main
# column stays hidden until check_token accepts the supplied token.
with gr.Blocks(
    css=r"""
/* 精确命中这个 Textbox 的文本区 */
textarea[aria-label="Textual Description"] {
font-size: 28px !important;
line-height: 1.6 !important;
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif !important;
color: #222 !important;
}
/* 兜底选择器 */
#desc textarea,
.large-font textarea,
div[data-testid="textbox"] textarea {
font-size: 24px !important;
}
"""
) as demo:
    # --- Authentication Layer ---
    with gr.Row():
        token_input = gr.Textbox(
            label="Invite Token",
            type="password",
            placeholder="Enter your invite token to unlock the app"
        )
        unlock_btn = gr.Button("Unlock")
    status_text = gr.Markdown()

    # --- Main Application (initially hidden) ---
    with gr.Column(visible=False) as main_app:
        img_input = gr.Image(type="pil", label="Upload Image")
        with gr.Row():
            model_list = gr.Dropdown(
                choices=["yolo_accurate", "yolo_extra_large"],
                value="yolo_accurate",
                label="Select Model",
                info="Choose the YOLO model for detection"
            )
            run_btn = gr.Button("Run Detection")
        img_out = gr.Image(type="pil", label="Image with Boxes")
        text_out = gr.Textbox(label="Textual Description", lines=8, elem_id="desc", elem_classes=["large-font"])
        json_state = gr.State("")
        download_btn = gr.DownloadButton(
            label="Download Box Info as JSON",
            interactive=False  # Start as disabled.
        )

    # --- Backend Functions ---
    def create_json_file(json_str):
        """Write ``json_str`` to a persistent temp file and return its path.

        Returns None when ``json_str`` is empty, in which case no file is
        created.
        """
        if not json_str:
            return None
        # delete=False: the file must outlive this function so it can be
        # served by the download button.
        with tempfile.NamedTemporaryFile(
            prefix="detection_info_",
            mode='w', delete=False, suffix='.json', encoding='utf-8'
        ) as f:
            f.write(json_str)
        return f.name

    def _process_and_prepare_download(image, selected_model):
        """Run detection, write the JSON file, and build the UI updates.

        Returns:
            ``(boxed_image, description, json_string, download_button_update)``.
        """
        boxed_img, textual, json_str = process_image(image, selected_model)
        filepath = create_json_file(json_str)
        # Use the legacy gr.update() for compatibility with older Gradio versions.
        # Only enable the button when there is actually a file to download
        # (e.g. not when no image was supplied).
        download_update = gr.update(value=filepath, interactive=filepath is not None)
        return boxed_img, textual, json_str, download_update

    def check_token(token):
        """Reveal the main app when ``token`` matches ACCESS_TOKEN.

        Access is always denied when ACCESS_TOKEN is unset or empty.

        Returns:
            ``(visibility_update_for_main_app, status_message)``.
        """
        import hmac  # local import keeps this block self-contained

        # Compare as UTF-8 bytes: compare_digest rejects non-ASCII str inputs,
        # and a constant-time comparison avoids leaking how much of the token
        # matched through response timing.
        supplied = token.encode("utf-8") if isinstance(token, str) else b""
        expected = ACCESS_TOKEN.encode("utf-8") if ACCESS_TOKEN else b""
        if expected and hmac.compare_digest(supplied, expected):
            return gr.update(visible=True), "Token accepted. You can now use the application."
        return gr.update(visible=False), "Invalid token. Please try again."

    # --- Event Listeners ---
    unlock_btn.click(
        check_token,
        inputs=token_input,
        outputs=[main_app, status_text]
    )
    run_btn.click(
        _process_and_prepare_download,
        inputs=[img_input, model_list],
        # The output now includes the download button itself.
        outputs=[img_out, text_out, json_state, download_btn]
    )
    # The download_btn no longer needs its own click event.

demo.launch()