import os
import io
import json
import base64
import gradio as gr
import numpy as np
import torch
import subprocess
import sys
from PIL import Image
from huggingface_hub import hf_hub_download
from typing import Optional, Dict, Any, List, Union


# Setup OmniParser repository and models
def setup_omniparser():
    """Clone OmniParser repository and download model weights"""
    try:
        # Check if OmniParser repository exists
        if not os.path.exists("OmniParser"):
            print("Cloning OmniParser repository...")
            subprocess.run(["git", "clone", "https://github.com/microsoft/OmniParser.git"], check=True)

        # Add OmniParser to Python path
        omniparser_path = os.path.abspath("OmniParser")
        if omniparser_path not in sys.path:
            sys.path.append(omniparser_path)
            print(f"Added {omniparser_path} to Python path")

        # Create weights directories
        os.makedirs("OmniParser/weights/icon_detect", exist_ok=True)
        os.makedirs("OmniParser/weights/icon_caption_florence", exist_ok=True)

        # Download model weights if they don't exist
        if not os.path.exists("OmniParser/weights/icon_detect/model.pt") or not os.path.exists("OmniParser/weights/icon_caption_florence/model.safetensors"):
            print("Downloading model weights...")
            # Download detection model files
            for f in ["train_args.yaml", "model.pt", "model.yaml"]:
                hf_hub_download(
                    repo_id="microsoft/OmniParser-v2.0",
                    filename=f"icon_detect/{f}",
                    local_dir="OmniParser/weights"
                )
            # Download caption model files
            for f in ["config.json", "generation_config.json", "model.safetensors"]:
                hf_hub_download(
                    repo_id="microsoft/OmniParser-v2.0",
                    filename=f"icon_caption/{f}",
                    local_dir="OmniParser/weights"
                )
            # Rename the caption folder to match the expected path
            if os.path.exists("OmniParser/weights/icon_caption") and not os.path.exists("OmniParser/weights/icon_caption_florence"):
                os.rename("OmniParser/weights/icon_caption", "OmniParser/weights/icon_caption_florence")

        # Patch PaddleOCR initialization in utils.py to fix compatibility issue
        utils_path = os.path.join(omniparser_path, "util", "utils.py")
        if os.path.exists(utils_path):
            print("Patching utils.py to fix compatibility issues...")
            # Create a simplified version of utils.py with essential functions.
            # Note: docstrings inside this template use single-quoted triples so
            # they do not terminate the surrounding triple-double-quoted string.
            simplified_utils = """import os
import io
import cv2
import base64
import numpy as np
import torch
from PIL import Image, ImageDraw


def check_ocr_box(image, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args=None, use_paddleocr=True):
    '''Custom implementation of check_ocr_box that uses EasyOCR'''
    try:
        import easyocr

        # Convert PIL Image to numpy array
        img_np = np.array(image)

        # Initialize EasyOCR
        reader = easyocr.Reader(['en'])

        # Run OCR
        results = reader.readtext(img_np)

        # Extract text and bounding boxes
        texts = []
        boxes = []
        for result in results:
            box, text, _ = result
            texts.append(text)
            # Convert box format if needed
            if output_bb_format == 'xyxy':
                # Convert from [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] to [x1,y1,x3,y3]
                x1, y1 = box[0]
                x3, y3 = box[2]
                boxes.append([x1, y1, x3, y3])
            else:
                boxes.append(box)

        return (texts, boxes), False
    except Exception as e:
        print(f"Error in OCR: {str(e)}")
        return ([], []), False


def get_yolo_model(model_path):
    '''Load YOLO model for icon detection'''
    try:
        from ultralytics import YOLO
        model = YOLO(model_path)
        return model
    except Exception as e:
        print(f"Error loading YOLO model: {str(e)}")
        return None

def get_caption_model_processor(model_name, model_name_or_path):
    '''Load caption model and processor'''
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained(model_name_or_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        return (model, processor)
    except Exception as e:
        print(f"Error loading caption model: {str(e)}")
        return None


def get_som_labeled_img(image, yolo_model, BOX_TRESHOLD=0.05, output_coord_in_ratio=True, ocr_bbox=None,
                        draw_bbox_config=None, caption_model_processor=None, ocr_text=None,
                        iou_threshold=0.1, imgsz=640):
    '''Simplified implementation of get_som_labeled_img'''
    try:
        # Create a copy of the image for visualization
        vis_img = image.copy()
        draw = ImageDraw.Draw(vis_img)

        # Run YOLO detection
        results = yolo_model(image, imgsz=imgsz)

        # Process results
        elements = []
        for i, det in enumerate(results[0].boxes.data):
            x1, y1, x2, y2, conf, cls = det
            if conf < BOX_TRESHOLD:
                continue
            # Convert tensor values to plain floats so they can be drawn and JSON-serialized
            x1, y1, x2, y2, conf = float(x1), float(y1), float(x2), float(y2), float(conf)

            # Draw bounding box
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)

            # Generate caption
            caption = f"UI Element {i}"

            # Add to elements list
            elements.append({
                "id": i,
                "text": "",
                "caption": caption,
                "coordinates": [x1 / image.width, y1 / image.height, x2 / image.width, y2 / image.height],
                "is_interactable": True,
                "confidence": conf
            })

        # Convert to base64
        buffered = io.BytesIO()
        vis_img.save(buffered, format="PNG")
        img_str = "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()

        return img_str, [], elements
    except Exception as e:
        print(f"Error in get_som_labeled_img: {str(e)}")
        return "Error processing image", [], []
"""
            # Write the simplified utils.py
            with open(utils_path, 'w') as f:
                f.write(simplified_utils)
            print("Created simplified utils.py with essential functions")

        print("OmniParser setup completed successfully!")
        return True
    except Exception as e:
        print(f"Error setting up OmniParser: {str(e)}")
        return False


# Setup OmniParser
setup_success = setup_omniparser()


# Create our own implementation of check_ocr_box to avoid PaddleOCR issues
def custom_check_ocr_box(image, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args=None, use_paddleocr=True):
    """
    Custom implementation of check_ocr_box that doesn't rely on PaddleOCR
    """
    print("Using custom OCR implementation (EasyOCR only)")
    try:
        import easyocr
        import numpy as np

        # Convert PIL Image to numpy array
        img_np = np.array(image)

        # Initialize EasyOCR
        reader = easyocr.Reader(['en'])

        # Run OCR
        results = reader.readtext(img_np)

        # Extract text and bounding boxes
        texts = []
        boxes = []
        for result in results:
            box, text, _ = result
            texts.append(text)
            # Convert box format if needed
            if output_bb_format == 'xyxy':
                # Convert from [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] to [x1,y1,x3,y3]
                x1, y1 = box[0]
                x3, y3 = box[2]
                boxes.append([x1, y1, x3, y3])
            else:
                boxes.append(box)

        return (texts, boxes), False
    except Exception as e:
        print(f"Error in custom OCR: {str(e)}")
        return ([], []), False
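
# Optional optimization sketch (an assumption, not wired into the functions above):
# both OCR paths construct easyocr.Reader(['en']) on every call, which reloads the
# OCR weights each time. A module-level cache like this hypothetical helper would
# avoid that cost if the OCR code were refactored to use it.
_EASYOCR_READER = None


def get_cached_easyocr_reader(languages=('en',)):
    """Return a lazily created, module-cached easyocr.Reader instance."""
    global _EASYOCR_READER
    if _EASYOCR_READER is None:
        import easyocr
        _EASYOCR_READER = easyocr.Reader(list(languages))
    return _EASYOCR_READER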

# Import OmniParser utilities
if setup_success:
    try:
        # First try to import the patched version
        from OmniParser.util.utils import get_yolo_model, get_caption_model_processor, get_som_labeled_img

        # Try to import check_ocr_box, but use our custom version if it fails
        try:
            from OmniParser.util.utils import check_ocr_box
            print("Successfully imported all OmniParser utilities")
        except (ImportError, ValueError) as e:
            print(f"Using custom OCR implementation due to error: {str(e)}")
            check_ocr_box = custom_check_ocr_box
    except ImportError as e:
        print(f"Error importing OmniParser utilities: {str(e)}")

        # Fallback to a simple error message
        def error_message(*args, **kwargs):
            return "Error: OmniParser utilities could not be imported. Please check the logs."

        # Create dummy functions that return error messages
        check_ocr_box = get_yolo_model = get_caption_model_processor = get_som_labeled_img = error_message
else:
    print("Using dummy functions due to setup failure")

    # Create dummy functions that return error messages
    def error_message(*args, **kwargs):
        return "Error: OmniParser setup failed. Please check the logs."

    check_ocr_box = get_yolo_model = get_caption_model_processor = get_som_labeled_img = error_message


# Initialize models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize models with correct paths
try:
    # YOLO model for object detection
    yolo_model = get_yolo_model(model_path='OmniParser/weights/icon_detect/model.pt')

    # VLM (Vision Language Model) for captioning
    caption_model_processor = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path="OmniParser/weights/icon_caption_florence"
    )

    print("Models initialized successfully")
    models_initialized = True

    # ENHANCEMENT OPPORTUNITY: Data Fusion
    # The current implementation uses YOLO for detection and VLM for captioning separately.
    # A more integrated approach could:
    # 1. Use YOLO for initial detection of UI elements
    # 2. Use VLM to refine the detections and provide more context
    # 3. Implement a confidence-based merging strategy for overlapping detections
    #    (see the illustrative merge sketch below this block)
    # 4. Use SAM (Segment Anything Model) for more precise segmentation of UI elements
    #
    # Example implementation:
    # ```
    # def enhanced_detection(image, yolo_model, vlm_model, sam_model):
    #     # Get YOLO detections
    #     yolo_boxes = yolo_model(image)
    #
    #     # Use VLM to analyze the entire image for context
    #     global_context = vlm_model.analyze_image(image)
    #
    #     # For each YOLO box, use VLM to get more detailed information
    #     refined_detections = []
    #     for box in yolo_boxes:
    #         # Crop the region
    #         region = crop_image(image, box)
    #
    #         # Get VLM description
    #         description = vlm_model.describe_region(region, context=global_context)
    #
    #         # Use SAM for precise segmentation
    #         mask = sam_model.segment(image, box)
    #
    #         refined_detections.append({
    #             "box": box,
    #             "description": description,
    #             "mask": mask,
    #             "confidence": combine_confidence(box.conf, description.conf)
    #         })
    #
    #     return refined_detections
    # ```
except Exception as e:
    print(f"Error initializing models: {str(e)}")
    # Create dummy models for graceful failure
    yolo_model = None
    caption_model_processor = None
    models_initialized = False
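
# Illustrative sketch for the "Data Fusion" note above (an assumption: this helper
# is not part of OmniParser and is not called anywhere in this app). It shows one
# way a confidence-based merge of overlapping detections could look, using plain
# [x1, y1, x2, y2] boxes and keeping the higher-confidence detection when two
# boxes overlap beyond an IoU threshold.
def _box_iou(box_a, box_b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def merge_overlapping_detections(detections, iou_threshold=0.5):
    """Greedy confidence-based merge: each detection is a dict with "box" and
    "confidence" keys; higher-confidence detections suppress overlapping ones."""
    kept = []
    for det in sorted(detections, key=lambda d: d["confidence"], reverse=True):
        if all(_box_iou(det["box"], k["box"]) <= iou_threshold for k in kept):
            kept.append(det)
    return kept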

# Fallback implementation for when OmniParser fails
def fallback_process_image(image):
    """
    Fallback implementation that simulates OmniParser functionality
    for when the actual models fail to load
    """
    from PIL import Image, ImageDraw, ImageFont
    import random

    # Create a copy of the image for visualization
    vis_img = image.copy()
    draw = ImageDraw.Draw(vis_img)

    # Define some mock UI element types
    element_types = ["Button", "Text Field", "Checkbox", "Dropdown", "Menu Item", "Icon", "Link"]

    # Generate some random elements
    elements = []
    num_elements = min(10, int(image.width * image.height / 50000))  # Scale with image size

    for i in range(num_elements):
        # Generate random position and size
        x1 = random.randint(0, image.width - 100)
        y1 = random.randint(0, image.height - 50)
        width = random.randint(50, 200)
        height = random.randint(30, 80)
        x2 = min(x1 + width, image.width)
        y2 = min(y1 + height, image.height)

        # Generate random element type and caption
        element_type = random.choice(element_types)
        captions = {
            "Button": ["Submit", "Cancel", "OK", "Apply", "Save"],
            "Text Field": ["Enter text", "Username", "Password", "Search", "Email"],
            "Checkbox": ["Select option", "Enable feature", "Remember me", "Agree to terms"],
            "Dropdown": ["Select item", "Choose option", "Select country", "Language"],
            "Menu Item": ["File", "Edit", "View", "Help", "Tools", "Settings"],
            "Icon": ["Home", "Settings", "Profile", "Notification", "Search"],
            "Link": ["Learn more", "Click here", "Details", "Documentation", "Help"]
        }
        text = random.choice(captions[element_type])
        caption = f"{element_type}: {text}"

        # Add to elements list
        elements.append({
            "id": i,
            "text": text,
            "caption": caption,
            "coordinates": [x1 / image.width, y1 / image.height, x2 / image.width, y2 / image.height],
            "is_interactable": element_type in ["Button", "Checkbox", "Dropdown", "Link", "Text Field"],
            "confidence": random.uniform(0.7, 0.95)
        })

        # Draw on visualization
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1, y1 - 10), f"{i}: {text}", fill="red")

    return {
        "elements": elements,
        "visualization": vis_img,
        "note": "This is a fallback visualization as OmniParser models could not be loaded."
    }

def process_image(
    image: Image.Image,
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
    use_paddleocr: bool = True,
    imgsz: int = 640
) -> Dict[str, Any]:
    """
    Process an image with OmniParser and return structured data

    Args:
        image: PIL Image to process
        box_threshold: Threshold for bounding box confidence
        iou_threshold: Threshold for IOU overlap
        use_paddleocr: Whether to use PaddleOCR for text detection
        imgsz: Image size for icon detection

    Returns:
        Dictionary with parsed elements and visualization
    """
    # Check if models are initialized
    if not models_initialized or yolo_model is None or caption_model_processor is None:
        print("Models not initialized properly, using fallback implementation")
        return fallback_process_image(image)

    try:
        # Calculate overlay ratio based on image size
        box_overlay_ratio = image.size[0] / 3200

        # Configure drawing parameters
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # Run OCR to detect text
        try:
            # ENHANCEMENT OPPORTUNITY: OCR Integration
            # The current implementation uses OCR separately from YOLO detection.
            # A more integrated approach could:
            # 1. Use OCR results to refine YOLO detections
            # 2. Merge overlapping text and UI element detections
            # 3. Use text content to improve element classification
            #
            # Example implementation:
            # ```
            # def integrated_ocr_detection(image, ocr_results, yolo_detections):
            #     merged_detections = []
            #
            #     # For each YOLO detection
            #     for yolo_box in yolo_detections:
            #         # Find overlapping OCR text
            #         overlapping_text = []
            #         for text, text_box in ocr_results:
            #             if calculate_iou(yolo_box, text_box) > threshold:
            #                 overlapping_text.append(text)
            #
            #         # Use text content to refine element classification
            #         element_type = classify_element_with_text(yolo_box, overlapping_text)
            #
            #         merged_detections.append({
            #             "box": yolo_box,
            #             "text": " ".join(overlapping_text),
            #             "type": element_type
            #         })
            #
            #     return merged_detections
            # ```
            ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
                image,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9},
                use_paddleocr=use_paddleocr
            )

            # Check if OCR returned an error message (string)
            if isinstance(ocr_bbox_rslt, str):
                print(f"OCR error: {ocr_bbox_rslt}, using fallback implementation")
                return fallback_process_image(image)

            text, ocr_bbox = ocr_bbox_rslt
        except Exception as e:
            print(f"OCR error: {str(e)}, using fallback implementation")
            return fallback_process_image(image)
        # Process image with OmniParser
        try:
            # ENHANCEMENT OPPORTUNITY: SAM Integration
            # The current implementation doesn't use SAM (Segment Anything Model).
            # Integrating SAM could:
            # 1. Provide more precise segmentation of UI elements
            # 2. Better handle complex UI layouts with overlapping elements
            # 3. Improve detection of irregular-shaped elements
            #
            # Example implementation:
            # ```
            # def integrate_sam(image, boxes, sam_model):
            #     # Initialize SAM predictor
            #     predictor = SamPredictor(sam_model)
            #     predictor.set_image(np.array(image))
            #
            #     refined_elements = []
            #     for box in boxes:
            #         # Convert box to SAM input format
            #         input_box = np.array([box[0], box[1], box[2], box[3]])
            #
            #         # Get SAM mask
            #         masks, scores, _ = predictor.predict(
            #             box=input_box,
            #             multimask_output=False
            #         )
            #
            #         # Use the mask to refine the element boundaries
            #         refined_elements.append({
            #             "box": box,
            #             "mask": masks[0],
            #             "mask_confidence": scores[0]
            #         })
            #
            #     return refined_elements
            # ```
            dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
                imgsz=imgsz
            )

            # Check if get_som_labeled_img returned an error message (string)
            if isinstance(dino_labled_img, str) and not dino_labled_img.startswith("data:"):
                print(f"OmniParser error: {dino_labled_img}, using fallback implementation")
                return fallback_process_image(image)

            # Convert base64 image to PIL Image (strip the data-URI prefix before decoding)
            base64_data = dino_labled_img.split(",", 1)[-1]
            visualization = Image.open(io.BytesIO(base64.b64decode(base64_data)))

            # Create structured output
            elements = []
            for i, element in enumerate(parsed_content_list):
                # ENHANCEMENT OPPORTUNITY: Confidence Scoring
                # The current implementation uses a simple confidence score.
                # A more sophisticated approach could:
                # 1. Combine confidence scores from multiple models (YOLO, VLM, OCR)
                # 2. Consider element context and relationships
                # 3. Use historical data to improve confidence scoring
                #
                # Example implementation:
                # ```
                # def calculate_confidence(yolo_conf, vlm_conf, ocr_conf, element_type):
                #     # Base confidence from YOLO
                #     base_conf = yolo_conf
                #
                #     # Adjust based on VLM confidence
                #     if vlm_conf > 0.8:
                #         base_conf = (base_conf + vlm_conf) / 2
                #
                #     # Adjust based on element type
                #     if element_type == "button" and ocr_conf > 0.9:
                #         base_conf = (base_conf + ocr_conf) / 2
                #
                #     # Normalize to 0-1 range
                #     return min(1.0, base_conf)
                # ```
                elements.append({
                    "id": i,
                    "text": element.get("text", ""),
                    "caption": element.get("caption", ""),
                    "coordinates": element.get("coordinates", []),
                    "is_interactable": element.get("is_interactable", False),
                    "confidence": element.get("confidence", 0.0)
                })

            # ENHANCEMENT OPPORTUNITY: Predictive Monitoring
            # The current implementation doesn't include predictive monitoring.
            # Adding this could:
            # 1. Verify that detected elements make sense in the UI context
            # 2. Identify missing or incorrectly detected elements
            # 3. Provide feedback for improving detection accuracy
            #
            # Example implementation:
            # ```
            # def verify_detections(elements, image, vlm_model):
            #     # Use VLM to analyze the entire image
            #     global_description = vlm_model.describe_image(image)
            #
            #     # Check if detected elements match the global description
            #     expected_elements = extract_expected_elements(global_description)
            #
            #     # Compare detected vs expected
            #     missing_elements = [e for e in expected_elements if not any(
            #         similar_element(e, detected) for detected in elements
            #     )]
            #
            #     # Provide feedback
            #     return {
            #         "verified_elements": elements,
            #         "missing_elements": missing_elements,
            #         "confidence": calculate_overall_confidence(elements, expected_elements)
            #     }
            # ```

            # Return structured data and visualization
            return {
                "elements": elements,
                "visualization": visualization
            }
        except Exception as e:
            print(f"OmniParser error: {str(e)}, using fallback implementation")
            return fallback_process_image(image)
    except Exception as e:
        print(f"Error processing image: {str(e)}, using fallback implementation")
        # Use fallback implementation
        return fallback_process_image(image)


# API endpoint function
def api_endpoint(image):
    """
    API endpoint that accepts an image and returns parsed elements

    Args:
        image: Uploaded image file

    Returns:
        JSON with parsed elements
    """
    if image is None:
        return json.dumps({"error": "No image provided"})

    try:
        # Process the image
        result = process_image(image)

        # Check if there was an error
        if "error" in result:
            return json.dumps({
                "status": "error",
                "error": result["error"],
                "elements": []
            })

        # Convert visualization to base64 for JSON response
        buffered = io.BytesIO()
        result["visualization"].save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        # Create response
        response = {
            "status": "success",
            "elements": result["elements"],
            "visualization": img_str
        }

        return json.dumps(response)
    except Exception as e:
        print(f"API endpoint error: {str(e)}")
        return json.dumps({
            "status": "error",
            "error": f"API processing error: {str(e)}",
            "elements": []
        })
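
# Illustrative client-side helper (an assumption: not used by the app itself).
# It shows how a consumer of api_endpoint's JSON response could decode the
# base64 "visualization" field back into a PIL image; the key names match the
# response dictionary built above.
def decode_api_visualization(response_json: str) -> Optional[Image.Image]:
    """Decode the visualization image from an api_endpoint JSON response."""
    payload = json.loads(response_json)
    if payload.get("status") != "success" or "visualization" not in payload:
        return None
    return Image.open(io.BytesIO(base64.b64decode(payload["visualization"])))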

# Function to handle UI submission
def handle_submission(image, box_threshold=0.05, iou_threshold=0.1, use_paddleocr=True, imgsz=640):
    """Handle UI submission and provide appropriate feedback"""
    if image is None:
        return {"error": "No image provided"}, None

    # Process the image
    result = process_image(
        image,
        box_threshold=box_threshold,
        iou_threshold=iou_threshold,
        use_paddleocr=use_paddleocr,
        imgsz=imgsz
    )

    # Return the result
    if "error" in result:
        return {"error": result["error"]}, result.get("visualization", None)
    elif "note" in result:
        # This is from the fallback implementation
        return {
            "note": result["note"],
            "elements": result["elements"]
        }, result["visualization"]
    else:
        return {"elements": result["elements"]}, result["visualization"]


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
# OmniParser v2.0 API

Upload an image to parse UI elements and get structured data.

## Quick Start
You can use the [test UI image](/file=static/test_ui.png) to try out the API, or upload your own UI screenshot.

## API Usage
You can use this API by sending a POST request with a file upload to this URL.

```python
import requests

# Replace with your actual API URL after deployment
OMNIPARSER_API_URL = "https://your-username-omniparser-api.hf.space/api/parse"

# Upload a file
files = {'image': open('screenshot.png', 'rb')}

# Send request
response = requests.post(OMNIPARSER_API_URL, files=files)

# Get JSON result
result = response.json()
```
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type='pil', label='Upload image')

            # Function to load test image
            def load_test_image():
                if os.path.exists("static/test_ui.png"):
                    return Image.open("static/test_ui.png")
                return None

            test_image_button = gr.Button(value='Load Test Image')
            test_image_button.click(fn=load_test_image, inputs=[], outputs=[image_input])

            with gr.Accordion("Advanced Options", open=False):
                box_threshold = gr.Slider(
                    label='Box Threshold',
                    minimum=0.01, maximum=1.0, step=0.01, value=0.05
                )
                iou_threshold = gr.Slider(
                    label='IOU Threshold',
                    minimum=0.01, maximum=1.0, step=0.01, value=0.1
                )
                use_paddleocr = gr.Checkbox(
                    label='Use PaddleOCR',
                    value=True
                )
                imgsz = gr.Slider(
                    label='Icon Detect Image Size',
                    minimum=640, maximum=1920, step=32, value=640
                )

            submit_button = gr.Button(value='Parse Image', variant='primary')

            # Status message
            status = gr.Markdown("Ready to parse images")

        with gr.Column():
            json_output = gr.JSON(label='Parsed Elements (JSON)')
            image_output = gr.Image(type='pil', label='Visualization')

    # Connect the interface
    submit_button.click(
        fn=handle_submission,
        inputs=[image_input, box_threshold, iou_threshold, use_paddleocr, imgsz],
        outputs=[json_output, image_output],
        api_name="parse"  # This creates the /api/parse endpoint
    )

    # Function to get status
    def get_status():
        if models_initialized:
            return f"✅ OmniParser v2.0 API - Running on {'GPU' if torch.cuda.is_available() else 'CPU'}"
        else:
            return "⚠️ OmniParser v2.0 API - Running in fallback mode (models not loaded)"

    # Update status on load
    demo.load(
        fn=get_status,
        outputs=status
    )

# Create test image if it doesn't exist
try:
    if not os.path.exists("static/test_ui.png"):
        print("Creating test UI image...")
        from create_test_image import create_test_ui_image
        test_image_path = create_test_ui_image()
        print(f"Test image created at {test_image_path}")
except Exception as e:
    print(f"Error creating test image: {str(e)}")

# Launch the app
demo.launch()
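
# Illustrative client usage via gradio_client (a sketch, not the documented API:
# the exact endpoint path depends on the Gradio version, and the Space name below
# is a placeholder). gradio_client resolves the api_name="parse" registered above
# as "/parse" and returns one value per output component (JSON, then image path).
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/omniparser-api")
#   json_result, visualization_path = client.predict(
#       handle_file("screenshot.png"),  # image
#       0.05,                           # box_threshold
#       0.1,                            # iou_threshold
#       True,                           # use_paddleocr
#       640,                            # imgsz
#       api_name="/parse",
#   )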