"""
OCR Model Interface

Abstract interface for Optical Character Recognition models.
Supports both local engines and cloud services.
"""

import math
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from ..chunks.models import BoundingBox
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class OCREngine(str, Enum):
    """Supported OCR engines."""

    PADDLEOCR = "paddleocr"
    TESSERACT = "tesseract"
    EASYOCR = "easyocr"
    CUSTOM = "custom"


@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models."""

    engine: OCREngine = OCREngine.PADDLEOCR
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    # NOTE(review): not enforced in this module; presumably engine
    # implementations use it to filter low-confidence output — confirm.
    min_confidence: float = 0.5

    # PaddleOCR specific
    use_angle_cls: bool = True
    use_gpu: bool = True

    # Tesseract specific
    tesseract_config: str = ""
    psm_mode: int = 3  # Page segmentation mode

    def __post_init__(self):
        super().__post_init__()
        # Derive a default model name from the engine when none was given.
        if not self.name:
            self.name = f"ocr_{self.engine.value}"


def _enclosing_bbox(boxes: List[BoundingBox]) -> BoundingBox:
    """
    Return the smallest box that contains every box in ``boxes``.

    ``boxes`` must be non-empty. The ``normalized`` flag is copied from the
    first box, so callers should pass boxes that share one coordinate space.
    """
    return BoundingBox(
        x_min=min(b.x_min for b in boxes),
        y_min=min(b.y_min for b in boxes),
        x_max=max(b.x_max for b in boxes),
        y_max=max(b.y_max for b in boxes),
        normalized=boxes[0].normalized,
    )


@dataclass
class OCRWord:
    """A single recognized word with its bounding box."""

    text: str
    bbox: BoundingBox
    confidence: float
    language: Optional[str] = None
    is_handwritten: bool = False
    font_size: Optional[float] = None
    is_bold: bool = False
    is_italic: bool = False


@dataclass
class OCRLine:
    """A line of text composed of words."""

    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        """Number of words attached to this line."""
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """
        Create a line from a list of words.

        The line text is the words joined by single spaces, the confidence is
        the mean word confidence, and the bounding box encloses all words.

        Raises:
            ValueError: If ``words`` is empty.
        """
        if not words:
            raise ValueError("Cannot create line from empty word list")

        return cls(
            text=" ".join(w.text for w in words),
            bbox=_enclosing_bbox([w.bbox for w in words]),
            confidence=sum(w.confidence for w in words) / len(words),
            words=words,
            line_index=line_index,
        )


@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""

    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"  # text, table, figure, etc.

    @property
    def line_count(self) -> int:
        """Number of lines attached to this block."""
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """
        Create a block from a list of lines.

        The block text is the lines joined by newlines, the confidence is the
        mean line confidence, and the bounding box encloses all lines.

        Raises:
            ValueError: If ``lines`` is empty.
        """
        if not lines:
            raise ValueError("Cannot create block from empty line list")

        return cls(
            text="\n".join(line.text for line in lines),
            bbox=_enclosing_bbox([line.bbox for line in lines]),
            confidence=sum(line.confidence for line in lines) / len(lines),
            lines=lines,
            block_type=block_type,
        )


@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""

    text: str  # Full text of the page
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0  # Degrees
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        """Number of entries in ``words``."""
        return len(self.words)

    @property
    def line_count(self) -> int:
        """Number of entries in ``lines``."""
        return len(self.lines)

    @property
    def block_count(self) -> int:
        """Number of entries in ``blocks``."""
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        A word is included when its IoU with ``bbox`` reaches ``threshold``
        or when ``bbox`` contains the word's center.

        Args:
            bbox: Region to extract text from (should use the same
                coordinate space as the word boxes)
            threshold: Minimum IoU overlap required

        Returns:
            Space-joined text of the matching words, ordered top to bottom,
            then left to right.
        """
        words_in_region = [
            word
            for word in self.words
            if word.bbox.iou(bbox) >= threshold or bbox.contains(word.bbox.center)
        ]

        # Sort by position (top to bottom, left to right)
        words_in_region.sort(key=lambda w: (w.bbox.y_min, w.bbox.x_min))
        return " ".join(w.text for w in words_in_region)


def _translate_bbox(
    bbox: BoundingBox,
    offset_x: float,
    offset_y: float,
    crop_width: int,
    crop_height: int,
) -> BoundingBox:
    """Map a box from cropped-image space into full-image pixel space."""
    if bbox.normalized:
        # Normalized boxes came from recognize() on the crop, so they are
        # presumably relative to the crop's size; convert before shifting.
        bbox = bbox.to_pixel(crop_width, crop_height)
    return BoundingBox(
        x_min=bbox.x_min + offset_x,
        y_min=bbox.y_min + offset_y,
        x_max=bbox.x_max + offset_x,
        y_max=bbox.y_max + offset_y,
        normalized=False,
    )


class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-annotate so type checkers see the OCR-specific config type.
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model provides (OCR only)."""
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs,
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
        """
        pass

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs,
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation calls :meth:`recognize` sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]

    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs,
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR. The default
        implementation runs full OCR and returns block boxes. If the engine
        produced no blocks, it falls back to line boxes, then word boxes.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text
        """
        result = self.recognize(image, **kwargs)
        # Engines may not group output into blocks; use the coarsest
        # granularity that is actually populated.
        for level in (result.blocks, result.lines, result.words):
            if level:
                return [item.bbox for item in level]
        return []

    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs,
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        The region is cropped from the full image and passed to
        :meth:`recognize`. Every bounding box in the result is then translated
        back into pixel coordinates of the full image. This covers blocks,
        lines and words, including those reachable only through
        ``block.lines`` or ``line.words``. ``image_width`` and
        ``image_height`` are left as reported by :meth:`recognize` for the
        crop.

        Args:
            image: Full image
            region: Region to OCR (pixel or normalized coordinates)
            **kwargs: Additional parameters forwarded to ``recognize``

        Returns:
            OCR result for the region, with bounding boxes in full-image
            pixel coordinates. An empty result is returned if the region
            does not overlap the image.
        """
        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)

        # Convert normalized coords to pixels if needed
        if region.normalized:
            pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
        else:
            pixel_bbox = region

        # Round outward so the crop covers the whole region. Then clamp to the
        # image so the engine is not fed padding outside the picture.
        left = max(0, math.floor(pixel_bbox.x_min))
        top = max(0, math.floor(pixel_bbox.y_min))
        right = min(pil_image.width, math.ceil(pixel_bbox.x_max))
        bottom = min(pil_image.height, math.ceil(pixel_bbox.y_max))
        if right <= left or bottom <= top:
            return OCRResult(text="")

        cropped = pil_image.crop((left, top, right, bottom))

        # Run OCR on cropped region
        result = self.recognize(cropped, **kwargs)

        # Gather every object that carries a bbox, including nested ones.
        all_blocks = list(result.blocks)
        all_lines = list(result.lines) + [ln for blk in all_blocks for ln in blk.lines]
        all_words = list(result.words) + [w for ln in all_lines for w in ln.words]

        # Shift each distinct object once, even if shared between lists.
        # The offset is the integer crop origin: that is where the cropped
        # image's (0, 0) sits in the full image.
        seen_ids = set()
        for item in (*all_blocks, *all_lines, *all_words):
            if id(item) in seen_ids:
                continue
            seen_ids.add(id(item))
            item.bbox = _translate_bbox(
                item.bbox, left, top, cropped.width, cropped.height
            )

        return result