"""
Layout Detection Model Interface

Abstract interface for document layout analysis models.
Detects regions like text blocks, tables, figures, headers, etc.
"""
import hashlib
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from ..chunks.models import BoundingBox, ChunkType
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)
class LayoutRegionType(str, Enum):
    """Types of layout regions that can be detected.

    The ``str`` mix-in makes each member compare equal to its string value
    (``LayoutRegionType.TABLE == "table"``).
    """

    # Running text
    TEXT = "text"
    TITLE = "title"
    HEADING = "heading"
    PARAGRAPH = "paragraph"
    LIST = "list"

    # Structured content
    TABLE = "table"
    FIGURE = "figure"
    CHART = "chart"
    FORMULA = "formula"
    CODE = "code"

    # Page-level elements
    HEADER = "header"
    FOOTER = "footer"
    PAGE_NUMBER = "page_number"
    CAPTION = "caption"
    FOOTNOTE = "footnote"

    # Marks and form elements
    LOGO = "logo"
    SIGNATURE = "signature"
    STAMP = "stamp"
    WATERMARK = "watermark"
    FORM_FIELD = "form_field"
    CHECKBOX = "checkbox"

    # Fallback
    UNKNOWN = "unknown"

    def to_chunk_type(self) -> ChunkType:
        """Convert layout region type to chunk type.

        Returns:
            The ``ChunkType`` paired with this member in the table below.
            Members without an entry (currently only ``UNKNOWN``) fall back
            to ``ChunkType.TEXT``.
        """
        mapping = {
            LayoutRegionType.TEXT: ChunkType.TEXT,
            LayoutRegionType.TITLE: ChunkType.TITLE,
            LayoutRegionType.HEADING: ChunkType.HEADING,
            LayoutRegionType.PARAGRAPH: ChunkType.PARAGRAPH,
            LayoutRegionType.LIST: ChunkType.LIST,
            LayoutRegionType.TABLE: ChunkType.TABLE,
            LayoutRegionType.FIGURE: ChunkType.FIGURE,
            LayoutRegionType.CHART: ChunkType.CHART,
            LayoutRegionType.FORMULA: ChunkType.FORMULA,
            LayoutRegionType.CODE: ChunkType.CODE,
            LayoutRegionType.HEADER: ChunkType.HEADER,
            LayoutRegionType.FOOTER: ChunkType.FOOTER,
            LayoutRegionType.PAGE_NUMBER: ChunkType.PAGE_NUMBER,
            LayoutRegionType.CAPTION: ChunkType.CAPTION,
            LayoutRegionType.FOOTNOTE: ChunkType.FOOTNOTE,
            LayoutRegionType.LOGO: ChunkType.LOGO,
            LayoutRegionType.SIGNATURE: ChunkType.SIGNATURE,
            LayoutRegionType.STAMP: ChunkType.STAMP,
            LayoutRegionType.WATERMARK: ChunkType.WATERMARK,
            LayoutRegionType.FORM_FIELD: ChunkType.FORM_FIELD,
            LayoutRegionType.CHECKBOX: ChunkType.CHECKBOX,
        }
        return mapping.get(self, ChunkType.TEXT)
@dataclass
class LayoutConfig(ModelConfig):
    """Configuration for layout detection models.

    Extends ``ModelConfig`` with layout-specific options. This module only
    reads ``detect_reading_order`` (in ``LayoutModel.get_capabilities``);
    how the other options are interpreted is left to implementations.
    """

    # presumably the minimum detection score for keeping a region - TODO confirm
    min_confidence: float = 0.5
    # presumably toggles merging of overlapping detections - TODO confirm
    merge_overlapping: bool = True
    # presumably the overlap ratio above which regions are merged - TODO confirm
    overlap_threshold: float = 0.5
    # When True, LayoutModel advertises the READING_ORDER capability.
    detect_reading_order: bool = True
    # presumably toggles column analysis - TODO confirm
    detect_columns: bool = True
    # presumably restricts detection to these types; None = no restriction - TODO confirm
    region_types: Optional[List[LayoutRegionType]] = None

    def __post_init__(self) -> None:
        """Run the base-class hook, then default an empty name to "layout_detector"."""
        super().__post_init__()
        if not self.name:
            self.name = "layout_detector"
