"""Gradio app that detects digits, ticks and fractions on an image.

Detections come either from a local YOLO model (``process_image``) or from a
hosted Roboflow workflow (``online_process_image``).  Overlapping detections
are suppressed, the remaining boxes are drawn onto the image, and a textual
description of the boxes is generated.

Throughout this file a "box info" is a dict with at least the keys
``class``, ``x``, ``y``, ``width``, ``height``, ``confidence`` and
``class_id``; ``x``/``y`` are treated as the box centre (corners are always
derived as ``x -/+ width/2`` and ``y -/+ height/2``).
"""

import hmac
import json
import os
import sys
import tempfile

import gradio as gr
import ultralytics
from inference_sdk import InferenceHTTPClient
from PIL import Image, ImageDraw
from ultralytics import YOLO

classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven',
           'eight', 'nine', 'tick', 'fraction']

API_KEY = os.environ.get("ROBOFLOW_API_KEY")
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")

# Weight files selectable from the UI dropdown, keyed by dropdown value.
_MODEL_PATHS = {
    "yolo_accurate": "./vt_dataset_yolov12_v7_weights.pt",
    "yolo_extra_large": "./VT_dataset_2_Yolov12_Extra_large.pt",
}
_DEFAULT_MODEL_PATH = "./vt_dataset_yolov12_v7_weights.pt"


def RoboFlowGetOutlineBoxesPIL(pil_img):
    """Run the hosted Roboflow workflow on ``pil_img`` and return its raw result.

    The result is returned unmodified; ``parse_roboflow_result`` extracts the
    box list from it.
    """
    client = InferenceHTTPClient(
        api_url="https://detect.roboflow.com",
        api_key=API_KEY,
    )
    result = client.run_workflow(
        workspace_name="mathnet-mmpuo",
        workflow_id="custom-workflow-2",
        images={"image": pil_img},
        use_cache=True,  # cache workflow definition for 15 minutes
    )
    return result


def localOutlineBox(image, selected_model="yolo_accurate"):
    """Detect boxes on ``image`` with a local YOLO model.

    Args:
        image: Image passed straight to the YOLO model.
        selected_model: Key into the known weight files; unknown keys use the
            default weights.

    Returns:
        A list of box-info dicts (``class``, ``x``, ``y``, ``width``,
        ``height``, ``confidence``, ``class_id``) built from the first YOLO
        result.
    """
    model_path = _MODEL_PATHS.get(selected_model, _DEFAULT_MODEL_PATH)

    # Honour the requested model, but keep working in deployments that only
    # ship the default weights.
    if model_path != _DEFAULT_MODEL_PATH and not os.path.exists(model_path):
        print(f"model is not available at {model_path}; "
              f"falling back to {_DEFAULT_MODEL_PATH}")
        model_path = _DEFAULT_MODEL_PATH

    if os.path.exists(model_path):
        print("model exists")
        print(f'current model path is: {model_path}')
    else:
        print("model is not available")

    model = YOLO(model_path)
    yolo_ret = model(image, verbose=False)
    useful_ret = yolo_ret[0]  # only the first result is used
    names = model.names

    all_box_info = []
    for bb in useful_ret.boxes:
        x, y, w, h = bb.xywh[0].tolist()
        class_id = int(bb.cls)
        all_box_info.append({
            'class': names[class_id],
            'x': x,
            'y': y,
            'width': w,
            'height': h,
            'confidence': float(bb.conf),
            'class_id': class_id,
        })

    print(f"total length of all box_info is: {len(all_box_info)}")
    return all_box_info


