"""
OCR Model Interface
Abstract interface for Optical Character Recognition models.
Supports both local engines and cloud services.
"""
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from ..chunks.models import BoundingBox
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class OCREngine(str, Enum):
    """Supported OCR engines."""

    PADDLEOCR = "paddleocr"
    TESSERACT = "tesseract"
    EASYOCR = "easyocr"
    CUSTOM = "custom"


@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models."""

    engine: OCREngine = OCREngine.PADDLEOCR
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    min_confidence: float = 0.5

    # PaddleOCR specific
    use_angle_cls: bool = True
    use_gpu: bool = True

    # Tesseract specific
    tesseract_config: str = ""
    psm_mode: int = 3  # Page segmentation mode

    def __post_init__(self):
        super().__post_init__()
        if not self.name:
            self.name = f"ocr_{self.engine.value}"
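

# Illustrative only: one way OCRConfig might be instantiated. The field defaults
# shown here come from this module; `name` is inherited from ModelConfig and is
# assumed to default to an empty value, so __post_init__ derives it from the
# engine.
#
#     config = OCRConfig(
#         engine=OCREngine.TESSERACT,
#         languages=["en", "de"],
#         min_confidence=0.6,
#         psm_mode=6,
#     )
#     assert config.name == "ocr_tesseract"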


@dataclass
class OCRWord:
    """A single recognized word with its bounding box."""

    text: str
    bbox: BoundingBox
    confidence: float
    language: Optional[str] = None
    is_handwritten: bool = False
    font_size: Optional[float] = None
    is_bold: bool = False
    is_italic: bool = False


@dataclass
class OCRLine:
    """A line of text composed of words."""

    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """Create a line from a list of words."""
        if not words:
            raise ValueError("Cannot create line from empty word list")
        text = " ".join(w.text for w in words)
        confidence = sum(w.confidence for w in words) / len(words)
        # Compute bounding box that encompasses all words
        x_min = min(w.bbox.x_min for w in words)
        y_min = min(w.bbox.y_min for w in words)
        x_max = max(w.bbox.x_max for w in words)
        y_max = max(w.bbox.y_max for w in words)
        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=words[0].bbox.normalized
        )
        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            words=words,
            line_index=line_index
        )


@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""

    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"  # text, table, figure, etc.

    @property
    def line_count(self) -> int:
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """Create a block from a list of lines."""
        if not lines:
            raise ValueError("Cannot create block from empty line list")
        text = "\n".join(line.text for line in lines)
        confidence = sum(line.confidence for line in lines) / len(lines)
        x_min = min(line.bbox.x_min for line in lines)
        y_min = min(line.bbox.y_min for line in lines)
        x_max = max(line.bbox.x_max for line in lines)
        y_max = max(line.bbox.y_max for line in lines)
        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=lines[0].bbox.normalized
        )
        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            lines=lines,
            block_type=block_type
        )
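

# Illustrative only: assembling the word -> line -> block hierarchy with the
# helpers above. The pixel coordinates are made up; BoundingBox comes from
# ..chunks.models and is called here with the same keyword arguments used in
# from_words()/from_lines().
#
#     w1 = OCRWord(
#         text="Hello",
#         bbox=BoundingBox(x_min=10, y_min=12, x_max=62, y_max=30, normalized=False),
#         confidence=0.98,
#     )
#     w2 = OCRWord(
#         text="world",
#         bbox=BoundingBox(x_min=68, y_min=12, x_max=121, y_max=30, normalized=False),
#         confidence=0.94,
#     )
#     line = OCRLine.from_words([w1, w2])   # text="Hello world", confidence is the mean
#     block = OCRBlock.from_lines([line])   # bbox spans every member line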


@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""

    text: str  # Full text of the page
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0  # Degrees
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        return len(self.words)

    @property
    def line_count(self) -> int:
        return len(self.lines)

    @property
    def block_count(self) -> int:
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        Args:
            bbox: Region to extract text from
            threshold: Minimum IoU overlap required

        Returns:
            Concatenated text of words in region
        """
        words_in_region = []
        for word in self.words:
            iou = word.bbox.iou(bbox)
            if iou >= threshold or bbox.contains(word.bbox.center):
                words_in_region.append(word)
        # Sort by position (top to bottom, left to right)
        words_in_region.sort(key=lambda w: (w.bbox.y_min, w.bbox.x_min))
        return " ".join(w.text for w in words_in_region)


class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-annotate so type checkers see the narrowed OCRConfig type.
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
        """
        pass

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation processes sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]

    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text
        """
        # Default: run full OCR and extract bboxes
        result = self.recognize(image, **kwargs)
        return [block.bbox for block in result.blocks]

    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        Args:
            image: Full image
            region: Region to OCR
            **kwargs: Additional parameters

        Returns:
            OCR result for the region
        """
        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)
        # Convert normalized coords to pixels if needed
        if region.normalized:
            pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
        else:
            pixel_bbox = region
        # Crop the region
        cropped = pil_image.crop((
            int(pixel_bbox.x_min),
            int(pixel_bbox.y_min),
            int(pixel_bbox.x_max),
            int(pixel_bbox.y_max)
        ))
        # Run OCR on cropped region
        result = self.recognize(cropped, **kwargs)
        # Adjust bounding boxes to original image coordinates
        offset_x = pixel_bbox.x_min
        offset_y = pixel_bbox.y_min
        for word in result.words:
            word.bbox = BoundingBox(
                x_min=word.bbox.x_min + offset_x,
                y_min=word.bbox.y_min + offset_y,
                x_max=word.bbox.x_max + offset_x,
                y_max=word.bbox.y_max + offset_y,
                normalized=False
            )
        for line in result.lines:
            line.bbox = BoundingBox(
                x_min=line.bbox.x_min + offset_x,
                y_min=line.bbox.y_min + offset_y,
                x_max=line.bbox.x_max + offset_x,
                y_max=line.bbox.y_max + offset_y,
                normalized=False
            )
        for block in result.blocks:
            block.bbox = BoundingBox(
                x_min=block.bbox.x_min + offset_x,
                y_min=block.bbox.y_min + offset_y,
                x_max=block.bbox.x_max + offset_x,
                y_max=block.bbox.y_max + offset_y,
                normalized=False
            )
        return result
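

# Illustrative only: the rough shape of a concrete engine built on this
# interface. The TesseractOCRModel name and the run_tesseract() helper are
# hypothetical, and BatchableModel may declare further abstract methods that
# a real implementation would also need to satisfy.
#
#     from .base import ensure_pil_image
#
#     class TesseractOCRModel(OCRModel):
#         def recognize(self, image: ImageInput, **kwargs) -> OCRResult:
#             pil_image = ensure_pil_image(image)
#             text = run_tesseract(pil_image, self.config)  # hypothetical engine call
#             # A full implementation would also populate words/lines/blocks.
#             return OCRResult(
#                 text=text,
#                 image_width=pil_image.width,
#                 image_height=pil_image.height,
#             )
#
#     model = TesseractOCRModel(OCRConfig(engine=OCREngine.TESSERACT))
#     result = model.recognize("scan.png")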