""" |
|
|
Document Chunker Implementation |
|
|
|
|
|
Creates semantic chunks from document content with bounding box tracking. |
|
|
Includes TableAwareChunker for preserving table structure in markdown format. |
|
|
""" |
|
|
|
|
|
import uuid |
|
|
import time |
|
|
import re |
|
|
from typing import List, Optional, Dict, Any, Tuple |
|
|
from dataclasses import dataclass |
|
|
from pydantic import BaseModel, Field |
|
|
from loguru import logger |
|
|
from collections import defaultdict |
|
|
|
|
|
from ..schemas.core import ( |
|
|
BoundingBox, |
|
|
DocumentChunk, |
|
|
ChunkType, |
|
|
LayoutRegion, |
|
|
LayoutType, |
|
|
OCRRegion, |
|
|
) |
|
|
|
|
|
|
|
|


class ChunkerConfig(BaseModel):
    """Configuration for document chunking."""

    # Size constraints
    max_chunk_chars: int = Field(
        default=1000,
        ge=100,
        description="Maximum characters per chunk"
    )
    min_chunk_chars: int = Field(
        default=50,
        ge=10,
        description="Minimum characters per chunk"
    )
    overlap_chars: int = Field(
        default=100,
        ge=0,
        description="Character overlap between chunks"
    )

    # Strategy
    strategy: str = Field(
        default="semantic",
        description="Chunking strategy: semantic, fixed, or layout"
    )
    respect_layout: bool = Field(
        default=True,
        description="Respect layout region boundaries"
    )
    merge_small_regions: bool = Field(
        default=True,
        description="Merge small adjacent regions"
    )

    # Special content
    chunk_tables: bool = Field(
        default=True,
        description="Create separate chunks for tables"
    )
    chunk_figures: bool = Field(
        default=True,
        description="Create separate chunks for figures"
    )
    include_captions: bool = Field(
        default=True,
        description="Include captions with figures/tables"
    )

    # Splitting
    split_on_sentences: bool = Field(
        default=True,
        description="Split on sentence boundaries when possible"
    )

    # Table structure (FG-002)
    preserve_table_structure: bool = Field(
        default=True,
        description="Preserve table structure as markdown with structured data"
    )
    table_row_threshold: float = Field(
        default=10.0,
        description="Y-coordinate threshold for grouping cells into rows"
    )
    table_col_threshold: float = Field(
        default=20.0,
        description="X-coordinate threshold for grouping cells into columns"
    )
    detect_table_headers: bool = Field(
        default=True,
        description="Attempt to detect and mark header rows"
    )
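
# Example (illustrative sketch): tightening the grouping thresholds for
# densely packed scanned tables. The values here are assumptions for
# demonstration, not tuned defaults.
#
#   config = ChunkerConfig(
#       max_chunk_chars=800,
#       overlap_chars=80,
#       table_row_threshold=6.0,
#       table_col_threshold=12.0,
#   )
#   chunker = SemanticChunker(config)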


# Maps detected layout region types to chunk types; unmapped types fall back
# to ChunkType.TEXT at the call site.
LAYOUT_TO_CHUNK_TYPE = {
    LayoutType.TEXT: ChunkType.TEXT,
    LayoutType.TITLE: ChunkType.TITLE,
    LayoutType.HEADING: ChunkType.HEADING,
    LayoutType.PARAGRAPH: ChunkType.PARAGRAPH,
    LayoutType.LIST: ChunkType.LIST_ITEM,
    LayoutType.TABLE: ChunkType.TABLE,
    LayoutType.FIGURE: ChunkType.FIGURE,
    LayoutType.CHART: ChunkType.CHART,
    LayoutType.FORMULA: ChunkType.FORMULA,
    LayoutType.CAPTION: ChunkType.CAPTION,
    LayoutType.FOOTNOTE: ChunkType.FOOTNOTE,
    LayoutType.HEADER: ChunkType.HEADER,
    LayoutType.FOOTER: ChunkType.FOOTER,
}


class DocumentChunker:
    """Base class for document chunkers."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        self.config = config or ChunkerConfig()

    def create_chunks(
        self,
        ocr_regions: List[OCRRegion],
        layout_regions: Optional[List[LayoutRegion]] = None,
        document_id: str = "",
        source_path: Optional[str] = None,
    ) -> List[DocumentChunk]:
        """
        Create chunks from OCR and layout regions.

        Args:
            ocr_regions: OCR text regions
            layout_regions: Optional layout regions
            document_id: Parent document ID
            source_path: Source file path

        Returns:
            List of DocumentChunk
        """
        raise NotImplementedError


class SemanticChunker(DocumentChunker):
    """
    Semantic chunker that respects document structure.

    Creates chunks based on:
    - Layout region boundaries
    - Semantic coherence (paragraphs, sections)
    - Size constraints with overlap
    """

    def create_chunks(
        self,
        ocr_regions: List[OCRRegion],
        layout_regions: Optional[List[LayoutRegion]] = None,
        document_id: str = "",
        source_path: Optional[str] = None,
    ) -> List[DocumentChunk]:
        """Create semantic chunks from document content."""
        if not ocr_regions:
            return []

        start_time = time.time()

        if layout_regions and self.config.respect_layout:
            chunks = self._chunk_by_layout(
                ocr_regions, layout_regions, document_id, source_path
            )
        else:
            chunks = self._chunk_by_text(
                ocr_regions, document_id, source_path
            )

        # Renumber chunks into a single document-wide sequence.
        for i, chunk in enumerate(chunks):
            chunk.sequence_index = i

        logger.debug(
            f"Created {len(chunks)} chunks in "
            f"{(time.time() - start_time) * 1000:.1f}ms"
        )

        return chunks
    def _chunk_by_layout(
        self,
        ocr_regions: List[OCRRegion],
        layout_regions: List[LayoutRegion],
        document_id: str,
        source_path: Optional[str],
    ) -> List[DocumentChunk]:
        """Create chunks based on layout regions."""
        chunks = []

        # Process regions in reading order, falling back to top-to-bottom,
        # left-to-right position.
        sorted_layouts = sorted(
            layout_regions,
            key=lambda r: (r.reading_order or 0, r.bbox.y_min, r.bbox.x_min)
        )

        for layout in sorted_layouts:
            contained_ocr = self._get_contained_ocr(ocr_regions, layout)
            if not contained_ocr:
                continue

            chunk_type = LAYOUT_TO_CHUNK_TYPE.get(layout.type, ChunkType.TEXT)

            if layout.type == LayoutType.TABLE and self.config.chunk_tables:
                chunk = self._create_table_chunk(
                    contained_ocr, layout, document_id, source_path
                )
                chunks.append(chunk)
            elif layout.type in (LayoutType.FIGURE, LayoutType.CHART) and self.config.chunk_figures:
                chunk = self._create_figure_chunk(
                    contained_ocr, layout, document_id, source_path
                )
                chunks.append(chunk)
            else:
                text_chunks = self._create_text_chunks(
                    contained_ocr, layout, chunk_type, document_id, source_path
                )
                chunks.extend(text_chunks)

        return chunks
    def _chunk_by_text(
        self,
        ocr_regions: List[OCRRegion],
        document_id: str,
        source_path: Optional[str],
    ) -> List[DocumentChunk]:
        """Create chunks from text without layout guidance."""
        chunks = []

        # Sort by page, then top-to-bottom, left-to-right.
        sorted_regions = sorted(
            ocr_regions,
            key=lambda r: (r.page, r.bbox.y_min, r.bbox.x_min)
        )

        # Group regions by page so chunks never span page boundaries.
        pages: Dict[int, List[OCRRegion]] = {}
        for r in sorted_regions:
            if r.page not in pages:
                pages[r.page] = []
            pages[r.page].append(r)

        for page_num in sorted(pages.keys()):
            page_regions = pages[page_num]
            page_chunks = self._split_text_regions(
                page_regions, document_id, source_path, page_num
            )
            chunks.extend(page_chunks)

        return chunks
    def _get_contained_ocr(
        self,
        ocr_regions: List[OCRRegion],
        layout: LayoutRegion,
    ) -> List[OCRRegion]:
        """Get OCR regions contained within a layout region."""
        contained = []
        for ocr in ocr_regions:
            if ocr.page == layout.page:
                # Accept regions that substantially overlap the layout box
                # or are fully contained by it.
                iou = layout.bbox.iou(ocr.bbox)
                if iou > 0.3 or layout.bbox.contains(ocr.bbox):
                    contained.append(ocr)
        return contained
    def _create_text_chunks(
        self,
        ocr_regions: List[OCRRegion],
        layout: LayoutRegion,
        chunk_type: ChunkType,
        document_id: str,
        source_path: Optional[str],
    ) -> List[DocumentChunk]:
        """Create text chunks from OCR regions, splitting if needed."""
        chunks = []

        text = " ".join(r.text for r in ocr_regions)
        avg_conf = sum(r.confidence for r in ocr_regions) / len(ocr_regions)

        if len(text) <= self.config.max_chunk_chars:
            # Single chunk covering the whole layout region.
            chunk = DocumentChunk(
                chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
                chunk_type=chunk_type,
                text=text,
                bbox=layout.bbox,
                page=layout.page,
                document_id=document_id,
                source_path=source_path,
                sequence_index=0,
                confidence=avg_conf,
            )
            chunks.append(chunk)
        else:
            # Too long: split into overlapping chunks.
            split_chunks = self._split_text(
                text, layout.bbox, layout.page, chunk_type,
                document_id, source_path, avg_conf
            )
            chunks.extend(split_chunks)

        return chunks
    def _split_text(
        self,
        text: str,
        bbox: BoundingBox,
        page: int,
        chunk_type: ChunkType,
        document_id: str,
        source_path: Optional[str],
        confidence: float,
    ) -> List[DocumentChunk]:
        """Split long text into multiple chunks with overlap."""
        chunks = []
        max_chars = self.config.max_chunk_chars
        overlap = self.config.overlap_chars

        if self.config.split_on_sentences:
            sentences = self._split_sentences(text)
        else:
            sentences = [text]

        current_text = ""
        for sentence in sentences:
            if len(current_text) + len(sentence) > max_chars and current_text:
                # Flush the accumulated text as one chunk.
                chunk = DocumentChunk(
                    chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
                    chunk_type=chunk_type,
                    text=current_text.strip(),
                    bbox=bbox,
                    page=page,
                    document_id=document_id,
                    source_path=source_path,
                    sequence_index=len(chunks),
                    confidence=confidence,
                )
                chunks.append(chunk)

                # Carry trailing context into the next chunk.
                if overlap > 0:
                    overlap_text = current_text[-overlap:] if len(current_text) > overlap else current_text
                    current_text = overlap_text + " " + sentence
                else:
                    current_text = sentence
            else:
                current_text += " " + sentence if current_text else sentence

        # Flush any remaining text.
        if current_text.strip():
            chunk = DocumentChunk(
                chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
                chunk_type=chunk_type,
                text=current_text.strip(),
                bbox=bbox,
                page=page,
                document_id=document_id,
                source_path=source_path,
                sequence_index=len(chunks),
                confidence=confidence,
            )
            chunks.append(chunk)

        return chunks
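
    # Overlap behavior (illustrative): with max_chunk_chars=20 and
    # overlap_chars=8, the last 8 characters of each flushed chunk are
    # prepended to the next one, so context is preserved across the split
    # at the cost of some duplicated text.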
    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences on terminal punctuation."""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]
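
    # Illustrative behavior of the sentence splitter (the lookbehind keeps
    # punctuation attached to each sentence):
    #
    #   _split_sentences("Revenue grew 5%. Costs fell! Why?")
    #   -> ["Revenue grew 5%.", "Costs fell!", "Why?"]
    #
    # Abbreviations such as "Dr." will also trigger a split; the regex is a
    # deliberate lightweight heuristic, not a full sentence tokenizer.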
    def _create_table_chunk(
        self,
        ocr_regions: List[OCRRegion],
        layout: LayoutRegion,
        document_id: str,
        source_path: Optional[str],
    ) -> DocumentChunk:
        """
        Create a chunk for table content with structure preservation.

        Enhanced table handling (FG-002):
        - Reconstructs table structure from OCR regions
        - Generates a markdown table representation
        - Stores structured data for SQL-like queries
        - Detects and marks header rows
        """
        if not ocr_regions:
            return DocumentChunk(
                chunk_id=f"{document_id}_table_{uuid.uuid4().hex[:8]}",
                chunk_type=ChunkType.TABLE,
                text="[Empty Table]",
                bbox=layout.bbox,
                page=layout.page,
                document_id=document_id,
                source_path=source_path,
                sequence_index=0,
                confidence=0.0,
                extra=layout.extra or {},
            )

        avg_conf = sum(r.confidence for r in ocr_regions) / len(ocr_regions)

        if not self.config.preserve_table_structure:
            # Legacy behavior: flat pipe-joined cell text.
            text = " | ".join(r.text for r in ocr_regions)
            return DocumentChunk(
                chunk_id=f"{document_id}_table_{uuid.uuid4().hex[:8]}",
                chunk_type=ChunkType.TABLE,
                text=text,
                bbox=layout.bbox,
                page=layout.page,
                document_id=document_id,
                source_path=source_path,
                sequence_index=0,
                confidence=avg_conf,
                extra=layout.extra or {},
            )

        # Rebuild rows/columns from cell positions, then render as markdown.
        table_data = self._reconstruct_table_structure(ocr_regions)
        markdown_table = self._table_to_markdown(
            table_data["rows"],
            table_data["headers"],
            table_data["has_header"]
        )

        # Keep the structured form alongside the rendered text.
        table_extra = {
            **(layout.extra or {}),
            "table_structure": {
                "row_count": table_data["row_count"],
                "col_count": table_data["col_count"],
                "has_header": table_data["has_header"],
                "headers": table_data["headers"],
                "cells": table_data["cells"],
                "cell_positions": table_data["cell_positions"],
            },
            "format": "markdown",
            "searchable_text": table_data["searchable_text"],
        }

        return DocumentChunk(
            chunk_id=f"{document_id}_table_{uuid.uuid4().hex[:8]}",
            chunk_type=ChunkType.TABLE,
            text=markdown_table,
            bbox=layout.bbox,
            page=layout.page,
            document_id=document_id,
            source_path=source_path,
            sequence_index=0,
            confidence=avg_conf,
            extra=table_extra,
        )
    def _reconstruct_table_structure(
        self,
        ocr_regions: List[OCRRegion],
    ) -> Dict[str, Any]:
        """
        Reconstruct table structure from OCR regions based on spatial positions.

        Groups OCR regions into rows and columns by analyzing their bounding
        boxes. Returns structured table data for markdown generation and
        structured queries.
        """
        if not ocr_regions:
            return {
                "rows": [],
                "headers": [],
                "has_header": False,
                "row_count": 0,
                "col_count": 0,
                "cells": [],
                "cell_positions": [],
                "searchable_text": "",
            }

        # Sort top-to-bottom, then left-to-right.
        sorted_regions = sorted(
            ocr_regions,
            key=lambda r: (r.bbox.y_min, r.bbox.x_min)
        )

        # Group regions into rows: a region whose y_min lies within the row
        # threshold of the current row's anchor belongs to the same row.
        row_threshold = self.config.table_row_threshold
        rows: List[List[OCRRegion]] = []
        current_row: List[OCRRegion] = []
        current_y = None

        for region in sorted_regions:
            if current_y is None:
                current_y = region.bbox.y_min
                current_row.append(region)
            elif abs(region.bbox.y_min - current_y) <= row_threshold:
                current_row.append(region)
            else:
                if current_row:
                    # Order cells within the row left-to-right.
                    current_row.sort(key=lambda r: r.bbox.x_min)
                    rows.append(current_row)
                current_row = [region]
                current_y = region.bbox.y_min

        if current_row:
            current_row.sort(key=lambda r: r.bbox.x_min)
            rows.append(current_row)

        # Detect column boundaries shared across rows.
        col_positions = self._detect_column_positions(rows)
        num_cols = len(col_positions) if col_positions else max(len(row) for row in rows)

        # Assign each region to a column, producing a rectangular cell grid.
        cells: List[List[str]] = []
        cell_positions: List[List[Dict[str, Any]]] = []

        for row in rows:
            row_cells = self._assign_cells_to_columns(row, col_positions, num_cols)
            cells.append([cell["text"] for cell in row_cells])
            cell_positions.append([{
                "text": cell["text"],
                "bbox": cell["bbox"],
                "confidence": cell["confidence"]
            } for cell in row_cells])

        # Optionally detect a header row.
        has_header = False
        headers: List[str] = []
        if self.config.detect_table_headers and len(cells) > 0:
            has_header, headers = self._detect_header_row(cells, rows)

        # Build searchable text, pairing data cells with their headers.
        searchable_parts = []
        for i, row in enumerate(cells):
            if has_header and i == 0:
                searchable_parts.append("Headers: " + ", ".join(row))
            else:
                if has_header and headers:
                    for j, cell in enumerate(row):
                        if j < len(headers) and headers[j]:
                            searchable_parts.append(f"{headers[j]}: {cell}")
                        else:
                            searchable_parts.append(cell)
                else:
                    searchable_parts.extend(row)

        return {
            "rows": cells,
            "headers": headers,
            "has_header": has_header,
            "row_count": len(cells),
            "col_count": num_cols,
            "cells": cells,
            "cell_positions": cell_positions,
            "searchable_text": " | ".join(searchable_parts),
        }
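
    # Illustrative result for a 2x2 table with a detected header row (note
    # that "rows"/"cells" include the header row, so row_count counts it):
    #
    #   {
    #       "rows": [["Item", "Qty"], ["Bolt", "4"]],
    #       "headers": ["Item", "Qty"],
    #       "has_header": True,
    #       "row_count": 2,
    #       "col_count": 2,
    #       ...
    #   }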
    def _detect_column_positions(
        self,
        rows: List[List[OCRRegion]],
    ) -> List[Tuple[float, float]]:
        """
        Detect consistent column boundaries from table rows.

        Returns a list of (x_start, x_end) tuples, one per column.
        """
        if not rows:
            return []

        col_threshold = self.config.table_col_threshold

        # Collect the left edge of every cell across all rows.
        all_x_starts = []
        for row in rows:
            for region in row:
                all_x_starts.append(region.bbox.x_min)

        if not all_x_starts:
            return []

        # Cluster x-starts: an edge within col_threshold of the previous edge
        # joins the current cluster; each cluster's mean becomes a column
        # center.
        all_x_starts.sort()
        columns = []
        current_col_regions = [all_x_starts[0]]

        for x in all_x_starts[1:]:
            if x - current_col_regions[-1] <= col_threshold:
                current_col_regions.append(x)
            else:
                col_center = sum(current_col_regions) / len(current_col_regions)
                columns.append(col_center)
                current_col_regions = [x]

        if current_col_regions:
            col_center = sum(current_col_regions) / len(current_col_regions)
            columns.append(col_center)

        # Convert centers to (x_start, x_end) ranges; each column extends to
        # the midpoint between neighboring centers, and the last column gets
        # extra room for right-aligned content.
        col_ranges = []
        for i, col_x in enumerate(columns):
            x_start = col_x - col_threshold
            if i < len(columns) - 1:
                x_end = (col_x + columns[i + 1]) / 2
            else:
                x_end = col_x + col_threshold * 3
            col_ranges.append((x_start, x_end))

        return col_ranges
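
    # Worked example (illustrative): with table_col_threshold=20.0, cell left
    # edges [10, 12, 105, 110, 200] cluster into columns centered at 11,
    # 107.5, and 200, yielding ranges of roughly (-9, 59.25), (87.5, 153.75),
    # and (180, 260).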
    def _assign_cells_to_columns(
        self,
        row_regions: List[OCRRegion],
        col_positions: List[Tuple[float, float]],
        num_cols: int,
    ) -> List[Dict[str, Any]]:
        """
        Assign OCR regions in a row to their respective columns.

        Handles merged cells and missing cells.
        """
        # Start with an empty cell for every column.
        row_cells = [
            {"text": "", "bbox": None, "confidence": 0.0}
            for _ in range(num_cols)
        ]

        if not col_positions:
            # No detected columns: assign regions positionally.
            for i, region in enumerate(row_regions):
                if i < num_cols:
                    row_cells[i] = {
                        "text": region.text.strip(),
                        "bbox": region.bbox.to_xyxy(),
                        "confidence": region.confidence,
                    }
            return row_cells

        # Match each region to the column range containing its left edge.
        for region in row_regions:
            region_x = region.bbox.x_min
            assigned = False

            for col_idx, (x_start, x_end) in enumerate(col_positions):
                if x_start <= region_x <= x_end:
                    # Multiple regions landing in one column are concatenated.
                    if row_cells[col_idx]["text"]:
                        row_cells[col_idx]["text"] += " " + region.text.strip()
                    else:
                        row_cells[col_idx]["text"] = region.text.strip()
                    row_cells[col_idx]["bbox"] = region.bbox.to_xyxy()
                    row_cells[col_idx]["confidence"] = max(
                        row_cells[col_idx]["confidence"],
                        region.confidence
                    )
                    assigned = True
                    break

            # Fall back to the nearest column center.
            if not assigned:
                min_dist = float("inf")
                nearest_col = 0
                for col_idx, (x_start, x_end) in enumerate(col_positions):
                    col_center = (x_start + x_end) / 2
                    dist = abs(region_x - col_center)
                    if dist < min_dist:
                        min_dist = dist
                        nearest_col = col_idx

                if row_cells[nearest_col]["text"]:
                    row_cells[nearest_col]["text"] += " " + region.text.strip()
                else:
                    row_cells[nearest_col]["text"] = region.text.strip()
                row_cells[nearest_col]["bbox"] = region.bbox.to_xyxy()
                row_cells[nearest_col]["confidence"] = region.confidence

        return row_cells
    def _detect_header_row(
        self,
        cells: List[List[str]],
        rows: List[List[OCRRegion]],
    ) -> Tuple[bool, List[str]]:
        """
        Detect whether the first row is a header row.

        Heuristics used:
        - First row contains mostly non-numeric text
        - First row text is shorter (labels vs data)
        - First row has distinct formatting (reserved; `rows` is accepted for
          this but not yet used)
        """
        if not cells or len(cells) < 2:
            return False, []

        first_row = cells[0]
        other_rows = cells[1:]

        # Fraction of first-row cells that are textual rather than numeric.
        first_row_numeric_count = sum(
            1 for cell in first_row
            if cell and self._is_numeric(cell)
        )
        first_row_text_ratio = (len(first_row) - first_row_numeric_count) / max(len(first_row), 1)

        # Average fraction of numeric cells in the remaining rows.
        other_numeric_ratios = []
        for row in other_rows:
            if row:
                numeric_count = sum(1 for cell in row if cell and self._is_numeric(cell))
                other_numeric_ratios.append(numeric_count / max(len(row), 1))

        avg_other_numeric = sum(other_numeric_ratios) / max(len(other_numeric_ratios), 1)

        # A mostly-textual first row above mostly-numeric data rows suggests
        # a header.
        is_header = (
            first_row_text_ratio > 0.5 and
            (avg_other_numeric > first_row_text_ratio * 0.5 or first_row_text_ratio > 0.8)
        )

        # Short first-row cells (labels) relative to data rows also suggest one.
        first_row_avg_len = sum(len(cell) for cell in first_row) / max(len(first_row), 1)
        other_avg_lens = [
            sum(len(cell) for cell in row) / max(len(row), 1)
            for row in other_rows
        ]
        avg_other_len = sum(other_avg_lens) / max(len(other_avg_lens), 1)

        if first_row_avg_len < avg_other_len * 0.8:
            is_header = True

        return is_header, first_row if is_header else []
    def _is_numeric(self, text: str) -> bool:
        """Check if text is primarily numeric (including currency, percentages)."""
        cleaned = re.sub(r'[$€£¥%,.\s\-+()]', '', text)
        return cleaned.isdigit() if cleaned else False
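
    # Illustrative behavior: currency/percent decoration is stripped before
    # the digit check.
    #
    #   _is_numeric("$1,234.50")  -> True
    #   _is_numeric("(5.2%)")     -> True
    #   _is_numeric("Q3 Revenue") -> False
    #   _is_numeric("--")         -> False  (nothing left after cleaning)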
    def _table_to_markdown(
        self,
        rows: List[List[str]],
        headers: List[str],
        has_header: bool,
    ) -> str:
        """
        Convert table data to markdown format.

        Creates a properly formatted markdown table with:
        - Header row (detected, or generic column names otherwise)
        - Separator row
        - Data rows
        """
        if not rows:
            return "[Empty Table]"

        num_cols = max(len(row) for row in rows)
        if num_cols == 0:
            return "[Empty Table]"

        # Pad ragged rows to a rectangular grid.
        normalized_rows = []
        for row in rows:
            normalized = row + [""] * (num_cols - len(row))
            normalized_rows.append(normalized)

        md_lines = []

        if has_header and headers:
            # Use the detected header row.
            header_line = "| " + " | ".join(headers + [""] * (num_cols - len(headers))) + " |"
            separator = "| " + " | ".join(["---"] * num_cols) + " |"
            md_lines.append(header_line)
            md_lines.append(separator)
            data_rows = normalized_rows[1:]
        else:
            # No header detected: synthesize generic column names.
            generic_headers = [f"Col{i+1}" for i in range(num_cols)]
            header_line = "| " + " | ".join(generic_headers) + " |"
            separator = "| " + " | ".join(["---"] * num_cols) + " |"
            md_lines.append(header_line)
            md_lines.append(separator)
            data_rows = normalized_rows

        for row in data_rows:
            # Escape pipes so cell text cannot break the table.
            escaped_row = [cell.replace("|", "\\|") for cell in row]
            row_line = "| " + " | ".join(escaped_row) + " |"
            md_lines.append(row_line)

        return "\n".join(md_lines)
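
    # Example output (illustrative), for cells [["Item", "Qty"], ["Bolt", "4"]]
    # with a detected header:
    #
    #   | Item | Qty |
    #   | --- | --- |
    #   | Bolt | 4 |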
    def _create_figure_chunk(
        self,
        ocr_regions: List[OCRRegion],
        layout: LayoutRegion,
        document_id: str,
        source_path: Optional[str],
    ) -> DocumentChunk:
        """Create a chunk for figure/chart content."""
        # Any OCR text inside the figure (labels, captions) becomes the text.
        text = " ".join(r.text for r in ocr_regions) if ocr_regions else "[Figure]"
        avg_conf = sum(r.confidence for r in ocr_regions) / len(ocr_regions) if ocr_regions else 0.5

        chunk_type = ChunkType.CHART if layout.type == LayoutType.CHART else ChunkType.FIGURE

        return DocumentChunk(
            chunk_id=f"{document_id}_{chunk_type.value}_{uuid.uuid4().hex[:8]}",
            chunk_type=chunk_type,
            text=text,
            bbox=layout.bbox,
            page=layout.page,
            document_id=document_id,
            source_path=source_path,
            sequence_index=0,
            confidence=avg_conf,
            caption=text if ocr_regions else None,
        )
    def _split_text_regions(
        self,
        ocr_regions: List[OCRRegion],
        document_id: str,
        source_path: Optional[str],
        page_num: int,
    ) -> List[DocumentChunk]:
        """Split OCR regions into chunks without layout guidance."""
        if not ocr_regions:
            return []

        chunks = []
        current_text = ""
        current_regions = []

        for region in ocr_regions:
            if len(current_text) + len(region.text) > self.config.max_chunk_chars:
                if current_regions:
                    # Flush the accumulated regions as one chunk.
                    chunk = self._create_chunk_from_regions(
                        current_regions, document_id, source_path, page_num, len(chunks)
                    )
                    chunks.append(chunk)
                current_text = region.text
                current_regions = [region]
            else:
                current_text += (" " if current_text else "") + region.text
                current_regions.append(region)

        # Flush the final chunk.
        if current_regions:
            chunk = self._create_chunk_from_regions(
                current_regions, document_id, source_path, page_num, len(chunks)
            )
            chunks.append(chunk)

        return chunks
    def _create_chunk_from_regions(
        self,
        regions: List[OCRRegion],
        document_id: str,
        source_path: Optional[str],
        page_num: int,
        sequence_index: int,
    ) -> DocumentChunk:
        """Create a chunk from a list of OCR regions."""
        text = " ".join(r.text for r in regions)
        avg_conf = sum(r.confidence for r in regions) / len(regions)

        # The union of the region boxes becomes the chunk's bounding box.
        x_min = min(r.bbox.x_min for r in regions)
        y_min = min(r.bbox.y_min for r in regions)
        x_max = max(r.bbox.x_max for r in regions)
        y_max = max(r.bbox.y_max for r in regions)

        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=False,
        )

        return DocumentChunk(
            chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
            chunk_type=ChunkType.TEXT,
            text=text,
            bbox=bbox,
            page=page_num,
            document_id=document_id,
            source_path=source_path,
            sequence_index=sequence_index,
            confidence=avg_conf,
        )


_document_chunker: Optional[DocumentChunker] = None


def get_document_chunker(
    config: Optional[ChunkerConfig] = None,
) -> DocumentChunker:
    """
    Get or create the singleton document chunker.

    Note: the config argument only takes effect on the first call; later
    calls return the existing instance unchanged.
    """
    global _document_chunker
    if _document_chunker is None:
        config = config or ChunkerConfig()
        _document_chunker = SemanticChunker(config)
    return _document_chunker
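
# Minimal usage sketch (illustrative). The OCRRegion/BoundingBox constructor
# fields below are assumed from how this module reads them (text, confidence,
# page, bbox); check ..schemas.core for the actual signatures.
#
#   regions = [
#       OCRRegion(
#           text="Quarterly results improved across all segments.",
#           confidence=0.97,
#           page=0,
#           bbox=BoundingBox(x_min=50, y_min=80, x_max=540, y_max=102,
#                            normalized=False),
#       ),
#   ]
#   chunker = get_document_chunker()
#   chunks = chunker.create_chunks(
#       regions, document_id="doc-001", source_path="report.pdf"
#   )
#   for c in chunks:
#       print(c.chunk_type, c.page, c.text[:60])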