"""
Schema Definitions for Field Extraction

Pydantic-compatible schemas for defining extraction targets.
"""

from dataclasses import dataclass, field as dataclass_field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field, create_model


class FieldType(str, Enum):
    """Types of extractable fields."""

    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    DATE = "date"
    DATETIME = "datetime"
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"
    LIST = "list"
    OBJECT = "object"


# JSON Schema "type" keyword emitted for each FieldType. Semantic string
# types (dates, currency, phone numbers, ...) are plain JSON strings.
_JSON_TYPES: Dict[FieldType, str] = {
    FieldType.STRING: "string",
    FieldType.INTEGER: "integer",
    FieldType.FLOAT: "number",
    FieldType.BOOLEAN: "boolean",
    FieldType.DATE: "string",
    FieldType.DATETIME: "string",
    FieldType.CURRENCY: "string",
    FieldType.PERCENTAGE: "string",
    FieldType.EMAIL: "string",
    FieldType.PHONE: "string",
    FieldType.ADDRESS: "string",
    FieldType.LIST: "array",
    FieldType.OBJECT: "object",
}

# JSON Schema "format" keyword for the FieldTypes that have a standard one.
_JSON_FORMATS: Dict[FieldType, str] = {
    FieldType.DATE: "date",
    FieldType.DATETIME: "date-time",
    FieldType.EMAIL: "email",
}

# Reverse lookups used when parsing JSON Schema back into FieldTypes.
_FIELD_TYPES_BY_JSON_TYPE: Dict[str, FieldType] = {
    "integer": FieldType.INTEGER,
    "number": FieldType.FLOAT,
    "boolean": FieldType.BOOLEAN,
    "array": FieldType.LIST,
    "object": FieldType.OBJECT,
}
_FIELD_TYPES_BY_FORMAT: Dict[str, FieldType] = {
    fmt: field_type for field_type, fmt in _JSON_FORMATS.items()
}

# Python annotation used for each FieldType in generated Pydantic models.
# Semantic string types are represented as plain ``str``.
_PYTHON_TYPES: Dict[FieldType, type] = {
    FieldType.STRING: str,
    FieldType.INTEGER: int,
    FieldType.FLOAT: float,
    FieldType.BOOLEAN: bool,
    FieldType.DATE: str,
    FieldType.DATETIME: str,
    FieldType.CURRENCY: str,
    FieldType.PERCENTAGE: str,
    FieldType.EMAIL: str,
    FieldType.PHONE: str,
    FieldType.ADDRESS: str,
    FieldType.LIST: list,
    FieldType.OBJECT: dict,
}


