|
|
""" |
|
|
Test script for VisionOCRAgent |
|
|
|
|
|
Tests OCR functionality with llava:7b vision model. |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Put the "src" directory that sits next to this script at the front of the
# module search path so that the `agents.vision_ocr_agent` import below can be
# resolved when this file is run directly. Must execute before that import.
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
|
|
|
|
from agents.vision_ocr_agent import VisionOCRAgent |
|
|
from loguru import logger |
|
|
|
|
|
|
|
|
# Reset logging: drop any previously registered loguru handlers, then send
# output to stderr, emitting only records at INFO severity or higher.
logger.remove()
logger.add(sys.stderr, level="INFO")
|
|
|
|
|
|
|
|
async def test_vision_ocr(
    model_name="llava:7b",
    candidate_paths=(
        "/home/mhamdan/SPARKNET/Dataset/Google 08.02.2012.pdf",
        "/home/mhamdan/SPARKNET/uploads/patents/d58fc23c-58ce-4e1c-9ca7-2c63493f90eb.pdf",
    ),
):
    """Smoke-test VisionOCRAgent setup against a sample patent.

    Builds the agent, checks that the vision model is available, and locates
    the first existing file among ``candidate_paths``.

    NOTE(review): the located patent is only logged; no OCR call is made on it
    yet. The agent methods used elsewhere in this file take image paths, while
    the candidates here are PDFs. TODO: rasterize a page and OCR it.

    Args:
        model_name: Model tag passed to ``VisionOCRAgent``.
        candidate_paths: Files to try, in order; the first one that exists
            is used as the test patent.

    Returns:
        True if every check passes; False if the model is unavailable, no
        candidate file exists, or any exception is raised along the way.
    """
    # The whole body sits inside the try so that failures while constructing
    # the agent or probing the model also yield False instead of escaping.
    try:
        logger.info("Initializing VisionOCRAgent...")
        agent = VisionOCRAgent(model_name=model_name)

        if not agent.is_available():
            logger.error(
                f"{model_name} model not available. "
                f"Please run: ollama pull {model_name}"
            )
            return False

        logger.success("VisionOCRAgent initialized successfully")

        # First candidate that exists on disk wins; None if none do.
        test_patent = next(
            (path for path in candidate_paths if Path(path).exists()), None
        )
        if test_patent is None:
            logger.error("No test patent found")
            return False

        logger.info(f"Testing with patent: {test_patent}")

        logger.info("Test 1: Agent availability - PASSED")

        logger.success("VisionOCRAgent basic tests completed successfully")
        return True

    except Exception as e:
        # Top-level test boundary: report with full traceback, then fail.
        logger.exception(f"Test failed: {e}")
        return False
|
|
|
|
|
|
|
|
async def test_with_image(image_path: str, model_name="llava:7b"):
    """Run OCR and patent-page analysis on one image and log the results.

    Args:
        image_path: Path to the image file to process.
        model_name: Model tag passed to ``VisionOCRAgent``.

    Returns:
        None. Results (or the reason for stopping early) are only logged.
    """
    # Fail fast on a bad path before spending time constructing the agent.
    if not Path(image_path).exists():
        logger.error(f"Image not found: {image_path}")
        return

    agent = VisionOCRAgent(model_name=model_name)

    if not agent.is_available():
        logger.error(
            f"{model_name} model not available. "
            f"Please run: ollama pull {model_name}"
        )
        return

    logger.info(f"Testing OCR with image: {image_path}")

    # Raw text extraction.
    text = await agent.extract_text_from_image(image_path)
    logger.info(f"Extracted text length: {len(text)} characters")
    # Only mark the preview as truncated when it actually is.
    suffix = "..." if len(text) > 200 else ""
    logger.info(f"Text preview: {text[:200]}{suffix}")

    # Structured patent-page analysis.
    analysis = await agent.analyze_patent_page(image_path)
    logger.info(f"Patent analysis: {analysis}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the async smoke test and map its boolean outcome onto the process
    # exit status: 1 on failure, 0 on success.
    passed = asyncio.run(test_vision_ocr())

    if not passed:
        logger.error("Tests failed")
        sys.exit(1)

    logger.success("All tests passed!")
    sys.exit(0)
|
|
|