|
|
""" |
|
|
Extraction Validation |
|
|
|
|
|
Validates extracted data and provides confidence scoring. |
|
|
""" |
|
|
|
|
|
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from ..chunks.models import (
    ExtractionResult,
    FieldExtraction,
    ConfidenceLevel,
)
from .schema import ExtractionSchema, FieldSpec, FieldType
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass
class ValidationIssue:
    """A validation issue found during extraction validation."""

    # Name of the extracted field the issue refers to.
    field_name: str
    # Machine-readable category, e.g. "missing", "type_mismatch", "abstained".
    issue_type: str
    # Human-readable description of the problem.
    message: str
    # One of "error", "warning", or "info" (only "error" blocks validity).
    severity: str = "warning"
    # Optional hint for how a reviewer should resolve the issue.
    suggested_action: Optional[str] = None
|
|
|
|
|
|
|
|
@dataclass
class ValidationResult:
    """Outcome of validating an extraction result against a schema."""

    # True only when no error-severity issues exist and the overall
    # confidence clears the schema's threshold.
    is_valid: bool
    issues: List[ValidationIssue] = field(default_factory=list)
    confidence_score: float = 0.0
    field_scores: Dict[str, float] = field(default_factory=dict)
    recommendations: List[str] = field(default_factory=list)

    @property
    def error_count(self) -> int:
        """Number of error-severity issues."""
        return len([issue for issue in self.issues if issue.severity == "error"])

    @property
    def warning_count(self) -> int:
        """Number of warning-severity issues."""
        return len([issue for issue in self.issues if issue.severity == "warning"])

    def get_issues_for_field(self, field_name: str) -> List[ValidationIssue]:
        """Return every issue recorded against *field_name*."""
        matching: List[ValidationIssue] = []
        for issue in self.issues:
            if issue.field_name == field_name:
                matching.append(issue)
        return matching
