|
|
""" |
|
|
Table Extraction Model Interface |
|
|
|
|
|
Abstract interface for table structure recognition and cell extraction. |
|
|
Handles complex tables with merged cells, headers, and nested structures. |
|
|
""" |
|
|
|
|
|
from abc import abstractmethod |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
from ..chunks.models import BoundingBox, TableCell, TableChunk |
|
|
from .base import ( |
|
|
BaseModel, |
|
|
BatchableModel, |
|
|
ImageInput, |
|
|
ModelCapability, |
|
|
ModelConfig, |
|
|
) |
|
|
from .layout import LayoutRegion |
|
|
|
|
|
|
|
|
class TableCellType(str, Enum):
    """
    Closed set of labels for classifying a table cell.

    The enum mixes in ``str``, so every member is itself a string and
    compares equal to its raw value (``TableCellType.HEADER == "header"``).
    """

    # NOTE(review): nothing in this module reads these members; the meaning of
    # each label is inferred from its name - confirm against the implementations.
    HEADER = "header"
    DATA = "data"
    INDEX = "index"
    MERGED = "merged"
    EMPTY = "empty"
|
|
|
|
|
|
|
|
@dataclass
class TableConfig(ModelConfig):
    """
    Settings for table extraction models.

    Extends ``ModelConfig`` with table-specific options and supplies a
    default model name when none was given.
    """

    # NOTE(review): none of these options are read inside this module; the
    # descriptions below are inferred from the field names - confirm against
    # the concrete model implementations.
    min_confidence: float = 0.5  # presumably the lowest confidence to accept
    detect_headers: bool = True  # presumably toggles header detection
    detect_merged_cells: bool = True  # presumably toggles merged-cell detection
    max_rows: int = 500  # presumably an upper bound on table rows
    max_cols: int = 50  # presumably an upper bound on table columns
    extract_cell_text: bool = True  # presumably toggles per-cell text extraction

    def __post_init__(self):
        """Run the parent initialisation, then default an empty ``name``."""
        super().__post_init__()
        if self.name:
            # A name was supplied; leave it untouched.
            return
        self.name = "table_extractor"
