# SPARKNET/tests/unit/test_vision_ocr.py
# MHamdan — Initial commit: SPARKNET framework (a9dc537)
"""
Test script for VisionOCRAgent
Tests OCR functionality with llava:7b vision model.
"""
import asyncio
import sys
from pathlib import Path

# Make the project's ``src`` directory importable before importing from it.
# Look next to this file first, then in each ancestor directory, and use the
# first ``src`` directory that actually exists.  If none exists, fall back to
# ``<this directory>/src`` so sys.path still receives an entry.
_THIS_DIR = Path(__file__).resolve().parent
_SRC_DIR = next(
    (
        base / "src"
        for base in (_THIS_DIR, *_THIS_DIR.parents)
        if (base / "src").is_dir()
    ),
    _THIS_DIR / "src",
)
sys.path.insert(0, str(_SRC_DIR))

from agents.vision_ocr_agent import VisionOCRAgent  # noqa: E402
from loguru import logger  # noqa: E402

# Configure logger: drop the default sink and log INFO and above to stderr.
logger.remove()
logger.add(sys.stderr, level="INFO")
async def test_vision_ocr():
    """Run a smoke test of VisionOCRAgent against a sample patent PDF.

    The test builds an agent for the ``llava:7b`` model, checks that the
    agent reports itself as available, and then looks for a sample patent
    PDF on disk.  No OCR call is made yet; see the note at the end of the
    body.

    The PDF is the first existing path among the ``SPARKNET_TEST_PATENT``
    environment variable (when set and non-empty) followed by the two
    built-in default locations.

    Returns:
        True when the agent is available and a sample patent exists,
        False otherwise (the reason is logged).
    """
    import os  # local import: only needed for the optional path override

    # Initialize agent
    logger.info("Initializing VisionOCRAgent...")
    agent = VisionOCRAgent(model_name="llava:7b")

    # Check if model is available
    if not agent.is_available():
        logger.error("llava:7b model not available. Please run: ollama pull llava:7b")
        return False

    logger.success("VisionOCRAgent initialized successfully")

    # Candidate sample patents, in priority order.
    candidates = [
        "/home/mhamdan/SPARKNET/Dataset/Google 08.02.2012.pdf",
        "/home/mhamdan/SPARKNET/uploads/patents/d58fc23c-58ce-4e1c-9ca7-2c63493f90eb.pdf",
    ]
    # The defaults only exist on one machine; let other environments point
    # the test at their own PDF without editing this file.
    override = os.environ.get("SPARKNET_TEST_PATENT")
    if override:
        candidates.insert(0, override)

    test_patent = next(
        (candidate for candidate in candidates if Path(candidate).exists()),
        None,
    )
    if test_patent is None:
        logger.error("No test patent found")
        return False

    logger.info(f"Testing with patent: {test_patent}")

    # Test 1: only the agent's availability is verified for now.
    logger.info("Test 1: Agent availability - PASSED")

    # Note: For full testing, we'd need to:
    # 1. Convert PDF page to image
    # 2. Call extract_text_from_image()
    # 3. Call analyze_patent_page()
    logger.success("VisionOCRAgent basic tests completed successfully")
    return True
async def test_with_image(image_path: str):
    """Exercise text extraction and patent-page analysis on one image.

    Args:
        image_path: Path of the image handed to the agent.

    Returns:
        None.  Results are only logged.  When the agent reports the model
        as unavailable, an error is logged and nothing else is run.
    """
    ocr_agent = VisionOCRAgent(model_name="llava:7b")

    if ocr_agent.is_available():
        logger.info(f"Testing OCR with image: {image_path}")

        # Text extraction: log the size and the first 200 characters.
        extracted = await ocr_agent.extract_text_from_image(image_path)
        char_count = len(extracted)
        preview = extracted[:200]
        logger.info(f"Extracted text length: {char_count} characters")
        logger.info(f"Text preview: {preview}...")

        # Patent page analysis: log whatever the agent returns.
        page_analysis = await ocr_agent.analyze_patent_page(image_path)
        logger.info(f"Patent analysis: {page_analysis}")
    else:
        logger.error("Model not available")
if __name__ == "__main__":
    # Run the basic smoke test and map its boolean result to an exit code:
    # 0 when it reports success, 1 otherwise.
    passed = asyncio.run(test_vision_ocr())
    if not passed:
        logger.error("Tests failed")
        sys.exit(1)
    logger.success("All tests passed!")
    sys.exit(0)