|
|
""" |
|
|
Layout Detection Base Interface |
|
|
|
|
|
Defines the abstract interface for document layout detection. |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import List, Optional, Dict, Any |
|
|
from dataclasses import dataclass, field |
|
|
from pydantic import BaseModel, Field |
|
|
import numpy as np |
|
|
|
|
|
from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion |
|
|
|
|
|
|
|
|
class LayoutConfig(BaseModel):
    """
    Settings model for layout detection.

    Every field has a default, so ``LayoutConfig()`` with no arguments is a
    usable configuration. Where a field declares ``ge``/``le``, those are the
    bounds attached to it via ``Field``; ``title_max_lines`` and
    ``heading_font_ratio`` declare no bounds.
    """

    # Detection backend, selected by name.
    method: str = Field(
        description="Detection method: rule_based, paddle_structure, layoutlm",
        default="rule_based",
    )

    # Confidence floor for detected regions, bounded to [0.0, 1.0].
    min_confidence: float = Field(
        description="Minimum confidence for detected regions",
        default=0.5,
        le=1.0,
        ge=0.0,
    )

    # Per-category switches; every category is enabled by default.
    detect_tables: bool = Field(
        default=True,
        description="Detect table regions",
    )
    detect_figures: bool = Field(
        default=True,
        description="Detect figure regions",
    )
    detect_headers: bool = Field(
        default=True,
        description="Detect header/footer",
    )
    detect_titles: bool = Field(
        default=True,
        description="Detect title/heading",
    )
    detect_lists: bool = Field(
        default=True,
        description="Detect list structures",
    )

    # IoU level used when merging overlapping regions, bounded to [0.0, 1.0].
    merge_threshold: float = Field(
        description="IoU threshold for merging overlapping regions",
        default=0.7,
        le=1.0,
        ge=0.0,
    )

    # Hardware selection; the device ID must be non-negative.
    use_gpu: bool = Field(
        default=True,
        description="Use GPU acceleration",
    )
    gpu_id: int = Field(
        default=0,
        description="GPU device ID",
        ge=0,
    )

    # Minimum table dimensions; each must be at least 1.
    table_min_rows: int = Field(
        default=2,
        description="Minimum rows for table",
        ge=1,
    )
    table_min_cols: int = Field(
        default=2,
        description="Minimum columns for table",
        ge=1,
    )

    # Title/heading heuristics parameters (no bounds declared on either).
    title_max_lines: int = Field(
        default=3,
        description="Max lines for title",
    )
    heading_font_ratio: float = Field(
        description="Font size ratio vs body text for headings",
        default=1.2,
    )
|
|
|
|
|
|
|
|
@dataclass
class LayoutResult:
    """
    Outcome of running layout detection on one page.

    Holds the detected regions together with image dimensions, timing and a
    success/error status, and offers helpers that filter ``regions`` by type.
    """

    # Page this result belongs to (required; the only field without a default).
    page: int
    # Detected regions; each instance gets its own fresh list by default.
    regions: List[LayoutRegion] = field(default_factory=list)
    # Dimensions of the analysed image; 0 when not supplied.
    image_width: int = 0
    image_height: int = 0
    # Detection duration -- presumably milliseconds, going by the name.
    processing_time_ms: float = 0.0

    # Status flag and optional error text; defaults describe a clean run.
    success: bool = True
    error: Optional[str] = None

    def get_regions_by_type(self, layout_type: LayoutType) -> List[LayoutRegion]:
        """Return, in original order, the regions whose ``type`` equals ``layout_type``."""
        matches = []
        for region in self.regions:
            if region.type == layout_type:
                matches.append(region)
        return matches

    def get_tables(self) -> List[LayoutRegion]:
        """Return the regions typed as ``LayoutType.TABLE``."""
        wanted = LayoutType.TABLE
        return self.get_regions_by_type(wanted)

    def get_figures(self) -> List[LayoutRegion]:
        """Return the regions typed as ``LayoutType.FIGURE``."""
        wanted = LayoutType.FIGURE
        return self.get_regions_by_type(wanted)

    def get_text_regions(self) -> List[LayoutRegion]:
        """Return the text-like regions (text, title, heading, paragraph, list)."""
        text_kinds = {
            LayoutType.TEXT,
            LayoutType.TITLE,
            LayoutType.HEADING,
            LayoutType.PARAGRAPH,
            LayoutType.LIST,
        }
        selected = []
        for region in self.regions:
            if region.type not in text_kinds:
                continue
            selected.append(region)
        return selected
|
|
|
|
|
|
|
|
class LayoutDetector(ABC):
    """
    Abstract base class for layout detectors.

    Subclasses must implement ``initialize`` and ``detect``. The base class
    provides batch detection on top of ``detect``, the ``name`` and
    ``is_initialized`` properties, and context-manager support that calls
    ``initialize`` on entry when the detector is not yet initialized.
    """

    def __init__(self, config: Optional[LayoutConfig] = None):
        """
        Initialize layout detector.

        Args:
            config: Layout detection configuration. When omitted (or falsy),
                a default ``LayoutConfig()`` is created.
        """
        self.config = config or LayoutConfig()
        # The base class only ever sets this to False.
        # NOTE(review): initialize() implementations are presumably expected
        # to flip it to True -- confirm in the concrete detectors.
        self._initialized = False

    @abstractmethod
    def initialize(self):
        """Initialize the detector (load models, etc.)."""

    @abstractmethod
    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions in an image.

        Args:
            image: Image as numpy array (RGB, HWC format)
            page_number: Page number
            ocr_regions: Optional OCR regions for text-aware detection

        Returns:
            LayoutResult with detected regions
        """

    def detect_batch(
        self,
        images: List[np.ndarray],
        page_numbers: Optional[List[int]] = None,
        ocr_results: Optional[List[List[OCRRegion]]] = None,
    ) -> List[LayoutResult]:
        """
        Detect layout in multiple images.

        Each image is passed to ``detect`` in order, together with its page
        number and OCR regions.

        Args:
            images: List of images
            page_numbers: Optional page numbers, one per image. Defaults to
                ``0 .. len(images) - 1``.
            ocr_results: Optional OCR regions for each page, one entry per
                image. Defaults to ``None`` for every image.

        Returns:
            List of LayoutResult, one per image, in input order

        Raises:
            ValueError: If ``page_numbers`` or ``ocr_results`` is given but
                its length differs from ``len(images)``.
        """
        image_count = len(images)

        # zip() stops at the shortest input, so a length mismatch would
        # silently drop pages; reject it explicitly instead.
        if page_numbers is None:
            page_numbers = list(range(image_count))
        elif len(page_numbers) != image_count:
            raise ValueError(
                f"page_numbers has {len(page_numbers)} entries "
                f"but {image_count} images were given"
            )

        if ocr_results is None:
            ocr_results = [None] * image_count
        elif len(ocr_results) != image_count:
            raise ValueError(
                f"ocr_results has {len(ocr_results)} entries "
                f"but {image_count} images were given"
            )

        return [
            self.detect(img, page_num, ocr)
            for img, page_num, ocr in zip(images, page_numbers, ocr_results)
        ]

    @property
    def name(self) -> str:
        """Return detector name (the concrete class name)."""
        return self.__class__.__name__

    @property
    def is_initialized(self) -> bool:
        """Check if detector is initialized."""
        return self._initialized

    def __enter__(self):
        """Context manager entry: initialize if needed, then return self."""
        if not self._initialized:
            self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: no cleanup; exceptions are not suppressed."""
        pass
|
|
|