MHamdan's picture
Initial commit: SPARKNET framework
d520909
"""
Table Extraction Model Interface
Abstract interface for table structure recognition and cell extraction.
Handles complex tables with merged cells, headers, and nested structures.
"""
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from ..chunks.models import BoundingBox, TableCell, TableChunk
from .base import (
BaseModel,
BatchableModel,
ImageInput,
ModelCapability,
ModelConfig,
)
from .layout import LayoutRegion
class TableCellType(str, Enum):
    """Roles a table cell may be classified as.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (e.g. ``TableCellType.DATA == "data"``).
    """

    HEADER = "header"
    DATA = "data"
    INDEX = "index"
    MERGED = "merged"
    EMPTY = "empty"
@dataclass
class TableConfig(ModelConfig):
    """Configuration for table extraction models.

    Adds table-specific options on top of ``ModelConfig``. If no ``name``
    is supplied (empty/falsy), it defaults to ``"table_extractor"`` after
    the base class post-init hook has run.
    """

    min_confidence: float = 0.5
    detect_headers: bool = True
    detect_merged_cells: bool = True
    max_rows: int = 500
    max_cols: int = 50
    # Whether to OCR cell contents
    extract_cell_text: bool = True

    def __post_init__(self):
        super().__post_init__()
        if self.name:
            return
        self.name = "table_extractor"
