MHamdan's picture
Initial commit: SPARKNET framework
d520909
"""
Schema Definitions for Field Extraction
Pydantic-compatible schemas for defining extraction targets.
"""
from dataclasses import dataclass, field as dataclass_field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, Union
from pydantic import BaseModel, Field, create_model
class FieldType(str, Enum):
    """
    Enumeration of the value kinds an extraction field can hold.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (for example ``FieldType.DATE == "date"``) and can
    be looked up from that value with ``FieldType("date")``.
    """

    # Scalar primitives
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"

    # Calendar / clock values
    DATE = "date"
    DATETIME = "datetime"

    # Strings with a specific semantic format
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"

    # Composite values (see FieldSpec.nested_schema / list_item_type)
    LIST = "list"
    OBJECT = "object"
@dataclass
class FieldSpec:
    """
    Specification for a single extraction field.

    Groups what to extract (name, type, description, required/default), how
    to validate it (pattern, numeric bounds, length bounds, allowed values),
    how nested values are described (``nested_schema`` / ``list_item_type``),
    and optional hints for the extractor.
    """

    name: str
    field_type: FieldType = FieldType.STRING
    description: str = ""
    required: bool = True
    default: Any = None

    # Validation
    pattern: Optional[str] = None  # Regex pattern for validation
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    # Length bounds: emitted as minItems/maxItems for LIST fields and as
    # minLength/maxLength for every other type.
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    allowed_values: Optional[List[Any]] = None

    # Nested schema (for OBJECT and LIST types)
    nested_schema: Optional["ExtractionSchema"] = None
    list_item_type: Optional[FieldType] = None  # Scalar item type for LIST without nested_schema

    # Extraction hints
    aliases: List[str] = dataclass_field(default_factory=list)  # Alternative names
    examples: List[str] = dataclass_field(default_factory=list)  # Example values
    context_hints: List[str] = dataclass_field(default_factory=list)  # Where to look

    # Confidence threshold for this field
    min_confidence: float = 0.5

    def to_json_schema(self) -> Dict[str, Any]:
        """
        Convert this field specification to a JSON Schema fragment.

        - ``type`` comes from ``field_type``. Semantic string kinds such as
          currency or phone map to ``"string"``.
        - Date, datetime and email fields also get a ``format``.
        - Length bounds become ``minItems``/``maxItems`` for LIST fields,
          because JSON Schema's ``minLength``/``maxLength`` only constrain
          strings. For other types they stay ``minLength``/``maxLength``.
        - LIST items are described by ``nested_schema`` when set, otherwise
          by ``list_item_type`` when set.
        - OBJECT fields inline ``nested_schema``. This field's own
          description, when non-empty, takes precedence over the nested
          schema's description.

        Returns:
            A new dict describing this field. The ``enum`` value, when
            present, is the same list object as ``allowed_values``.
        """
        type_mapping = {
            FieldType.STRING: "string",
            FieldType.INTEGER: "integer",
            FieldType.FLOAT: "number",
            FieldType.BOOLEAN: "boolean",
            FieldType.DATE: "string",
            FieldType.DATETIME: "string",
            FieldType.CURRENCY: "string",
            FieldType.PERCENTAGE: "string",
            FieldType.EMAIL: "string",
            FieldType.PHONE: "string",
            FieldType.ADDRESS: "string",
            FieldType.LIST: "array",
            FieldType.OBJECT: "object",
        }
        format_mapping = {
            FieldType.DATE: "date",
            FieldType.DATETIME: "date-time",
            FieldType.EMAIL: "email",
        }
        is_list = self.field_type == FieldType.LIST

        schema: Dict[str, Any] = {
            "type": type_mapping.get(self.field_type, "string"),
        }

        # Inline the nested object first, so the field-level keys set below
        # (notably "description") are not overwritten by the nested schema.
        if self.field_type == FieldType.OBJECT and self.nested_schema is not None:
            schema.update(self.nested_schema.to_json_schema())

        if self.description:
            schema["description"] = self.description
        if self.pattern:
            schema["pattern"] = self.pattern

        value_format = format_mapping.get(self.field_type)
        if value_format:
            schema["format"] = value_format

        if self.min_value is not None:
            schema["minimum"] = self.min_value
        if self.max_value is not None:
            schema["maximum"] = self.max_value

        # JSON Schema uses different keywords for string length vs array size.
        min_key, max_key = ("minItems", "maxItems") if is_list else ("minLength", "maxLength")
        if self.min_length is not None:
            schema[min_key] = self.min_length
        if self.max_length is not None:
            schema[max_key] = self.max_length

        if self.allowed_values:
            schema["enum"] = self.allowed_values

        if is_list:
            if self.nested_schema is not None:
                schema["items"] = self.nested_schema.to_json_schema()
            elif self.list_item_type is not None:
                item_schema: Dict[str, Any] = {
                    "type": type_mapping.get(self.list_item_type, "string"),
                }
                item_format = format_mapping.get(self.list_item_type)
                if item_format:
                    item_schema["format"] = item_format
                schema["items"] = item_schema

        return schema
