"""
Table Extraction Model Interface

Abstract interface for table structure recognition and cell extraction.
Handles complex tables with merged cells, headers, and nested structures.
"""

import hashlib
import logging
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from ..chunks.models import BoundingBox, TableCell, TableChunk
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)
from .layout import LayoutRegion

logger = logging.getLogger(__name__)


class TableCellType(str, Enum):
    """Types of table cells."""

    HEADER = "header"
    DATA = "data"
    INDEX = "index"
    MERGED = "merged"
    EMPTY = "empty"


@dataclass
class TableConfig(ModelConfig):
    """Configuration for table extraction models."""

    min_confidence: float = 0.5
    detect_headers: bool = True
    detect_merged_cells: bool = True
    max_rows: int = 500
    max_cols: int = 50
    extract_cell_text: bool = True  # Whether to OCR cell contents

    def __post_init__(self):
        """Run the parent initialisation, then default an empty name."""
        super().__post_init__()
        if not self.name:
            self.name = "table_extractor"


@dataclass
class TableStructure:
    """
    Detected table structure with cell grid.

    Represents the logical structure of a table including merged cells,
    headers, and cell relationships.
    """

    bbox: BoundingBox
    cells: List[TableCell] = field(default_factory=list)
    num_rows: int = 0
    num_cols: int = 0

    # Header information
    header_rows: List[int] = field(default_factory=list)  # 0-indexed row indices
    header_cols: List[int] = field(default_factory=list)  # 0-indexed col indices

    # Confidence
    structure_confidence: float = 0.0
    cell_confidence_avg: float = 0.0

    # Additional metadata
    has_merged_cells: bool = False
    is_bordered: bool = True
    table_id: str = ""

    def __post_init__(self):
        """Derive a ``table_id`` from the bbox and dimensions when none is given."""
        if not self.table_id:
            content = f"table_{self.bbox.xyxy}_{self.num_rows}x{self.num_cols}"
            # MD5 serves only as a short content fingerprint here.
            self.table_id = hashlib.md5(content.encode()).hexdigest()[:12]

    @staticmethod
    def _escape_markdown_cell(text: str) -> str:
        """Make cell text safe to embed in a single Markdown table row."""
        escaped = text.replace("|", "\\|")
        # A raw line break would end the table row, so render it as <br>.
        # "\r\n" is replaced first so it yields a single <br>.
        for line_break in ("\r\n", "\n", "\r"):
            escaped = escaped.replace(line_break, "<br>")
        return escaped

    def get_cell(self, row: int, col: int) -> Optional[TableCell]:
        """
        Get cell at specific position.

        Args:
            row: Row index to look up
            col: Column index to look up

        Returns:
            The first cell (in ``self.cells`` order) that is anchored at
            ``(row, col)`` or whose row/column span covers that position,
            or None when no cell matches.
        """
        for cell in self.cells:
            if cell.row == row and cell.col == col:
                return cell
            # Check merged cells
            if (cell.row <= row < cell.row + cell.rowspan
                    and cell.col <= col < cell.col + cell.colspan):
                return cell
        return None

    def get_row(self, row_index: int) -> List[TableCell]:
        """Get the cells anchored in a row, ordered by column."""
        return sorted(
            [c for c in self.cells if c.row == row_index],
            key=lambda c: c.col
        )

    def get_col(self, col_index: int) -> List[TableCell]:
        """Get the cells anchored in a column, ordered by row."""
        return sorted(
            [c for c in self.cells if c.col == col_index],
            key=lambda c: c.row
        )

    def get_headers(self) -> List[TableCell]:
        """Get all header cells."""
        return [c for c in self.cells if c.is_header]

    def to_csv(self, delimiter: str = ",") -> str:
        """
        Convert table to CSV string.

        Positions without a cell become empty fields. A position covered by
        a merged cell resolves through ``get_cell`` and so repeats that
        cell's text.

        Args:
            delimiter: Field separator placed between cells

        Returns:
            Rows joined with "\\n", without a trailing newline
        """
        rows = []
        for r in range(self.num_rows):
            row_cells = []
            for c in range(self.num_cols):
                cell = self.get_cell(r, c)
                text = cell.text if cell else ""
                # Quote fields holding the delimiter, a quote, or any line
                # break; a bare "\r" would otherwise corrupt the row.
                needs_quoting = (
                    delimiter in text
                    or '"' in text
                    or "\n" in text
                    or "\r" in text
                )
                if needs_quoting:
                    text = '"' + text.replace('"', '""') + '"'
                row_cells.append(text)
            rows.append(delimiter.join(row_cells))
        return "\n".join(rows)

    def to_markdown(self) -> str:
        """
        Convert table to Markdown format.

        The first row is always followed by a separator line, so it is
        rendered as the header row. Pipes in cell text are escaped and line
        breaks are rendered as ``<br>`` so every table row stays on one line.

        Returns:
            The Markdown table, or "" when the table has no rows or columns
        """
        if self.num_rows == 0 or self.num_cols == 0:
            return ""

        lines = []

        # Build rows
        for r in range(self.num_rows):
            row_texts = []
            for c in range(self.num_cols):
                cell = self.get_cell(r, c)
                text = self._escape_markdown_cell(cell.text) if cell else ""
                row_texts.append(text)
            lines.append("| " + " | ".join(row_texts) + " |")

            # Add separator after first row (header)
            if r == 0:
                separators = ["---"] * self.num_cols
                lines.append("| " + " | ".join(separators) + " |")

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to structured dictionary."""
        return {
            "num_rows": self.num_rows,
            "num_cols": self.num_cols,
            "header_rows": self.header_rows,
            "header_cols": self.header_cols,
            "cells": [
                {
                    "row": c.row,
                    "col": c.col,
                    "text": c.text,
                    "rowspan": c.rowspan,
                    "colspan": c.colspan,
                    "is_header": c.is_header,
                    "confidence": c.confidence
                }
                for c in self.cells
            ]
        }

    def to_table_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> TableChunk:
        """
        Convert to TableChunk for the chunks module.

        Args:
            doc_id: Identifier of the source document
            page: Page the table was found on
            sequence_index: Position of the chunk in the document sequence

        Returns:
            TableChunk whose text is the Markdown rendering of this table
            and whose confidence is ``structure_confidence``
        """
        return TableChunk(
            chunk_id=TableChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="table"
            ),
            doc_id=doc_id,
            text=self.to_markdown(),
            page=page,
            bbox=self.bbox,
            confidence=self.structure_confidence,
            sequence_index=sequence_index,
            cells=self.cells,
            num_rows=self.num_rows,
            num_cols=self.num_cols,
            header_rows=self.header_rows,
            header_cols=self.header_cols,
            has_merged_cells=self.has_merged_cells
        )


@dataclass
class TableExtractionResult:
    """Result of table extraction from a page."""

    tables: List[TableStructure] = field(default_factory=list)
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def table_count(self) -> int:
        """Number of tables in this result."""
        return len(self.tables)

    def get_table_at_region(
        self,
        region: LayoutRegion,
        iou_threshold: float = 0.5
    ) -> Optional[TableStructure]:
        """
        Find table that matches a layout region.

        Args:
            region: Layout region to match against
            iou_threshold: IoU a table must strictly exceed to qualify

        Returns:
            The qualifying table with the highest IoU against
            ``region.bbox``, or None when no table exceeds the threshold
        """
        best_match = None
        best_iou = 0.0

        for table in self.tables:
            iou = table.bbox.iou(region.bbox)
            if iou > iou_threshold and iou > best_iou:
                best_match = table
                best_iou = iou

        return best_match


class TableModel(BatchableModel):
    """
    Abstract base class for table extraction models.

    Implementations should handle:
    - Table structure detection (rows, columns)
    - Cell boundary detection
    - Merged cell handling
    - Header detection
    - Cell content extraction
    """

    def __init__(self, config: Optional[TableConfig] = None):
        """Initialise with ``config``, or a default TableConfig named "table"."""
        super().__init__(config or TableConfig(name="table"))
        # Re-declared so type checkers see the narrowed TableConfig type.
        self.config: TableConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report table extraction as this model's capability."""
        return [ModelCapability.TABLE_EXTRACTION]

    @abstractmethod
    def extract_structure(
        self,
        image: ImageInput,
        table_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> TableStructure:
        """
        Extract table structure from an image.

        Args:
            image: Input image containing a table
            table_region: Optional bounding box of the table region
            **kwargs: Additional parameters

        Returns:
            TableStructure with cells and metadata
        """
        pass

    def extract_all_tables(
        self,
        image: ImageInput,
        table_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> TableExtractionResult:
        """
        Extract all tables from an image.

        When ``table_regions`` is given, each region is extracted
        independently and a region whose extraction raises is logged and
        skipped. Without regions, a single extraction runs over the whole
        image; its result is kept only if it has at least one row, and any
        exception propagates to the caller.

        Args:
            image: Input document image
            table_regions: Optional list of table bounding boxes
            **kwargs: Additional parameters

        Returns:
            TableExtractionResult with all detected tables
        """
        # perf_counter is monotonic, so the duration cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        tables: List[TableStructure] = []

        if table_regions:
            # Extract from specified regions
            total = len(table_regions)
            for index, region in enumerate(table_regions, start=1):
                try:
                    table = self.extract_structure(image, region, **kwargs)
                except Exception:
                    # Best-effort: one bad region must not discard the
                    # others, but the failure should leave a trace.
                    logger.warning(
                        "Table extraction failed for region %d of %d; skipping",
                        index,
                        total,
                        exc_info=True,
                    )
                    continue
                tables.append(table)
        else:
            # Detect and extract all tables
            table = self.extract_structure(image, **kwargs)
            if table.num_rows > 0:
                tables.append(table)

        processing_time = (time.perf_counter() - start_time) * 1000

        return TableExtractionResult(
            tables=tables,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[TableExtractionResult]:
        """Process multiple images, one ``extract_all_tables`` call per image."""
        return [self.extract_all_tables(img, **kwargs) for img in inputs]

    @abstractmethod
    def extract_cell_text(
        self,
        image: ImageInput,
        cell_bbox: BoundingBox,
        **kwargs
    ) -> str:
        """
        Extract text from a specific cell region.

        Args:
            image: Image containing the cell
            cell_bbox: Bounding box of the cell
            **kwargs: Additional parameters

        Returns:
            Extracted text content
        """
        pass