MHamdan's picture
Initial commit: SPARKNET framework
d520909
"""
Reading Order Reconstructor Implementation
Rule-based reading order reconstruction for document elements.
"""
import time
from typing import List, Optional, Dict, Any, Tuple
from collections import defaultdict
import numpy as np
from loguru import logger
from .base import ReadingOrderReconstructor, ReadingOrderConfig, ReadingOrderResult
from ..schemas.core import BoundingBox, LayoutRegion, OCRRegion, LayoutType
class RuleBasedReadingOrder(ReadingOrderReconstructor):
    """
    Rule-based reading order reconstruction.

    Handles:
    - Single column documents
    - Multi-column layouts
    - Mixed layouts (text + figures)
    - Headers and footers

    A region is located through its ``bbox`` attribute, or used directly when
    it already is a ``BoundingBox``. Regions that have no bounding box cannot
    be placed geometrically; they are appended after all located regions in
    their original relative order, so every input region appears exactly once
    in the result.
    """

    def initialize(self):
        """Initialize (no model loading needed)."""
        self._initialized = True
        logger.info("Initialized rule-based reading order reconstructor")

    def reconstruct(
        self,
        regions: List[Any],
        layout_regions: Optional[List[LayoutRegion]] = None,
        page_width: Optional[int] = None,
        page_height: Optional[int] = None,
    ) -> ReadingOrderResult:
        """Reconstruct reading order using a rule-based approach.

        Args:
            regions: Items to order. Each is either a ``BoundingBox`` or an
                object exposing a ``bbox`` attribute.
            layout_regions: Optional layout regions. When given (and
                ``config.header_footer_separate`` is set), HEADER/FOOTER
                regions are used to move contained items to the start/end.
            page_width: Page width; estimated as the largest ``x_max`` of the
                located regions when omitted.
            page_height: Page height; estimated as the largest ``y_max`` of
                the located regions when omitted.

        Returns:
            A ``ReadingOrderResult``. ``order`` holds indices into
            ``regions``, and ``column_assignments`` is keyed by those same
            indices (regions without a bbox have no column entry).
        """
        if not self._initialized:
            self.initialize()

        start_time = time.perf_counter()

        if not regions:
            return ReadingOrderResult(success=True)

        # Remember where each located bbox came from in ``regions`` so that
        # regions without a bbox do not shift the indices we return.
        region_indices: List[int] = []
        bboxes: List[BoundingBox] = []
        for idx, region in enumerate(regions):
            bbox = self._region_bbox(region)
            if bbox is not None:
                region_indices.append(idx)
                bboxes.append(bbox)

        if not bboxes:
            return ReadingOrderResult(success=True)

        # Estimate page dimensions if not provided
        if page_width is None:
            page_width = max(b.x_max for b in bboxes)
        if page_height is None:
            page_height = max(b.y_max for b in bboxes)

        # Detect columns and sort. Indices at this stage are positions in
        # ``bboxes``, not in ``regions``.
        num_columns, local_assignments = self._detect_columns(bboxes, page_width)
        if num_columns == 1:
            local_order = self._sort_single_column(bboxes)
        else:
            local_order = self._sort_multi_column(
                bboxes, local_assignments, num_columns
            )

        # Translate bbox positions back into indices into ``regions``.
        order = [region_indices[j] for j in local_order]
        column_assignments = {
            region_indices[j]: col for j, col in local_assignments.items()
        }

        # Regions without geometry go last, keeping their input order.
        located = set(region_indices)
        order.extend(i for i in range(len(regions)) if i not in located)

        # Handle headers/footers
        if self.config.header_footer_separate and layout_regions:
            order = self._adjust_for_headers_footers(
                order, regions, layout_regions, page_height
            )

        processing_time = (time.perf_counter() - start_time) * 1000

        return ReadingOrderResult(
            order=order,
            ordered_regions=[regions[i] for i in order],
            num_columns=num_columns,
            column_assignments=column_assignments,
            processing_time_ms=processing_time,
            success=True,
        )

    @staticmethod
    def _region_bbox(region: Any) -> Optional[BoundingBox]:
        """Return the bounding box of ``region``, or None if it has none.

        A ``bbox`` attribute takes precedence; otherwise the region itself is
        returned when it is a ``BoundingBox``.
        """
        if hasattr(region, 'bbox'):
            return region.bbox
        if isinstance(region, BoundingBox):
            return region
        return None

    def _extract_bboxes(self, regions: List[Any]) -> List[BoundingBox]:
        """Extract bounding boxes from regions, skipping regions without one.

        Note: the returned list is NOT index-aligned with ``regions`` when
        some regions lack a bbox.
        """
        bboxes = []
        for r in regions:
            bbox = self._region_bbox(r)
            if bbox is not None:
                bboxes.append(bbox)
        return bboxes

    def _detect_columns(
        self,
        bboxes: List[BoundingBox],
        page_width: int,
    ) -> Tuple[int, Dict[int, int]]:
        """Detect column structure in the document.

        Candidate column separators are placed at the midpoint of every gap
        between consecutive distinct region x-centres that exceeds
        ``page_width * config.column_gap_threshold``. When there are more
        candidates than ``config.max_columns`` allows, the widest gaps are
        kept.

        Returns:
            ``(num_columns, assignments)``, where ``assignments`` maps each
            index into ``bboxes`` to a 0-based column number.
        """
        if not self.config.detect_columns or len(bboxes) < 4:
            return 1, {i: 0 for i in range(len(bboxes))}

        # Horizontal centre of every region.
        x_centers = [(b.x_min + b.x_max) / 2 for b in bboxes]

        min_gap = page_width * self.config.column_gap_threshold
        sorted_centers = sorted(set(x_centers))

        # (gap width, separator position) for every sufficiently large gap.
        candidates = []
        for left, right in zip(sorted_centers, sorted_centers[1:]):
            gap = right - left
            if gap > min_gap:
                candidates.append((gap, (left + right) / 2))

        # Determine number of columns (limited by max_columns, at least 1).
        num_columns = max(1, min(len(candidates) + 1, self.config.max_columns))
        if num_columns == 1:
            return 1, {i: 0 for i in range(len(bboxes))}

        # Keep the widest gaps: they are the most likely true separators.
        # ``sorted`` is stable, so equal-width gaps keep left-to-right order.
        widest = sorted(candidates, key=lambda c: c[0], reverse=True)
        separators = sorted(pos for _, pos in widest[:num_columns - 1])
        column_boundaries = [0] + separators + [page_width]

        # Assign regions to columns
        assignments = {}
        for i, center in enumerate(x_centers):
            for col in range(num_columns):
                if column_boundaries[col] <= center < column_boundaries[col + 1]:
                    assignments[i] = col
                    break
            else:
                # Centres outside [0, page_width) fall into the last column.
                assignments[i] = num_columns - 1
        return num_columns, assignments

    def _sort_single_column(self, bboxes: List[BoundingBox]) -> List[int]:
        """Sort regions in single-column layout; returns indices into ``bboxes``."""
        indexed = list(enumerate(bboxes))
        if self.config.vertical_priority:
            # Primary sort by y, secondary by x
            indexed.sort(key=lambda x: (x[1].y_min, x[1].x_min))
        else:
            # Primary sort by x, secondary by y
            indexed.sort(key=lambda x: (x[1].x_min, x[1].y_min))

        if self.config.reading_direction == "rtl":
            # Group into rows by approximate y (rows come back sorted
            # left-to-right), then read each row right-to-left.
            rows = self._group_by_y(indexed)
            result = []
            for row in rows:
                row.reverse()
                result.extend(i for i, _ in row)
            return result

        return [i for i, _ in indexed]

    def _sort_multi_column(
        self,
        bboxes: List[BoundingBox],
        column_assignments: Dict[int, int],
        num_columns: int,
    ) -> List[int]:
        """Sort regions in multi-column layout; returns indices into ``bboxes``.

        Each column is read top-to-bottom. Columns are visited left-to-right
        for ``reading_direction == "ltr"``, otherwise right-to-left.
        """
        # Group by column
        columns = defaultdict(list)
        for i, bbox in enumerate(bboxes):
            col = column_assignments.get(i, 0)
            columns[col].append((i, bbox))

        # Sort within each column (top to bottom)
        for col in columns:
            columns[col].sort(key=lambda x: (x[1].y_min, x[1].x_min))

        if self.config.reading_direction == "ltr":
            col_order = range(num_columns)
        else:
            col_order = range(num_columns - 1, -1, -1)

        result = []
        for col in col_order:
            result.extend(i for i, _ in columns.get(col, []))
        return result

    def _group_by_y(
        self,
        indexed_bboxes: List[Tuple[int, BoundingBox]],
        tolerance: float = 10.0,
    ) -> List[List[Tuple[int, BoundingBox]]]:
        """Group ``(index, bbox)`` pairs into rows by ``y_min``.

        A new row starts when an item's ``y_min`` differs from the first item
        of the current row by more than ``tolerance``. Each row is sorted by
        ``x_min`` ascending.
        """
        if not indexed_bboxes:
            return []

        sorted_items = sorted(indexed_bboxes, key=lambda x: x[1].y_min)

        rows = []
        current_row = [sorted_items[0]]
        current_y = sorted_items[0][1].y_min

        for item in sorted_items[1:]:
            if abs(item[1].y_min - current_y) <= tolerance:
                current_row.append(item)
            else:
                current_row.sort(key=lambda x: x[1].x_min)
                rows.append(current_row)
                current_row = [item]
                current_y = item[1].y_min

        current_row.sort(key=lambda x: x[1].x_min)
        rows.append(current_row)
        return rows

    def _adjust_for_headers_footers(
        self,
        order: List[int],
        regions: List[Any],
        layout_regions: List[LayoutRegion],
        page_height: int,
    ) -> List[int]:
        """Reorder ``order`` so headers come first and footers last.

        A region is a header (footer) when it is contained in a HEADER
        (FOOTER) layout region, or when it lies entirely in the top (bottom)
        10% of the page. A region matching both is treated as a header so
        that every index appears exactly once in the result.
        """
        header_indices = set()
        footer_indices = set()

        header_y_threshold = page_height * 0.1
        footer_y_threshold = page_height * 0.9

        region_bboxes = [self._region_bbox(r) for r in regions]

        # Regions contained in HEADER / FOOTER layout regions.
        for layout_r in layout_regions:
            if layout_r.type == LayoutType.HEADER:
                target = header_indices
            elif layout_r.type == LayoutType.FOOTER:
                target = footer_indices
            else:
                continue
            for i, bbox in enumerate(region_bboxes):
                if bbox is not None and layout_r.bbox.contains(bbox):
                    target.add(i)

        # Also detect by position
        for i, bbox in enumerate(region_bboxes):
            if bbox is None:
                continue
            if bbox.y_max < header_y_threshold:
                header_indices.add(i)
            elif bbox.y_min > footer_y_threshold:
                footer_indices.add(i)

        # Avoid emitting an index twice when it was flagged as both.
        footer_indices -= header_indices

        # Reorder: headers first, then body, then footers
        headers = [i for i in order if i in header_indices]
        footers = [i for i in order if i in footer_indices]
        body = [i for i in order if i not in header_indices and i not in footer_indices]

        return headers + body + footers
# Factory

# Lazily created, process-wide reconstructor shared by every caller.
_reading_order: Optional[ReadingOrderReconstructor] = None


def get_reading_order_reconstructor(
    config: Optional[ReadingOrderConfig] = None,
) -> ReadingOrderReconstructor:
    """Return the shared reading order reconstructor, creating it on first use.

    Args:
        config: Configuration for the reconstructor. Only consulted when the
            singleton does not exist yet; a falsy/None value falls back to
            ``ReadingOrderConfig()``. Later calls return the existing
            instance and ignore this argument.

    Returns:
        The singleton ``RuleBasedReadingOrder`` instance (already initialized,
        unless its ``initialize()`` raised on creation).
    """
    global _reading_order
    if _reading_order is not None:
        return _reading_order

    # Publish the instance before initializing it, matching prior behavior.
    _reading_order = RuleBasedReadingOrder(config or ReadingOrderConfig())
    _reading_order.initialize()
    return _reading_order