|
|
|
|
|
|
|
|
@dataclass
class TableStructure:
    """
    Detected table structure with its cell grid.

    Represents the logical structure of a table including merged cells,
    headers, and cell relationships, and offers conversions to CSV,
    Markdown, a plain dictionary and a ``TableChunk``.
    """

    # Bounding box of the whole table (also feeds the generated ``table_id``).
    bbox: BoundingBox
    # Cells of the table; each is addressed through its row/col anchor and
    # its rowspan/colspan.
    cells: List[TableCell] = field(default_factory=list)
    num_rows: int = 0
    num_cols: int = 0

    # Row / column indices flagged as headers.
    header_rows: List[int] = field(default_factory=list)
    header_cols: List[int] = field(default_factory=list)

    # Confidence values supplied by the producer of this structure.
    structure_confidence: float = 0.0
    cell_confidence_avg: float = 0.0

    has_merged_cells: bool = False
    is_bordered: bool = True
    # Short identifier; generated in ``__post_init__`` when left empty.
    table_id: str = ""

    def __post_init__(self):
        """Generate ``table_id`` from the bbox and dimensions if it is empty."""
        if not self.table_id:
            import hashlib

            # MD5 is used only to derive a short, stable identifier, not for
            # security. The id reflects the values at construction time and
            # is not refreshed if bbox/num_rows/num_cols change later.
            content = f"table_{self.bbox.xyxy}_{self.num_rows}x{self.num_cols}"
            self.table_id = hashlib.md5(content.encode()).hexdigest()[:12]

    def get_cell(self, row: int, col: int) -> Optional[TableCell]:
        """
        Get the cell at a specific grid position.

        Cells are scanned in list order and the first one that is either
        anchored at ``(row, col)`` or whose row/column span covers that
        position is returned.

        Args:
            row: Row index of the position
            col: Column index of the position

        Returns:
            The matching cell, or None when no cell covers the position
        """
        for cell in self.cells:
            # Exact anchor match (also catches cells whose span is not >= 1).
            if cell.row == row and cell.col == col:
                return cell

            # Position falls inside the area covered by a merged cell.
            if (cell.row <= row < cell.row + cell.rowspan and
                    cell.col <= col < cell.col + cell.colspan):
                return cell
        return None

    def _build_grid(self) -> Dict[Tuple[int, int], TableCell]:
        """
        Index cells by every grid position they cover.

        Built in a single pass so the serializers do not need a linear
        ``get_cell`` scan per position. ``setdefault`` keeps the first cell
        in list order for each position, which is the same precedence
        ``get_cell`` applies. Spans are clamped to the table dimensions
        because only positions inside the grid are ever looked up.
        """
        grid: Dict[Tuple[int, int], TableCell] = {}
        for cell in self.cells:
            grid.setdefault((cell.row, cell.col), cell)
            row_stop = min(cell.row + cell.rowspan, self.num_rows)
            col_stop = min(cell.col + cell.colspan, self.num_cols)
            for r in range(max(cell.row, 0), row_stop):
                for c in range(max(cell.col, 0), col_stop):
                    grid.setdefault((r, c), cell)
        return grid

    @staticmethod
    def _cell_text(cell: Optional[TableCell]) -> str:
        """Return the cell's text, or "" for a missing cell or a None text."""
        if cell is None or cell.text is None:
            return ""
        return cell.text

    def get_row(self, row_index: int) -> List[TableCell]:
        """
        Get all cells anchored in a row, ordered by column.

        Only cells whose own ``row`` equals ``row_index`` are returned;
        cells that merely span into the row from above are not included.
        """
        return sorted(
            [c for c in self.cells if c.row == row_index],
            key=lambda c: c.col
        )

    def get_col(self, col_index: int) -> List[TableCell]:
        """
        Get all cells anchored in a column, ordered by row.

        Only cells whose own ``col`` equals ``col_index`` are returned;
        cells that merely span into the column from the left are not included.
        """
        return sorted(
            [c for c in self.cells if c.col == col_index],
            key=lambda c: c.row
        )

    def get_headers(self) -> List[TableCell]:
        """Get all cells flagged with ``is_header``, in list order."""
        return [c for c in self.cells if c.is_header]

    def to_csv(self, delimiter: str = ",") -> str:
        """
        Convert the table to a CSV string.

        Every grid position is emitted; positions covered by a merged cell
        repeat that cell's text and uncovered positions become empty fields.
        Fields containing the delimiter, a double quote, or a line break
        (LF or CR) are quoted, with embedded quotes doubled.

        Args:
            delimiter: Field separator

        Returns:
            Rows joined with "\\n" ("" when the table has no rows)
        """
        grid = self._build_grid()
        rows = []
        for r in range(self.num_rows):
            row_cells = []
            for c in range(self.num_cols):
                text = self._cell_text(grid.get((r, c)))

                # A bare carriage return breaks a record just like a newline,
                # so it must trigger quoting as well.
                if (delimiter in text or '"' in text
                        or '\n' in text or '\r' in text):
                    text = '"' + text.replace('"', '""') + '"'
                row_cells.append(text)
            rows.append(delimiter.join(row_cells))
        return "\n".join(rows)

    def to_markdown(self) -> str:
        """
        Convert the table to a Markdown pipe table.

        The first row is followed by a ``---`` separator line. Pipes in
        cell text are escaped and line breaks are replaced by spaces,
        because a raw newline would split the table row.

        Returns:
            Markdown text, or "" when the table has no rows or no columns
        """
        if self.num_rows == 0 or self.num_cols == 0:
            return ""

        grid = self._build_grid()
        lines = []

        for r in range(self.num_rows):
            row_texts = []
            for c in range(self.num_cols):
                text = self._cell_text(grid.get((r, c)))
                # Keep each table row on a single physical line.
                text = (text.replace("\r\n", " ")
                            .replace("\r", " ")
                            .replace("\n", " "))
                row_texts.append(text.replace("|", "\\|"))
            lines.append("| " + " | ".join(row_texts) + " |")

            # Markdown needs the delimiter row right after the first row.
            if r == 0:
                separators = ["---"] * self.num_cols
                lines.append("| " + " | ".join(separators) + " |")

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a structured dictionary.

        Returns:
            Dict with the table dimensions, header indices and one entry
            per cell (row, col, text, spans, header flag, confidence)
        """
        return {
            "num_rows": self.num_rows,
            "num_cols": self.num_cols,
            "header_rows": self.header_rows,
            "header_cols": self.header_cols,
            "cells": [
                {
                    "row": c.row,
                    "col": c.col,
                    "text": c.text,
                    "rowspan": c.rowspan,
                    "colspan": c.colspan,
                    "is_header": c.is_header,
                    "confidence": c.confidence
                }
                for c in self.cells
            ]
        }

    def to_table_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> TableChunk:
        """
        Convert to a TableChunk for the chunks module.

        The chunk text is the Markdown rendering of the table and its
        confidence is ``structure_confidence``.

        Args:
            doc_id: Identifier of the source document
            page: Page the table was found on
            sequence_index: Position of the chunk in the document sequence

        Returns:
            TableChunk carrying this table's cells and metadata
        """
        return TableChunk(
            chunk_id=TableChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="table"
            ),
            doc_id=doc_id,
            text=self.to_markdown(),
            page=page,
            bbox=self.bbox,
            confidence=self.structure_confidence,
            sequence_index=sequence_index,
            cells=self.cells,
            num_rows=self.num_rows,
            num_cols=self.num_cols,
            header_rows=self.header_rows,
            header_cols=self.header_cols,
            has_merged_cells=self.has_merged_cells
        )
|
|
|
|
|
|
|
|
@dataclass
class TableExtractionResult:
    """Outcome of running table extraction on a page."""

    # Extracted tables.
    tables: List[TableStructure] = field(default_factory=list)
    # Elapsed extraction time in milliseconds.
    processing_time_ms: float = 0.0
    # Free-form extra information attached by the model.
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def table_count(self) -> int:
        """Number of tables held in this result."""
        return len(self.tables)

    def get_table_at_region(
        self,
        region: LayoutRegion,
        iou_threshold: float = 0.5
    ) -> Optional[TableStructure]:
        """
        Pick the table whose bounding box best overlaps a layout region.

        Args:
            region: Layout region to match against
            iou_threshold: IoU a table must strictly exceed to qualify

        Returns:
            The qualifying table with the highest IoU (the earliest one on
            ties), or None when no table exceeds the threshold
        """
        winner = None
        top_score = 0.0

        for candidate in self.tables:
            overlap = candidate.bbox.iou(region.bbox)
            # Skip anything that fails the threshold or does not beat the
            # current best.
            if not (overlap > iou_threshold and overlap > top_score):
                continue
            winner, top_score = candidate, overlap

        return winner
|
|
|
|
|
|
|
|
class TableModel(BatchableModel):
    """
    Abstract base class for table extraction models.

    Implementations should handle:
    - Table structure detection (rows, columns)
    - Cell boundary detection
    - Merged cell handling
    - Header detection
    - Cell content extraction

    Subclasses must implement ``extract_structure`` and
    ``extract_cell_text``; ``extract_all_tables`` and ``process_batch``
    are built on top of ``extract_structure``.
    """

    def __init__(self, config: Optional[TableConfig] = None):
        """
        Initialise the model.

        Args:
            config: Table configuration; when omitted (or falsy) a
                ``TableConfig`` named "table" is created
        """
        super().__init__(config or TableConfig(name="table"))
        # Re-assigned only to give the attribute the narrower TableConfig
        # annotation (assumes the base initialiser sets ``self.config``).
        self.config: TableConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report the single capability of this model: table extraction."""
        return [ModelCapability.TABLE_EXTRACTION]

    @abstractmethod
    def extract_structure(
        self,
        image: ImageInput,
        table_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> TableStructure:
        """
        Extract table structure from an image.

        Args:
            image: Input image containing a table
            table_region: Optional bounding box of the table region
            **kwargs: Additional parameters

        Returns:
            TableStructure with cells and metadata
        """

    def extract_all_tables(
        self,
        image: ImageInput,
        table_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> TableExtractionResult:
        """
        Extract all tables from an image.

        With ``table_regions``, each region is extracted on its own; a
        region whose extraction raises is logged and skipped so the other
        regions still produce results. Without regions (None or an empty
        list) the whole image is treated as a single table, which is kept
        only if it has at least one row; errors in that path propagate.

        Args:
            image: Input document image
            table_regions: Optional list of table bounding boxes
            **kwargs: Additional parameters forwarded to ``extract_structure``

        Returns:
            TableExtractionResult with all detected tables and the elapsed
            time in milliseconds
        """
        import logging
        import time

        logger = logging.getLogger(__name__)
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()

        tables: List[TableStructure] = []

        if table_regions:
            total = len(table_regions)
            for index, region in enumerate(table_regions, start=1):
                try:
                    table = self.extract_structure(image, region, **kwargs)
                except Exception:
                    # Best effort: one bad region must not lose the others,
                    # but the failure is recorded instead of vanishing.
                    logger.warning(
                        "Table extraction failed for region %d of %d; skipping",
                        index,
                        total,
                        exc_info=True,
                    )
                    continue
                tables.append(table)
        else:
            table = self.extract_structure(image, **kwargs)
            if table.num_rows > 0:
                tables.append(table)

        processing_time = (time.perf_counter() - start_time) * 1000

        return TableExtractionResult(
            tables=tables,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[TableExtractionResult]:
        """
        Process multiple images.

        Each image is passed to ``extract_all_tables`` in turn with the
        same ``kwargs``.

        Args:
            inputs: Images to process
            **kwargs: Additional parameters forwarded for every image

        Returns:
            One TableExtractionResult per input, in input order
        """
        return [self.extract_all_tables(img, **kwargs) for img in inputs]

    @abstractmethod
    def extract_cell_text(
        self,
        image: ImageInput,
        cell_bbox: BoundingBox,
        **kwargs
    ) -> str:
        """
        Extract text from a specific cell region.

        Args:
            image: Image containing the cell
            cell_bbox: Bounding box of the cell
            **kwargs: Additional parameters

        Returns:
            Extracted text content
        """
|
|
|