def calculate_iou(box1, box2):
    """Return the intersection-over-union of two box-info dicts.

    Each box is a dict with centre ``x``/``y`` and ``width``/``height``.
    Returns 0.0 when the boxes do not overlap with positive area.
    """
    # Top-left and bottom-right corners of each box.
    box1_x1 = box1['x'] - box1['width'] / 2
    box1_y1 = box1['y'] - box1['height'] / 2
    box1_x2 = box1['x'] + box1['width'] / 2
    box1_y2 = box1['y'] + box1['height'] / 2

    box2_x1 = box2['x'] - box2['width'] / 2
    box2_y1 = box2['y'] - box2['height'] / 2
    box2_x2 = box2['x'] + box2['width'] / 2
    box2_y2 = box2['y'] + box2['height'] / 2

    # Corners of the intersection rectangle.
    inter_x1 = max(box1_x1, box2_x1)
    inter_y1 = max(box1_y1, box2_y1)
    inter_x2 = min(box1_x2, box2_x2)
    inter_y2 = min(box1_y2, box2_y2)

    if not (inter_x1 < inter_x2 and inter_y1 < inter_y2):
        return 0.0
    inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)

    box1_area = box1['width'] * box1['height']
    box2_area = box2['width'] * box2['height']
    union_area = box1_area + box2_area - inter_area

    return inter_area / union_area


def parse_roboflow_result(result, kept_classes):
    """Return the Roboflow predictions whose ``class`` is in ``kept_classes``.

    The predictions are read from
    ``result[0]['predictions']['predictions']['predictions']``.
    """
    predictions = result[0]['predictions']['predictions']['predictions']
    return [box for box in predictions if box['class'] in kept_classes]


def _suppress_overlaps(boxes, iou_threshold):
    """Greedily keep boxes whose IoU with every already-kept box is <= threshold."""
    kept = []
    for box in boxes:
        if all(calculate_iou(box, other) <= iou_threshold for other in kept):
            kept.append(box)
    return kept


def filter_overlapping_boxes(filter_box_info, iou_threshold=0.5):
    """Drop boxes that overlap an already-kept box by more than ``iou_threshold``.

    Digit boxes and non-digit boxes are filtered independently of each other.
    Digit boxes are visited in descending confidence order; non-digit boxes
    are visited in their input order.

    Returns:
        The kept non-digit boxes followed by the kept digit boxes.
    """
    digit_classes = {'zero', 'one', 'two', 'three', 'four', 'five', 'six',
                     'seven', 'eight', 'nine'}

    digit_boxes = [b for b in filter_box_info if b['class'] in digit_classes]
    other_boxes = [b for b in filter_box_info if b['class'] not in digit_classes]

    digit_boxes.sort(key=lambda b: b['confidence'], reverse=True)

    kept_digit_boxes = _suppress_overlaps(digit_boxes, iou_threshold)
    kept_other_boxes = _suppress_overlaps(other_boxes, iou_threshold)
    return kept_other_boxes + kept_digit_boxes


def getCenterXDis(box1, box2):
    """Return the absolute horizontal distance between two box centres.

    Both boxes are ``(x0, y0, x1, y1)`` sequences.
    """
    x0_1, _, x1_1, _ = box1
    x0_2, _, x1_2, _ = box2
    center_x1 = (x0_1 + x1_1) / 2
    center_x2 = (x0_2 + x1_2) / 2
    return abs(center_x1 - center_x2)


def to_ordinal(n):
    """Convert an integer to its ordinal string representation.

    e.g., 1 -> "1st", 2 -> "2nd", 13 -> "13th"

    Raises:
        TypeError: If ``n`` is not an ``int``.
    """
    if not isinstance(n, int):
        raise TypeError("Input must be an integer.")

    # 11th, 12th and 13th are special cases.
    if 11 <= (n % 100) <= 13:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"