@dataclass
class ExtractionSchema:
    """
    Schema defining fields to extract from a document.

    Can be nested for complex document structures: a FieldSpec of type
    OBJECT or LIST may carry its own ExtractionSchema in ``nested_schema``.
    ``add_field`` and the ``add_*_field`` helpers return ``self`` so calls
    can be chained.
    """

    name: str
    description: str = ""
    fields: List[FieldSpec] = dataclass_field(default_factory=list)

    # Schema-level settings
    allow_partial: bool = True  # Allow partial extraction
    abstain_on_low_confidence: bool = True
    min_overall_confidence: float = 0.5

    def add_field(self, field: FieldSpec) -> "ExtractionSchema":
        """Append ``field`` (no duplicate-name check) and return ``self``."""
        self.fields.append(field)
        return self

    def _add_typed_field(
        self,
        name: str,
        field_type: FieldType,
        description: str,
        required: bool,
        **kwargs
    ) -> "ExtractionSchema":
        """Build a FieldSpec of ``field_type`` and append it (shared by the add_*_field helpers)."""
        return self.add_field(
            FieldSpec(
                name=name,
                field_type=field_type,
                description=description,
                required=required,
                **kwargs
            )
        )

    def add_string_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a STRING field. Extra ``kwargs`` are forwarded to FieldSpec."""
        return self._add_typed_field(name, FieldType.STRING, description, required, **kwargs)

    def add_number_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        is_integer: bool = False,
        **kwargs
    ) -> "ExtractionSchema":
        """Add an INTEGER field when ``is_integer`` is true, otherwise a FLOAT field."""
        field_type = FieldType.INTEGER if is_integer else FieldType.FLOAT
        return self._add_typed_field(name, field_type, description, required, **kwargs)

    def add_date_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a DATE field. Extra ``kwargs`` are forwarded to FieldSpec."""
        return self._add_typed_field(name, FieldType.DATE, description, required, **kwargs)

    def add_currency_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a CURRENCY field. Extra ``kwargs`` are forwarded to FieldSpec."""
        return self._add_typed_field(name, FieldType.CURRENCY, description, required, **kwargs)

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field named ``name``, or None if there is none."""
        return next((spec for spec in self.fields if spec.name == name), None)

    def get_required_fields(self) -> List[FieldSpec]:
        """Return all fields with ``required=True``, in schema order."""
        return [spec for spec in self.fields if spec.required]

    def get_optional_fields(self) -> List[FieldSpec]:
        """Return all fields with ``required=False``, in schema order."""
        return [spec for spec in self.fields if not spec.required]

    def to_json_schema(self) -> Dict[str, Any]:
        """
        Convert to an object-typed JSON Schema.

        Each field becomes a property. Required fields are listed under
        ``"required"``, which is omitted when empty.
        """
        properties: Dict[str, Any] = {}
        required: List[str] = []

        for spec in self.fields:
            properties[spec.name] = spec.to_json_schema()
            if spec.required:
                required.append(spec.name)

        schema: Dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }
        if required:
            schema["required"] = required
        if self.description:
            schema["description"] = self.description

        return schema

    def to_pydantic_model(self) -> Type[BaseModel]:
        """
        Generate a Pydantic model class named ``self.name`` from this schema.

        Required fields have no default and must be supplied. Optional fields
        are typed ``Optional[...]`` and default to ``FieldSpec.default``, so
        an explicit ``None`` (a value that was not extracted) validates.
        """
        field_definitions: Dict[str, Any] = {}

        for spec in self.fields:
            python_type: Any = self._get_python_type(spec.field_type)
            if spec.required:
                default: Any = ...  # Ellipsis marks the field as required
            else:
                python_type = Optional[python_type]
                default = spec.default
            field_definitions[spec.name] = (
                python_type,
                Field(default=default, description=spec.description),
            )

        return create_model(
            self.name,
            **field_definitions
        )

    def _get_python_type(self, field_type: FieldType) -> type:
        """Get the Python type used for ``field_type`` in generated models (default ``str``)."""
        type_mapping = {
            FieldType.STRING: str,
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.DATE: str,
            FieldType.DATETIME: str,
            FieldType.CURRENCY: str,
            FieldType.PERCENTAGE: str,
            FieldType.EMAIL: str,
            FieldType.PHONE: str,
            FieldType.ADDRESS: str,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        return type_mapping.get(field_type, str)

    @classmethod
    def from_json_schema(cls, schema: Dict[str, Any], name: str = "Schema") -> "ExtractionSchema":
        """
        Create an ExtractionSchema from an object-typed JSON Schema.

        Each entry of ``schema["properties"]`` becomes a FieldSpec.
        - Object properties that declare ``properties`` are converted
          recursively into ``nested_schema``.
        - Arrays whose ``items`` is an object with ``properties`` become
          ``nested_schema``. Other dict ``items`` set ``list_item_type``.
        - Length bounds are read from ``minLength``/``maxLength``, falling
          back to ``minItems``/``maxItems``.

        Args:
            schema: JSON Schema dict with ``properties`` and optional
                ``required`` (list of names) and ``description``.
            name: Name given to the resulting schema.

        Returns:
            A new ExtractionSchema.
        """
        extraction_schema = cls(
            name=name,
            description=schema.get("description", ""),
        )

        properties = schema.get("properties", {})
        required_names = schema.get("required", [])
        # Draft-3 style ``"required": true`` is a boolean, not a list of names.
        required = set(required_names) if isinstance(required_names, (list, tuple, set)) else set()

        for field_name, field_schema in properties.items():
            field_type = cls._json_type_to_field_type(field_schema)
            nested_schema: Optional["ExtractionSchema"] = None
            list_item_type: Optional[FieldType] = None

            if field_type == FieldType.OBJECT and "properties" in field_schema:
                nested_schema = cls.from_json_schema(field_schema, name=field_name)
            elif field_type == FieldType.LIST and isinstance(field_schema.get("items"), dict):
                items = field_schema["items"]
                item_type = cls._json_type_to_field_type(items)
                if item_type == FieldType.OBJECT and "properties" in items:
                    nested_schema = cls.from_json_schema(items, name=field_name)
                else:
                    list_item_type = item_type

            extraction_schema.add_field(
                FieldSpec(
                    name=field_name,
                    field_type=field_type,
                    description=field_schema.get("description", ""),
                    required=field_name in required,
                    pattern=field_schema.get("pattern"),
                    min_value=field_schema.get("minimum"),
                    max_value=field_schema.get("maximum"),
                    min_length=cls._first_present(field_schema, "minLength", "minItems"),
                    max_length=cls._first_present(field_schema, "maxLength", "maxItems"),
                    allowed_values=field_schema.get("enum"),
                    nested_schema=nested_schema,
                    list_item_type=list_item_type,
                )
            )

        return extraction_schema

    @staticmethod
    def _first_present(source: Dict[str, Any], *keys: str) -> Any:
        """Return the first non-None value among ``keys`` in ``source``, or None."""
        for key in keys:
            value = source.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _json_type_to_field_type(field_schema: Dict[str, Any]) -> FieldType:
        """
        Map a JSON Schema property to the closest FieldType.

        ``type`` selects the non-string kinds. For strings, and for a missing
        or unrecognised type, ``format`` distinguishes date, date-time and
        email; everything else is STRING. A list-valued ``type`` such as
        ``["integer", "null"]`` (a nullable field) is reduced to its first
        non-"null" entry.
        """
        json_type = field_schema.get("type", "string")
        if isinstance(json_type, list):
            non_null = [t for t in json_type if t != "null"]
            json_type = non_null[0] if non_null else "null"
        format_ = field_schema.get("format", "")

        if json_type == "integer":
            return FieldType.INTEGER
        elif json_type == "number":
            return FieldType.FLOAT
        elif json_type == "boolean":
            return FieldType.BOOLEAN
        elif json_type == "array":
            return FieldType.LIST
        elif json_type == "object":
            return FieldType.OBJECT
        elif format_ == "date":
            return FieldType.DATE
        elif format_ == "date-time":
            return FieldType.DATETIME
        elif format_ == "email":
            return FieldType.EMAIL
        else:
            return FieldType.STRING
# Pre-built schemas for common document types
def create_invoice_schema() -> ExtractionSchema:
    """
    Build the stock extraction schema for invoices.

    Required fields: invoice number, invoice date, vendor name and total
    amount. Due date, vendor/customer details, subtotal, tax, currency code
    and payment terms are optional.
    """
    # Each add_* call returns the schema itself, so the whole definition
    # reads as one chained expression (fields keep this declaration order).
    return (
        ExtractionSchema(name="Invoice", description="Invoice document extraction schema")
        .add_string_field("invoice_number", "Invoice number or ID", required=True)
        .add_date_field("invoice_date", "Date of invoice")
        .add_date_field("due_date", "Payment due date", required=False)
        .add_string_field("vendor_name", "Name of vendor/seller")
        .add_string_field("vendor_address", "Address of vendor", required=False)
        .add_string_field("customer_name", "Name of customer/buyer", required=False)
        .add_string_field("customer_address", "Address of customer", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount due", required=True)
        .add_string_field("currency", "Currency code (USD, EUR, etc.)", required=False)
        .add_string_field("payment_terms", "Payment terms", required=False)
    )
def create_receipt_schema() -> ExtractionSchema:
    """
    Build the stock extraction schema for receipts.

    Required fields: merchant name, transaction date and total amount paid.
    Merchant address, transaction time, subtotal, tax, payment method and
    card last-four digits are optional.
    """
    # Chained builder calls; each add_* returns the same schema instance.
    return (
        ExtractionSchema(name="Receipt", description="Receipt document extraction schema")
        .add_string_field("merchant_name", "Name of merchant/store")
        .add_string_field("merchant_address", "Address of merchant", required=False)
        .add_date_field("transaction_date", "Date of transaction")
        .add_string_field("transaction_time", "Time of transaction", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount paid")
        .add_string_field("payment_method", "Method of payment", required=False)
        .add_string_field("last_four_digits", "Last 4 digits of card", required=False)
    )
def create_contract_schema() -> ExtractionSchema:
    """
    Build the stock extraction schema for contracts.

    Required fields: effective date and the names of both parties. Title,
    expiration date, contract value, governing law and a termination-clause
    summary are optional.
    """
    # Chained builder calls; each add_* returns the same schema instance.
    return (
        ExtractionSchema(name="Contract", description="Contract document extraction schema")
        .add_string_field("contract_title", "Title of the contract", required=False)
        .add_date_field("effective_date", "Date contract becomes effective")
        .add_date_field("expiration_date", "Date contract expires", required=False)
        .add_string_field("party_a_name", "Name of first party")
        .add_string_field("party_b_name", "Name of second party")
        .add_currency_field("contract_value", "Total contract value", required=False)
        .add_string_field("governing_law", "Governing law/jurisdiction", required=False)
        .add_string_field("termination_clause", "Summary of termination terms", required=False)
    )