"""
OCR Model Interface

Abstract interface for Optical Character Recognition models.
Supports both local engines and cloud services.
"""

from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

from ..chunks.models import BoundingBox
from .base import (
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class OCREngine(str, Enum):
    """Supported OCR engines."""

    PADDLEOCR = "paddleocr"
    TESSERACT = "tesseract"
    EASYOCR = "easyocr"
    CUSTOM = "custom"


@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models."""

    engine: OCREngine = OCREngine.PADDLEOCR
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    min_confidence: float = 0.5

    # PaddleOCR-specific options
    use_angle_cls: bool = True
    use_gpu: bool = True

    # Tesseract-specific options
    tesseract_config: str = ""
    psm_mode: int = 3  # page segmentation mode (3 = fully automatic)

    def __post_init__(self):
        super().__post_init__()
        if not self.name:
            self.name = f"ocr_{self.engine.value}"


@dataclass
class OCRWord:
    """A single recognized word with its bounding box."""

    text: str
    bbox: BoundingBox
    confidence: float
    language: Optional[str] = None
    is_handwritten: bool = False
    font_size: Optional[float] = None
    is_bold: bool = False
    is_italic: bool = False


@dataclass
class OCRLine:
    """A line of text composed of words."""

    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """Create a line from a list of words."""
        if not words:
            raise ValueError("Cannot create line from empty word list")

        text = " ".join(w.text for w in words)
        confidence = sum(w.confidence for w in words) / len(words)

        # Bounding box is the union of the word boxes
        x_min = min(w.bbox.x_min for w in words)
        y_min = min(w.bbox.y_min for w in words)
        x_max = max(w.bbox.x_max for w in words)
        y_max = max(w.bbox.y_max for w in words)

        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=words[0].bbox.normalized
        )

        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            words=words,
            line_index=line_index
        )


@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""

    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"

    @property
    def line_count(self) -> int:
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """Create a block from a list of lines."""
        if not lines:
            raise ValueError("Cannot create block from empty line list")

        text = "\n".join(line.text for line in lines)
        confidence = sum(line.confidence for line in lines) / len(lines)

        # Bounding box is the union of the line boxes
        x_min = min(line.bbox.x_min for line in lines)
        y_min = min(line.bbox.y_min for line in lines)
        x_max = max(line.bbox.x_max for line in lines)
        y_max = max(line.bbox.y_max for line in lines)

        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=lines[0].bbox.normalized
        )

        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            lines=lines,
            block_type=block_type
        )
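

# Example (sketch): composing the word -> line -> block hierarchy bottom-up.
# Coordinates are illustrative pixel values.
#
#     w1 = OCRWord(text="Hello", confidence=0.98,
#                  bbox=BoundingBox(x_min=10, y_min=10, x_max=60, y_max=30,
#                                   normalized=False))
#     w2 = OCRWord(text="world", confidence=0.95,
#                  bbox=BoundingBox(x_min=65, y_min=10, x_max=120, y_max=30,
#                                   normalized=False))
#     line = OCRLine.from_words([w1, w2])   # text == "Hello world"
#     block = OCRBlock.from_lines([line])   # single-paragraph block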


@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""

    text: str
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        return len(self.words)

    @property
    def line_count(self) -> int:
        return len(self.lines)

    @property
    def block_count(self) -> int:
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        Args:
            bbox: Region to extract text from
            threshold: Minimum IoU overlap required

        Returns:
            Concatenated text of words in the region
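
        Example:
            A sketch of typical use, with an illustrative pixel-space region
            (``page_result`` stands for any OCRResult):

                region = BoundingBox(x_min=50, y_min=100, x_max=400, y_max=300,
                                     normalized=False)
                caption = page_result.get_text_in_region(region, threshold=0.3)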
        """
        # Collect words that overlap the region strongly enough, or whose
        # center point falls inside it (catches small words with low IoU).
        words_in_region = []
        for word in self.words:
            iou = word.bbox.iou(bbox)
            if iou >= threshold or bbox.contains(word.bbox.center):
                words_in_region.append(word)

        # Rough reading order: top-to-bottom, then left-to-right
        words_in_region.sort(key=lambda w: (w.bbox.y_min, w.bbox.x_min))
        return " ".join(w.text for w in words_in_region)


class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring
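
    Example:
        A minimal sketch of a concrete subclass, assuming BatchableModel
        imposes no further abstract methods; ``my_engine.read_text`` is a
        hypothetical engine call, not a real API:

            class MyOCRModel(OCRModel):
                def recognize(self, image, **kwargs):
                    raw = my_engine.read_text(image)  # hypothetical
                    words = [
                        OCRWord(text=r.text, bbox=r.box, confidence=r.score)
                        for r in raw
                    ]
                    return OCRResult(
                        text=" ".join(w.text for w in words),
                        words=words,
                    )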
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-annotate for type checkers: narrow the inherited ModelConfig
        # attribute to OCRConfig
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
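
        Example:
            Caller-side sketch (``model`` is any concrete OCRModel):

                result = model.recognize("page.png")
                print(result.text, result.confidence)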
        """
        pass

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation processes sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]

    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR. Note that this default
        implementation falls back to full recognition and returns the block
        boxes; engines with a separate detection stage should override it.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text
        """
        result = self.recognize(image, **kwargs)
        return [block.bbox for block in result.blocks]

    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        Args:
            image: Full image
            region: Region to OCR
            **kwargs: Additional parameters

        Returns:
            OCR result for the region, with bounding boxes translated back
            into full-image pixel coordinates
        """
        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)

        # Convert a normalized region to pixel coordinates before cropping
        if region.normalized:
            pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
        else:
            pixel_bbox = region

        cropped = pil_image.crop((
            int(pixel_bbox.x_min),
            int(pixel_bbox.y_min),
            int(pixel_bbox.x_max),
            int(pixel_bbox.y_max)
        ))

        result = self.recognize(cropped, **kwargs)

        # Translate boxes from crop-local to full-image coordinates.
        # This assumes recognize() returns pixel-space (non-normalized) boxes.
        offset_x = pixel_bbox.x_min
        offset_y = pixel_bbox.y_min

        def _shift(bbox: BoundingBox) -> BoundingBox:
            return BoundingBox(
                x_min=bbox.x_min + offset_x,
                y_min=bbox.y_min + offset_y,
                x_max=bbox.x_max + offset_x,
                y_max=bbox.y_max + offset_y,
                normalized=False
            )

        for word in result.words:
            word.bbox = _shift(word.bbox)
        for line in result.lines:
            line.bbox = _shift(line.bbox)
        for block in result.blocks:
            block.bbox = _shift(block.bbox)

        return result