import os
import io
import json
import base64
import gradio as gr
import numpy as np
import torch
import subprocess
import sys
from PIL import Image
from huggingface_hub import hf_hub_download
from typing import Optional, Dict, Any, List, Union
# Setup OmniParser repository and models
def setup_omniparser():
    """Clone the OmniParser repository and download model weights."""
    try:
        # Check if the OmniParser repository exists
        if not os.path.exists("OmniParser"):
            print("Cloning OmniParser repository...")
            subprocess.run(["git", "clone", "https://github.com/microsoft/OmniParser.git"], check=True)

        # Add OmniParser to the Python path
        omniparser_path = os.path.abspath("OmniParser")
        if omniparser_path not in sys.path:
            sys.path.append(omniparser_path)
            print(f"Added {omniparser_path} to Python path")

        # Create weights directories
        os.makedirs("OmniParser/weights/icon_detect", exist_ok=True)
        os.makedirs("OmniParser/weights/icon_caption_florence", exist_ok=True)

        # Download model weights if they don't exist
        if not os.path.exists("OmniParser/weights/icon_detect/model.pt") or not os.path.exists("OmniParser/weights/icon_caption_florence/model.safetensors"):
            print("Downloading model weights...")
            # Download detection model files
            for f in ["train_args.yaml", "model.pt", "model.yaml"]:
                hf_hub_download(
                    repo_id="microsoft/OmniParser-v2.0",
                    filename=f"icon_detect/{f}",
                    local_dir="OmniParser/weights"
                )
            # Download caption model files
            for f in ["config.json", "generation_config.json", "model.safetensors"]:
                hf_hub_download(
                    repo_id="microsoft/OmniParser-v2.0",
                    filename=f"icon_caption/{f}",
                    local_dir="OmniParser/weights"
                )
            # Rename the caption folder to match the expected path
            if os.path.exists("OmniParser/weights/icon_caption") and not os.path.exists("OmniParser/weights/icon_caption_florence"):
                os.rename("OmniParser/weights/icon_caption", "OmniParser/weights/icon_caption_florence")

        # Patch utils.py to work around the PaddleOCR compatibility issue
        utils_path = os.path.join(omniparser_path, "util", "utils.py")
        if os.path.exists(utils_path):
            print("Patching utils.py to fix compatibility issues...")
            # Create a simplified version of utils.py with the essential functions.
            # The outer string uses triple single quotes so the embedded triple
            # double-quoted docstrings do not terminate it early.
            simplified_utils = '''import os
import io
import cv2
import base64
import numpy as np
import torch
from PIL import Image, ImageDraw


def check_ocr_box(image, display_img=False, output_bb_format='xyxy', goal_filtering=None,
                  easyocr_args=None, use_paddleocr=True):
    """
    Custom implementation of check_ocr_box that uses EasyOCR.
    """
    try:
        import easyocr
        # Convert the PIL Image to a numpy array
        img_np = np.array(image)
        # Initialize EasyOCR
        reader = easyocr.Reader(['en'])
        # Run OCR
        results = reader.readtext(img_np)
        # Extract text and bounding boxes
        texts = []
        boxes = []
        for result in results:
            box, text, _ = result
            texts.append(text)
            # Convert the box format if needed
            if output_bb_format == 'xyxy':
                # Convert from [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] to [x1,y1,x3,y3]
                x1, y1 = box[0]
                x3, y3 = box[2]
                boxes.append([x1, y1, x3, y3])
            else:
                boxes.append(box)
        return (texts, boxes), False
    except Exception as e:
        print(f"Error in OCR: {str(e)}")
        return ([], []), False


def get_yolo_model(model_path):
    """
    Load the YOLO model for icon detection.
    """
    try:
        from ultralytics import YOLO
        model = YOLO(model_path)
        return model
    except Exception as e:
        print(f"Error loading YOLO model: {str(e)}")
        return None


def get_caption_model_processor(model_name, model_name_or_path):
    """
    Load the caption model and processor.
    """
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        # Florence-2 checkpoints ship custom modeling code, so trust_remote_code is required
        processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        return (model, processor)
    except Exception as e:
        print(f"Error loading caption model: {str(e)}")
        return None


def get_som_labeled_img(image, yolo_model, BOX_TRESHOLD=0.05, output_coord_in_ratio=True,
                        ocr_bbox=None, draw_bbox_config=None, caption_model_processor=None,
                        ocr_text=None, iou_threshold=0.1, imgsz=640):
    """
    Simplified implementation of get_som_labeled_img.
    """
    try:
        # Create a copy of the image for visualization
        vis_img = image.copy()
        draw = ImageDraw.Draw(vis_img)
        # Run YOLO detection
        results = yolo_model(image, imgsz=imgsz)
        # Process results
        elements = []
        for i, det in enumerate(results[0].boxes.data):
            x1, y1, x2, y2, conf, cls = det
            # Convert tensors to plain floats so drawing and JSON serialization work
            x1, y1, x2, y2, conf = float(x1), float(y1), float(x2), float(y2), float(conf)
            if conf < BOX_TRESHOLD:
                continue
            # Draw the bounding box
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            # Generate a placeholder caption
            caption = f"UI Element {i}"
            # Add to the elements list
            elements.append({
                "id": i,
                "text": "",
                "caption": caption,
                "coordinates": [x1 / image.width, y1 / image.height, x2 / image.width, y2 / image.height],
                "is_interactable": True,
                "confidence": conf
            })
        # Convert to base64
        buffered = io.BytesIO()
        vis_img.save(buffered, format="PNG")
        img_str = "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()
        return img_str, [], elements
    except Exception as e:
        print(f"Error in get_som_labeled_img: {str(e)}")
        return "Error processing image", [], []
'''
            # Write the simplified utils.py
            with open(utils_path, 'w') as f:
                f.write(simplified_utils)
            print("Created simplified utils.py with essential functions")

        print("OmniParser setup completed successfully!")
        return True
    except Exception as e:
        print(f"Error setting up OmniParser: {str(e)}")
        return False
# Set up OmniParser
setup_success = setup_omniparser()


# Our own implementation of check_ocr_box, to avoid PaddleOCR issues
def custom_check_ocr_box(image, display_img=False, output_bb_format='xyxy', goal_filtering=None,
                         easyocr_args=None, use_paddleocr=True):
    """
    Custom implementation of check_ocr_box that does not rely on PaddleOCR.
    """
    print("Using custom OCR implementation (EasyOCR only)")
    try:
        import easyocr
        # Convert the PIL Image to a numpy array
        img_np = np.array(image)
        # Initialize EasyOCR
        reader = easyocr.Reader(['en'])
        # Run OCR
        results = reader.readtext(img_np)
        # Extract text and bounding boxes
        texts = []
        boxes = []
        for result in results:
            box, text, _ = result
            texts.append(text)
            # Convert the box format if needed
            if output_bb_format == 'xyxy':
                # Convert from [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] to [x1,y1,x3,y3]
                x1, y1 = box[0]
                x3, y3 = box[2]
                boxes.append([x1, y1, x3, y3])
            else:
                boxes.append(box)
        return (texts, boxes), False
    except Exception as e:
        print(f"Error in custom OCR: {str(e)}")
        return ([], []), False


# Import OmniParser utilities
if setup_success:
    try:
        # First try to import the patched version
        from OmniParser.util.utils import get_yolo_model, get_caption_model_processor, get_som_labeled_img
        # Try to import check_ocr_box, falling back to our custom version if it fails
        try:
            from OmniParser.util.utils import check_ocr_box
            print("Successfully imported all OmniParser utilities")
        except (ImportError, ValueError) as e:
            print(f"Using custom OCR implementation due to error: {str(e)}")
            check_ocr_box = custom_check_ocr_box
    except ImportError as e:
        print(f"Error importing OmniParser utilities: {str(e)}")

        # Fall back to a simple error message
        def error_message(*args, **kwargs):
            return "Error: OmniParser utilities could not be imported. Please check the logs."

        # Create dummy functions that return error messages
        check_ocr_box = get_yolo_model = get_caption_model_processor = get_som_labeled_img = error_message
else:
    print("Using dummy functions due to setup failure")

    # Create dummy functions that return error messages
    def error_message(*args, **kwargs):
        return "Error: OmniParser setup failed. Please check the logs."

    check_ocr_box = get_yolo_model = get_caption_model_processor = get_som_labeled_img = error_message
# Initialize models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize models with the correct paths
try:
    # YOLO model for object detection
    yolo_model = get_yolo_model(model_path='OmniParser/weights/icon_detect/model.pt')
    # VLM (vision-language model) for captioning
    caption_model_processor = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path="OmniParser/weights/icon_caption_florence"
    )
    print("Models initialized successfully")
    models_initialized = True

    # ENHANCEMENT OPPORTUNITY: Data fusion
    # The current implementation uses YOLO for detection and the VLM for captioning separately.
    # A more integrated approach could:
    # 1. Use YOLO for the initial detection of UI elements
    # 2. Use the VLM to refine the detections and provide more context
    # 3. Implement a confidence-based merging strategy for overlapping detections
    # 4. Use SAM (Segment Anything Model) for more precise segmentation of UI elements
    #
    # Example implementation:
    # ```
    # def enhanced_detection(image, yolo_model, vlm_model, sam_model):
    #     # Get YOLO detections
    #     yolo_boxes = yolo_model(image)
    #
    #     # Use the VLM to analyze the entire image for context
    #     global_context = vlm_model.analyze_image(image)
    #
    #     # For each YOLO box, use the VLM to get more detailed information
    #     refined_detections = []
    #     for box in yolo_boxes:
    #         # Crop the region
    #         region = crop_image(image, box)
    #
    #         # Get the VLM description
    #         description = vlm_model.describe_region(region, context=global_context)
    #
    #         # Use SAM for precise segmentation
    #         mask = sam_model.segment(image, box)
    #
    #         refined_detections.append({
    #             "box": box,
    #             "description": description,
    #             "mask": mask,
    #             "confidence": combine_confidence(box.conf, description.conf)
    #         })
    #
    #     return refined_detections
    # ```
except Exception as e:
    print(f"Error initializing models: {str(e)}")
    # Create dummy models for graceful failure
    yolo_model = None
    caption_model_processor = None
    models_initialized = False
# Fallback implementation for when OmniParser fails
def fallback_process_image(image):
    """
    Fallback implementation that simulates OmniParser functionality
    when the actual models fail to load.
    """
    from PIL import Image, ImageDraw, ImageFont
    import random

    # Create a copy of the image for visualization
    vis_img = image.copy()
    draw = ImageDraw.Draw(vis_img)

    # Define some mock UI element types
    element_types = ["Button", "Text Field", "Checkbox", "Dropdown", "Menu Item", "Icon", "Link"]

    # Generate some random elements
    elements = []
    num_elements = min(10, int(image.width * image.height / 50000))  # Scale with image size
    for i in range(num_elements):
        # Generate a random position and size (guard against very small images)
        x1 = random.randint(0, max(image.width - 100, 0))
        y1 = random.randint(0, max(image.height - 50, 0))
        width = random.randint(50, 200)
        height = random.randint(30, 80)
        x2 = min(x1 + width, image.width)
        y2 = min(y1 + height, image.height)

        # Generate a random element type and caption
        element_type = random.choice(element_types)
        captions = {
            "Button": ["Submit", "Cancel", "OK", "Apply", "Save"],
            "Text Field": ["Enter text", "Username", "Password", "Search", "Email"],
            "Checkbox": ["Select option", "Enable feature", "Remember me", "Agree to terms"],
            "Dropdown": ["Select item", "Choose option", "Select country", "Language"],
            "Menu Item": ["File", "Edit", "View", "Help", "Tools", "Settings"],
            "Icon": ["Home", "Settings", "Profile", "Notification", "Search"],
            "Link": ["Learn more", "Click here", "Details", "Documentation", "Help"]
        }
        text = random.choice(captions[element_type])
        caption = f"{element_type}: {text}"

        # Add to the elements list
        elements.append({
            "id": i,
            "text": text,
            "caption": caption,
            "coordinates": [x1 / image.width, y1 / image.height, x2 / image.width, y2 / image.height],
            "is_interactable": element_type in ["Button", "Checkbox", "Dropdown", "Link", "Text Field"],
            "confidence": random.uniform(0.7, 0.95)
        })

        # Draw on the visualization
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1, y1 - 10), f"{i}: {text}", fill="red")

    return {
        "elements": elements,
        "visualization": vis_img,
        "note": "This is a fallback visualization as the OmniParser models could not be loaded."
    }
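
# Illustrative helper (not wired into the pipeline): the OCR-integration sketch
# inside process_image below refers to a calculate_iou(box_a, box_b) function.
# This is a minimal sketch of that helper, assuming boxes in [x1, y1, x2, y2]
# format; the name and signature are placeholders rather than OmniParser API.
def calculate_iou(box_a, box_b):
    """Return the intersection-over-union of two xyxy boxes."""
    # Intersection rectangle
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    if inter == 0.0:
        return 0.0
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)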
def process_image(
    image: Image.Image,
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
    use_paddleocr: bool = True,
    imgsz: int = 640
) -> Dict[str, Any]:
    """
    Process an image with OmniParser and return structured data.

    Args:
        image: PIL Image to process
        box_threshold: Confidence threshold for bounding boxes
        iou_threshold: IOU threshold for overlap suppression
        use_paddleocr: Whether to use PaddleOCR for text detection
        imgsz: Image size for icon detection

    Returns:
        Dictionary with parsed elements and a visualization
    """
    # Check that the models are initialized
    if not models_initialized or yolo_model is None or caption_model_processor is None:
        print("Models not initialized properly, using fallback implementation")
        return fallback_process_image(image)

    try:
        # Calculate the overlay ratio based on the image size
        box_overlay_ratio = image.size[0] / 3200

        # Configure drawing parameters
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # Run OCR to detect text
        try:
            # ENHANCEMENT OPPORTUNITY: OCR integration
            # The current implementation runs OCR separately from YOLO detection.
            # A more integrated approach could:
            # 1. Use OCR results to refine YOLO detections
            # 2. Merge overlapping text and UI element detections
            # 3. Use text content to improve element classification
            #
            # Example implementation:
            # ```
            # def integrated_ocr_detection(image, ocr_results, yolo_detections):
            #     merged_detections = []
            #
            #     # For each YOLO detection
            #     for yolo_box in yolo_detections:
            #         # Find overlapping OCR text
            #         overlapping_text = []
            #         for text, text_box in ocr_results:
            #             if calculate_iou(yolo_box, text_box) > threshold:
            #                 overlapping_text.append(text)
            #
            #         # Use the text content to refine element classification
            #         element_type = classify_element_with_text(yolo_box, overlapping_text)
            #
            #         merged_detections.append({
            #             "box": yolo_box,
            #             "text": " ".join(overlapping_text),
            #             "type": element_type
            #         })
            #
            #     return merged_detections
            # ```
            ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
                image,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9},
                use_paddleocr=use_paddleocr
            )
            # Check whether OCR returned an error message (string)
            if isinstance(ocr_bbox_rslt, str):
                print(f"OCR error: {ocr_bbox_rslt}, using fallback implementation")
                return fallback_process_image(image)
            text, ocr_bbox = ocr_bbox_rslt
        except Exception as e:
            print(f"OCR error: {str(e)}, using fallback implementation")
            return fallback_process_image(image)

        # Process the image with OmniParser
        try:
            # ENHANCEMENT OPPORTUNITY: SAM integration
            # The current implementation doesn't use SAM (Segment Anything Model).
            # Integrating SAM could:
            # 1. Provide more precise segmentation of UI elements
            # 2. Better handle complex UI layouts with overlapping elements
            # 3. Improve detection of irregularly shaped elements
            #
            # Example implementation:
            # ```
            # def integrate_sam(image, boxes, sam_model):
            #     # Initialize the SAM predictor
            #     predictor = SamPredictor(sam_model)
            #     predictor.set_image(np.array(image))
            #
            #     refined_elements = []
            #     for box in boxes:
            #         # Convert the box to SAM input format
            #         input_box = np.array([box[0], box[1], box[2], box[3]])
            #
            #         # Get the SAM mask
            #         masks, scores, _ = predictor.predict(
            #             box=input_box,
            #             multimask_output=False
            #         )
            #
            #         # Use the mask to refine the element boundaries
            #         refined_elements.append({
            #             "box": box,
            #             "mask": masks[0],
            #             "mask_confidence": scores[0]
            #         })
            #
            #     return refined_elements
            # ```
            dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
                imgsz=imgsz
            )
            # Check whether get_som_labeled_img returned an error message (string)
            if isinstance(dino_labeled_img, str) and not dino_labeled_img.startswith("data:"):
                print(f"OmniParser error: {dino_labeled_img}, using fallback implementation")
                return fallback_process_image(image)

            # Convert the base64 image to a PIL Image, stripping any data-URI prefix first
            base64_data = dino_labeled_img.split(",", 1)[1] if dino_labeled_img.startswith("data:") else dino_labeled_img
            visualization = Image.open(io.BytesIO(base64.b64decode(base64_data)))

            # Create structured output
            elements = []
            for i, element in enumerate(parsed_content_list):
                # ENHANCEMENT OPPORTUNITY: Confidence scoring
                # The current implementation uses a simple confidence score.
                # A more sophisticated approach could:
                # 1. Combine confidence scores from multiple models (YOLO, VLM, OCR)
                # 2. Consider element context and relationships
                # 3. Use historical data to improve confidence scoring
                #
                # Example implementation:
                # ```
                # def calculate_confidence(yolo_conf, vlm_conf, ocr_conf, element_type):
                #     # Base confidence from YOLO
                #     base_conf = yolo_conf
                #
                #     # Adjust based on VLM confidence
                #     if vlm_conf > 0.8:
                #         base_conf = (base_conf + vlm_conf) / 2
                #
                #     # Adjust based on element type
                #     if element_type == "button" and ocr_conf > 0.9:
                #         base_conf = (base_conf + ocr_conf) / 2
                #
                #     # Normalize to the 0-1 range
                #     return min(1.0, base_conf)
                # ```
                elements.append({
                    "id": i,
                    "text": element.get("text", ""),
                    "caption": element.get("caption", ""),
                    "coordinates": element.get("coordinates", []),
                    "is_interactable": element.get("is_interactable", False),
                    "confidence": element.get("confidence", 0.0)
                })

            # ENHANCEMENT OPPORTUNITY: Predictive monitoring
            # The current implementation doesn't include predictive monitoring.
            # Adding it could:
            # 1. Verify that detected elements make sense in the UI context
            # 2. Identify missing or incorrectly detected elements
            # 3. Provide feedback for improving detection accuracy
            #
            # Example implementation:
            # ```
            # def verify_detections(elements, image, vlm_model):
            #     # Use the VLM to analyze the entire image
            #     global_description = vlm_model.describe_image(image)
            #
            #     # Check whether the detected elements match the global description
            #     expected_elements = extract_expected_elements(global_description)
            #
            #     # Compare detected vs. expected
            #     missing_elements = [e for e in expected_elements if not any(
            #         similar_element(e, detected) for detected in elements
            #     )]
            #
            #     # Provide feedback
            #     return {
            #         "verified_elements": elements,
            #         "missing_elements": missing_elements,
            #         "confidence": calculate_overall_confidence(elements, expected_elements)
            #     }
            # ```

            # Return the structured data and visualization
            return {
                "elements": elements,
                "visualization": visualization
            }
        except Exception as e:
            print(f"OmniParser error: {str(e)}, using fallback implementation")
            return fallback_process_image(image)
    except Exception as e:
        print(f"Error processing image: {str(e)}, using fallback implementation")
        # Use the fallback implementation
        return fallback_process_image(image)
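
# Minimal local-usage sketch for process_image. Assumptions: a screenshot named
# "screenshot.png" sits next to app.py, and the OMNIPARSER_LOCAL_TEST environment
# variable is a placeholder guard so this block never runs during a normal
# Space launch.
if os.environ.get("OMNIPARSER_LOCAL_TEST") == "1":
    _test_image = Image.open("screenshot.png")
    _test_result = process_image(_test_image, box_threshold=0.05, iou_threshold=0.1)
    print(f"Local test detected {len(_test_result['elements'])} elements")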
# API endpoint function
def api_endpoint(image):
    """
    API endpoint that accepts an image and returns the parsed elements.

    Args:
        image: Uploaded image file

    Returns:
        JSON string with the parsed elements
    """
    if image is None:
        return json.dumps({"error": "No image provided"})

    try:
        # Process the image
        result = process_image(image)

        # Check whether there was an error
        if "error" in result:
            return json.dumps({
                "status": "error",
                "error": result["error"],
                "elements": []
            })

        # Convert the visualization to base64 for the JSON response
        buffered = io.BytesIO()
        result["visualization"].save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        # Create the response
        response = {
            "status": "success",
            "elements": result["elements"],
            "visualization": img_str
        }
        return json.dumps(response)
    except Exception as e:
        print(f"API endpoint error: {str(e)}")
        return json.dumps({
            "status": "error",
            "error": f"API processing error: {str(e)}",
            "elements": []
        })


# Handle UI submission
def handle_submission(image, box_threshold=0.05, iou_threshold=0.1, use_paddleocr=True, imgsz=640):
    """Handle a UI submission and provide appropriate feedback."""
    if image is None:
        return {"error": "No image provided"}, None

    # Process the image
    result = process_image(
        image,
        box_threshold=box_threshold,
        iou_threshold=iou_threshold,
        use_paddleocr=use_paddleocr,
        imgsz=imgsz
    )

    # Return the result
    if "error" in result:
        return {"error": result["error"]}, result.get("visualization", None)
    elif "note" in result:
        # This came from the fallback implementation
        return {
            "note": result["note"],
            "elements": result["elements"]
        }, result["visualization"]
    else:
        return {"elements": result["elements"]}, result["visualization"]
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
# OmniParser v2.0 API

Upload an image to parse UI elements and get structured data.

## Quick Start
You can use the [test UI image](/file=static/test_ui.png) to try out the API, or upload your own UI screenshot.

## API Usage
You can use this API by sending a POST request with a file upload to this URL.

```python
import requests

# Replace with your actual API URL after deployment
OMNIPARSER_API_URL = "https://your-username-omniparser-api.hf.space/api/parse"

# Upload a file
files = {'image': open('screenshot.png', 'rb')}

# Send the request
response = requests.post(OMNIPARSER_API_URL, files=files)

# Get the JSON result
result = response.json()
```
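
Depending on the Gradio version running in this Space, the raw POST route may differ; the `gradio_client` package is a version-stable alternative. The sketch below is illustrative only: the Space URL is a placeholder, and newer client versions may require wrapping the image path with `gradio_client.handle_file`.

```python
from gradio_client import Client

# Point the client at the deployed Space (placeholder URL)
client = Client("https://your-username-omniparser-api.hf.space")

# Arguments mirror the Gradio inputs: image, box_threshold, iou_threshold, use_paddleocr, imgsz
json_result, visualization = client.predict(
    "screenshot.png", 0.05, 0.1, True, 640,
    api_name="/parse"
)
```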
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type='pil', label='Upload image')

            # Load the bundled test image
            def load_test_image():
                if os.path.exists("static/test_ui.png"):
                    return Image.open("static/test_ui.png")
                return None

            test_image_button = gr.Button(value='Load Test Image')
            test_image_button.click(fn=load_test_image, inputs=[], outputs=[image_input])

            with gr.Accordion("Advanced Options", open=False):
                box_threshold = gr.Slider(
                    label='Box Threshold',
                    minimum=0.01,
                    maximum=1.0,
                    step=0.01,
                    value=0.05
                )
                iou_threshold = gr.Slider(
                    label='IOU Threshold',
                    minimum=0.01,
                    maximum=1.0,
                    step=0.01,
                    value=0.1
                )
                use_paddleocr = gr.Checkbox(
                    label='Use PaddleOCR',
                    value=True
                )
                imgsz = gr.Slider(
                    label='Icon Detect Image Size',
                    minimum=640,
                    maximum=1920,
                    step=32,
                    value=640
                )

            submit_button = gr.Button(value='Parse Image', variant='primary')

            # Status message
            status = gr.Markdown("Ready to parse images")

        with gr.Column():
            json_output = gr.JSON(label='Parsed Elements (JSON)')
            image_output = gr.Image(type='pil', label='Visualization')

    # Connect the interface
    submit_button.click(
        fn=handle_submission,
        inputs=[image_input, box_threshold, iou_threshold, use_paddleocr, imgsz],
        outputs=[json_output, image_output],
        api_name="parse"  # This creates the /api/parse endpoint
    )

    # Report model status
    def get_status():
        if models_initialized:
            return f"✅ OmniParser v2.0 API - Running on {'GPU' if torch.cuda.is_available() else 'CPU'}"
        else:
            return "⚠️ OmniParser v2.0 API - Running in fallback mode (models not loaded)"

    # Update the status on load
    demo.load(
        fn=get_status,
        outputs=status
    )


# Create the test image if it doesn't exist
try:
    if not os.path.exists("static/test_ui.png"):
        print("Creating test UI image...")
        from create_test_image import create_test_ui_image
        test_image_path = create_test_ui_image()
        print(f"Test image created at {test_image_path}")
except Exception as e:
    print(f"Error creating test image: {str(e)}")

# Launch the app
demo.launch()