""" Extraction Validation Validates extracted data and provides confidence scoring. """ import logging from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple from ..chunks.models import ( ExtractionResult, FieldExtraction, ConfidenceLevel, ) from .schema import ExtractionSchema, FieldSpec, FieldType logger = logging.getLogger(__name__) @dataclass class ValidationIssue: """A validation issue found during extraction validation.""" field_name: str issue_type: str # "missing", "invalid", "low_confidence", "type_mismatch" message: str severity: str = "warning" # "error", "warning", "info" suggested_action: Optional[str] = None @dataclass class ValidationResult: """Result of extraction validation.""" is_valid: bool issues: List[ValidationIssue] = field(default_factory=list) confidence_score: float = 0.0 field_scores: Dict[str, float] = field(default_factory=dict) recommendations: List[str] = field(default_factory=list) @property def error_count(self) -> int: return sum(1 for i in self.issues if i.severity == "error") @property def warning_count(self) -> int: return sum(1 for i in self.issues if i.severity == "warning") def get_issues_for_field(self, field_name: str) -> List[ValidationIssue]: """Get all issues for a specific field.""" return [i for i in self.issues if i.field_name == field_name] class ExtractionValidator: """ Validates extraction results against schemas. Checks for: - Required field presence - Type correctness - Value constraints - Confidence thresholds """ def __init__( self, min_confidence: float = 0.5, strict_mode: bool = False, ): self.min_confidence = min_confidence self.strict_mode = strict_mode def validate( self, extraction: ExtractionResult, schema: ExtractionSchema, ) -> ValidationResult: """ Validate extraction result against schema. Args: extraction: Extraction result to validate schema: Schema defining expected fields Returns: ValidationResult with issues and scores """ issues: List[ValidationIssue] = [] field_scores: Dict[str, float] = {} # Check each field for field_spec in schema.fields: field_issues, score = self._validate_field( field_spec=field_spec, extraction=extraction, ) issues.extend(field_issues) field_scores[field_spec.name] = score # Check for unexpected fields expected_fields = {f.name for f in schema.fields} for field_name in extraction.data.keys(): if field_name not in expected_fields: issues.append(ValidationIssue( field_name=field_name, issue_type="unexpected", message=f"Unexpected field: {field_name}", severity="info", )) # Calculate overall score if field_scores: confidence_score = sum(field_scores.values()) / len(field_scores) else: confidence_score = 0.0 # Determine validity is_valid = ( all(i.severity != "error" for i in issues) and confidence_score >= schema.min_overall_confidence ) # Generate recommendations recommendations = self._generate_recommendations(issues, extraction) return ValidationResult( is_valid=is_valid, issues=issues, confidence_score=confidence_score, field_scores=field_scores, recommendations=recommendations, ) def _validate_field( self, field_spec: FieldSpec, extraction: ExtractionResult, ) -> Tuple[List[ValidationIssue], float]: """Validate a single field.""" issues: List[ValidationIssue] = [] score = 1.0 value = extraction.data.get(field_spec.name) field_extraction = self._get_field_extraction(field_spec.name, extraction) # Check presence if value is None: if field_spec.required: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="missing", message=f"Required field '{field_spec.name}' is missing", severity="error", suggested_action="Manual review required", )) return issues, 0.0 else: return issues, 1.0 # Optional field, OK to be missing # Check abstention if field_spec.name in extraction.abstained_fields: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="abstained", message=f"Field '{field_spec.name}' was abstained due to low confidence", severity="warning", suggested_action="Manual verification recommended", )) score *= 0.5 # Check confidence if field_extraction: if field_extraction.confidence < self.min_confidence: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="low_confidence", message=f"Field '{field_spec.name}' has low confidence: {field_extraction.confidence:.2f}", severity="warning", suggested_action="Manual verification recommended", )) score *= field_extraction.confidence else: score *= field_extraction.confidence # Check type type_issues = self._validate_type(field_spec, value) issues.extend(type_issues) if type_issues: score *= 0.7 # Check constraints constraint_issues = self._validate_constraints(field_spec, value) issues.extend(constraint_issues) if constraint_issues: score *= 0.8 return issues, max(0.0, min(1.0, score)) def _validate_type( self, field_spec: FieldSpec, value: Any, ) -> List[ValidationIssue]: """Validate field type.""" issues = [] expected_type = self._get_expected_python_type(field_spec.field_type) if expected_type and not isinstance(value, expected_type): # Try conversion try: expected_type(value) except (ValueError, TypeError): issues.append(ValidationIssue( field_name=field_spec.name, issue_type="type_mismatch", message=f"Field '{field_spec.name}' expected {field_spec.field_type.value}, got {type(value).__name__}", severity="warning" if not self.strict_mode else "error", )) return issues def _validate_constraints( self, field_spec: FieldSpec, value: Any, ) -> List[ValidationIssue]: """Validate field constraints.""" issues = [] # Pattern if field_spec.pattern: import re if not re.match(field_spec.pattern, str(value)): issues.append(ValidationIssue( field_name=field_spec.name, issue_type="pattern_mismatch", message=f"Field '{field_spec.name}' does not match pattern: {field_spec.pattern}", severity="warning", )) # Range try: num_value = float(value) if field_spec.min_value is not None and num_value < field_spec.min_value: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="below_minimum", message=f"Field '{field_spec.name}' value {num_value} is below minimum {field_spec.min_value}", severity="warning", )) if field_spec.max_value is not None and num_value > field_spec.max_value: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="above_maximum", message=f"Field '{field_spec.name}' value {num_value} is above maximum {field_spec.max_value}", severity="warning", )) except (ValueError, TypeError): pass # Length str_value = str(value) if field_spec.min_length is not None and len(str_value) < field_spec.min_length: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="too_short", message=f"Field '{field_spec.name}' is too short: {len(str_value)} < {field_spec.min_length}", severity="warning", )) if field_spec.max_length is not None and len(str_value) > field_spec.max_length: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="too_long", message=f"Field '{field_spec.name}' is too long: {len(str_value)} > {field_spec.max_length}", severity="warning", )) # Allowed values if field_spec.allowed_values and value not in field_spec.allowed_values: issues.append(ValidationIssue( field_name=field_spec.name, issue_type="not_in_allowed", message=f"Field '{field_spec.name}' value '{value}' not in allowed values", severity="warning", )) return issues def _get_field_extraction( self, field_name: str, extraction: ExtractionResult, ) -> Optional[FieldExtraction]: """Get field extraction by name.""" for fe in extraction.fields: if fe.field_name == field_name: return fe return None def _get_expected_python_type(self, field_type: FieldType) -> Optional[type]: """Get expected Python type for field type.""" type_map = { FieldType.INTEGER: int, FieldType.FLOAT: float, FieldType.BOOLEAN: bool, FieldType.LIST: list, FieldType.OBJECT: dict, } return type_map.get(field_type) def _generate_recommendations( self, issues: List[ValidationIssue], extraction: ExtractionResult, ) -> List[str]: """Generate recommendations based on issues.""" recommendations = [] # Count issue types missing_count = sum(1 for i in issues if i.issue_type == "missing") low_conf_count = sum(1 for i in issues if i.issue_type == "low_confidence") type_count = sum(1 for i in issues if i.issue_type == "type_mismatch") if missing_count > 0: recommendations.append( f"Review document for {missing_count} missing required field(s)" ) if low_conf_count > 0: recommendations.append( f"Manual verification recommended for {low_conf_count} low-confidence field(s)" ) if type_count > 0: recommendations.append( f"Check data types for {type_count} field(s) with type mismatches" ) if extraction.overall_confidence < 0.5: recommendations.append( "Overall extraction confidence is low - consider manual review" ) if len(extraction.abstained_fields) > 0: recommendations.append( f"System abstained on {len(extraction.abstained_fields)} field(s) due to uncertainty" ) return recommendations class CrossFieldValidator: """ Validates relationships between fields. Checks for: - Consistency (e.g., subtotal + tax = total) - Logical relationships - Date ordering """ def validate_consistency( self, extraction: ExtractionResult, rules: List[Dict[str, Any]], ) -> List[ValidationIssue]: """ Validate cross-field consistency rules. Rules format: { "type": "sum", "fields": ["subtotal", "tax"], "equals": "total", "tolerance": 0.01 } """ issues = [] for rule in rules: rule_type = rule.get("type") if rule_type == "sum": issue = self._validate_sum_rule(extraction, rule) if issue: issues.append(issue) elif rule_type == "date_order": issue = self._validate_date_order(extraction, rule) if issue: issues.append(issue) elif rule_type == "required_if": issue = self._validate_required_if(extraction, rule) if issue: issues.append(issue) return issues def _validate_sum_rule( self, extraction: ExtractionResult, rule: Dict[str, Any], ) -> Optional[ValidationIssue]: """Validate that sum of fields equals another field.""" fields = rule.get("fields", []) equals_field = rule.get("equals") tolerance = rule.get("tolerance", 0.01) try: sum_value = sum( float(extraction.data.get(f, 0) or 0) for f in fields ) expected = float(extraction.data.get(equals_field, 0) or 0) if abs(sum_value - expected) > tolerance: return ValidationIssue( field_name=equals_field, issue_type="sum_mismatch", message=f"Sum of {fields} ({sum_value}) does not equal {equals_field} ({expected})", severity="warning", ) except (ValueError, TypeError): pass return None def _validate_date_order( self, extraction: ExtractionResult, rule: Dict[str, Any], ) -> Optional[ValidationIssue]: """Validate that dates are in correct order.""" from datetime import datetime before_field = rule.get("before") after_field = rule.get("after") before_val = extraction.data.get(before_field) after_val = extraction.data.get(after_field) if not before_val or not after_val: return None try: # Try common date formats formats = ["%Y-%m-%d", "%m/%d/%Y", "%d/%m/%Y", "%B %d, %Y"] before_date = None after_date = None for fmt in formats: try: before_date = datetime.strptime(str(before_val), fmt) break except ValueError: continue for fmt in formats: try: after_date = datetime.strptime(str(after_val), fmt) break except ValueError: continue if before_date and after_date and before_date > after_date: return ValidationIssue( field_name=after_field, issue_type="date_order", message=f"Date {before_field} ({before_val}) should be before {after_field} ({after_val})", severity="warning", ) except Exception: pass return None def _validate_required_if( self, extraction: ExtractionResult, rule: Dict[str, Any], ) -> Optional[ValidationIssue]: """Validate conditional required fields.""" field = rule.get("field") required_if = rule.get("required_if") # Field that must exist condition_value = rule.get("value") # Optional specific value condition_field_value = extraction.data.get(required_if) # Check if condition is met condition_met = False if condition_value is not None: condition_met = condition_field_value == condition_value else: condition_met = condition_field_value is not None if condition_met: field_value = extraction.data.get(field) if field_value is None: return ValidationIssue( field_name=field, issue_type="conditional_required", message=f"Field '{field}' is required when '{required_if}' is present", severity="warning", ) return None