|
|
""" |
|
|
Base OCR Interface |
|
|
|
|
|
Defines the abstract OCR engine interface and common data structures. |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import List, Optional, Dict, Any, Tuple |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
import numpy as np |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from ..schemas.core import BoundingBox, OCRRegion |
|
|
|
|
|
|
|
|
class OCRLanguage(str, Enum):
    """Language codes an OCR engine can be asked to recognize.

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (e.g. ``OCRLanguage.ENGLISH == "en"``), so members can
    be used wherever a raw language-code string is expected.

    NOTE(review): only some values are ISO 639-1 codes ("en", "fr", "es",
    "it", "pt", "ru", "ar", "hi"). The others ("ch", "chinese_cht", "german",
    "japan", "korean", "latin") appear to be PaddleOCR-style ``lang``
    identifiers. Confirm the expected scheme before passing these values to a
    non-Paddle engine.
    """

    ENGLISH = "en"
    CHINESE_SIMPLIFIED = "ch"  # not ISO: "ch" is Chamorro in ISO 639-1 (Chinese is "zh")
    CHINESE_TRADITIONAL = "chinese_cht"  # not an ISO code
    FRENCH = "fr"
    GERMAN = "german"  # not an ISO code (ISO 639-1 is "de")
    SPANISH = "es"
    ITALIAN = "it"
    PORTUGUESE = "pt"
    RUSSIAN = "ru"
    JAPANESE = "japan"  # not an ISO code (ISO 639-1 is "ja")
    KOREAN = "korean"  # not an ISO code (ISO 639-1 is "ko")
    ARABIC = "ar"
    HINDI = "hi"
    LATIN = "latin"  # not an ISO code; presumably a Latin-script model — TODO confirm
|
|
|
|
|
|
|
|
class OCRConfig(BaseModel):
    """Settings that control how an OCR engine detects and recognizes text.

    Numeric thresholds carry ``ge``/``le`` bounds, so pydantic validation
    rejects out-of-range values when the model is constructed. Field order is
    significant for schema/serialization output and is kept stable.
    """

    # -- Engine selection --------------------------------------------------
    engine: str = Field(default="paddle", description="OCR engine: paddle or tesseract")

    # NOTE(review): the description says "ISO codes", but several OCRLanguage
    # values (e.g. "german", "japan") are not ISO 639-1 — confirm the scheme.
    languages: List[str] = Field(
        default=["en"], description="Languages to detect (ISO codes)"
    )

    # -- Text detection thresholds ----------------------------------------
    det_db_thresh: float = Field(
        default=0.3, ge=0.0, le=1.0, description="Detection threshold for text regions"
    )
    det_db_box_thresh: float = Field(
        default=0.5, ge=0.0, le=1.0, description="Box detection threshold"
    )

    # -- Recognition and confidence filtering -----------------------------
    rec_batch_num: int = Field(default=6, ge=1, description="Recognition batch size")
    # NOTE(review): min_confidence and drop_score (below) are both score cut-offs
    # with the same default; this class does not show which one engines apply.
    min_confidence: float = Field(
        default=0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"
    )

    # -- Hardware and model options ---------------------------------------
    use_gpu: bool = Field(default=True, description="Use GPU acceleration")
    gpu_id: int = Field(default=0, ge=0, description="GPU device ID")
    use_angle_cls: bool = Field(
        default=True, description="Use angle classification for rotated text"
    )
    use_dilation: bool = Field(default=False, description="Use dilation for detection")

    # -- Output control ---------------------------------------------------
    drop_score: float = Field(
        default=0.5, ge=0.0, le=1.0, description="Drop results below this score"
    )
    return_word_boxes: bool = Field(
        default=False, description="Return word-level boxes (vs line-level)"
    )

    # -- Optional preprocessing (None / False disables the step) ----------
    preprocess_resize: Optional[int] = Field(
        default=None, description="Resize image max dimension before OCR"
    )
    preprocess_denoise: bool = Field(default=False, description="Apply denoising before OCR")
|
|
|
|
|
|
|
|
@dataclass
class OCRResult:
    """
    Result of OCR processing for a single image/page.

    Attributes:
        regions: Recognized regions. The helper methods below read each
            region's ``text``, ``confidence`` and ``bbox`` attributes.
        full_text: Text for the whole page.
        confidence_avg: Mean region confidence (defaults to 0.0).
        processing_time_ms: Processing time in milliseconds.
        engine: Name of the engine that produced the result.
        language_detected: Detected language code, if reported.
        success: Outcome flag; presumably False when OCR failed.
        error: Optional error message accompanying a failure.
    """

    regions: List["OCRRegion"] = field(default_factory=list)
    full_text: str = ""
    confidence_avg: float = 0.0
    processing_time_ms: float = 0.0
    engine: str = "unknown"
    language_detected: Optional[str] = None

    # Status of the OCR run.
    success: bool = True
    error: Optional[str] = None

    def get_text_in_bbox(self, bbox: "BoundingBox", iou_threshold: float = 0.5) -> str:
        """Get the text of regions lying within a bounding box.

        A region matches when ``bbox`` fully contains its box, or otherwise
        when the IoU between the two boxes is strictly greater than
        ``iou_threshold``.

        Args:
            bbox: Query bounding box.
            iou_threshold: IoU cut-off for partially overlapping regions.
                Defaults to 0.5, the previously hard-coded value.

        Returns:
            Text of the matching regions in their original order, joined
            with single spaces. Returns "" when nothing matches.
        """
        # `or` short-circuits, so IoU is only computed for non-contained regions.
        return " ".join(
            region.text
            for region in self.regions
            if bbox.contains(region.bbox) or bbox.iou(region.bbox) > iou_threshold
        )

    def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
        """Return a new result keeping only regions with confidence >= ``min_confidence``.

        ``full_text`` and ``confidence_avg`` are recomputed from the kept
        regions. ``full_text`` becomes a space-join of the kept regions'
        text, so any formatting of the original ``full_text`` is not
        preserved. ``confidence_avg`` is 0.0 when nothing is kept. All other
        fields are copied unchanged. The region objects are shared with
        ``self``, not copied.
        """
        kept = [r for r in self.regions if r.confidence >= min_confidence]
        # Use a float 0.0 (not int 0) so the type matches the field annotation.
        avg = sum(r.confidence for r in kept) / len(kept) if kept else 0.0
        return OCRResult(
            regions=kept,
            full_text=" ".join(r.text for r in kept),
            confidence_avg=avg,
            processing_time_ms=self.processing_time_ms,
            engine=self.engine,
            language_detected=self.language_detected,
            success=self.success,
            error=self.error,
        )
|
|
|
|
|
|
|
|
class OCREngine(ABC):
    """
    Abstract base class for OCR engines.

    Defines the interface that all OCR implementations must follow.
    Subclasses implement ``initialize``, ``recognize`` and
    ``get_supported_languages``. This base provides sequential batch
    processing and context-manager entry that initializes on demand.

    NOTE(review): nothing in this base class sets ``_initialized`` to True.
    Subclasses are expected to do so at the end of ``initialize()``.
    Otherwise ``is_initialized`` stays False, and every ``with`` entry calls
    ``initialize()`` again.
    """

    def __init__(self, config: Optional["OCRConfig"] = None):
        """
        Initialize OCR engine (no models are loaded here; see ``initialize``).

        Args:
            config: OCR configuration. A default ``OCRConfig()`` is used
                when not provided.
        """
        self.config = config or OCRConfig()
        self._initialized = False

    @abstractmethod
    def initialize(self):
        """Initialize the OCR engine (load models, etc.)."""

    @abstractmethod
    def recognize(
        self,
        image: np.ndarray,
        page_number: int = 0,
    ) -> "OCRResult":
        """
        Perform OCR on an image.

        Args:
            image: Image as numpy array (RGB, HWC format)
            page_number: Page number for multi-page documents

        Returns:
            OCRResult with recognized text and regions
        """

    def recognize_batch(
        self,
        images: List[np.ndarray],
        page_numbers: Optional[List[int]] = None,
    ) -> List["OCRResult"]:
        """
        Perform OCR on multiple images by calling ``recognize`` once per image.

        Args:
            images: List of images
            page_numbers: Optional page numbers, one per image. Defaults to
                ``0, 1, ..., len(images) - 1``.

        Returns:
            List of OCRResult, in the same order as ``images``.

        Raises:
            ValueError: If ``page_numbers`` is given and its length differs
                from ``len(images)``. Previously, ``zip`` silently dropped
                the unmatched images or pages.
        """
        if page_numbers is None:
            page_numbers = range(len(images))
        elif len(page_numbers) != len(images):
            raise ValueError(
                f"page_numbers has {len(page_numbers)} entries but "
                f"{len(images)} images were given"
            )

        return [self.recognize(img, page_num) for img, page_num in zip(images, page_numbers)]

    @abstractmethod
    def get_supported_languages(self) -> List[str]:
        """Return list of supported language codes."""

    @property
    def name(self) -> str:
        """Return engine name (the concrete class name)."""
        return self.__class__.__name__

    @property
    def is_initialized(self) -> bool:
        """Check if engine is initialized (reads the ``_initialized`` flag)."""
        return self._initialized

    def __enter__(self):
        """Context manager entry: call ``initialize()`` unless already initialized."""
        if not self._initialized:
            self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit.

        Releases nothing. It returns None, so exceptions raised inside the
        ``with`` body propagate.
        """
        return None
|
|
|