@dataclass
class TableStructure:
    """
    Detected table structure with cell grid.

    Represents the logical structure of a table including merged cells,
    headers, and cell relationships. Row/column indices are 0-based. A cell
    anchored at (row, col) with rowspan/colspan greater than 1 also covers
    the positions below/right of its anchor (see ``get_cell``).
    """

    bbox: BoundingBox
    cells: List[TableCell] = field(default_factory=list)
    num_rows: int = 0
    num_cols: int = 0

    # Header information
    header_rows: List[int] = field(default_factory=list)  # 0-indexed row indices
    header_cols: List[int] = field(default_factory=list)  # 0-indexed col indices

    # Confidence
    structure_confidence: float = 0.0
    cell_confidence_avg: float = 0.0

    # Additional metadata
    has_merged_cells: bool = False
    is_bordered: bool = True
    table_id: str = ""

    def __post_init__(self):
        # When no ID is supplied, derive a deterministic 12-hex-char ID from
        # the bbox coordinates and grid size. MD5 is only a fingerprint here.
        if not self.table_id:
            import hashlib

            content = f"table_{self.bbox.xyxy}_{self.num_rows}x{self.num_cols}"
            self.table_id = hashlib.md5(content.encode()).hexdigest()[:12]

    def get_cell(self, row: int, col: int) -> Optional[TableCell]:
        """Return the cell occupying (row, col), or None if no cell covers it.

        A cell matches if it is anchored at the position or if its
        rowspan/colspan extends over it. The first matching cell in
        ``self.cells`` order wins.
        """
        for cell in self.cells:
            if cell.row == row and cell.col == col:
                return cell
            # Check merged cells
            if (cell.row <= row < cell.row + cell.rowspan and
                    cell.col <= col < cell.col + cell.colspan):
                return cell
        return None

    def _cell_grid(self) -> Dict[Tuple[int, int], TableCell]:
        """Map each in-bounds (row, col) to the cell ``get_cell`` would return.

        Built in one pass over ``self.cells`` so the exporters do not rescan
        the whole cell list for every grid position. ``setdefault`` keeps the
        first cell (in list order) that anchors or spans a position, which
        mirrors ``get_cell``'s first-match semantics.
        """
        grid: Dict[Tuple[int, int], TableCell] = {}
        for cell in self.cells:
            # get_cell() matches a cell's anchor even if its spans are < 1.
            grid.setdefault((cell.row, cell.col), cell)
            # Clip spans to the grid: only in-bounds positions are looked up.
            row_stop = min(cell.row + cell.rowspan, self.num_rows)
            col_stop = min(cell.col + cell.colspan, self.num_cols)
            for r in range(max(cell.row, 0), row_stop):
                for c in range(max(cell.col, 0), col_stop):
                    grid.setdefault((r, c), cell)
        return grid

    def get_row(self, row_index: int) -> List[TableCell]:
        """Return cells anchored in ``row_index``, sorted by column.

        Cells that only span into this row from an earlier row are not
        included.
        """
        return sorted(
            [c for c in self.cells if c.row == row_index],
            key=lambda c: c.col
        )

    def get_col(self, col_index: int) -> List[TableCell]:
        """Return cells anchored in ``col_index``, sorted by row.

        Cells that only span into this column from an earlier column are
        not included.
        """
        return sorted(
            [c for c in self.cells if c.col == col_index],
            key=lambda c: c.row
        )

    def get_headers(self) -> List[TableCell]:
        """Return all cells whose ``is_header`` flag is set, in list order."""
        return [c for c in self.cells if c.is_header]

    def to_csv(self, delimiter: str = ",") -> str:
        """Render the table as CSV text with ``num_rows`` x ``num_cols`` fields.

        A merged cell's text is repeated at every position it covers, and
        uncovered positions are empty. Fields containing the delimiter, a
        double quote, CR, or LF are quoted with embedded quotes doubled.
        Rows are joined with ``"\\n"``.

        Args:
            delimiter: Field separator.

        Returns:
            The CSV string ("" when the table has no rows).
        """
        grid = self._cell_grid()
        rows = []
        for r in range(self.num_rows):
            row_cells = []
            for c in range(self.num_cols):
                cell = grid.get((r, c))
                text = (cell.text or "") if cell is not None else ""
                # Escape delimiter, quotes and any line break (\r was missed before).
                if delimiter in text or any(ch in text for ch in '"\r\n'):
                    text = '"' + text.replace('"', '""') + '"'
                row_cells.append(text)
            rows.append(delimiter.join(row_cells))
        return "\n".join(rows)

    def to_markdown(self) -> str:
        """Render the table as a Markdown (pipe) table.

        The first row is always treated as the Markdown header row. Pipes in
        cell text are backslash-escaped, and embedded line breaks become
        spaces, because a raw newline would terminate the table row.

        Returns:
            The Markdown string, or "" when the table has no rows or columns.
        """
        if self.num_rows == 0 or self.num_cols == 0:
            return ""

        grid = self._cell_grid()
        lines = []
        for r in range(self.num_rows):
            row_texts = []
            for c in range(self.num_cols):
                cell = grid.get((r, c))
                text = (cell.text or "") if cell is not None else ""
                text = " ".join(text.splitlines()).replace("|", "\\|")
                row_texts.append(text)
            lines.append("| " + " | ".join(row_texts) + " |")

            # Add separator after first row (header)
            if r == 0:
                separators = ["---"] * self.num_cols
                lines.append("| " + " | ".join(separators) + " |")

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Return grid size, header indices and per-cell fields as a plain dict.

        Header index lists are copied so that mutating the result does not
        alter this structure.
        """
        return {
            "num_rows": self.num_rows,
            "num_cols": self.num_cols,
            "header_rows": list(self.header_rows),
            "header_cols": list(self.header_cols),
            "cells": [
                {
                    "row": c.row,
                    "col": c.col,
                    "text": c.text,
                    "rowspan": c.rowspan,
                    "colspan": c.colspan,
                    "is_header": c.is_header,
                    "confidence": c.confidence
                }
                for c in self.cells
            ]
        }

    def to_table_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> TableChunk:
        """Convert to a TableChunk for the chunks module.

        The chunk text is the Markdown rendering of this table, and its
        confidence is ``structure_confidence``.
        """
        return TableChunk(
            chunk_id=TableChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="table"
            ),
            doc_id=doc_id,
            text=self.to_markdown(),
            page=page,
            bbox=self.bbox,
            confidence=self.structure_confidence,
            sequence_index=sequence_index,
            cells=self.cells,
            num_rows=self.num_rows,
            num_cols=self.num_cols,
            header_rows=self.header_rows,
            header_cols=self.header_cols,
            has_merged_cells=self.has_merged_cells
        )
@dataclass
class TableExtractionResult:
    """Tables extracted from a single page, plus timing and model metadata."""

    tables: List[TableStructure] = field(default_factory=list)
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def table_count(self) -> int:
        """Number of tables in this result."""
        return len(self.tables)

    def get_table_at_region(
        self,
        region: LayoutRegion,
        iou_threshold: float = 0.5
    ) -> Optional[TableStructure]:
        """Return the table whose bbox best overlaps ``region``'s bbox.

        Only tables with IoU strictly above both ``iou_threshold`` and 0.0
        qualify. Among qualifying tables, the one with the highest IoU is
        returned, and ties go to the earliest table in ``self.tables``.
        Returns None when nothing qualifies.
        """
        candidates = []
        for tbl in self.tables:
            overlap = tbl.bbox.iou(region.bbox)
            if overlap > iou_threshold and overlap > 0.0:
                candidates.append((overlap, tbl))

        if not candidates:
            return None

        # max() keeps the first of equal maxima, matching a strict-'>' scan.
        _, winner = max(candidates, key=lambda pair: pair[0])
        return winner
class TableModel(BatchableModel):
    """
    Abstract base class for table extraction models.

    Implementations should handle:
    - Table structure detection (rows, columns)
    - Cell boundary detection
    - Merged cell handling
    - Header detection
    - Cell content extraction

    Subclasses must implement ``extract_structure`` and ``extract_cell_text``.
    ``extract_all_tables`` and ``process_batch`` are built on top of them.
    """

    def __init__(self, config: Optional[TableConfig] = None):
        super().__init__(config or TableConfig(name="table"))
        # Re-annotate for type checkers; the value is presumably the config
        # stored by the base-class __init__ — confirm in BaseModel.
        self.config: TableConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report that this model provides table extraction."""
        return [ModelCapability.TABLE_EXTRACTION]

    @abstractmethod
    def extract_structure(
        self,
        image: ImageInput,
        table_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> TableStructure:
        """
        Extract table structure from an image.

        Args:
            image: Input image containing a table
            table_region: Optional bounding box of the table region
            **kwargs: Additional parameters

        Returns:
            TableStructure with cells and metadata
        """
        pass

    def extract_all_tables(
        self,
        image: ImageInput,
        table_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> TableExtractionResult:
        """
        Extract all tables from an image.

        With a non-empty ``table_regions``, each region is passed to
        ``extract_structure``. A region that raises is logged (with
        traceback) and skipped, so one bad region does not abort the page.
        With ``None`` or an empty list, ``extract_structure`` runs once on
        the whole image, and its table is kept only if ``num_rows > 0``.

        Args:
            image: Input document image
            table_regions: Optional list of table bounding boxes
            **kwargs: Additional parameters forwarded to extract_structure

        Returns:
            TableExtractionResult with all detected tables and the elapsed
            time in milliseconds
        """
        import logging
        import time

        # perf_counter is monotonic and high-resolution; time.time() can jump.
        start_time = time.perf_counter()
        tables: List[TableStructure] = []

        if table_regions:
            # Extract from specified regions
            for index, region in enumerate(table_regions):
                try:
                    table = self.extract_structure(image, region, **kwargs)
                except Exception:
                    # Best-effort per region, but never silently: record why.
                    logging.getLogger(__name__).warning(
                        "Table extraction failed for region %d (%s); skipping",
                        index,
                        region,
                        exc_info=True,
                    )
                    continue
                tables.append(table)
        else:
            # Detect and extract all tables
            table = self.extract_structure(image, **kwargs)
            if table.num_rows > 0:
                tables.append(table)

        processing_time = (time.perf_counter() - start_time) * 1000
        return TableExtractionResult(
            tables=tables,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[TableExtractionResult]:
        """Run ``extract_all_tables`` on each image, in order.

        The same ``**kwargs`` (including any ``table_regions``) are forwarded
        for every image.
        """
        return [self.extract_all_tables(img, **kwargs) for img in inputs]

    @abstractmethod
    def extract_cell_text(
        self,
        image: ImageInput,
        cell_bbox: BoundingBox,
        **kwargs
    ) -> str:
        """
        Extract text from a specific cell region.

        Args:
            image: Image containing the cell
            cell_bbox: Bounding box of the cell
            **kwargs: Additional parameters

        Returns:
            Extracted text content
        """
        pass