@dataclass
class FieldSpec:
    """Specification for a single extraction field."""

    name: str
    field_type: FieldType = FieldType.STRING
    description: str = ""
    required: bool = True
    default: Any = None

    # Validation
    pattern: Optional[str] = None  # Regex pattern for validation
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    # Length bounds: string length for scalar fields, item count for LIST
    # fields (emitted as minItems/maxItems in JSON Schema).
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    allowed_values: Optional[List[Any]] = None

    # Nested schema (for OBJECT and LIST types)
    nested_schema: Optional["ExtractionSchema"] = None
    # Element type for LIST fields with scalar items; only consulted when
    # nested_schema is not set.
    list_item_type: Optional[FieldType] = None

    # Extraction hints
    aliases: List[str] = dataclass_field(default_factory=list)  # Alternative names
    examples: List[str] = dataclass_field(default_factory=list)  # Example values
    context_hints: List[str] = dataclass_field(default_factory=list)  # Where to look

    # Confidence threshold for this field
    min_confidence: float = 0.5

    def to_json_schema(self) -> Dict[str, Any]:
        """
        Convert this field to a JSON Schema fragment.

        Returns:
            A dict containing ``type`` plus whichever of ``description``,
            ``pattern``, ``format``, ``minimum``/``maximum``, length bounds,
            ``enum``, ``items`` (LIST) or ``properties``/``required``
            (OBJECT with a nested schema) apply.
        """
        json_type = _JSON_TYPES.get(self.field_type, "string")
        schema: Dict[str, Any] = {"type": json_type}

        if self.description:
            schema["description"] = self.description
        if self.pattern:
            schema["pattern"] = self.pattern

        format_ = _JSON_FORMATS.get(self.field_type)
        if format_:
            schema["format"] = format_

        if self.min_value is not None:
            schema["minimum"] = self.min_value
        if self.max_value is not None:
            schema["maximum"] = self.max_value

        # JSON Schema bounds array sizes with minItems/maxItems; minLength and
        # maxLength only constrain strings and are ignored on arrays.
        if json_type == "array":
            min_key, max_key = "minItems", "maxItems"
        else:
            min_key, max_key = "minLength", "maxLength"
        if self.min_length is not None:
            schema[min_key] = self.min_length
        if self.max_length is not None:
            schema[max_key] = self.max_length

        if self.allowed_values:
            schema["enum"] = self.allowed_values

        if self.field_type == FieldType.LIST:
            if self.nested_schema is not None:
                schema["items"] = self.nested_schema.to_json_schema()
            elif self.list_item_type is not None:
                # Reuse the scalar mapping so e.g. DATE items get format "date".
                schema["items"] = FieldSpec(
                    name=self.name, field_type=self.list_item_type
                ).to_json_schema()
        elif self.field_type == FieldType.OBJECT and self.nested_schema is not None:
            # Merge in the nested object's properties/required, but let this
            # field's own keys (notably its description) take precedence.
            schema = {**self.nested_schema.to_json_schema(), **schema}

        return schema