|
|
|
|
|
|
|
|
class ExtractionValidator:
    """
    Validates extraction results against schemas.

    Checks for:
    - Required field presence
    - Type correctness
    - Value constraints
    - Confidence thresholds
    """

    def __init__(
        self,
        min_confidence: float = 0.5,
        strict_mode: bool = False,
    ):
        """
        Args:
            min_confidence: Per-field confidence threshold; fields below it
                receive a "low_confidence" warning.
            strict_mode: When True, type mismatches are reported as errors
                instead of warnings.
        """
        self.min_confidence = min_confidence
        self.strict_mode = strict_mode

    def validate(
        self,
        extraction: ExtractionResult,
        schema: ExtractionSchema,
    ) -> ValidationResult:
        """
        Validate extraction result against schema.

        Args:
            extraction: Extraction result to validate
            schema: Schema defining expected fields

        Returns:
            ValidationResult with issues and scores
        """
        issues: List[ValidationIssue] = []
        field_scores: Dict[str, float] = {}

        # Per-field checks: presence, confidence, type, constraints.
        for field_spec in schema.fields:
            field_issues, score = self._validate_field(
                field_spec=field_spec,
                extraction=extraction,
            )
            issues.extend(field_issues)
            field_scores[field_spec.name] = score

        # Flag extracted fields the schema does not declare (info only;
        # they never affect validity or scoring).
        expected_fields = {f.name for f in schema.fields}
        for field_name in extraction.data.keys():
            if field_name not in expected_fields:
                issues.append(ValidationIssue(
                    field_name=field_name,
                    issue_type="unexpected",
                    message=f"Unexpected field: {field_name}",
                    severity="info",
                ))

        # Overall confidence is the unweighted mean of per-field scores.
        if field_scores:
            confidence_score = sum(field_scores.values()) / len(field_scores)
        else:
            confidence_score = 0.0

        # Valid iff no error-severity issues and confidence clears the
        # schema-level threshold.
        is_valid = (
            all(i.severity != "error" for i in issues) and
            confidence_score >= schema.min_overall_confidence
        )

        recommendations = self._generate_recommendations(issues, extraction)

        return ValidationResult(
            is_valid=is_valid,
            issues=issues,
            confidence_score=confidence_score,
            field_scores=field_scores,
            recommendations=recommendations,
        )

    def _validate_field(
        self,
        field_spec: FieldSpec,
        extraction: ExtractionResult,
    ) -> Tuple[List[ValidationIssue], float]:
        """Validate a single field against its spec.

        Returns:
            (issues, score): score starts at 1.0 and is multiplied down by
            each problem found, then clamped to [0, 1].
        """
        issues: List[ValidationIssue] = []
        score = 1.0

        value = extraction.data.get(field_spec.name)
        field_extraction = self._get_field_extraction(field_spec.name, extraction)

        # Presence: a missing required field is an immediate error.
        if value is None:
            if field_spec.required:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="missing",
                    message=f"Required field '{field_spec.name}' is missing",
                    severity="error",
                    suggested_action="Manual review required",
                ))
                return issues, 0.0
            # Optional and absent: nothing further to check.
            return issues, 1.0

        # Abstained fields keep their value but carry extra uncertainty.
        if field_spec.name in extraction.abstained_fields:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="abstained",
                message=f"Field '{field_spec.name}' was abstained due to low confidence",
                severity="warning",
                suggested_action="Manual verification recommended",
            ))
            score *= 0.5

        # Confidence always scales the score; additionally warn when it
        # falls below the configured threshold.
        if field_extraction:
            if field_extraction.confidence < self.min_confidence:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="low_confidence",
                    message=f"Field '{field_spec.name}' has low confidence: {field_extraction.confidence:.2f}",
                    severity="warning",
                    suggested_action="Manual verification recommended",
                ))
            score *= field_extraction.confidence

        # Type check (fixed 0.7 penalty when any mismatch is found).
        type_issues = self._validate_type(field_spec, value)
        issues.extend(type_issues)
        if type_issues:
            score *= 0.7

        # Constraint checks (fixed 0.8 penalty when any constraint fails).
        constraint_issues = self._validate_constraints(field_spec, value)
        issues.extend(constraint_issues)
        if constraint_issues:
            score *= 0.8

        return issues, max(0.0, min(1.0, score))

    def _validate_type(
        self,
        field_spec: FieldSpec,
        value: Any,
    ) -> List[ValidationIssue]:
        """Check that *value* matches the field's declared type.

        A wrongly-typed value is still accepted when it coerces cleanly to
        a number (e.g. the string "42" for an INTEGER field).  Coercion is
        only attempted for int/float: ``bool()`` never raises and ``list()``
        accepts any iterable, so using the constructor as a probe for those
        types would silently mask every genuine mismatch.
        """
        issues: List[ValidationIssue] = []

        expected_type = self._get_expected_python_type(field_spec.field_type)
        if expected_type is None or isinstance(value, expected_type):
            return issues

        # Numeric coercion is the only meaningful "close enough" signal.
        # Booleans are excluded: True for a FLOAT field is almost certainly
        # an extraction error, not a numeric value.
        coercible = False
        if expected_type in (int, float) and not isinstance(value, bool):
            try:
                expected_type(value)
                coercible = True
            except (ValueError, TypeError):
                pass

        if not coercible:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="type_mismatch",
                message=f"Field '{field_spec.name}' expected {field_spec.field_type.value}, got {type(value).__name__}",
                severity="error" if self.strict_mode else "warning",
            ))

        return issues

    def _validate_constraints(
        self,
        field_spec: FieldSpec,
        value: Any,
    ) -> List[ValidationIssue]:
        """Validate pattern, numeric-range, length, and allowed-value
        constraints declared on the field spec.  All failures are warnings.
        """
        issues: List[ValidationIssue] = []

        # Pattern constraint.  NOTE(review): re.match anchors at the start
        # only; if full-string matching is intended, switch to re.fullmatch.
        if field_spec.pattern:
            if not re.match(field_spec.pattern, str(value)):
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="pattern_mismatch",
                    message=f"Field '{field_spec.name}' does not match pattern: {field_spec.pattern}",
                    severity="warning",
                ))

        # Numeric range constraints (EAFP: silently skipped for values
        # that are not numeric at all).
        try:
            num_value = float(value)
        except (ValueError, TypeError):
            pass
        else:
            if field_spec.min_value is not None and num_value < field_spec.min_value:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="below_minimum",
                    message=f"Field '{field_spec.name}' value {num_value} is below minimum {field_spec.min_value}",
                    severity="warning",
                ))
            if field_spec.max_value is not None and num_value > field_spec.max_value:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="above_maximum",
                    message=f"Field '{field_spec.name}' value {num_value} is above maximum {field_spec.max_value}",
                    severity="warning",
                ))

        # Length constraints: measure sized containers (str, list, dict)
        # directly; scalars fall back to their string representation.
        # (Previously lists were measured by repr length.)
        length = len(value) if hasattr(value, "__len__") else len(str(value))
        if field_spec.min_length is not None and length < field_spec.min_length:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="too_short",
                message=f"Field '{field_spec.name}' is too short: {length} < {field_spec.min_length}",
                severity="warning",
            ))
        if field_spec.max_length is not None and length > field_spec.max_length:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="too_long",
                message=f"Field '{field_spec.name}' is too long: {length} > {field_spec.max_length}",
                severity="warning",
            ))

        # Enumerated values.
        if field_spec.allowed_values and value not in field_spec.allowed_values:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="not_in_allowed",
                message=f"Field '{field_spec.name}' value '{value}' not in allowed values",
                severity="warning",
            ))

        return issues

    def _get_field_extraction(
        self,
        field_name: str,
        extraction: ExtractionResult,
    ) -> Optional[FieldExtraction]:
        """Return the FieldExtraction for *field_name*, or None if absent."""
        return next(
            (fe for fe in extraction.fields if fe.field_name == field_name),
            None,
        )

    def _get_expected_python_type(self, field_type: FieldType) -> Optional[type]:
        """Map a schema FieldType to the Python type used for isinstance checks.

        Returns None for field types with no mapping (those skip type
        validation entirely).  Note: bool is a subclass of int, so boolean
        values satisfy the INTEGER check here.
        """
        type_map = {
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        return type_map.get(field_type)

    def _generate_recommendations(
        self,
        issues: List[ValidationIssue],
        extraction: ExtractionResult,
    ) -> List[str]:
        """Generate human-readable review recommendations from the issues
        found and the extraction's own confidence/abstention signals."""
        recommendations: List[str] = []

        missing_count = sum(1 for i in issues if i.issue_type == "missing")
        low_conf_count = sum(1 for i in issues if i.issue_type == "low_confidence")
        type_count = sum(1 for i in issues if i.issue_type == "type_mismatch")

        if missing_count > 0:
            recommendations.append(
                f"Review document for {missing_count} missing required field(s)"
            )

        if low_conf_count > 0:
            recommendations.append(
                f"Manual verification recommended for {low_conf_count} low-confidence field(s)"
            )

        if type_count > 0:
            recommendations.append(
                f"Check data types for {type_count} field(s) with type mismatches"
            )

        if extraction.overall_confidence < 0.5:
            recommendations.append(
                "Overall extraction confidence is low - consider manual review"
            )

        if len(extraction.abstained_fields) > 0:
            recommendations.append(
                f"System abstained on {len(extraction.abstained_fields)} field(s) due to uncertainty"
            )

        return recommendations
|
|
|
|
|
|
|
|
class CrossFieldValidator:
    """
    Validates relationships between fields.

    Checks for:
    - Consistency (e.g., subtotal + tax = total)
    - Logical relationships
    - Date ordering
    """

    # Date formats tried, in order, when parsing date values.
    # NOTE(review): "%m/%d/%Y" is tried before "%d/%m/%Y", so ambiguous
    # dates such as 03/04/2024 resolve as US-style — confirm this matches
    # the document corpus.
    _DATE_FORMATS = ("%Y-%m-%d", "%m/%d/%Y", "%d/%m/%Y", "%B %d, %Y")

    def validate_consistency(
        self,
        extraction: ExtractionResult,
        rules: List[Dict[str, Any]],
    ) -> List[ValidationIssue]:
        """
        Validate cross-field consistency rules.

        Rules format:
        {
            "type": "sum",
            "fields": ["subtotal", "tax"],
            "equals": "total",
            "tolerance": 0.01
        }

        Rules with an unknown "type" are silently skipped.
        """
        issues: List[ValidationIssue] = []

        # Dispatch table keeps rule handling flat and easy to extend.
        handlers = {
            "sum": self._validate_sum_rule,
            "date_order": self._validate_date_order,
            "required_if": self._validate_required_if,
        }

        for rule in rules:
            handler = handlers.get(rule.get("type"))
            if handler is None:
                continue
            issue = handler(extraction, rule)
            if issue:
                issues.append(issue)

        return issues

    def _validate_sum_rule(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate that the sum of ``rule["fields"]`` equals ``rule["equals"]``.

        The rule is skipped when the target field is absent: treating a
        missing total as 0 would produce spurious mismatch warnings.
        Missing component fields still count as 0, so partial extractions
        are checked against whatever total was found.
        """
        fields = rule.get("fields", [])
        equals_field = rule.get("equals")
        tolerance = rule.get("tolerance", 0.01)

        target = extraction.data.get(equals_field)
        if target is None:
            return None

        try:
            sum_value = sum(
                float(extraction.data.get(f, 0) or 0)
                for f in fields
            )
            expected = float(target)
        except (ValueError, TypeError):
            # Non-numeric operands: cannot evaluate the rule.
            return None

        if abs(sum_value - expected) > tolerance:
            return ValidationIssue(
                field_name=equals_field,
                issue_type="sum_mismatch",
                message=f"Sum of {fields} ({sum_value}) does not equal {equals_field} ({expected})",
                severity="warning",
            )

        return None

    def _parse_date(self, raw: Any) -> Optional[datetime]:
        """Parse *raw* with the first matching known format, or None."""
        text = str(raw)
        for fmt in self._DATE_FORMATS:
            try:
                return datetime.strptime(text, fmt)
            except ValueError:
                continue
        return None

    def _validate_date_order(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate that the date in ``rule["before"]`` does not come after
        the date in ``rule["after"]``.  Unparseable or missing dates make
        the rule a no-op rather than an issue.
        """
        before_field = rule.get("before")
        after_field = rule.get("after")

        before_val = extraction.data.get(before_field)
        after_val = extraction.data.get(after_field)

        if not before_val or not after_val:
            return None

        before_date = self._parse_date(before_val)
        after_date = self._parse_date(after_val)

        if before_date and after_date and before_date > after_date:
            return ValidationIssue(
                field_name=after_field,
                issue_type="date_order",
                message=f"Date {before_field} ({before_val}) should be before {after_field} ({after_val})",
                severity="warning",
            )

        return None

    def _validate_required_if(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate conditional required fields.

        ``rule["field"]`` must be present when ``rule["required_if"]``
        equals ``rule["value"]`` — or, when no "value" is given, whenever
        the condition field is present at all.
        """
        field = rule.get("field")
        required_if = rule.get("required_if")
        condition_value = rule.get("value")

        condition_field_value = extraction.data.get(required_if)

        # No "value" in the rule means mere presence triggers the rule.
        if condition_value is not None:
            condition_met = condition_field_value == condition_value
        else:
            condition_met = condition_field_value is not None

        if condition_met and extraction.data.get(field) is None:
            return ValidationIssue(
                field_name=field,
                issue_type="conditional_required",
                message=f"Field '{field}' is required when '{required_if}' is present",
                severity="warning",
            )

        return None
|
|
|