MHamdan's picture
Initial commit: SPARKNET framework
d520909
"""
OCR Model Interface
Abstract interface for Optical Character Recognition models.
Supports both local engines and cloud services.
"""
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from ..chunks.models import BoundingBox
from .base import (
BaseModel,
BatchableModel,
ImageInput,
ModelCapability,
ModelConfig,
)
class OCREngine(str, Enum):
    """Supported OCR engines.

    Members subclass ``str``, so each compares equal to its string value
    (e.g. ``OCREngine.TESSERACT == "tesseract"``) and ``OCREngine("tesseract")``
    looks a member up by value. Declaration order is the enum's iteration order.
    """
    PADDLEOCR = "paddleocr"
    TESSERACT = "tesseract"
    EASYOCR = "easyocr"
    # Presumably for engines not listed above (user-provided implementations) — TODO confirm.
    CUSTOM = "custom"
@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models.

    Engine-agnostic settings come first, followed by options that only a
    specific engine consults. ``engine`` may be given either as an
    ``OCREngine`` member or as its string value (e.g. ``"tesseract"``); it is
    normalised to the enum member in ``__post_init__``.
    """
    engine: OCREngine = OCREngine.PADDLEOCR
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    min_confidence: float = 0.5
    # PaddleOCR specific
    use_angle_cls: bool = True
    use_gpu: bool = True
    # Tesseract specific
    tesseract_config: str = ""
    psm_mode: int = 3  # Page segmentation mode

    def __post_init__(self):
        """Normalise ``engine`` and derive a default ``name`` when none is set.

        Raises:
            ValueError: If ``engine`` is not a valid ``OCREngine`` value.
        """
        super().__post_init__()
        # Accept plain strings (e.g. values loaded from a config file) and map
        # them to the enum member. Previously a string here crashed below with
        # an AttributeError on ``.value``; unknown names now fail clearly.
        self.engine = OCREngine(self.engine)
        if not self.name:
            # e.g. "ocr_paddleocr"
            self.name = f"ocr_{self.engine.value}"
@dataclass
class OCRWord:
    """A single recognized word with its bounding box.

    Field order defines the positional constructor signature; new fields
    should be appended with defaults so existing callers keep working.
    """
    text: str  # recognized text of the word
    bbox: BoundingBox  # location of the word in the image
    # Engine-reported recognition confidence. Scale is presumably 0-1
    # (compare OCRConfig.min_confidence = 0.5) — TODO confirm per engine.
    confidence: float
    language: Optional[str] = None  # None when the engine gives no language
    is_handwritten: bool = False
    # Units are not established in this file (points vs pixels) — TODO confirm.
    font_size: Optional[float] = None
    is_bold: bool = False
    is_italic: bool = False
@dataclass
class OCRLine:
    """A line of text composed of words."""
    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        """Number of words in the line."""
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """Create a line from a list of words.

        The line text is the word texts joined by single spaces, the
        confidence is the unweighted mean of the word confidences, and the
        bounding box is the smallest box enclosing every word box.

        Args:
            words: Words making up the line, in the order their text should
                be joined. Any iterable is accepted; it is copied, so later
                changes to the caller's list do not affect the line.
            line_index: Stored unchanged on the returned line.

        Returns:
            A new ``OCRLine``.

        Raises:
            ValueError: If ``words`` is empty, or if it mixes normalized and
                pixel-coordinate boxes (their union would be meaningless).
        """
        # Materialise once: the words are traversed several times below and
        # the line must not alias a list the caller may keep mutating.
        words = list(words)
        if not words:
            raise ValueError("Cannot create line from empty word list")

        normalized = words[0].bbox.normalized
        if any(w.bbox.normalized != normalized for w in words):
            raise ValueError(
                "Cannot create line from words with mixed normalized and pixel bounding boxes"
            )

        text = " ".join(w.text for w in words)
        confidence = sum(w.confidence for w in words) / len(words)

        # Union of all word boxes.
        bbox = BoundingBox(
            x_min=min(w.bbox.x_min for w in words),
            y_min=min(w.bbox.y_min for w in words),
            x_max=max(w.bbox.x_max for w in words),
            y_max=max(w.bbox.y_max for w in words),
            normalized=normalized,
        )
        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            words=words,
            line_index=line_index,
        )
@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""
    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"  # text, table, figure, etc.

    @property
    def line_count(self) -> int:
        """Number of lines in the block."""
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """Create a block from a list of lines.

        The block text is the line texts joined by newlines, the confidence
        is the unweighted mean of the line confidences, and the bounding box
        is the smallest box enclosing every line box.

        Args:
            lines: Lines making up the block, in the order their text should
                be joined. Any iterable is accepted; it is copied, so later
                changes to the caller's list do not affect the block.
            block_type: Stored unchanged on the returned block.

        Returns:
            A new ``OCRBlock``.

        Raises:
            ValueError: If ``lines`` is empty, or if it mixes normalized and
                pixel-coordinate boxes (their union would be meaningless).
        """
        # Materialise once: traversed several times, and must not alias the
        # caller's list.
        lines = list(lines)
        if not lines:
            raise ValueError("Cannot create block from empty line list")

        normalized = lines[0].bbox.normalized
        if any(line.bbox.normalized != normalized for line in lines):
            raise ValueError(
                "Cannot create block from lines with mixed normalized and pixel bounding boxes"
            )

        text = "\n".join(line.text for line in lines)
        confidence = sum(line.confidence for line in lines) / len(lines)

        # Union of all line boxes.
        bbox = BoundingBox(
            x_min=min(line.bbox.x_min for line in lines),
            y_min=min(line.bbox.y_min for line in lines),
            x_max=max(line.bbox.x_max for line in lines),
            y_max=max(line.bbox.y_max for line in lines),
            normalized=normalized,
        )
        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            lines=lines,
            block_type=block_type,
        )
@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""
    text: str  # Full text of the page
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0  # Degrees
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        """Number of entries in ``words``."""
        return len(self.words)

    @property
    def line_count(self) -> int:
        """Number of entries in ``lines``."""
        return len(self.lines)

    @property
    def block_count(self) -> int:
        """Number of entries in ``blocks``."""
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        A word is selected when its IoU with ``bbox`` is at least
        ``threshold`` OR the center of the word's box lies inside ``bbox``.
        The center test matters in practice: a small word inside a large
        region has a tiny IoU. ``bbox`` should use the same coordinate
        convention (normalized vs pixel) as the word boxes.

        Args:
            bbox: Region to extract text from
            threshold: Minimum IoU with the region for a word to be selected
                (words whose center is inside the region are always selected)

        Returns:
            Text of the selected words, space-separated, in reading order:
            visual rows top to bottom, left to right within each row.
        """
        words_in_region = [
            word for word in self.words
            if word.bbox.iou(bbox) >= threshold or bbox.contains(word.bbox.center)
        ]
        return " ".join(w.text for w in self._reading_order(words_in_region))

    @staticmethod
    def _reading_order(words: List[OCRWord]) -> List[OCRWord]:
        """Order words top-to-bottom by visual row, then left-to-right within a row.

        A plain ``(y_min, x_min)`` sort scrambles words on the same visual
        line whose tops differ by a pixel or two. Instead, a word joins the
        current row when its vertical center lies within the vertical span of
        the row's first (topmost) word.
        """
        rows: List[List[OCRWord]] = []
        for word in sorted(words, key=lambda w: (w.bbox.y_min, w.bbox.x_min)):
            if rows:
                anchor = rows[-1][0].bbox
                center_y = (word.bbox.y_min + word.bbox.y_max) / 2
                if center_y <= anchor.y_max:
                    rows[-1].append(word)
                    continue
            rows.append([word])
        return [w for row in rows for w in sorted(row, key=lambda w: w.bbox.x_min)]
