|
|
""" |
|
|
Extraction Schemas for Document Intelligence |
|
|
|
|
|
Pydantic models for schema-based field extraction, tables, and charts. |
|
|
""" |
|
|
|
|
|
from enum import Enum |
|
|
from typing import List, Dict, Any, Optional, Union |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from .core import BoundingBox, EvidenceRef |
|
|
|
|
|
|
|
|
class FieldType(str, Enum):
    """Supported field types for extraction."""

    # Members also subclass ``str``, so each compares equal to its raw value.

    # Primitive scalars
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"

    # Formatted values
    DATE = "date"
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"

    # Containers
    LIST = "list"
    OBJECT = "object"
|
|
|
|
|
|
|
|
class FieldDefinition(BaseModel):
    """
    Definition of a field to extract from a document.
    Used to build extraction schemas.
    """

    # Identity and typing. The first three are mandatory; `required` defaults
    # to False.
    name: str = Field(
        ...,
        description="Field name/key",
    )
    type: FieldType = Field(
        ...,
        description="Expected data type",
    )
    description: str = Field(
        ...,
        description="Human-readable description",
    )
    required: bool = Field(
        description="Whether field is required",
        default=False,
    )

    # Validation constraints. Each defaults to None; this class only stores
    # them and performs no checking of its own.
    pattern: Optional[str] = Field(
        description="Regex pattern for validation",
        default=None,
    )
    min_value: Optional[float] = Field(
        description="Minimum numeric value",
        default=None,
    )
    max_value: Optional[float] = Field(
        description="Maximum numeric value",
        default=None,
    )
    enum_values: Optional[List[str]] = Field(
        description="Allowed values",
        default=None,
    )

    # Hints for locating the field in a document.
    aliases: List[str] = Field(
        description="Alternative names/labels for the field",
        default_factory=list,
    )
    search_context: Optional[str] = Field(
        description="Context hint for where to find this field",
        default=None,
    )

    # Self-referential: a definition may carry child definitions.
    nested_fields: Optional[List["FieldDefinition"]] = Field(
        description="Nested field definitions for complex types",
        default=None,
    )
|
|
|
|
|
|
|
|
class ExtractionSchema(BaseModel):
    """
    Schema defining fields to extract from a document.
    Supports document-type-specific extraction rules.
    """

    # Schema identity.
    schema_id: str = Field(..., description="Unique schema identifier")
    name: str = Field(..., description="Human-readable schema name")
    description: str = Field(..., description="Schema description")
    version: str = Field(default="1.0", description="Schema version")

    # Top-level field definitions; this is the list the lookup helpers scan.
    fields: List[FieldDefinition] = Field(
        default_factory=list,
        description="Fields to extract"
    )

    document_types: List[str] = Field(
        default_factory=list,
        description="Applicable document types"
    )

    # Stored as plain strings; nothing in this class parses or evaluates them.
    cross_field_validations: List[str] = Field(
        default_factory=list,
        description="Cross-field validation expressions"
    )

    # Extraction policy settings. They are only declared here; enforcement,
    # if any, happens outside this class.
    require_evidence: bool = Field(
        default=True,
        description="Require evidence for all extracted fields"
    )
    min_confidence: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold"
    )
    abstain_on_low_confidence: bool = Field(
        default=True,
        description="Abstain rather than guess when confidence is low"
    )

    def get_field(self, name: str) -> Optional[FieldDefinition]:
        """
        Get a field definition by its name or one of its aliases.

        An exact match on ``FieldDefinition.name`` always takes priority, so
        a field listed earlier whose alias happens to equal ``name`` cannot
        shadow the field that is actually called ``name``. If no field has
        that name, the first field listing ``name`` among its aliases is
        returned. Only the top-level ``fields`` are searched; nested field
        definitions are not visited.

        Args:
            name: Field name or alias to look up (compared case-sensitively).

        Returns:
            The matching definition, or None when nothing matches.
        """
        alias_match: Optional[FieldDefinition] = None
        for candidate in self.fields:
            if candidate.name == name:
                return candidate
            # Remember only the first alias hit; keep scanning in case a
            # later field matches by its canonical name.
            if alias_match is None and name in candidate.aliases:
                alias_match = candidate
        return alias_match

    def get_required_fields(self) -> List[FieldDefinition]:
        """Get all top-level field definitions flagged as required, in schema order."""
        return [f for f in self.fields if f.required]
