File size: 9,579 Bytes
d520909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
OCR Model Interface

Abstract interface for Optical Character Recognition models.
Supports both local engines and cloud services.
"""

from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from ..chunks.models import BoundingBox
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class OCREngine(str, Enum):
    """Supported OCR engines.

    Inherits from ``str`` so values compare equal to their string form
    and serialize cleanly (e.g. in JSON configs).
    """

    PADDLEOCR = "paddleocr"  # PaddlePaddle OCR toolkit
    TESSERACT = "tesseract"  # Google Tesseract
    EASYOCR = "easyocr"  # EasyOCR (PyTorch-based)
    CUSTOM = "custom"  # User-provided engine implementation


@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models.

    Engine-specific options are grouped below; engines ignore options
    that do not apply to them.
    """

    engine: OCREngine = OCREngine.PADDLEOCR
    # BCP-47-style language codes the engine should recognize.
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    # Words below this confidence may be dropped by implementations.
    min_confidence: float = 0.5
    # PaddleOCR specific
    use_angle_cls: bool = True
    use_gpu: bool = True
    # Tesseract specific
    tesseract_config: str = ""
    psm_mode: int = 3  # Page segmentation mode

    def __post_init__(self):
        """Derive a default model name from the engine if none was given."""
        super().__post_init__()
        if not self.name:
            self.name = f"ocr_{self.engine.value}"


@dataclass
class OCRWord:
    """A single recognized word with its bounding box."""

    text: str
    bbox: BoundingBox
    confidence: float  # Engine-reported recognition confidence
    language: Optional[str] = None  # Detected language, if the engine reports it
    is_handwritten: bool = False
    # Optional style attributes; only populated by engines that detect them.
    font_size: Optional[float] = None
    is_bold: bool = False
    is_italic: bool = False


@dataclass
class OCRLine:
    """One recognized text line, aggregating its constituent words."""

    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        """Number of words on this line."""
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """Build a line whose text, confidence and bbox derive from *words*.

        Args:
            words: Non-empty list of recognized words, in reading order.
            line_index: Zero-based position of the line on the page.

        Returns:
            A new OCRLine; text is space-joined, confidence is the mean,
            and the bbox is the union of all word boxes.

        Raises:
            ValueError: If *words* is empty.
        """
        if not words:
            raise ValueError("Cannot create line from empty word list")

        boxes = [w.bbox for w in words]
        # Union bbox; normalized flag is inherited from the first word.
        merged = BoundingBox(
            x_min=min(b.x_min for b in boxes),
            y_min=min(b.y_min for b in boxes),
            x_max=max(b.x_max for b in boxes),
            y_max=max(b.y_max for b in boxes),
            normalized=boxes[0].normalized,
        )

        return cls(
            text=" ".join(w.text for w in words),
            bbox=merged,
            confidence=sum(w.confidence for w in words) / len(words),
            words=words,
            line_index=line_index,
        )


@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""

    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"  # text, table, figure, etc.

    @property
    def line_count(self) -> int:
        """Number of lines in this block."""
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """Build a block whose text, confidence and bbox derive from *lines*.

        Args:
            lines: Non-empty list of lines, in reading order.
            block_type: Semantic kind of block ("text", "table", ...).

        Returns:
            A new OCRBlock; text is newline-joined, confidence is the mean,
            and the bbox is the union of all line boxes.

        Raises:
            ValueError: If *lines* is empty.
        """
        if not lines:
            raise ValueError("Cannot create block from empty line list")

        boxes = [ln.bbox for ln in lines]
        # Union bbox; normalized flag is inherited from the first line.
        merged = BoundingBox(
            x_min=min(b.x_min for b in boxes),
            y_min=min(b.y_min for b in boxes),
            x_max=max(b.x_max for b in boxes),
            y_max=max(b.y_max for b in boxes),
            normalized=boxes[0].normalized,
        )

        return cls(
            text="\n".join(ln.text for ln in lines),
            bbox=merged,
            confidence=sum(ln.confidence for ln in lines) / len(lines),
            lines=lines,
            block_type=block_type,
        )


@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""

    text: str  # Full text of the page
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0  # Degrees
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        """Total number of recognized words."""
        return len(self.words)

    @property
    def line_count(self) -> int:
        """Total number of recognized lines."""
        return len(self.lines)

    @property
    def block_count(self) -> int:
        """Total number of recognized blocks."""
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        A word is included when its IoU with *bbox* meets *threshold*,
        or when its center falls inside *bbox*.

        Args:
            bbox: Region to extract text from
            threshold: Minimum IoU overlap required

        Returns:
            Concatenated text of words in region
        """
        hits = [
            w
            for w in self.words
            if w.bbox.iou(bbox) >= threshold or bbox.contains(w.bbox.center)
        ]
        # Reading order: top to bottom, then left to right.
        ordered = sorted(hits, key=lambda w: (w.bbox.y_min, w.bbox.x_min))
        return " ".join(w.text for w in ordered)


class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        """Initialize with an OCRConfig, falling back to defaults."""
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-annotate so type checkers see the narrowed OCRConfig type;
        # the value itself is unchanged.
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """This model family provides OCR only."""
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
        """
        pass

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation processes sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]

    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text
        """
        # Default: run full OCR and extract bboxes
        result = self.recognize(image, **kwargs)
        return [block.bbox for block in result.blocks]

    @staticmethod
    def _shift_bbox(bbox: BoundingBox, offset_x: float, offset_y: float) -> BoundingBox:
        """Return *bbox* translated by (offset_x, offset_y), as pixel coords."""
        return BoundingBox(
            x_min=bbox.x_min + offset_x,
            y_min=bbox.y_min + offset_y,
            x_max=bbox.x_max + offset_x,
            y_max=bbox.y_max + offset_y,
            normalized=False
        )

    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        The region is cropped out, recognized, and all resulting bounding
        boxes are translated back into the original image's pixel
        coordinate system.

        Args:
            image: Full image
            region: Region to OCR
            **kwargs: Additional parameters

        Returns:
            OCR result for the region
        """
        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)

        # Convert normalized coords to pixels if needed
        if region.normalized:
            pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
        else:
            pixel_bbox = region

        # Crop the region
        cropped = pil_image.crop((
            int(pixel_bbox.x_min),
            int(pixel_bbox.y_min),
            int(pixel_bbox.x_max),
            int(pixel_bbox.y_max)
        ))

        # Run OCR on cropped region
        result = self.recognize(cropped, **kwargs)

        # Translate every bbox (words, lines, blocks) back to the
        # original image's coordinates in a single pass.
        offset_x = pixel_bbox.x_min
        offset_y = pixel_bbox.y_min
        for item in (*result.words, *result.lines, *result.blocks):
            item.bbox = self._shift_bbox(item.bbox, offset_x, offset_y)

        return result