""" Layout Detector Implementations Rule-based and model-based layout detection. """ import time import uuid from typing import List, Optional, Dict, Tuple from collections import defaultdict import numpy as np from loguru import logger from .base import LayoutDetector, LayoutConfig, LayoutResult from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion class RuleBasedLayoutDetector(LayoutDetector): """ Rule-based layout detector using OCR region analysis. Uses heuristics based on: - Text positioning and alignment - Font size estimation (based on region height) - Spacing patterns - Structural patterns (tables, lists) """ def __init__(self, config: Optional[LayoutConfig] = None): """Initialize rule-based detector.""" super().__init__(config) def initialize(self): """Initialize detector (no model loading needed for rule-based).""" self._initialized = True logger.info("Initialized rule-based layout detector") def detect( self, image: np.ndarray, page_number: int = 0, ocr_regions: Optional[List[OCRRegion]] = None, ) -> LayoutResult: """ Detect layout regions using rule-based analysis. Args: image: Page image page_number: Page number ocr_regions: OCR regions for text-based analysis Returns: LayoutResult with detected regions """ if not self._initialized: self.initialize() start_time = time.time() height, width = image.shape[:2] regions = [] region_counter = 0 def make_region_id(): nonlocal region_counter region_counter += 1 return f"region_{page_number}_{region_counter}" if ocr_regions: # Analyze OCR regions to detect layout regions.extend(self._detect_titles_headings(ocr_regions, page_number, make_region_id, height)) regions.extend(self._detect_paragraphs(ocr_regions, page_number, make_region_id)) regions.extend(self._detect_lists(ocr_regions, page_number, make_region_id)) regions.extend(self._detect_tables_from_ocr(ocr_regions, page_number, make_region_id)) regions.extend(self._detect_headers_footers(ocr_regions, page_number, make_region_id, height)) # Image-based detection for figures/charts if self.config.detect_figures: regions.extend(self._detect_figures_from_image(image, page_number, make_region_id, ocr_regions)) # Merge overlapping regions regions = self._merge_overlapping_regions(regions) # Assign reading order regions = self._assign_reading_order(regions) processing_time = (time.time() - start_time) * 1000 return LayoutResult( page=page_number, regions=regions, image_width=width, image_height=height, processing_time_ms=processing_time, success=True, ) def _detect_titles_headings( self, ocr_regions: List[OCRRegion], page_number: int, make_id, page_height: int, ) -> List[LayoutRegion]: """Detect title and heading regions based on font size and position.""" if not ocr_regions or not self.config.detect_titles: return [] regions = [] # Calculate average text height heights = [r.bbox.height for r in ocr_regions if r.bbox.height > 0] if not heights: return [] avg_height = np.median(heights) title_threshold = avg_height * self.config.heading_font_ratio # Group regions by line lines = self._group_into_lines(ocr_regions) for line_id, line_regions in lines.items(): if not line_regions: continue # Calculate line properties line_height = max(r.bbox.height for r in line_regions) line_text = " ".join(r.text for r in line_regions) line_y = min(r.bbox.y_min for r in line_regions) # Check if this looks like a title/heading is_large_text = line_height > title_threshold is_short = len(line_text) < 100 is_top_of_page = line_y < page_height * 0.15 if is_large_text and is_short: # Merge line regions into one bbox x_min = min(r.bbox.x_min for r in line_regions) y_min = min(r.bbox.y_min for r in line_regions) x_max = max(r.bbox.x_max for r in line_regions) y_max = max(r.bbox.y_max for r in line_regions) # Determine if title or heading if is_top_of_page and line_height > title_threshold * 1.2: layout_type = LayoutType.TITLE else: layout_type = LayoutType.HEADING regions.append(LayoutRegion( id=make_id(), type=layout_type, confidence=0.8, bbox=BoundingBox( x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False, ), page=page_number, ocr_region_ids=[i for i, r in enumerate(ocr_regions) if r in line_regions], )) return regions def _detect_paragraphs( self, ocr_regions: List[OCRRegion], page_number: int, make_id, ) -> List[LayoutRegion]: """Detect paragraph regions by grouping nearby text.""" if not ocr_regions: return [] regions = [] # Group regions by proximity lines = self._group_into_lines(ocr_regions) paragraphs = self._group_lines_into_paragraphs(lines, ocr_regions) for para_lines in paragraphs: if not para_lines: continue # Get all OCR regions in this paragraph para_regions = [] for line_id in para_lines: para_regions.extend(lines.get(line_id, [])) if not para_regions: continue # Calculate bounding box x_min = min(r.bbox.x_min for r in para_regions) y_min = min(r.bbox.y_min for r in para_regions) x_max = max(r.bbox.x_max for r in para_regions) y_max = max(r.bbox.y_max for r in para_regions) regions.append(LayoutRegion( id=make_id(), type=LayoutType.PARAGRAPH, confidence=0.7, bbox=BoundingBox( x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False, ), page=page_number, ocr_region_ids=[i for i, r in enumerate(ocr_regions) if r in para_regions], )) return regions def _detect_lists( self, ocr_regions: List[OCRRegion], page_number: int, make_id, ) -> List[LayoutRegion]: """Detect list structures based on bullet/number patterns.""" if not ocr_regions or not self.config.detect_lists: return [] regions = [] # List indicators bullet_patterns = {'•', '-', '–', '—', '*', '○', '●', '■', '□', '▪', '▸', '▹'} number_patterns = ('1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.', '1)', '2)', '3)', '4)', '5)', 'a.', 'b.', 'c.', 'a)', 'b)', 'c)') # Group by lines lines = self._group_into_lines(ocr_regions) # Find consecutive lines that look like list items list_lines = [] current_list = [] sorted_line_ids = sorted(lines.keys()) for line_id in sorted_line_ids: line_regions = lines[line_id] if not line_regions: continue first_text = line_regions[0].text.strip() # Check if line starts with list indicator is_list_item = ( any(first_text.startswith(p) for p in bullet_patterns) or any(first_text.startswith(p) for p in number_patterns) or (len(first_text) <= 3 and first_text.endswith('.')) ) if is_list_item: current_list.append(line_id) else: if len(current_list) >= 2: list_lines.append(current_list) current_list = [] # Don't forget the last list if len(current_list) >= 2: list_lines.append(current_list) # Create list regions for list_line_ids in list_lines: list_regions = [] for line_id in list_line_ids: list_regions.extend(lines.get(line_id, [])) if not list_regions: continue x_min = min(r.bbox.x_min for r in list_regions) y_min = min(r.bbox.y_min for r in list_regions) x_max = max(r.bbox.x_max for r in list_regions) y_max = max(r.bbox.y_max for r in list_regions) regions.append(LayoutRegion( id=make_id(), type=LayoutType.LIST, confidence=0.75, bbox=BoundingBox( x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False, ), page=page_number, ocr_region_ids=[i for i, r in enumerate(ocr_regions) if r in list_regions], extra={"item_count": len(list_line_ids)}, )) return regions def _detect_tables_from_ocr( self, ocr_regions: List[OCRRegion], page_number: int, make_id, ) -> List[LayoutRegion]: """Detect table regions based on aligned text patterns.""" if not ocr_regions or not self.config.detect_tables: return [] regions = [] # Group regions by approximate x-position (columns) x_groups = defaultdict(list) x_tolerance = 20 # pixels for region in ocr_regions: x_center = region.bbox.center[0] # Find closest existing column matched = False for x_key in list(x_groups.keys()): if abs(x_center - x_key) < x_tolerance: x_groups[x_key].append(region) matched = True break if not matched: x_groups[x_center].append(region) # Find areas where multiple columns align vertically if len(x_groups) >= self.config.table_min_cols: # Check for row alignment columns = sorted(x_groups.keys()) # Find overlapping y-ranges across columns # This is a simplified heuristic all_regions = [r for regions in x_groups.values() for r in regions] if len(all_regions) >= self.config.table_min_rows * self.config.table_min_cols: x_min = min(r.bbox.x_min for r in all_regions) y_min = min(r.bbox.y_min for r in all_regions) x_max = max(r.bbox.x_max for r in all_regions) y_max = max(r.bbox.y_max for r in all_regions) # Only create table if it spans significant width width_ratio = (x_max - x_min) / max(r.bbox.page_width or 1000 for r in all_regions) if width_ratio > 0.3: regions.append(LayoutRegion( id=make_id(), type=LayoutType.TABLE, confidence=0.6, # Lower confidence for rule-based bbox=BoundingBox( x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False, ), page=page_number, extra={"estimated_cols": len(columns)}, )) return regions def _detect_headers_footers( self, ocr_regions: List[OCRRegion], page_number: int, make_id, page_height: int, ) -> List[LayoutRegion]: """Detect header and footer regions.""" if not ocr_regions or not self.config.detect_headers: return [] regions = [] header_threshold = page_height * 0.08 footer_threshold = page_height * 0.92 header_regions = [r for r in ocr_regions if r.bbox.y_max < header_threshold] footer_regions = [r for r in ocr_regions if r.bbox.y_min > footer_threshold] if header_regions: x_min = min(r.bbox.x_min for r in header_regions) y_min = min(r.bbox.y_min for r in header_regions) x_max = max(r.bbox.x_max for r in header_regions) y_max = max(r.bbox.y_max for r in header_regions) regions.append(LayoutRegion( id=make_id(), type=LayoutType.HEADER, confidence=0.7, bbox=BoundingBox(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False), page=page_number, )) if footer_regions: x_min = min(r.bbox.x_min for r in footer_regions) y_min = min(r.bbox.y_min for r in footer_regions) x_max = max(r.bbox.x_max for r in footer_regions) y_max = max(r.bbox.y_max for r in footer_regions) regions.append(LayoutRegion( id=make_id(), type=LayoutType.FOOTER, confidence=0.7, bbox=BoundingBox(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False), page=page_number, )) return regions def _detect_figures_from_image( self, image: np.ndarray, page_number: int, make_id, ocr_regions: Optional[List[OCRRegion]], ) -> List[LayoutRegion]: """Detect figure regions using image analysis.""" # This is a simplified approach - in production, use a vision model regions = [] # Find large areas without text (potential figures) if ocr_regions: height, width = image.shape[:2] # Create a mask of text regions text_mask = np.zeros((height, width), dtype=np.uint8) for r in ocr_regions: bbox = r.bbox x1, y1, x2, y2 = int(bbox.x_min), int(bbox.y_min), int(bbox.x_max), int(bbox.y_max) text_mask[y1:y2, x1:x2] = 255 # Find large non-text areas (very simplified) # In production, use connected components or contour detection # This is a placeholder for more sophisticated detection return regions def _group_into_lines( self, ocr_regions: List[OCRRegion], ) -> Dict[int, List[OCRRegion]]: """Group OCR regions into lines based on y-position.""" if not ocr_regions: return {} lines = defaultdict(list) y_tolerance = 10 # pixels # Sort by y position sorted_regions = sorted(ocr_regions, key=lambda r: r.bbox.y_min) current_line_id = 0 current_y = sorted_regions[0].bbox.y_min if sorted_regions else 0 for region in sorted_regions: if abs(region.bbox.y_min - current_y) > y_tolerance: current_line_id += 1 current_y = region.bbox.y_min lines[current_line_id].append(region) # Sort each line by x position for line_id in lines: lines[line_id] = sorted(lines[line_id], key=lambda r: r.bbox.x_min) return dict(lines) def _group_lines_into_paragraphs( self, lines: Dict[int, List[OCRRegion]], all_regions: List[OCRRegion], ) -> List[List[int]]: """Group lines into paragraphs based on spacing.""" if not lines: return [] paragraphs = [] current_para = [] sorted_line_ids = sorted(lines.keys()) for i, line_id in enumerate(sorted_line_ids): if not current_para: current_para.append(line_id) continue prev_line = lines[sorted_line_ids[i - 1]] curr_line = lines[line_id] if not prev_line or not curr_line: continue # Calculate vertical gap prev_y_max = max(r.bbox.y_max for r in prev_line) curr_y_min = min(r.bbox.y_min for r in curr_line) gap = curr_y_min - prev_y_max # Calculate average line height avg_height = np.mean([r.bbox.height for r in prev_line + curr_line]) # Large gap indicates new paragraph if gap > avg_height * 1.5: paragraphs.append(current_para) current_para = [line_id] else: current_para.append(line_id) if current_para: paragraphs.append(current_para) return paragraphs def _merge_overlapping_regions( self, regions: List[LayoutRegion], ) -> List[LayoutRegion]: """Merge overlapping regions of the same type.""" if not regions: return [] # Group by type by_type = defaultdict(list) for r in regions: by_type[r.type].append(r) merged = [] for layout_type, type_regions in by_type.items(): # Simple merging: keep non-overlapping or merge overlapping # This is simplified - production should use more sophisticated merging merged.extend(type_regions) return merged def _assign_reading_order( self, regions: List[LayoutRegion], ) -> List[LayoutRegion]: """Assign reading order to regions (top-to-bottom, left-to-right).""" if not regions: return [] # Sort by y first, then x sorted_regions = sorted( regions, key=lambda r: (r.bbox.y_min, r.bbox.x_min) ) for i, region in enumerate(sorted_regions): region.reading_order = i return sorted_regions # Factory functions _layout_detector: Optional[LayoutDetector] = None def create_layout_detector( config: Optional[LayoutConfig] = None, initialize: bool = True, ) -> LayoutDetector: """Create a layout detector instance.""" if config is None: config = LayoutConfig() if config.method == "rule_based": detector = RuleBasedLayoutDetector(config) else: # Default to rule-based logger.warning(f"Unknown method {config.method}, using rule_based") detector = RuleBasedLayoutDetector(config) if initialize: detector.initialize() return detector def get_layout_detector( config: Optional[LayoutConfig] = None, ) -> LayoutDetector: """Get or create singleton layout detector.""" global _layout_detector if _layout_detector is None: _layout_detector = create_layout_detector(config) return _layout_detector