def packFilterBoxInfo(filter_box_info):
    """Read each fraction box as a ``"numerator/denominator"`` string.

    A digit belongs to a fraction when its centre lies inside the fraction
    box (bounds inclusive).  Digits whose centre is above the fraction's
    centre form the numerator, the rest form the denominator; each part is
    read left to right by the digit's left edge.  A part with no digits is
    rendered as ``"?"``.

    Returns:
        One string per fraction box, ordered by the fraction's left edge.
    """
    digit_classes = {
        'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5',
        'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 'zero': '0',
    }

    fraction_boxes = [b for b in filter_box_info if b['class'] == 'fraction']
    number_boxes = [b for b in filter_box_info if b['class'] in digit_classes]

    fraction_boxes.sort(key=lambda b: b['x'] - b['width'] / 2)

    fraction_values = []
    for frac_box in fraction_boxes:
        frac_x = frac_box['x']
        frac_y = frac_box['y']
        half_w = frac_box['width'] / 2
        half_h = frac_box['height'] / 2

        # (left_edge, digit) pairs for each half of the fraction.
        numerator_digits = []
        denominator_digits = []

        for num_box in number_boxes:
            inside = (frac_x - half_w <= num_box['x'] <= frac_x + half_w
                      and frac_y - half_h <= num_box['y'] <= frac_y + half_h)
            if not inside:
                continue
            entry = (num_box['x'] - num_box['width'] / 2,
                     digit_classes[num_box['class']])
            # Smaller y than the fraction centre means above it.
            if num_box['y'] < frac_y:
                numerator_digits.append(entry)
            else:
                denominator_digits.append(entry)

        # Stable sort by left edge only, so ties keep detection order.
        numerator_digits.sort(key=lambda e: e[0])
        denominator_digits.sort(key=lambda e: e[0])

        numerator = ''.join(d for _, d in numerator_digits) or "?"
        denominator = ''.join(d for _, d in denominator_digits) or "?"
        fraction_values.append(f"{numerator}/{denominator}")

    return fraction_values


def getOverlap(box1, box2):
    """Return the fraction of ``box2``'s area that is covered by ``box1``.

    Both boxes are ``(x0, y0, x1, y1)`` with top-left and bottom-right
    corners.  Returns 0.0 when they do not overlap with positive area.
    """
    b1_x1, b1_y1, b1_x2, b1_y2 = box1
    b2_x1, b2_y1, b2_x2, b2_y2 = box2

    inter_x1 = max(b1_x1, b2_x1)
    inter_y1 = max(b1_y1, b2_y1)
    inter_x2 = min(b1_x2, b2_x2)
    inter_y2 = min(b1_y2, b2_y2)

    if not (inter_x1 < inter_x2 and inter_y1 < inter_y2):
        return 0.0
    inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)

    # Normalised by box2 only, so the measure is not symmetric.
    b2_area = abs(b2_x2 - b2_x1) * abs(b2_y2 - b2_y1)
    return inter_area / b2_area


def tick2fraction(ticks, fractions):
    """Greedily pair each fraction with its horizontally nearest unused tick.

    Fractions are processed in order; each takes the not-yet-used tick with
    the smallest centre-x distance.  Pairing stops once no ticks are left.

    Returns:
        A list of ``"T<tick_index>-F<fraction_index>"`` strings (0-based).
    """
    ret = []
    used_ticks = set()
    for fi, frac in enumerate(fractions):
        candidates = [
            (getCenterXDis(tick, frac), ti)
            for ti, tick in enumerate(ticks)
            if ti not in used_ticks
        ]
        if not candidates:
            # No tick left for this fraction, hence none for later ones either.
            break
        candidates.sort(key=lambda c: c[0])
        nearest_tick = candidates[0][1]
        used_ticks.add(nearest_tick)
        ret.append(f"T{nearest_tick}-F{fi}")
    return ret


def _boxes_outside(candidates, containers, min_overlap=0.5):
    """Return candidates not covered by any container by at least ``min_overlap``."""
    return [
        cand for cand in candidates
        if not any(getOverlap(cont, cand) >= min_overlap for cont in containers)
    ]


def _format_box(box):
    """Format an ``(x0, y0, x1, y1)`` box as ``((x0, y0), (x1, y1))`` with 2 decimals."""
    return f"(({box[0]:.2f}, {box[1]:.2f}), ({box[2]:.2f}, {box[3]:.2f}))"