@dataclass
class LayoutRegion:
    """A detected layout region.

    ``region_type``, ``bbox`` and ``confidence`` are required. When
    ``region_id`` is left empty it is derived from the type and the box in
    ``__post_init__``.
    """

    region_type: LayoutRegionType
    bbox: BoundingBox
    confidence: float
    region_id: str = ""

    # presumably -1 means "no reading position assigned" - TODO confirm
    reading_order: int = -1

    # Links to other regions by region_id.
    parent_id: Optional[str] = None
    child_ids: List[str] = field(default_factory=list)

    # Column placement; defaults describe a single-column page.
    column_index: int = 0
    num_columns: int = 1

    # Free-form extra data; each instance gets its own dict.
    attributes: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in ``region_id`` when the caller did not supply one.

        The generated ID is the first 12 hex digits of the MD5 digest of
        ``"<region_type.value>_<bbox.xyxy>"``. It is deterministic, so two
        regions with the same type and the same box receive the same ID.
        """
        if self.region_id:
            return
        # MD5 serves as a short fingerprint here, not as a security measure.
        fingerprint = f"{self.region_type.value}_{self.bbox.xyxy}"
        self.region_id = hashlib.md5(fingerprint.encode()).hexdigest()[:12]
@dataclass
class LayoutResult:
    """Complete layout analysis result for a page.

    Holds the detected regions plus page-level metadata, and offers helpers
    for filtering and ordering the regions.
    """

    regions: List[LayoutRegion] = field(default_factory=list)
    # Region IDs in reading sequence; empty means "no explicit order".
    reading_order: List[str] = field(default_factory=list)
    num_columns: int = 1
    page_orientation: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    def get_regions_by_type(self, region_type: LayoutRegionType) -> List[LayoutRegion]:
        """Get all regions of a specific type, in their stored order."""
        return [r for r in self.regions if r.region_type == region_type]

    def get_region_by_id(self, region_id: str) -> Optional[LayoutRegion]:
        """Get the first region whose ID matches, or None when there is none."""
        return next((r for r in self.regions if r.region_id == region_id), None)

    def get_ordered_regions(self) -> List[LayoutRegion]:
        """Get regions in reading order.

        Without an explicit ``reading_order`` the regions are sorted top to
        bottom, then left to right, by the top-left corner of their boxes.
        With one, the regions are returned in that sequence; IDs that match
        no region are skipped, and regions whose ID is not listed are
        omitted from the result.
        """
        if not self.reading_order:
            return sorted(self.regions, key=lambda r: (r.bbox.y_min, r.bbox.x_min))

        # Index once instead of scanning the region list for every ID.
        # setdefault keeps the first region for a duplicated ID, which is
        # what get_region_by_id returns.
        by_id: Dict[str, LayoutRegion] = {}
        for region in self.regions:
            by_id.setdefault(region.region_id, region)
        return [by_id[rid] for rid in self.reading_order if rid in by_id]

    def get_tables(self) -> List[LayoutRegion]:
        """Get all table regions."""
        return self.get_regions_by_type(LayoutRegionType.TABLE)

    def get_figures(self) -> List[LayoutRegion]:
        """Get all figure regions."""
        return self.get_regions_by_type(LayoutRegionType.FIGURE)

    def get_text_regions(self) -> List[LayoutRegion]:
        """Get all text-based regions.

        Text-based means TEXT, TITLE, HEADING, PARAGRAPH, LIST, CAPTION or
        FOOTNOTE.
        """
        text_types = {
            LayoutRegionType.TEXT,
            LayoutRegionType.TITLE,
            LayoutRegionType.HEADING,
            LayoutRegionType.PARAGRAPH,
            LayoutRegionType.LIST,
            LayoutRegionType.CAPTION,
            LayoutRegionType.FOOTNOTE,
        }
        return [r for r in self.regions if r.region_type in text_types]
class LayoutModel(BatchableModel):
    """
    Abstract base class for layout detection models.

    Implementations should detect:
    - Document regions (text, tables, figures, etc.)
    - Reading order
    - Column structure
    - Region hierarchy

    Subclasses implement ``detect``; batch processing and the table/figure
    convenience methods are built on top of it.
    """

    def __init__(self, config: Optional[LayoutConfig] = None):
        """Create the model, defaulting to a ``LayoutConfig`` named "layout"."""
        super().__init__(config or LayoutConfig(name="layout"))
        # Re-annotate so type checkers see the layout-specific config type.
        self.config: LayoutConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report LAYOUT_DETECTION, plus READING_ORDER when the config enables it."""
        caps = [ModelCapability.LAYOUT_DETECTION]
        if self.config.detect_reading_order:
            caps.append(ModelCapability.READING_ORDER)
        return caps

    @abstractmethod
    def detect(
        self,
        image: ImageInput,
        **kwargs
    ) -> LayoutResult:
        """
        Detect layout regions in an image.

        Args:
            image: Input document image
            **kwargs: Additional parameters

        Returns:
            LayoutResult with detected regions
        """

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[LayoutResult]:
        """Process multiple images.

        Calls ``detect`` once per image, in order, passing ``kwargs`` through
        unchanged; results are returned in input order.
        """
        return [self.detect(img, **kwargs) for img in inputs]

    def detect_tables(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[LayoutRegion]:
        """
        Detect only table regions.

        Convenience method that runs full layout detection and keeps the
        table regions of the result.
        """
        return self.detect(image, **kwargs).get_tables()

    def detect_figures(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[LayoutRegion]:
        """Detect only figure regions.

        Convenience method that runs full layout detection and keeps the
        figure regions of the result.
        """
        return self.detect(image, **kwargs).get_figures()
class ReadingOrderModel(BaseModel):
    """
    Abstract base class for reading order determination.

    Some implementations may be separate from layout detection,
    requiring a specialized model for complex layouts.
    """

    def get_capabilities(self) -> List[ModelCapability]:
        """Report READING_ORDER as the only capability."""
        return [ModelCapability.READING_ORDER]

    @abstractmethod
    def determine_order(
        self,
        regions: List[LayoutRegion],
        image: Optional[ImageInput] = None,
        **kwargs
    ) -> List[str]:
        """
        Determine reading order for a list of regions.

        Args:
            regions: Layout regions to order
            image: Optional image for visual cues
            **kwargs: Additional parameters

        Returns:
            List of region_ids in reading order
        """
class HeuristicReadingOrderModel(ReadingOrderModel):
    """
    Simple heuristic-based reading order model.

    Uses geometric analysis for column detection and ordering.
    Suitable for simple document layouts.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """Create the model, defaulting to a config named "heuristic_reading_order"."""
        super().__init__(config or ModelConfig(name="heuristic_reading_order"))

    def load(self) -> None:
        """Mark the model as loaded; this method does nothing else."""
        self._is_loaded = True

    def unload(self) -> None:
        """Mark the model as not loaded."""
        self._is_loaded = False

    def determine_order(
        self,
        regions: List[LayoutRegion],
        image: Optional[ImageInput] = None,
        column_threshold: float = 0.3,
        **kwargs
    ) -> List[str]:
        """
        Determine reading order using heuristics.

        Strategy:
        1. Detect columns based on x-coordinate clustering
        2. Within each column, sort top-to-bottom
        3. Process columns left-to-right

        Args:
            regions: Layout regions to order
            image: Accepted for interface compatibility; not used here
            column_threshold: Forwarded to ``_detect_columns``
            **kwargs: Accepted for interface compatibility; not used here

        Returns:
            List of region_ids in reading order (empty for no regions)
        """
        if not regions:
            return []

        ordered_ids: List[str] = []
        for column in self._detect_columns(regions, column_threshold):
            top_to_bottom = sorted(column, key=lambda r: r.bbox.y_min)
            ordered_ids.extend(r.region_id for r in top_to_bottom)
        return ordered_ids

    def _detect_columns(
        self,
        regions: List[LayoutRegion],
        threshold: float
    ) -> List[List[LayoutRegion]]:
        """Detect columns by x-coordinate clustering.

        Regions are visited in order of their left edge. A region joins the
        current column when its horizontal extent overlaps that of the region
        added just before it; otherwise it starts a new column. Columns are
        returned left to right.

        NOTE(review): ``threshold`` is accepted but never consulted - any
        positive horizontal overlap joins the column. Confirm the intended
        meaning before wiring it in.
        """
        if not regions:
            return []

        by_left_edge = sorted(regions, key=lambda r: r.bbox.x_min)
        columns: List[List[LayoutRegion]] = [[by_left_edge[0]]]

        for region in by_left_edge[1:]:
            # Only the most recently added region is compared, not the
            # whole column's extent.
            previous = columns[-1][-1]
            shared_left = max(region.bbox.x_min, previous.bbox.x_min)
            shared_right = min(region.bbox.x_max, previous.bbox.x_max)

            if shared_right > shared_left:
                columns[-1].append(region)
            else:
                columns.append([region])

        return columns