@dataclass
class ExtractionSchema:
    """
    Schema defining fields to extract from a document.

    Can be nested for complex document structures.
    """

    name: str
    description: str = ""
    fields: List[FieldSpec] = dataclass_field(default_factory=list)

    # Schema-level settings
    allow_partial: bool = True  # Allow partial extraction
    abstain_on_low_confidence: bool = True
    min_overall_confidence: float = 0.5

    def add_field(self, field: FieldSpec) -> "ExtractionSchema":
        """Append ``field`` to the schema and return ``self`` for chaining."""
        self.fields.append(field)
        return self

    def _add_typed_field(
        self,
        field_type: FieldType,
        name: str,
        description: str,
        required: bool,
        **kwargs,
    ) -> "ExtractionSchema":
        """Build a FieldSpec of ``field_type`` and add it (shared by add_*_field)."""
        return self.add_field(
            FieldSpec(
                name=name,
                field_type=field_type,
                description=description,
                required=required,
                **kwargs,
            )
        )

    def add_string_field(
        self, name: str, description: str = "", required: bool = True, **kwargs
    ) -> "ExtractionSchema":
        """Add a string field; extra kwargs are passed to FieldSpec."""
        return self._add_typed_field(
            FieldType.STRING, name, description, required, **kwargs
        )

    def add_number_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        is_integer: bool = False,
        **kwargs,
    ) -> "ExtractionSchema":
        """Add an INTEGER (if ``is_integer``) or FLOAT field; kwargs go to FieldSpec."""
        field_type = FieldType.INTEGER if is_integer else FieldType.FLOAT
        return self._add_typed_field(field_type, name, description, required, **kwargs)

    def add_date_field(
        self, name: str, description: str = "", required: bool = True, **kwargs
    ) -> "ExtractionSchema":
        """Add a date field; extra kwargs are passed to FieldSpec."""
        return self._add_typed_field(
            FieldType.DATE, name, description, required, **kwargs
        )

    def add_currency_field(
        self, name: str, description: str = "", required: bool = True, **kwargs
    ) -> "ExtractionSchema":
        """Add a currency field; extra kwargs are passed to FieldSpec."""
        return self._add_typed_field(
            FieldType.CURRENCY, name, description, required, **kwargs
        )

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field named ``name``, or None if there is none."""
        return next((spec for spec in self.fields if spec.name == name), None)

    def get_required_fields(self) -> List[FieldSpec]:
        """Get all required fields, in declaration order."""
        return [f for f in self.fields if f.required]

    def get_optional_fields(self) -> List[FieldSpec]:
        """Get all optional fields, in declaration order."""
        return [f for f in self.fields if not f.required]

    def to_json_schema(self) -> Dict[str, Any]:
        """
        Convert to a JSON Schema ``object``.

        Each field becomes an entry in ``properties``; required fields are
        listed under ``required`` (omitted when empty).
        """
        properties: Dict[str, Any] = {}
        required: List[str] = []
        for spec in self.fields:
            properties[spec.name] = spec.to_json_schema()
            if spec.required:
                required.append(spec.name)

        schema: Dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }
        if required:
            schema["required"] = required
        if self.description:
            schema["description"] = self.description
        return schema

    def to_pydantic_model(self) -> Type[BaseModel]:
        """
        Generate a Pydantic model class named ``self.name`` from this schema.

        Required fields have no default. Optional fields are annotated
        ``Optional[T]`` (so an explicit ``None`` validates) and default to
        ``FieldSpec.default``. Nested schemas are not expanded: OBJECT fields
        are typed ``dict`` and LIST fields ``list``.
        """
        field_definitions: Dict[str, Any] = {}
        for spec in self.fields:
            python_type = self._get_python_type(spec.field_type)
            if spec.required:
                annotation: Any = python_type
                default: Any = ...  # Ellipsis marks the field as required
            else:
                # A bare ``str`` annotation would reject an explicit None
                # under pydantic v2 even though None is the default.
                annotation = Optional[python_type]
                default = spec.default
            field_definitions[spec.name] = (
                annotation,
                Field(default=default, description=spec.description),
            )

        return create_model(self.name, **field_definitions)

    def _get_python_type(self, field_type: FieldType) -> type:
        """Get the Python type for ``field_type`` (``str`` if unknown)."""
        return _PYTHON_TYPES.get(field_type, str)

    @classmethod
    def from_json_schema(
        cls, schema: Dict[str, Any], name: str = "Schema"
    ) -> "ExtractionSchema":
        """
        Create an ExtractionSchema from a JSON Schema object.

        Args:
            schema: JSON Schema dict; reads ``description``, ``properties``
                and ``required``.
            name: Name of the resulting schema. Nested object schemas are
                named after the property they belong to.

        Nested ``object`` properties and object-typed array ``items`` are
        converted recursively into ``nested_schema``; other array ``items``
        set ``list_item_type``.
        """
        extraction_schema = cls(
            name=name,
            description=schema.get("description", ""),
        )

        required = set(schema.get("required", []))
        for field_name, field_schema in schema.get("properties", {}).items():
            extraction_schema.add_field(
                cls._field_from_json_schema(
                    field_name, field_schema, field_name in required
                )
            )

        return extraction_schema

    @classmethod
    def _field_from_json_schema(
        cls, field_name: str, field_schema: Dict[str, Any], required: bool
    ) -> FieldSpec:
        """Build one FieldSpec from a JSON Schema property definition."""
        field_type = cls._json_type_to_field_type(field_schema)
        nested_schema: Optional["ExtractionSchema"] = None
        list_item_type: Optional[FieldType] = None

        if field_type == FieldType.LIST:
            # Arrays use minItems/maxItems; fall back to minLength/maxLength
            # for schemas produced by older versions of to_json_schema.
            min_length = field_schema.get("minItems", field_schema.get("minLength"))
            max_length = field_schema.get("maxItems", field_schema.get("maxLength"))
            items = field_schema.get("items")
            if isinstance(items, dict) and items:
                item_type = cls._json_type_to_field_type(items)
                if item_type == FieldType.OBJECT and "properties" in items:
                    nested_schema = cls.from_json_schema(items, name=field_name)
                else:
                    list_item_type = item_type
        else:
            min_length = field_schema.get("minLength")
            max_length = field_schema.get("maxLength")
            if field_type == FieldType.OBJECT and "properties" in field_schema:
                nested_schema = cls.from_json_schema(field_schema, name=field_name)

        return FieldSpec(
            name=field_name,
            field_type=field_type,
            description=field_schema.get("description", ""),
            required=required,
            pattern=field_schema.get("pattern"),
            min_value=field_schema.get("minimum"),
            max_value=field_schema.get("maximum"),
            min_length=min_length,
            max_length=max_length,
            allowed_values=field_schema.get("enum"),
            nested_schema=nested_schema,
            list_item_type=list_item_type,
        )

    @staticmethod
    def _json_type_to_field_type(field_schema: Dict[str, Any]) -> FieldType:
        """
        Convert a JSON Schema fragment's ``type``/``format`` to a FieldType.

        Union types such as ``["integer", "null"]`` resolve to their first
        non-null member. If ``type`` is absent, ``object``/``array`` is
        inferred from ``properties``/``items``. String-like types are refined
        by ``format`` (date, date-time, email); anything else is STRING.
        """
        json_type = field_schema.get("type")
        if isinstance(json_type, list):
            json_type = next((t for t in json_type if t != "null"), None)
        if json_type is None:
            if "properties" in field_schema:
                json_type = "object"
            elif "items" in field_schema:
                json_type = "array"

        if isinstance(json_type, str) and json_type in _FIELD_TYPES_BY_JSON_TYPE:
            return _FIELD_TYPES_BY_JSON_TYPE[json_type]

        format_ = field_schema.get("format", "")
        if isinstance(format_, str):
            return _FIELD_TYPES_BY_FORMAT.get(format_, FieldType.STRING)
        return FieldType.STRING


# Pre-built schemas for common document types

def create_invoice_schema() -> ExtractionSchema:
    """Create schema for invoice extraction."""
    schema = ExtractionSchema(
        name="Invoice",
        description="Invoice document extraction schema"
    )

    schema.add_string_field("invoice_number", "Invoice number or ID", required=True)
    schema.add_date_field("invoice_date", "Date of invoice")
    schema.add_date_field("due_date", "Payment due date", required=False)
    schema.add_string_field("vendor_name", "Name of vendor/seller")
    schema.add_string_field("vendor_address", "Address of vendor", required=False)
    schema.add_string_field("customer_name", "Name of customer/buyer", required=False)
    schema.add_string_field("customer_address", "Address of customer", required=False)
    schema.add_currency_field("subtotal", "Subtotal before tax", required=False)
    schema.add_currency_field("tax_amount", "Tax amount", required=False)
    schema.add_currency_field("total_amount", "Total amount due", required=True)
    schema.add_string_field("currency", "Currency code (USD, EUR, etc.)", required=False)
    schema.add_string_field("payment_terms", "Payment terms", required=False)

    return schema


def create_receipt_schema() -> ExtractionSchema:
    """Create schema for receipt extraction."""
    schema = ExtractionSchema(
        name="Receipt",
        description="Receipt document extraction schema"
    )

    schema.add_string_field("merchant_name", "Name of merchant/store")
    schema.add_string_field("merchant_address", "Address of merchant", required=False)
    schema.add_date_field("transaction_date", "Date of transaction")
    schema.add_string_field("transaction_time", "Time of transaction", required=False)
    schema.add_currency_field("subtotal", "Subtotal before tax", required=False)
    schema.add_currency_field("tax_amount", "Tax amount", required=False)
    schema.add_currency_field("total_amount", "Total amount paid")
    schema.add_string_field("payment_method", "Method of payment", required=False)
    schema.add_string_field("last_four_digits", "Last 4 digits of card", required=False)

    return schema


def create_contract_schema() -> ExtractionSchema:
    """Create schema for contract extraction."""
    schema = ExtractionSchema(
        name="Contract",
        description="Contract document extraction schema"
    )

    schema.add_string_field("contract_title", "Title of the contract", required=False)
    schema.add_date_field("effective_date", "Date contract becomes effective")
    schema.add_date_field("expiration_date", "Date contract expires", required=False)
    schema.add_string_field("party_a_name", "Name of first party")
    schema.add_string_field("party_b_name", "Name of second party")
    schema.add_currency_field("contract_value", "Total contract value", required=False)
    schema.add_string_field("governing_law", "Governing law/jurisdiction", required=False)
    schema.add_string_field("termination_clause", "Summary of termination terms", required=False)

    return schema