File size: 19,403 Bytes
d520909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
"""
Layout Detector Implementations

Rule-based and model-based layout detection.
"""

import time
import uuid
from typing import List, Optional, Dict, Tuple
from collections import defaultdict
import numpy as np
from loguru import logger

from .base import LayoutDetector, LayoutConfig, LayoutResult
from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion


class RuleBasedLayoutDetector(LayoutDetector):
    """
    Rule-based layout detector using OCR region analysis.

    Uses heuristics based on:
    - Text positioning and alignment
    - Font size estimation (based on region height)
    - Spacing patterns
    - Structural patterns (tables, lists)
    """

    def __init__(self, config: Optional[LayoutConfig] = None) -> None:
        """
        Initialize rule-based detector.

        Args:
            config: Optional layout configuration. It is passed unchanged to
                the ``LayoutDetector`` base-class initializer.
        """
        super().__init__(config)

    def initialize(self) -> None:
        """
        Initialize detector (no model loading needed for rule-based).

        Only sets the ``_initialized`` flag and logs an informational message.
        """
        self._initialized = True
        logger.info("Initialized rule-based layout detector")

    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions using rule-based analysis.

        The OCR-driven detectors run in a fixed order (titles/headings,
        paragraphs, lists, tables, headers/footers) and share one ID counter,
        so region IDs reflect that order. The figure detector runs afterwards
        when ``config.detect_figures`` is set. The combined regions are then
        passed through ``_merge_overlapping_regions`` and
        ``_assign_reading_order``.

        Args:
            image: Page image; the first two entries of ``image.shape`` are
                read as (height, width).
            page_number: Page number, stored on every region and embedded in
                the generated region IDs.
            ocr_regions: OCR regions for text-based analysis. When empty or
                ``None`` the OCR-driven detectors are skipped.

        Returns:
            LayoutResult with detected regions, the image dimensions and the
            elapsed processing time in milliseconds.
        """
        if not self._initialized:
            self.initialize()

        # perf_counter() is monotonic and high resolution. time.time() follows
        # the wall clock, which can be adjusted while we run and would then
        # report a wrong (even negative) duration.
        start_time = time.perf_counter()
        height, width = image.shape[:2]

        regions: List[LayoutRegion] = []
        region_counter = 0

        def make_region_id() -> str:
            """Return the next ID of the form ``region_<page>_<n>`` (n starts at 1)."""
            nonlocal region_counter
            region_counter += 1
            return f"region_{page_number}_{region_counter}"

        if ocr_regions:
            # Analyze OCR regions to detect layout
            regions.extend(self._detect_titles_headings(ocr_regions, page_number, make_region_id, height))
            regions.extend(self._detect_paragraphs(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_lists(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_tables_from_ocr(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_headers_footers(ocr_regions, page_number, make_region_id, height))

        # Image-based detection for figures/charts
        if self.config.detect_figures:
            regions.extend(self._detect_figures_from_image(image, page_number, make_region_id, ocr_regions))

        # Merge overlapping regions
        regions = self._merge_overlapping_regions(regions)

        # Assign reading order
        regions = self._assign_reading_order(regions)

        processing_time = (time.perf_counter() - start_time) * 1000

        return LayoutResult(
            page=page_number,
            regions=regions,
            image_width=width,
            image_height=height,
            processing_time_ms=processing_time,
            success=True,
        )

    def _detect_titles_headings(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """
        Detect title and heading regions based on font size and position.

        A text line qualifies when its tallest OCR box is higher than
        ``median box height * config.heading_font_ratio`` and its joined text
        is shorter than 100 characters. A qualifying line becomes a TITLE when
        it also starts in the top 15% of the page and is more than 1.2x the
        threshold; otherwise it becomes a HEADING.

        Args:
            ocr_regions: OCR regions of the page.
            page_number: Page number stored on each produced region.
            make_id: Zero-argument callable returning a fresh region ID.
            page_height: Page height, in the same units as the OCR boxes.

        Returns:
            One LayoutRegion per qualifying line. Returns an empty list when
            title detection is disabled, there are no OCR regions, or no
            region has a positive height.
        """
        if not ocr_regions or not self.config.detect_titles:
            return []

        # Box height stands in for font size; the median resists outliers.
        heights = [r.bbox.height for r in ocr_regions if r.bbox.height > 0]
        if not heights:
            return []

        typical_height = np.median(heights)
        title_threshold = typical_height * self.config.heading_font_ratio

        detected = []

        # Group regions by line; only the members matter, not the line key.
        for line_regions in self._group_into_lines(ocr_regions).values():
            if not line_regions:
                continue

            # Calculate line properties
            line_height = max(r.bbox.height for r in line_regions)
            line_text = " ".join(r.text for r in line_regions)
            line_y = min(r.bbox.y_min for r in line_regions)

            # Check if this looks like a title/heading
            is_large_text = line_height > title_threshold
            is_short = len(line_text) < 100
            if not (is_large_text and is_short):
                continue

            is_top_of_page = line_y < page_height * 0.15

            # Determine if title or heading
            if is_top_of_page and line_height > title_threshold * 1.2:
                layout_type = LayoutType.TITLE
            else:
                layout_type = LayoutType.HEADING

            # The grouped lines hold the very same objects as ``ocr_regions``,
            # so an identity set gives O(1) membership. This replaces a
            # field-by-field equality scan of the line for every OCR region.
            member_ids = {id(r) for r in line_regions}

            detected.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.8,
                # Merge line regions into one bbox
                bbox=BoundingBox(
                    x_min=min(r.bbox.x_min for r in line_regions),
                    y_min=line_y,
                    x_max=max(r.bbox.x_max for r in line_regions),
                    y_max=max(r.bbox.y_max for r in line_regions),
                    normalized=False,
                ),
                page=page_number,
                ocr_region_ids=[i for i, r in enumerate(ocr_regions) if id(r) in member_ids],
            ))

        return detected

    def _detect_paragraphs(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """
        Detect paragraph regions by grouping nearby text.

        OCR regions are grouped into lines, the lines into paragraphs (see
        ``_group_lines_into_paragraphs``), and each paragraph is reported as
        the bounding box around all of its OCR regions. Every line is
        considered here; this method does not exclude lines that other
        detectors may also report.

        Args:
            ocr_regions: OCR regions of the page.
            page_number: Page number stored on each produced region.
            make_id: Zero-argument callable returning a fresh region ID.

        Returns:
            One PARAGRAPH LayoutRegion per non-empty paragraph group.
        """
        if not ocr_regions:
            return []

        detected = []

        # Group regions by proximity
        lines = self._group_into_lines(ocr_regions)
        paragraphs = self._group_lines_into_paragraphs(lines, ocr_regions)

        for para_lines in paragraphs:
            if not para_lines:
                continue

            # Get all OCR regions in this paragraph
            para_regions = [r for line_id in para_lines for r in lines.get(line_id, [])]
            if not para_regions:
                continue

            # The grouped lines hold the very same objects as ``ocr_regions``,
            # so an identity set gives O(1) membership. This replaces an
            # equality scan of the paragraph for every OCR region.
            member_ids = {id(r) for r in para_regions}

            detected.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.PARAGRAPH,
                confidence=0.7,
                # Bounding box enclosing every OCR region of the paragraph
                bbox=BoundingBox(
                    x_min=min(r.bbox.x_min for r in para_regions),
                    y_min=min(r.bbox.y_min for r in para_regions),
                    x_max=max(r.bbox.x_max for r in para_regions),
                    y_max=max(r.bbox.y_max for r in para_regions),
                    normalized=False,
                ),
                page=page_number,
                ocr_region_ids=[i for i, r in enumerate(ocr_regions) if id(r) in member_ids],
            ))

        return detected

    def _detect_lists(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """
        Detect list structures based on bullet/number patterns.

        A line counts as a list item when the text of its left-most OCR region
        starts with a bullet glyph, starts with one of the legacy numbering
        prefixes, starts with a general enumerator (digits or a single letter
        followed by ``.`` or ``)`` and then whitespace or end of text), or is
        at most three characters long and ends with ``.``. A run of two or
        more consecutive list-item lines becomes one LIST region.

        Args:
            ocr_regions: OCR regions of the page.
            page_number: Page number stored on each produced region.
            make_id: Zero-argument callable returning a fresh region ID.

        Returns:
            One LIST LayoutRegion per detected run; ``extra["item_count"]``
            holds the number of lines in the run. Returns an empty list when
            list detection is disabled or there are no OCR regions.
        """
        # Local import keeps this block self-contained (``re`` is stdlib).
        import re

        if not ocr_regions or not self.config.detect_lists:
            return []

        # List indicators (tuples so str.startswith can test them in one call)
        bullet_prefixes = ('•', '-', '–', '—', '*', '○', '●', '■', '□', '▪', '▸', '▹')
        number_prefixes = ('1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.',
                           '1)', '2)', '3)', '4)', '5)', 'a.', 'b.', 'c.', 'a)', 'b)', 'c)')
        # General enumerator: covers "10.", "6)", "d)", "B." ... which the
        # fixed prefixes above miss. The trailing whitespace/end requirement
        # keeps abbreviations such as "e.g." and decimals such as "12.5" out.
        enumerator = re.compile(r'^(?:\d+|[A-Za-z])[.)](?:\s|$)')

        # Group by lines
        lines = self._group_into_lines(ocr_regions)

        # Find consecutive lines that look like list items
        runs = []
        current_run = []

        for line_id in sorted(lines):
            line_regions = lines[line_id]
            if not line_regions:
                continue

            first_text = line_regions[0].text.strip()

            # Check if line starts with list indicator
            is_list_item = (
                first_text.startswith(bullet_prefixes)
                or first_text.startswith(number_prefixes)
                or enumerator.match(first_text) is not None
                or (len(first_text) <= 3 and first_text.endswith('.'))
            )

            if is_list_item:
                current_run.append(line_id)
                continue

            # A non-item line ends the run; single-line runs are discarded.
            if len(current_run) >= 2:
                runs.append(current_run)
            current_run = []

        # Don't forget the last list
        if len(current_run) >= 2:
            runs.append(current_run)

        # Create list regions
        detected = []
        for run_line_ids in runs:
            list_regions = [r for line_id in run_line_ids for r in lines.get(line_id, [])]
            if not list_regions:
                continue

            # The grouped lines hold the very same objects as ``ocr_regions``,
            # so an identity set gives O(1) membership. This replaces an
            # equality scan of the run for every OCR region.
            member_ids = {id(r) for r in list_regions}

            detected.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.LIST,
                confidence=0.75,
                bbox=BoundingBox(
                    x_min=min(r.bbox.x_min for r in list_regions),
                    y_min=min(r.bbox.y_min for r in list_regions),
                    x_max=max(r.bbox.x_max for r in list_regions),
                    y_max=max(r.bbox.y_max for r in list_regions),
                    normalized=False,
                ),
                page=page_number,
                ocr_region_ids=[i for i, r in enumerate(ocr_regions) if id(r) in member_ids],
                extra={"item_count": len(run_line_ids)},
            ))

        return detected

    def _detect_tables_from_ocr(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """
        Detect a table region from horizontally aligned text.

        OCR regions are bucketed by the x-coordinate of their box centre: a
        region joins the first bucket whose anchor lies within 20 px,
        otherwise it opens a new bucket. At most one TABLE region is reported,
        and only when all of the following hold:

        - there are at least ``config.table_min_cols`` buckets;
        - there are at least ``table_min_rows * table_min_cols`` regions;
        - the enclosing box is wider than 30% of the page width
          (``bbox.page_width``, falling back to 1000 when it is falsy).

        This is a coarse heuristic: the reported box encloses every OCR region
        on the page.

        Args:
            ocr_regions: OCR regions of the page.
            page_number: Page number stored on the produced region.
            make_id: Zero-argument callable returning a fresh region ID.

        Returns:
            A list with zero or one TABLE LayoutRegion;
            ``extra["estimated_cols"]`` holds the number of x-buckets.
        """
        if not ocr_regions or not self.config.detect_tables:
            return []

        snap_distance = 20  # pixels
        column_buckets = {}  # anchor x-centre -> regions snapped to it

        for candidate in ocr_regions:
            center_x = candidate.bbox.center[0]
            for anchor_x, bucket in column_buckets.items():
                if abs(center_x - anchor_x) < snap_distance:
                    bucket.append(candidate)
                    break
            else:
                # No existing column is close enough: start a new one.
                column_buckets.setdefault(center_x, []).append(candidate)

        table_regions = []

        if len(column_buckets) < self.config.table_min_cols:
            return table_regions

        column_count = len(column_buckets)
        members = [item for bucket in column_buckets.values() for item in bucket]
        if len(members) < self.config.table_min_rows * self.config.table_min_cols:
            return table_regions

        left = min(item.bbox.x_min for item in members)
        top = min(item.bbox.y_min for item in members)
        right = max(item.bbox.x_max for item in members)
        bottom = max(item.bbox.y_max for item in members)

        # Only create table if it spans significant width
        reference_width = max(item.bbox.page_width or 1000 for item in members)
        width_ratio = (right - left) / reference_width
        if width_ratio > 0.3:
            table_regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.TABLE,
                confidence=0.6,  # Lower confidence for rule-based
                bbox=BoundingBox(
                    x_min=left, y_min=top,
                    x_max=right, y_max=bottom,
                    normalized=False,
                ),
                page=page_number,
                extra={"estimated_cols": column_count},
            ))

        return table_regions

    def _detect_headers_footers(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """
        Detect header and footer bands at the page margins.

        OCR regions lying entirely above 8% of the page height form the
        header; regions starting below 92% form the footer. Each non-empty
        band is reported as one region enclosing its members, header first.

        Args:
            ocr_regions: OCR regions of the page.
            page_number: Page number stored on each produced region.
            make_id: Zero-argument callable returning a fresh region ID.
            page_height: Page height, in the same units as the OCR boxes.

        Returns:
            Up to two LayoutRegions (HEADER, then FOOTER). Returns an empty
            list when header detection is disabled or there are no OCR
            regions.
        """
        if not ocr_regions or not self.config.detect_headers:
            return []

        top_cutoff = page_height * 0.08
        bottom_cutoff = page_height * 0.92

        above = [r for r in ocr_regions if r.bbox.y_max < top_cutoff]
        below = [r for r in ocr_regions if r.bbox.y_min > bottom_cutoff]

        detected = []
        # Header is handled before footer so region IDs keep their order.
        for is_header, members in ((True, above), (False, below)):
            if not members:
                continue

            left = min(r.bbox.x_min for r in members)
            top = min(r.bbox.y_min for r in members)
            right = max(r.bbox.x_max for r in members)
            bottom = max(r.bbox.y_max for r in members)

            detected.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.HEADER if is_header else LayoutType.FOOTER,
                confidence=0.7,
                bbox=BoundingBox(x_min=left, y_min=top, x_max=right, y_max=bottom, normalized=False),
                page=page_number,
            ))

        return detected

    def _detect_figures_from_image(
        self,
        image: np.ndarray,
        page_number: int,
        make_id,
        ocr_regions: Optional[List[OCRRegion]],
    ) -> List[LayoutRegion]:
        """
        Placeholder for image-based figure detection.

        Currently this only rasterises the OCR boxes into a text mask and
        then discards it; no figure regions are produced, so the result is
        always an empty list. ``page_number`` and ``make_id`` are not used
        yet.

        Args:
            image: Page image; the first two entries of ``image.shape`` are
                read as (height, width).
            page_number: Page number (unused for now).
            make_id: Region-ID factory (unused for now).
            ocr_regions: OCR regions whose boxes are painted into the mask.

        Returns:
            An empty list.
        """
        # This is a simplified approach - in production, use a vision model
        figures = []

        if not ocr_regions:
            return figures

        # Find large areas without text (potential figures)
        page_rows, page_cols = image.shape[:2]

        # Mask of text regions: 255 where an OCR box covers the pixel.
        occupied = np.zeros((page_rows, page_cols), dtype=np.uint8)
        for text_region in ocr_regions:
            box = text_region.bbox
            left, top = int(box.x_min), int(box.y_min)
            right, bottom = int(box.x_max), int(box.y_max)
            # NOTE(review): negative coordinates would wrap around under NumPy
            # slicing - confirm boxes are always inside the page.
            occupied[top:bottom, left:right] = 255

        # Find large non-text areas (very simplified)
        # In production, use connected components or contour detection
        # This is a placeholder for more sophisticated detection

        return figures

    def _group_into_lines(
        self,
        ocr_regions: List[OCRRegion],
        y_tolerance: float = 10,
    ) -> Dict[int, List[OCRRegion]]:
        """
        Group OCR regions into lines based on y-position.

        Regions are visited in order of ascending ``bbox.y_min``. A region
        stays on the current line while its ``y_min`` is within
        ``y_tolerance`` of the ``y_min`` of the region that opened the line;
        otherwise it opens a new line.

        Args:
            ocr_regions: OCR regions of the page.
            y_tolerance: Maximum vertical distance (in bbox units, pixels by
                default) from the line's first region. Defaults to 10, the
                previously hard-coded value; raise it for high-resolution
                scans or large fonts.

        Returns:
            Mapping from consecutive line IDs (starting at 0, top to bottom)
            to that line's regions sorted left to right. The values are the
            same objects as in ``ocr_regions``. Returns an empty dict for
            empty input.
        """
        if not ocr_regions:
            return {}

        # Sort by y position
        sorted_regions = sorted(ocr_regions, key=lambda r: r.bbox.y_min)

        lines = defaultdict(list)
        current_line_id = 0
        # Anchor of the current line: y_min of the region that opened it.
        current_y = sorted_regions[0].bbox.y_min

        for region in sorted_regions:
            if abs(region.bbox.y_min - current_y) > y_tolerance:
                current_line_id += 1
                current_y = region.bbox.y_min
            lines[current_line_id].append(region)

        # Sort each line by x position
        return {
            line_id: sorted(members, key=lambda r: r.bbox.x_min)
            for line_id, members in lines.items()
        }

    def _group_lines_into_paragraphs(
        self,
        lines: Dict[int, List[OCRRegion]],
        all_regions: List[OCRRegion],
    ) -> List[List[int]]:
        """
        Group lines into paragraphs based on spacing.

        Lines are walked in ascending line-ID order. A new paragraph starts
        whenever the vertical gap between the bottom of the previous line and
        the top of the current one exceeds 1.5x the mean box height of the
        two lines.

        Empty lines are ignored. Previously an empty line caused the
        non-empty line after it to be dropped from every paragraph, because
        the comparison against the empty predecessor was skipped.

        Args:
            lines: Mapping of line ID to the OCR regions on that line.
            all_regions: Unused; kept so the signature stays compatible.

        Returns:
            A list of paragraphs, each a list of line IDs in reading order.
            Returns an empty list when there are no non-empty lines.
        """
        # Only lines that actually contain regions take part in the grouping.
        line_ids = [line_id for line_id in sorted(lines) if lines[line_id]]
        if not line_ids:
            return []

        paragraphs = []
        current_para = [line_ids[0]]

        # Walk consecutive (previous, current) pairs of non-empty lines.
        for prev_id, curr_id in zip(line_ids, line_ids[1:]):
            prev_line = lines[prev_id]
            curr_line = lines[curr_id]

            # Calculate vertical gap
            prev_y_max = max(r.bbox.y_max for r in prev_line)
            curr_y_min = min(r.bbox.y_min for r in curr_line)
            gap = curr_y_min - prev_y_max

            # Calculate average line height
            avg_height = np.mean([r.bbox.height for r in prev_line + curr_line])

            # Large gap indicates new paragraph
            if gap > avg_height * 1.5:
                paragraphs.append(current_para)
                current_para = [curr_id]
            else:
                current_para.append(curr_id)

        paragraphs.append(current_para)
        return paragraphs

    def _merge_overlapping_regions(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """
        Regroup regions by layout type (merging is not implemented yet).

        Despite its name, this method does not merge or drop anything. It
        returns the same region objects, reordered so that regions of the
        same type are adjacent. Types appear in order of first occurrence,
        and the original order is kept within each type.

        Args:
            regions: Regions collected from all detectors.

        Returns:
            A new list with every input region, grouped by ``type``. Returns
            an empty list for empty input.
        """
        if not regions:
            return []

        # Bucket by type; dicts preserve first-seen key order.
        buckets = {}
        for region in regions:
            buckets.setdefault(region.type, []).append(region)

        # Simplified: every bucket is passed through untouched. Production
        # code should merge overlapping boxes inside each bucket here.
        regrouped = []
        for same_type in buckets.values():
            regrouped += same_type

        return regrouped

    def _assign_reading_order(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """
        Number regions top-to-bottom, then left-to-right.

        Regions are ordered by ``(bbox.y_min, bbox.x_min)`` and each region's
        ``reading_order`` attribute is set in place to its position in that
        order, starting at 0.

        Args:
            regions: Regions to order; the objects themselves are mutated.

        Returns:
            A new list with the same region objects in reading order. Returns
            an empty list for empty input.
        """
        if not regions:
            return []

        ordered = list(regions)
        # Sort by y first, then x
        ordered.sort(key=lambda item: (item.bbox.y_min, item.bbox.x_min))

        for position, item in enumerate(ordered):
            item.reading_order = position

        return ordered


# Factory functions
# Module-level cache for the shared detector instance; it stays None until
# get_layout_detector() creates the detector on first use.
_layout_detector: Optional[LayoutDetector] = None


def create_layout_detector(
    config: Optional[LayoutConfig] = None,
    initialize: bool = True,
) -> LayoutDetector:
    """
    Build a new layout detector.

    Only the ``"rule_based"`` method is implemented. Any other value of
    ``config.method`` logs a warning and still yields a rule-based detector.

    Args:
        config: Layout configuration; a default ``LayoutConfig()`` is created
            when omitted.
        initialize: When true, ``initialize()`` is called on the detector
            before it is returned.

    Returns:
        A ``RuleBasedLayoutDetector`` built from the configuration.
    """
    config = LayoutConfig() if config is None else config

    if config.method != "rule_based":
        # Default to rule-based
        logger.warning(f"Unknown method {config.method}, using rule_based")

    detector = RuleBasedLayoutDetector(config)
    if initialize:
        detector.initialize()

    return detector


def get_layout_detector(
    config: Optional[LayoutConfig] = None,
) -> LayoutDetector:
    """
    Return the shared layout detector, creating it on first use.

    Args:
        config: Configuration used only when the shared detector does not
            exist yet. Once a detector has been created, later calls return
            it as-is and this argument is ignored.

    Returns:
        The module-wide detector instance.
    """
    global _layout_detector

    if _layout_detector is not None:
        return _layout_detector

    _layout_detector = create_layout_detector(config)
    return _layout_detector