def generate_textual_description(box_info):
    """Build the textual description for a list of box-info dicts.

    The text reports, in order: the left-most standalone zero, the right-most
    standalone one, the tick and fraction boxes (left to right), the fraction
    values, and which tick each fraction is associated with.  A zero/one is
    "standalone" when no fraction or tick box covers at least half of it.
    """
    fraction_values = packFilterBoxInfo(box_info)

    # Corner-format boxes grouped by class name; unknown classes are ignored.
    class_summary = {c: [] for c in classes}
    for box in box_info:
        c_name = box['class']
        if c_name not in class_summary:
            continue
        x, y, w, h = box['x'], box['y'], box['width'], box['height']
        class_summary[c_name].append([x - w / 2, y - h / 2, x + w / 2, y + h / 2])

    containers = class_summary['fraction'] + class_summary['tick']
    kept_zero_boxes = _boxes_outside(class_summary['zero'], containers)
    kept_one_boxes = _boxes_outside(class_summary['one'], containers)
    kept_zero_boxes.sort(key=lambda b: b[0])
    kept_one_boxes.sort(key=lambda b: b[0])

    parts = [
        "The key elements are interpreted via visual translator. Their "
        "coordinates are represented as outlined boxes (top-left, bottom-right)."
    ]

    if kept_zero_boxes:
        parts.append(
            "\nThere is a zero on the left side of the number line. "
            f"Its coordinate is {_format_box(kept_zero_boxes[0])}"
        )
    if kept_one_boxes:
        parts.append(
            "\nThere is a one on the right side of the number line. "
            f"Its coordinate is {_format_box(kept_one_boxes[-1])}"
        )

    present_classes = ['fraction', 'tick']
    for class_name, boxes in class_summary.items():
        if class_name not in present_classes:
            continue
        # Sorted in place by the top-left x; tick2fraction below relies on
        # this left-to-right order for its indices.
        boxes.sort(key=lambda b: b[0])
        count = len(boxes)
        if count > 0:
            parts.append(f"\nThere are {count} {class_name}s. Their coordinates are: ")
            parts.extend(f"{_format_box(box)}, " for box in boxes)
            if class_name == "fraction":
                parts.append(
                    f"\nThe fraction numbers from left to right are: {fraction_values}. "
                )

    tick2fra = tick2fraction(class_summary['tick'], class_summary['fraction'])
    for cor in tick2fra:
        tick_part, fraction_part = cor.split('-')
        fraction_idx = int(fraction_part[1:])
        tick_idx = int(tick_part[1:])
        parts.append(
            f"{to_ordinal(fraction_idx + 1)} fraction is associated with "
            f"{to_ordinal(tick_idx + 1)} tick. "
        )

    return "".join(parts)


def drawWithAllBox_info(pil_image, box_info):
    """Draw every box onto ``pil_image`` in place and return the same image.

    The outline colour is chosen by ``class_id``; ids beyond the palette are
    drawn in black.
    """
    colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta',
              'yellow', 'brown', 'pink', 'gray', 'lime', 'navy']
    draw = ImageDraw.Draw(pil_image)
    for box in box_info:
        x, y, w, h = box['x'], box['y'], box['width'], box['height']
        class_id = box['class_id']
        color = colors[class_id] if class_id < len(colors) else 'black'
        draw.rectangle([x - w / 2, y - h / 2, x + w / 2, y + h / 2],
                       outline=color, width=2)
    return pil_image


def _to_pil_copy(image):
    """Return a PIL image that is safe to draw on without touching ``image``.

    A ``hasattr(image, 'copy')`` check is not enough here: NumPy arrays also
    have ``.copy`` but cannot be drawn on with ``ImageDraw``.
    """
    if isinstance(image, Image.Image):
        return image.copy()
    return Image.fromarray(image)


def online_process_image(image):
    """Detect boxes with the hosted Roboflow workflow.

    Returns:
        ``(boxed_image, textual_description, json_string)``; for a ``None``
        image this is ``(None, "", "")``.
    """
    if image is None:
        # Ensure we always return 3 values to prevent errors.
        return None, "", ""

    pil_image = _to_pil_copy(image)

    roboflow_ret = RoboFlowGetOutlineBoxesPIL(pil_image)
    all_box_info = parse_roboflow_result(roboflow_ret, classes)
    kept_box_info = filter_overlapping_boxes(all_box_info)

    boxed_img = drawWithAllBox_info(pil_image, kept_box_info)
    textual = generate_textual_description(kept_box_info)
    json_str = json.dumps(kept_box_info, indent=2)
    return boxed_img, textual, json_str


