|
|
""" |
|
|
Layout Detector Implementations

Rule-based and model-based layout detection.
|
|
""" |
|
|
|
|
|
import time |
|
|
import uuid |
|
|
from typing import List, Optional, Dict, Tuple |
|
|
from collections import defaultdict |
|
|
import numpy as np |
|
|
from loguru import logger |
|
|
|
|
|
from .base import LayoutDetector, LayoutConfig, LayoutResult |
|
|
from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion |
|
|
|
|
|
|
|
|
class RuleBasedLayoutDetector(LayoutDetector):
    """
    Rule-based layout detector using OCR region analysis.

    Uses heuristics based on:
    - Text positioning and alignment
    - Font size estimation (based on region height)
    - Spacing patterns
    - Structural patterns (tables, lists)

    Every region produced here carries a bounding box created with
    ``normalized=False`` that is the union of the OCR boxes it was built from.
    """

    # A region starts a new text line when its top edge is more than this far
    # from the top edge of the first region of the current line.
    _LINE_Y_TOLERANCE = 10

    # A region joins an existing column when its horizontal centre is less
    # than this far from the centre of the region that opened the column.
    _COLUMN_X_TOLERANCE = 20

    # Leading markers that flag the first OCR region of a line as a list item.
    _BULLET_PREFIXES = ('•', '-', '–', '—', '*', '○', '●', '■', '□', '▪', '▸', '▹')
    _NUMBER_PREFIXES = ('1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.',
                        '1)', '2)', '3)', '4)', '5)', 'a.', 'b.', 'c.', 'a)', 'b)', 'c)')

    def __init__(self, config: Optional[LayoutConfig] = None):
        """Initialize rule-based detector."""
        super().__init__(config)

    def initialize(self):
        """Initialize detector (no model loading needed for rule-based)."""
        self._initialized = True
        logger.info("Initialized rule-based layout detector")

    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions using rule-based analysis.

        Args:
            image: Page image; only ``image.shape[:2]`` (height, width) is read here
            page_number: Page number, embedded in region ids and in the result
            ocr_regions: OCR regions for text-based analysis; when empty or
                None all text-based detectors are skipped

        Returns:
            LayoutResult with detected regions in reading order
        """
        if not self._initialized:
            self.initialize()

        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        started = time.perf_counter()
        height, width = image.shape[:2]

        region_counter = 0

        def make_region_id():
            """Return the next sequential region id for this page (1-based)."""
            nonlocal region_counter
            region_counter += 1
            return f"region_{page_number}_{region_counter}"

        regions = []

        if ocr_regions:
            # The call order fixes the numbering of the generated region ids.
            regions.extend(self._detect_titles_headings(ocr_regions, page_number, make_region_id, height))
            regions.extend(self._detect_paragraphs(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_lists(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_tables_from_ocr(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_headers_footers(ocr_regions, page_number, make_region_id, height))

        if self.config.detect_figures:
            regions.extend(self._detect_figures_from_image(image, page_number, make_region_id, ocr_regions))

        regions = self._merge_overlapping_regions(regions)
        regions = self._assign_reading_order(regions)

        return LayoutResult(
            page=page_number,
            regions=regions,
            image_width=width,
            image_height=height,
            processing_time_ms=(time.perf_counter() - started) * 1000,
            success=True,
        )

    @staticmethod
    def _extent(members: List[OCRRegion]) -> Tuple[float, float, float, float]:
        """Return ``(x_min, y_min, x_max, y_max)`` enclosing every box in *members* (non-empty)."""
        return (
            min(r.bbox.x_min for r in members),
            min(r.bbox.y_min for r in members),
            max(r.bbox.x_max for r in members),
            max(r.bbox.y_max for r in members),
        )

    @classmethod
    def _union_bbox(cls, members: List[OCRRegion]) -> BoundingBox:
        """Return a non-normalized BoundingBox enclosing every box in *members* (non-empty)."""
        x_min, y_min, x_max, y_max = cls._extent(members)
        return BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=False,
        )

    @staticmethod
    def _indices_of(ocr_regions: List[OCRRegion], members: List[OCRRegion]) -> List[int]:
        """
        Return the positions in *ocr_regions* of the objects contained in *members*.

        Membership is decided by object identity through a set, so the lookup
        is linear in the number of regions and never invokes the regions'
        ``__eq__``.
        """
        member_ids = {id(r) for r in members}
        return [i for i, r in enumerate(ocr_regions) if id(r) in member_ids]

    def _detect_titles_headings(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """
        Detect title and heading regions based on font size and position.

        A text line qualifies when its tallest region exceeds the median region
        height times ``config.heading_font_ratio`` and its joined text is
        shorter than 100 characters. It is a TITLE when it also starts in the
        top 15% of the page and exceeds 1.2x the threshold; otherwise a HEADING.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id
            page_height: Page height used for the top-of-page test

        Returns:
            Title/heading regions (empty when disabled via ``config.detect_titles``)
        """
        if not ocr_regions or not self.config.detect_titles:
            return []

        # Zero/negative heights would distort the "typical" text size.
        heights = [r.bbox.height for r in ocr_regions if r.bbox.height > 0]
        if not heights:
            return []

        title_threshold = np.median(heights) * self.config.heading_font_ratio

        regions = []
        for line_regions in self._group_into_lines(ocr_regions).values():
            if not line_regions:
                continue

            line_height = max(r.bbox.height for r in line_regions)
            line_text = " ".join(r.text for r in line_regions)

            is_large_text = line_height > title_threshold
            is_short = len(line_text) < 100
            if not (is_large_text and is_short):
                continue

            line_top = min(r.bbox.y_min for r in line_regions)
            is_top_of_page = line_top < page_height * 0.15

            if is_top_of_page and line_height > title_threshold * 1.2:
                layout_type = LayoutType.TITLE
            else:
                layout_type = LayoutType.HEADING

            regions.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.8,
                bbox=self._union_bbox(line_regions),
                page=page_number,
                ocr_region_ids=self._indices_of(ocr_regions, line_regions),
            ))

        return regions

    def _detect_paragraphs(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """
        Detect paragraph regions by grouping nearby text.

        Regions are grouped into lines, lines into paragraphs by vertical
        spacing, and one PARAGRAPH region is emitted per paragraph.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id

        Returns:
            One paragraph region per detected group of lines
        """
        if not ocr_regions:
            return []

        lines = self._group_into_lines(ocr_regions)
        paragraphs = self._group_lines_into_paragraphs(lines, ocr_regions)

        regions = []
        for para_lines in paragraphs:
            if not para_lines:
                continue

            para_regions = [r for line_id in para_lines for r in lines.get(line_id, [])]
            if not para_regions:
                continue

            regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.PARAGRAPH,
                confidence=0.7,
                bbox=self._union_bbox(para_regions),
                page=page_number,
                ocr_region_ids=self._indices_of(ocr_regions, para_regions),
            ))

        return regions

    def _detect_lists(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """
        Detect list structures based on bullet/number patterns.

        A line counts as a list item when the stripped text of its left-most
        region starts with a known bullet or number marker, or is at most three
        characters long and ends with a period. Runs of two or more consecutive
        list-item lines become one LIST region.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id

        Returns:
            List regions, each with ``extra["item_count"]`` set to the number
            of lines in the run (empty when disabled via ``config.detect_lists``)
        """
        if not ocr_regions or not self.config.detect_lists:
            return []

        lines = self._group_into_lines(ocr_regions)

        runs = []          # line ids of each completed run of list items
        current_run = []
        for line_id in sorted(lines):
            line_regions = lines[line_id]
            if not line_regions:
                continue

            first_text = line_regions[0].text.strip()
            is_list_item = (
                first_text.startswith(self._BULLET_PREFIXES)
                or first_text.startswith(self._NUMBER_PREFIXES)
                or (len(first_text) <= 3 and first_text.endswith('.'))
            )

            if is_list_item:
                current_run.append(line_id)
                continue

            # A non-list line ends the run; single-item runs are discarded.
            if len(current_run) >= 2:
                runs.append(current_run)
            current_run = []

        if len(current_run) >= 2:
            runs.append(current_run)

        regions = []
        for run_line_ids in runs:
            list_regions = [r for line_id in run_line_ids for r in lines.get(line_id, [])]
            if not list_regions:
                continue

            regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.LIST,
                confidence=0.75,
                bbox=self._union_bbox(list_regions),
                page=page_number,
                ocr_region_ids=self._indices_of(ocr_regions, list_regions),
                extra={"item_count": len(run_line_ids)},
            ))

        return regions

    def _detect_tables_from_ocr(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """
        Detect table regions based on aligned text patterns.

        OCR regions are clustered into columns by horizontal centre. A single
        TABLE region covering all OCR regions is emitted when there are at
        least ``config.table_min_cols`` columns, at least
        ``table_min_rows * table_min_cols`` regions, and the covered width
        exceeds 30% of the page width.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on the region
            make_id: Zero-argument callable returning a fresh region id

        Returns:
            Zero or one table region, with ``extra["estimated_cols"]`` set to
            the number of column clusters (empty when disabled via
            ``config.detect_tables``)
        """
        if not ocr_regions or not self.config.detect_tables:
            return []

        # Greedy clustering: a region joins the first column (in creation
        # order) whose anchor centre is close enough, else it opens a new one.
        column_groups = {}
        for region in ocr_regions:
            x_center = region.bbox.center[0]
            for anchor_x, members in column_groups.items():
                if abs(x_center - anchor_x) < self._COLUMN_X_TOLERANCE:
                    members.append(region)
                    break
            else:
                column_groups.setdefault(x_center, []).append(region)

        column_count = len(column_groups)
        if column_count < self.config.table_min_cols:
            return []

        # Every OCR region lands in exactly one column, so the table candidate
        # is made of all regions on the page.
        if len(ocr_regions) < self.config.table_min_rows * self.config.table_min_cols:
            return []

        x_min, y_min, x_max, y_max = self._extent(ocr_regions)

        # Fall back to 1000 when a box does not report a page width.
        page_width = max(r.bbox.page_width or 1000 for r in ocr_regions)
        width_ratio = (x_max - x_min) / page_width
        if not (width_ratio > 0.3):
            return []

        return [LayoutRegion(
            id=make_id(),
            type=LayoutType.TABLE,
            confidence=0.6,
            bbox=BoundingBox(
                x_min=x_min, y_min=y_min,
                x_max=x_max, y_max=y_max,
                normalized=False,
            ),
            page=page_number,
            extra={"estimated_cols": column_count},
        )]

    def _detect_headers_footers(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """
        Detect header and footer regions.

        OCR regions lying entirely above 8% of the page height form the
        header; those starting below 92% form the footer.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id
            page_height: Page height the two bands are derived from

        Returns:
            Up to two regions, header first (empty when disabled via
            ``config.detect_headers``)
        """
        if not ocr_regions or not self.config.detect_headers:
            return []

        header_threshold = page_height * 0.08
        footer_threshold = page_height * 0.92

        bands = (
            (LayoutType.HEADER, [r for r in ocr_regions if r.bbox.y_max < header_threshold]),
            (LayoutType.FOOTER, [r for r in ocr_regions if r.bbox.y_min > footer_threshold]),
        )

        regions = []
        for layout_type, band_regions in bands:
            if not band_regions:
                continue
            regions.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.7,
                bbox=self._union_bbox(band_regions),
                page=page_number,
            ))

        return regions

    def _detect_figures_from_image(
        self,
        image: np.ndarray,
        page_number: int,
        make_id,
        ocr_regions: Optional[List[OCRRegion]],
    ) -> List[LayoutRegion]:
        """
        Detect figure regions using image analysis.

        Placeholder: no figure heuristic is implemented, so the result is
        always empty and none of the arguments are inspected.

        Args:
            image: Page image (unused)
            page_number: Page number (unused)
            make_id: Zero-argument callable returning a fresh region id (unused)
            ocr_regions: OCR regions of the page (unused)

        Returns:
            An empty list
        """
        # NOTE: a mask of the OCR boxes would be a natural starting point for a
        # real implementation; it is not built here because nothing reads it
        # and it would cost a full-page allocation on every call.
        return []

    def _group_into_lines(
        self,
        ocr_regions: List[OCRRegion],
    ) -> Dict[int, List[OCRRegion]]:
        """
        Group OCR regions into lines based on y-position.

        Regions are visited top to bottom; a new line starts whenever a
        region's top edge is more than ``_LINE_Y_TOLERANCE`` away from the top
        edge of the region that opened the current line.

        Args:
            ocr_regions: OCR regions of the page

        Returns:
            Mapping of 0-based line id (top to bottom) to that line's regions
            sorted left to right; empty dict for empty input
        """
        if not ocr_regions:
            return {}

        top_to_bottom = sorted(ocr_regions, key=lambda r: r.bbox.y_min)

        lines = defaultdict(list)
        line_id = 0
        anchor_y = top_to_bottom[0].bbox.y_min
        for region in top_to_bottom:
            if abs(region.bbox.y_min - anchor_y) > self._LINE_Y_TOLERANCE:
                line_id += 1
                anchor_y = region.bbox.y_min
            lines[line_id].append(region)

        return {
            lid: sorted(members, key=lambda r: r.bbox.x_min)
            for lid, members in lines.items()
        }

    def _group_lines_into_paragraphs(
        self,
        lines: Dict[int, List[OCRRegion]],
        all_regions: List[OCRRegion],
    ) -> List[List[int]]:
        """
        Group lines into paragraphs based on spacing.

        Consecutive lines stay in the same paragraph unless the vertical gap
        between them exceeds 1.5x the mean region height of the two lines.
        Lines without regions are ignored, so they can neither drop a
        neighbouring line nor split a paragraph.

        Args:
            lines: Mapping of line id to the regions on that line
            all_regions: All OCR regions of the page (currently unused)

        Returns:
            Paragraphs as lists of line ids, in ascending line-id order
        """
        line_ids = [lid for lid in sorted(lines) if lines[lid]]
        if not line_ids:
            return []

        paragraphs = []
        current_para = [line_ids[0]]

        for prev_id, line_id in zip(line_ids, line_ids[1:]):
            prev_line = lines[prev_id]
            curr_line = lines[line_id]

            prev_y_max = max(r.bbox.y_max for r in prev_line)
            curr_y_min = min(r.bbox.y_min for r in curr_line)
            gap = curr_y_min - prev_y_max

            avg_height = np.mean([r.bbox.height for r in prev_line + curr_line])

            if gap > avg_height * 1.5:
                paragraphs.append(current_para)
                current_para = [line_id]
            else:
                current_para.append(line_id)

        paragraphs.append(current_para)
        return paragraphs

    def _merge_overlapping_regions(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """
        Placeholder for merging overlapping regions of the same type.

        No merging is performed: the regions are only regrouped by type (types
        in first-seen order, input order kept within a type) and returned
        otherwise unchanged.

        Args:
            regions: Detected regions

        Returns:
            The same region objects, grouped by type
        """
        if not regions:
            return []

        by_type = defaultdict(list)
        for region in regions:
            by_type[region.type].append(region)

        return [region for group in by_type.values() for region in group]

    def _assign_reading_order(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """
        Assign reading order to regions (top-to-bottom, left-to-right).

        Args:
            regions: Detected regions; their ``reading_order`` is set in place

        Returns:
            A new list sorted by ``(bbox.y_min, bbox.x_min)`` whose items carry
            their 0-based position as ``reading_order``
        """
        if not regions:
            return []

        ordered = sorted(regions, key=lambda r: (r.bbox.y_min, r.bbox.x_min))
        for position, region in enumerate(ordered):
            region.reading_order = position

        return ordered
|
|
|
|
|
|
|
|
|
|
|
# Module-wide detector instance; stays None until get_layout_detector() creates it.
_layout_detector: Optional[LayoutDetector] = None
|
|
|
|
|
|
|
|
def create_layout_detector(
    config: Optional[LayoutConfig] = None,
    initialize: bool = True,
) -> LayoutDetector:
    """
    Build a new layout detector.

    Only the rule-based detector is available here; any other
    ``config.method`` is reported with a warning and the rule-based detector
    is used anyway.

    Args:
        config: Detector configuration; a default ``LayoutConfig()`` is used
            when None
        initialize: When true, call ``initialize()`` on the detector before
            returning it

    Returns:
        A new detector instance
    """
    config = LayoutConfig() if config is None else config

    if config.method != "rule_based":
        logger.warning(f"Unknown method {config.method}, using rule_based")

    detector = RuleBasedLayoutDetector(config)
    if initialize:
        detector.initialize()
    return detector
|
|
|
|
|
|
|
|
def get_layout_detector(
    config: Optional[LayoutConfig] = None,
) -> LayoutDetector:
    """
    Return the shared layout detector, creating it on first use.

    Args:
        config: Configuration passed to ``create_layout_detector`` when the
            shared instance does not exist yet; ignored once it does

    Returns:
        The module-wide detector instance
    """
    global _layout_detector

    if _layout_detector is not None:
        return _layout_detector

    _layout_detector = create_layout_detector(config)
    return _layout_detector
|
|
|