|
|
|
|
|
|
|
|
class TableCell(BaseModel):
    """
    Single cell in a table structure.
    """

    # Identity, grid position (both indices validated as >= 0) and content.
    cell_id: str = Field(
        ...,
        description="Unique cell identifier",
    )
    row: int = Field(
        ...,
        description="Row index (0-based)",
        ge=0,
    )
    col: int = Field(
        ...,
        description="Column index (0-based)",
        ge=0,
    )
    text: str = Field(
        ...,
        description="Cell text content",
    )
    bbox: BoundingBox = Field(
        ...,
        description="Cell bounding box",
    )

    # Spans default to one row / one column and are validated as >= 1.
    row_span: int = Field(
        description="Number of rows spanned",
        default=1,
        ge=1,
    )
    col_span: int = Field(
        description="Number of columns spanned",
        default=1,
        ge=1,
    )

    # Flags supplied by the producer; they are not derived from `text` here.
    is_header: bool = Field(
        description="Whether cell is a header",
        default=False,
    )
    is_empty: bool = Field(
        description="Whether cell is empty",
        default=False,
    )

    # Confidence score constrained to [0.0, 1.0]; defaults to 1.0.
    confidence: float = Field(
        default=1.0,
        le=1.0,
        ge=0.0,
    )
|
|
|
|
|
|
|
|
class TableData(BaseModel):
    """
    Structured table data extracted from a document.
    """

    # Cells are kept as a flat list; the rendering helpers below place them
    # on a num_rows x num_cols grid using each cell's row/col indices.
    table_id: str = Field(..., description="Unique table identifier")
    page: int = Field(..., ge=0, description="Page number")
    bbox: BoundingBox = Field(..., description="Table bounding box")

    num_rows: int = Field(..., ge=1, description="Number of rows")
    num_cols: int = Field(..., ge=1, description="Number of columns")
    cells: List[TableCell] = Field(default_factory=list, description="All cells")

    header_rows: List[int] = Field(
        default_factory=list,
        description="Row indices that are headers"
    )
    header_cols: List[int] = Field(
        default_factory=list,
        description="Column indices that are headers"
    )

    caption: Optional[str] = Field(default=None, description="Table caption")
    caption_bbox: Optional[BoundingBox] = Field(default=None)

    confidence: float = Field(default=1.0, ge=0.0, le=1.0)

    evidence: Optional[EvidenceRef] = Field(default=None)

    def _build_grid(self) -> List[List[Optional[str]]]:
        """
        Lay the cell texts out on a ``num_rows`` x ``num_cols`` grid.

        Positions without a cell stay ``None``. Cells whose indices fall
        outside the declared dimensions are skipped. Spans are not expanded:
        a cell's text is written only at its own (row, col). If several cells
        share a position, the last one in ``cells`` wins.
        """
        grid: List[List[Optional[str]]] = [
            [None] * self.num_cols for _ in range(self.num_rows)
        ]
        for cell in self.cells:
            if cell.row < self.num_rows and cell.col < self.num_cols:
                grid[cell.row][cell.col] = cell.text
        return grid

    @staticmethod
    def _markdown_cell(text: Optional[str]) -> str:
        """Render one grid entry as text that is safe inside a markdown table row."""
        if not text:
            return ""
        # A raw line break would terminate the table row and a raw pipe would
        # open a new column, so flatten the former and escape the latter.
        return " ".join(str(text).splitlines()).replace("|", "\\|")

    @staticmethod
    def _unique_headers(first_row: List[Optional[str]]) -> List[str]:
        """
        Derive one distinct dictionary key per column from the first grid row.

        Empty header cells become ``col_<index>``. A name that was already
        used gets a numeric suffix (``name_2``, ``name_3``, ...) so that no
        column overwrites another when rows are turned into dictionaries.
        """
        taken = set()
        headers: List[str] = []
        for index, raw in enumerate(first_row):
            base = str(raw) if raw else f"col_{index}"
            candidate = base
            suffix = 2
            while candidate in taken:
                candidate = f"{base}_{suffix}"
                suffix += 1
            taken.add(candidate)
            headers.append(candidate)
        return headers

    def to_markdown(self) -> str:
        """
        Convert table to markdown format.

        The first grid row is rendered as the markdown header row and is
        followed by exactly one delimiter row; markdown tables allow only
        one, so ``header_rows`` does not add further delimiters. Pipe
        characters in cell text are escaped and line breaks are replaced by
        spaces so every table row stays on a single line.

        Returns:
            The markdown table, or an empty string when there are no cells.
        """
        if not self.cells:
            return ""

        delimiter = "|" + "|".join(["---"] * self.num_cols) + "|"

        lines: List[str] = []
        for index, row in enumerate(self._build_grid()):
            rendered = " | ".join(self._markdown_cell(text) for text in row)
            lines.append("| " + rendered + " |")
            if index == 0:
                lines.append(delimiter)

        return "\n".join(lines)

    def to_dict_list(self) -> List[Dict[str, str]]:
        """
        Convert table to list of dictionaries (using first row as keys).

        Every row after the first becomes one dictionary with one entry per
        column; missing cells map to an empty string. Repeated header names
        are disambiguated with numeric suffixes instead of silently dropping
        columns.

        Returns:
            One dictionary per data row, or an empty list when the table has
            no cells or fewer than two rows.
        """
        if not self.cells or self.num_rows < 2:
            return []

        grid = self._build_grid()
        headers = self._unique_headers(grid[0])

        return [
            {key: (str(value) if value else "") for key, value in zip(headers, row)}
            for row in grid[1:]
        ]