class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring

    Subclasses must implement ``recognize``; batching, region detection and
    region recognition have default implementations built on it.
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-declare the attribute with the narrower type for type checkers;
        # the value itself is unchanged.
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities of this model (OCR only)."""
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
        """
        pass

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation processes sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]

    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR. The default
        implementation runs full OCR and returns the block boxes; if the
        engine produced no blocks it falls back to line boxes, then word
        boxes.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text (empty if nothing was found)
        """
        result = self.recognize(image, **kwargs)
        # Prefer the coarsest granularity the engine actually produced, so
        # engines that emit only lines or words still report regions.
        for elements in (result.blocks, result.lines, result.words):
            if elements:
                return [element.bbox for element in elements]
        return []

    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        The region is snapped outward to whole pixels and clamped to the
        image before cropping. Every word, line and block box in the result,
        including those reachable only through a line's words or a block's
        lines, is translated to pixel coordinates of the full image.

        Args:
            image: Full image
            region: Region to OCR (normalized or pixel coordinates)
            **kwargs: Additional parameters

        Returns:
            OCR result for the region. If the region does not overlap the
            image, an empty ``OCRResult`` is returned without calling the engine.
        """
        import math

        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)
        width, height = pil_image.width, pil_image.height

        # Convert normalized coords to pixels if needed
        if region.normalized:
            pixel_bbox = region.to_pixel(width, height)
        else:
            pixel_bbox = region

        # Integer crop window, clamped so it never extends past the image edge.
        left = max(0, math.floor(pixel_bbox.x_min))
        top = max(0, math.floor(pixel_bbox.y_min))
        right = min(width, math.ceil(pixel_bbox.x_max))
        bottom = min(height, math.ceil(pixel_bbox.y_max))
        if right <= left or bottom <= top:
            # Nothing to recognize; avoid handing a zero-size image to the engine.
            return OCRResult(text="")

        cropped = pil_image.crop((left, top, right, bottom))
        result = self.recognize(cropped, **kwargs)

        # Offset by the crop origin actually used, not the (possibly
        # fractional or out-of-bounds) requested corner.
        for element in self._collect_elements(result):
            element.bbox = self._crop_to_image_bbox(
                element.bbox, left, top, cropped.width, cropped.height
            )
        # NOTE(review): result.image_width/image_height are left as the engine
        # set them (presumably the crop size) — confirm what callers expect.
        return result

    @staticmethod
    def _collect_elements(result: OCRResult) -> List[Any]:
        """Return every distinct word, line and block in ``result`` exactly once."""
        candidates: List[Any] = [*result.words, *result.lines, *result.blocks]
        for line in result.lines:
            candidates.extend(line.words)
        for block in result.blocks:
            for line in block.lines:
                candidates.append(line)
                candidates.extend(line.words)
        # The same object is often referenced from several lists (e.g. a word
        # in both result.words and line.words); dedupe by identity so each
        # box is shifted only once.
        unique: Dict[int, Any] = {}
        for element in candidates:
            unique.setdefault(id(element), element)
        return list(unique.values())

    @staticmethod
    def _crop_to_image_bbox(
        bbox: BoundingBox,
        offset_x: int,
        offset_y: int,
        crop_width: int,
        crop_height: int,
    ) -> BoundingBox:
        """Translate a box from crop coordinates to full-image pixel coordinates."""
        if bbox.normalized:
            # Normalized boxes are relative to the crop; convert them to crop
            # pixels before applying the pixel offset.
            bbox = bbox.to_pixel(crop_width, crop_height)
        return BoundingBox(
            x_min=bbox.x_min + offset_x,
            y_min=bbox.y_min + offset_y,
            x_max=bbox.x_max + offset_x,
            y_max=bbox.y_max + offset_y,
            normalized=False,
        )