|
|
""" |
|
|
Schema Definitions for Field Extraction |
|
|
|
|
|
Dataclass-based schemas for defining extraction targets, convertible to and from JSON Schema and to Pydantic models.
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field as dataclass_field |
|
|
from enum import Enum |
|
|
from typing import Any, Callable, Dict, List, Optional, Type, Union |
|
|
|
|
|
from pydantic import BaseModel, Field, create_model |
|
|
|
|
|
|
|
|
class FieldType(str, Enum):
    """Kinds of value an extraction field can hold.

    Members subclass ``str``, so they compare equal to their plain string
    values (e.g. ``FieldType.DATE == "date"``).
    """

    # Scalar primitives
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"

    # Temporal values
    DATE = "date"
    DATETIME = "datetime"

    # Domain-specific values (all represented as strings in JSON Schema)
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"

    # Composite values
    LIST = "list"
    OBJECT = "object"


# JSON Schema "type" keyword emitted for each FieldType. Kinds without a
# JSON Schema primitive (currency, phone, ...) are emitted as strings.
_JSON_SCHEMA_TYPES: Dict[FieldType, str] = {
    FieldType.STRING: "string",
    FieldType.INTEGER: "integer",
    FieldType.FLOAT: "number",
    FieldType.BOOLEAN: "boolean",
    FieldType.DATE: "string",
    FieldType.DATETIME: "string",
    FieldType.CURRENCY: "string",
    FieldType.PERCENTAGE: "string",
    FieldType.EMAIL: "string",
    FieldType.PHONE: "string",
    FieldType.ADDRESS: "string",
    FieldType.LIST: "array",
    FieldType.OBJECT: "object",
}

# JSON Schema "format" keyword for the FieldTypes that have a standard one.
_JSON_SCHEMA_FORMATS: Dict[FieldType, str] = {
    FieldType.DATE: "date",
    FieldType.DATETIME: "date-time",
    FieldType.EMAIL: "email",
}


def _json_scalar_schema(field_type: FieldType) -> Dict[str, Any]:
    """Return a bare ``{"type": ..., "format": ...}`` fragment for *field_type*."""
    fragment: Dict[str, Any] = {"type": _JSON_SCHEMA_TYPES.get(field_type, "string")}
    json_format = _JSON_SCHEMA_FORMATS.get(field_type)
    if json_format is not None:
        fragment["format"] = json_format
    return fragment


@dataclass
class FieldSpec:
    """Specification for a single extraction field.

    Besides name and type, a spec carries optional validation constraints,
    structure information for LIST/OBJECT fields, and free-form hints.
    """

    name: str
    field_type: FieldType = FieldType.STRING
    description: str = ""
    required: bool = True
    default: Any = None

    # Validation constraints; only the ones that are set are emitted to JSON Schema.
    pattern: Optional[str] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    allowed_values: Optional[List[Any]] = None

    # Structure of composite fields. ``nested_schema`` describes an OBJECT
    # field, or the elements of a LIST field. ``list_item_type`` describes
    # scalar LIST elements when no nested schema is given.
    nested_schema: Optional["ExtractionSchema"] = None
    list_item_type: Optional[FieldType] = None

    # Free-form hints; not used by the JSON Schema conversion below.
    aliases: List[str] = dataclass_field(default_factory=list)
    examples: List[str] = dataclass_field(default_factory=list)
    context_hints: List[str] = dataclass_field(default_factory=list)

    # Per-field confidence threshold. Presumably enforced by the extractor;
    # nothing in this class reads it.
    min_confidence: float = 0.5

    def to_json_schema(self) -> Dict[str, Any]:
        """Convert this field to a JSON Schema fragment.

        Returns:
            A dict with ``type`` plus whichever of ``description``,
            ``pattern``, ``format``, ``minimum``/``maximum``,
            ``minLength``/``maxLength`` and ``enum`` apply.

            LIST fields get ``items`` from ``nested_schema`` when set,
            otherwise from ``list_item_type``. OBJECT fields are merged with
            their ``nested_schema``, but this field's own description wins.
        """
        schema: Dict[str, Any] = {
            "type": _JSON_SCHEMA_TYPES.get(self.field_type, "string"),
        }

        if self.description:
            schema["description"] = self.description

        if self.pattern:
            schema["pattern"] = self.pattern

        json_format = _JSON_SCHEMA_FORMATS.get(self.field_type)
        if json_format is not None:
            schema["format"] = json_format

        if self.min_value is not None:
            schema["minimum"] = self.min_value
        if self.max_value is not None:
            schema["maximum"] = self.max_value
        if self.min_length is not None:
            schema["minLength"] = self.min_length
        if self.max_length is not None:
            schema["maxLength"] = self.max_length
        if self.allowed_values:
            schema["enum"] = self.allowed_values

        if self.field_type == FieldType.LIST:
            if self.nested_schema is not None:
                schema["items"] = self.nested_schema.to_json_schema()
            elif self.list_item_type is not None:
                # Without this branch a scalar item type was silently dropped.
                schema["items"] = _json_scalar_schema(self.list_item_type)
        elif self.field_type == FieldType.OBJECT and self.nested_schema is not None:
            schema.update(self.nested_schema.to_json_schema())
            # The nested schema's generic description must not clobber the
            # description given for this particular field.
            if self.description:
                schema["description"] = self.description

        return schema
|
|
|
|
|
|
|
|
@dataclass
class ExtractionSchema:
    """
    Schema defining fields to extract from a document.

    Fields may reference further ``ExtractionSchema`` instances through
    ``FieldSpec.nested_schema``, so complex document structures can be
    described hierarchically. A schema can be exported with
    ``to_json_schema``, rebuilt with ``from_json_schema``, or turned into a
    Pydantic model with ``to_pydantic_model``.
    """

    name: str
    description: str = ""
    fields: List[FieldSpec] = dataclass_field(default_factory=list)

    # Extraction policy flags. No method in this class reads them;
    # presumably they are consumed by the extractor (verify with callers).
    allow_partial: bool = True
    abstain_on_low_confidence: bool = True
    min_overall_confidence: float = 0.5

    def add_field(self, field: FieldSpec) -> "ExtractionSchema":
        """Append *field* and return ``self`` so calls can be chained."""
        self.fields.append(field)
        return self

    def add_string_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Append a STRING field; extra kwargs are passed to ``FieldSpec``."""
        field = FieldSpec(
            name=name,
            field_type=FieldType.STRING,
            description=description,
            required=required,
            **kwargs
        )
        return self.add_field(field)

    def add_number_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        is_integer: bool = False,
        **kwargs
    ) -> "ExtractionSchema":
        """Append an INTEGER (if *is_integer*) or FLOAT field."""
        field = FieldSpec(
            name=name,
            field_type=FieldType.INTEGER if is_integer else FieldType.FLOAT,
            description=description,
            required=required,
            **kwargs
        )
        return self.add_field(field)

    def add_date_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Append a DATE field; extra kwargs are passed to ``FieldSpec``."""
        field = FieldSpec(
            name=name,
            field_type=FieldType.DATE,
            description=description,
            required=required,
            **kwargs
        )
        return self.add_field(field)

    def add_currency_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Append a CURRENCY field; extra kwargs are passed to ``FieldSpec``."""
        field = FieldSpec(
            name=name,
            field_type=FieldType.CURRENCY,
            description=description,
            required=required,
            **kwargs
        )
        return self.add_field(field)

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field named *name*, or ``None`` if absent."""
        for field in self.fields:
            if field.name == name:
                return field
        return None

    def get_required_fields(self) -> List[FieldSpec]:
        """Return the fields marked ``required``, in declaration order."""
        return [f for f in self.fields if f.required]

    def get_optional_fields(self) -> List[FieldSpec]:
        """Return the fields not marked ``required``, in declaration order."""
        return [f for f in self.fields if not f.required]

    def to_json_schema(self) -> Dict[str, Any]:
        """Convert to a JSON Schema object with ``properties`` and ``required``."""
        properties = {}
        required = []

        for field in self.fields:
            properties[field.name] = field.to_json_schema()
            if field.required:
                required.append(field.name)

        schema = {
            "type": "object",
            "properties": properties,
        }

        if required:
            schema["required"] = required

        if self.description:
            schema["description"] = self.description

        return schema

    def to_pydantic_model(self) -> Type[BaseModel]:
        """Generate a Pydantic model class from this schema.

        Required fields become mandatory model fields. Optional fields are
        typed ``Optional[T]`` and default to ``FieldSpec.default``, so an
        explicit ``None`` (e.g. "value not found") passes validation.

        Composite fields are typed ``list``/``dict``; nested schemas are not
        turned into nested models.
        """
        field_definitions: Dict[str, Any] = {}

        for field in self.fields:
            python_type: Any = self._get_python_type(field.field_type)
            if field.required:
                annotation, default = python_type, ...
            else:
                # A bare ``str`` annotation would reject an explicit None.
                annotation, default = Optional[python_type], field.default

            field_definitions[field.name] = (
                annotation,
                Field(default=default, description=field.description),
            )

        return create_model(
            self.name,
            **field_definitions
        )

    def _get_python_type(self, field_type: FieldType) -> type:
        """Map a FieldType to the Python type used in generated models."""
        type_mapping = {
            FieldType.STRING: str,
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.DATE: str,
            FieldType.DATETIME: str,
            FieldType.CURRENCY: str,
            FieldType.PERCENTAGE: str,
            FieldType.EMAIL: str,
            FieldType.PHONE: str,
            FieldType.ADDRESS: str,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        return type_mapping.get(field_type, str)

    @classmethod
    def from_json_schema(cls, schema: Dict[str, Any], name: str = "Schema") -> "ExtractionSchema":
        """Create an ExtractionSchema from a JSON Schema object.

        Args:
            schema: A JSON Schema dict with ``properties``, optionally with
                ``required`` and ``description``.
            name: Name for the resulting schema.

        Returns:
            A new schema with one field per property.

            Object properties that declare their own ``properties`` become
            nested schemas. Array ``items`` are handled as follows: an object
            with ``properties`` becomes a nested schema, and any other item
            schema becomes a ``list_item_type``.
        """
        extraction_schema = cls(
            name=name,
            description=schema.get("description", ""),
        )

        properties = schema.get("properties", {})
        required = set(schema.get("required", []))

        for field_name, field_schema in properties.items():
            field_type = cls._json_type_to_field_type(field_schema)
            nested_schema = None
            list_item_type = None

            if field_type == FieldType.OBJECT and "properties" in field_schema:
                nested_schema = cls.from_json_schema(field_schema, name=field_name)
            elif field_type == FieldType.LIST:
                items = field_schema.get("items")
                # Tuple-style ``items`` (a list of schemas) is not supported.
                if isinstance(items, dict):
                    item_type = cls._json_type_to_field_type(items)
                    if item_type == FieldType.OBJECT and "properties" in items:
                        nested_schema = cls.from_json_schema(items, name=field_name)
                    else:
                        list_item_type = item_type

            field = FieldSpec(
                name=field_name,
                field_type=field_type,
                description=field_schema.get("description", ""),
                required=field_name in required,
                default=field_schema.get("default"),
                pattern=field_schema.get("pattern"),
                min_value=field_schema.get("minimum"),
                max_value=field_schema.get("maximum"),
                min_length=field_schema.get("minLength"),
                max_length=field_schema.get("maxLength"),
                allowed_values=field_schema.get("enum"),
                nested_schema=nested_schema,
                list_item_type=list_item_type,
            )

            extraction_schema.add_field(field)

        return extraction_schema

    @staticmethod
    def _json_type_to_field_type(field_schema: Dict[str, Any]) -> FieldType:
        """Convert a JSON Schema property's ``type``/``format`` to a FieldType.

        - A union type such as ``["string", "null"]`` resolves to its first
          non-``"null"`` member.
        - An omitted ``type`` is inferred as ``"object"`` when ``properties``
          is present, as ``"array"`` when ``items`` is present, and as
          ``"string"`` otherwise.
        - String formats ``date``, ``date-time`` and ``email`` map to their
          dedicated types; anything else becomes STRING.
        """
        json_type = field_schema.get("type")
        if isinstance(json_type, list):
            json_type = next((t for t in json_type if t != "null"), "string")
        if json_type is None:
            if "properties" in field_schema:
                json_type = "object"
            elif "items" in field_schema:
                json_type = "array"
            else:
                json_type = "string"

        format_ = field_schema.get("format", "")

        if json_type == "integer":
            return FieldType.INTEGER
        elif json_type == "number":
            return FieldType.FLOAT
        elif json_type == "boolean":
            return FieldType.BOOLEAN
        elif json_type == "array":
            return FieldType.LIST
        elif json_type == "object":
            return FieldType.OBJECT
        elif format_ == "date":
            return FieldType.DATE
        elif format_ == "date-time":
            return FieldType.DATETIME
        elif format_ == "email":
            return FieldType.EMAIL
        else:
            return FieldType.STRING
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_invoice_schema() -> ExtractionSchema:
    """Build the extraction schema for invoices.

    The required fields are invoice_number, invoice_date, vendor_name and
    total_amount. The due date, addresses, customer name, subtotal, tax,
    currency code and payment terms are optional.
    """
    # Every add_* method returns the schema itself, so the definition reads
    # as a single fluent chain.
    return (
        ExtractionSchema(
            name="Invoice",
            description="Invoice document extraction schema",
        )
        .add_string_field("invoice_number", "Invoice number or ID", required=True)
        .add_date_field("invoice_date", "Date of invoice")
        .add_date_field("due_date", "Payment due date", required=False)
        .add_string_field("vendor_name", "Name of vendor/seller")
        .add_string_field("vendor_address", "Address of vendor", required=False)
        .add_string_field("customer_name", "Name of customer/buyer", required=False)
        .add_string_field("customer_address", "Address of customer", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount due", required=True)
        .add_string_field("currency", "Currency code (USD, EUR, etc.)", required=False)
        .add_string_field("payment_terms", "Payment terms", required=False)
    )
|
|
|
|
|
|
|
|
def create_receipt_schema() -> ExtractionSchema:
    """Build the extraction schema for point-of-sale receipts.

    The required fields are merchant_name, transaction_date and total_amount.
    The remaining fields are optional.
    """
    receipt = ExtractionSchema(
        name="Receipt",
        description="Receipt document extraction schema",
    )

    # (adder, field name, description, required), in the order fields are added.
    field_table = (
        (receipt.add_string_field, "merchant_name", "Name of merchant/store", True),
        (receipt.add_string_field, "merchant_address", "Address of merchant", False),
        (receipt.add_date_field, "transaction_date", "Date of transaction", True),
        (receipt.add_string_field, "transaction_time", "Time of transaction", False),
        (receipt.add_currency_field, "subtotal", "Subtotal before tax", False),
        (receipt.add_currency_field, "tax_amount", "Tax amount", False),
        (receipt.add_currency_field, "total_amount", "Total amount paid", True),
        (receipt.add_string_field, "payment_method", "Method of payment", False),
        (receipt.add_string_field, "last_four_digits", "Last 4 digits of card", False),
    )
    for add, field_name, field_description, is_required in field_table:
        add(field_name, field_description, required=is_required)

    return receipt
|
|
|
|
|
|
|
|
def create_contract_schema() -> ExtractionSchema:
    """Build the extraction schema for two-party contracts.

    The required fields are effective_date, party_a_name and party_b_name.
    The title, expiration date, value, governing law and termination summary
    are optional.
    """
    contract = ExtractionSchema(
        name="Contract",
        description="Contract document extraction schema",
    )
    contract.add_string_field(
        name="contract_title", description="Title of the contract", required=False
    ).add_date_field(
        name="effective_date", description="Date contract becomes effective"
    ).add_date_field(
        name="expiration_date", description="Date contract expires", required=False
    ).add_string_field(
        name="party_a_name", description="Name of first party"
    ).add_string_field(
        name="party_b_name", description="Name of second party"
    ).add_currency_field(
        name="contract_value", description="Total contract value", required=False
    ).add_string_field(
        name="governing_law", description="Governing law/jurisdiction", required=False
    ).add_string_field(
        name="termination_clause",
        description="Summary of termination terms",
        required=False,
    )
    return contract
|
|
|