def process_image(image, selected_model="yolo_accurate"):
    """Detect boxes with a local YOLO model.

    Args:
        image: Input image; it is passed to the model unchanged and is not
            drawn on.
        selected_model: Model key forwarded to ``localOutlineBox``.

    Returns:
        ``(boxed_image, textual_description, json_string)``; for a ``None``
        image this is ``(None, "", "")``.
    """
    print("start processing image")
    if image is None:
        # Ensure we always return 3 values to prevent errors.
        return None, "", ""

    all_box_info = localOutlineBox(image, selected_model)
    kept_box_info = filter_overlapping_boxes(all_box_info)

    # Draw on a copy so the caller's image is left untouched.
    boxed_img = drawWithAllBox_info(_to_pil_copy(image), kept_box_info)
    textual = generate_textual_description(kept_box_info)
    json_str = json.dumps(kept_box_info, indent=2)
    return boxed_img, textual, json_str


# The CSS below enlarges the "Textual Description" textbox: the first rule
# targets that textarea by its aria-label, the second is a fallback selector.
with gr.Blocks(
    css=r"""
/* 精确命中这个 Textbox 的文本区 */
textarea[aria-label="Textual Description"] {
  font-size: 28px !important;
  line-height: 1.6 !important;
  font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif !important;
  color: #222 !important;
}
/* 兜底选择器 */
#desc textarea, .large-font textarea, div[data-testid="textbox"] textarea {
  font-size: 24px !important;
}
"""
) as demo:
    # --- Authentication Layer ---
    with gr.Row():
        token_input = gr.Textbox(
            label="Invite Token",
            type="password",
            placeholder="Enter your invite token to unlock the app",
        )
        unlock_btn = gr.Button("Unlock")
    status_text = gr.Markdown()

    # --- Main Application (initially hidden) ---
    with gr.Column(visible=False) as main_app:
        img_input = gr.Image(type="pil", label="Upload Image")
        with gr.Row():
            model_list = gr.Dropdown(
                choices=["yolo_accurate", "yolo_extra_large"],
                value="yolo_accurate",
                label="Select Model",
                info="Choose the YOLO model for detection",
            )
        run_btn = gr.Button("Run Detection")
        img_out = gr.Image(type="pil", label="Image with Boxes")
        text_out = gr.Textbox(
            label="Textual Description",
            lines=8,
            elem_id="desc",
            elem_classes=["large-font"],
        )
        json_state = gr.State("")
        download_btn = gr.DownloadButton(
            label="Download Box Info as JSON",
            interactive=False,  # Start as disabled.
        )

    # --- Backend Functions ---
    def create_json_file(json_str):
        """Write ``json_str`` to a temp ``.json`` file and return its path.

        Returns ``None`` for an empty string.  The file is created with
        ``delete=False`` so it still exists when the download is served.
        """
        if not json_str:
            return None
        with tempfile.NamedTemporaryFile(
            prefix="detection_info_",
            mode='w',
            delete=False,
            suffix='.json',
            encoding='utf-8',
        ) as f:
            f.write(json_str)
        return f.name

    def _process_and_prepare_download(image, selected_model):
        """Process the image, create the JSON file, and update the UI."""
        boxed_img, textual, json_str = process_image(image, selected_model)
        filepath = create_json_file(json_str)
        # Use the legacy gr.update() for compatibility with older Gradio versions.
        download_update = gr.update(value=filepath, interactive=True)
        return boxed_img, textual, json_str, download_update

    def check_token(token):
        """Reveal the main app when ``token`` matches ``ACCESS_TOKEN``.

        Access is always denied when ``ACCESS_TOKEN`` is unset or empty.
        """
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        token_ok = (
            bool(ACCESS_TOKEN)
            and isinstance(token, str)
            and hmac.compare_digest(
                token.encode("utf-8"), ACCESS_TOKEN.encode("utf-8")
            )
        )
        if token_ok:
            return gr.update(visible=True), "Token accepted. You can now use the application."
        return gr.update(visible=False), "Invalid token. Please try again."

    # --- Event Listeners ---
    unlock_btn.click(
        check_token,
        inputs=token_input,
        outputs=[main_app, status_text],
    )

    run_btn.click(
        _process_and_prepare_download,
        inputs=[img_input, model_list],
        # The output includes the download button itself, so it needs no
        # click event of its own.
        outputs=[img_out, text_out, json_state, download_btn],
    )

demo.launch()