|
|
|
|
|
|
|
|
class ChartType(str, Enum):
    """Types of charts/graphs."""

    # Members also subclass ``str``, so each compares equal to its raw value.

    # Data plots
    BAR = "bar"
    LINE = "line"
    PIE = "pie"
    SCATTER = "scatter"
    AREA = "area"
    HISTOGRAM = "histogram"
    BOX = "box"
    HEATMAP = "heatmap"
    TREEMAP = "treemap"

    # Non-plot figures and the catch-all
    FLOWCHART = "flowchart"
    DIAGRAM = "diagram"
    OTHER = "other"
|
|
|
|
|
|
|
|
class ChartData(BaseModel):
    """
    Structured chart/graph data extracted from a document.
    """

    # Identity, location and kind; all four are mandatory.
    chart_id: str = Field(
        ...,
        description="Unique chart identifier",
    )
    page: int = Field(
        ...,
        description="Page number",
        ge=0,
    )
    bbox: BoundingBox = Field(
        ...,
        description="Chart bounding box",
    )
    chart_type: ChartType = Field(
        ...,
        description="Type of chart",
    )

    # Optional labels read from the chart; each defaults to None.
    title: Optional[str] = Field(
        description="Chart title",
        default=None,
    )
    x_axis_label: Optional[str] = Field(
        description="X-axis label",
        default=None,
    )
    y_axis_label: Optional[str] = Field(
        description="Y-axis label",
        default=None,
    )

    # Free-form dictionaries; no key layout is imposed by this model.
    series: List[Dict[str, Any]] = Field(description="Data series extracted from chart", default_factory=list)

    trends: List[str] = Field(description="Identified trends or patterns", default_factory=list)

    caption: Optional[str] = Field(
        description="Chart caption",
        default=None,
    )

    # Confidence score constrained to [0.0, 1.0]; defaults to 1.0.
    confidence: float = Field(
        default=1.0,
        le=1.0,
        ge=0.0,
    )
    evidence: Optional[EvidenceRef] = Field(
        default=None,
    )

    description: Optional[str] = Field(description="Natural language description of the chart", default=None)
|
|
|
|
|
|
|
|
class ExtractedField(BaseModel):
    """
    A single extracted field value with evidence.
    """

    # What was extracted and how confident the extractor is. `value` is
    # untyped (Any); `confidence` is mandatory and constrained to [0.0, 1.0].
    field_name: str = Field(
        ...,
        description="Field name from schema",
    )
    value: Any = Field(
        ...,
        description="Extracted value",
    )
    confidence: float = Field(
        ...,
        description="Extraction confidence",
        ge=0.0,
        le=1.0,
    )
    evidence: List[EvidenceRef] = Field(description="Evidence supporting the extraction", default_factory=list)

    # Validation outcome, filled in by whoever validates the value.
    is_valid: bool = Field(
        description="Whether value passed validation",
        default=True,
    )
    validation_errors: List[str] = Field(description="Validation error messages", default_factory=list)

    # Abstention bookkeeping.
    abstained: bool = Field(description="Whether extraction was abstained", default=False)
    abstain_reason: Optional[str] = Field(description="Reason for abstention", default=None)

    @property
    def is_grounded(self) -> bool:
        """Return True when evidence is attached and the extraction was not abstained."""
        if self.abstained:
            return False
        return bool(self.evidence)
|
|
|