""" |
|
|
Extraction Critic for Validation |
|
|
|
|
|
Validates extracted information against source evidence. |
|
|
Provides confidence scoring and abstention recommendations. |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Dict, Any, Tuple |
|
|
from enum import Enum
from pydantic import BaseModel, Field
from loguru import logger

# httpx is optional: if it is missing (or the LLM call fails at runtime), the
# critic falls back to the heuristic string-matching path in _heuristic_validation().
try:
    import httpx
    HTTPX_AVAILABLE = True
except ImportError:
    HTTPX_AVAILABLE = False


class ValidationStatus(str, Enum):
    """Validation status codes."""
VALID = "valid" |
|
|
INVALID = "invalid" |
|
|
UNCERTAIN = "uncertain" |
|
|
ABSTAIN = "abstain" |
|
|
NO_EVIDENCE = "no_evidence" |
|
|
|
|
|
|
|
|
class CriticConfig(BaseModel): |
|
|
"""Configuration for extraction critic.""" |
|
|
|
|
|
llm_provider: str = Field(default="ollama", description="LLM provider") |
|
|
    ollama_base_url: str = Field(default="http://localhost:11434")
    ollama_model: str = Field(default="llama3.2:3b")

    confidence_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum confidence for valid extraction",
    )
    evidence_required: bool = Field(
        default=True,
        description="Require evidence for validation",
    )
    strict_mode: bool = Field(
        default=False,
        description="Strict validation mode",
    )

    max_fields_per_request: int = Field(default=10, ge=1)
    timeout: float = Field(default=60.0, ge=1.0)
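
# Example override (sketch): a stricter profile for high-stakes documents. The
# threshold and model below are illustrative, not defaults used elsewhere here.
#
#     critic = ExtractionCritic(CriticConfig(confidence_threshold=0.85, ollama_model="llama3.1:8b"))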


class FieldValidation(BaseModel):
    """Validation result for a single field."""
    field_name: str
    extracted_value: Any
    status: ValidationStatus
    confidence: float
    reasoning: str

    evidence_found: bool = False
    evidence_snippet: Optional[str] = None
    evidence_page: Optional[int] = None

    suggested_value: Optional[Any] = None
    correction_reason: Optional[str] = None


class ValidationResult(BaseModel):
    """Complete validation result."""
    overall_status: ValidationStatus
    overall_confidence: float
    field_validations: List[FieldValidation]

    valid_count: int = 0
    invalid_count: int = 0
    uncertain_count: int = 0
    abstain_count: int = 0

    should_accept: bool
    abstain_reason: Optional[str] = None
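
# One reasonable calling pattern (sketch, not enforced by this module):
#
#     result = critic.validate_extraction(fields, evidence)
#     if result.should_accept:
#         persist(fields)                                  # hypothetical downstream step
#     else:
#         route_to_review(fields, result.abstain_reason)   # hypothetical downstream step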


class ExtractionCritic:
    """
    Critic for validating extracted information.

    Features:
    - Validates extracted fields against source evidence
    - Provides confidence scores
    - Recommends abstention when uncertain
    - Suggests corrections when possible
    """

    VALIDATION_PROMPT = """You are a critical validator for document extraction.
Your task is to validate extracted information against the source evidence.

For each field, determine:
1. Is the extracted value supported by the evidence? (yes/no/partially)
2. Confidence score (0.0 to 1.0)
3. Brief reasoning
4. If incorrect, suggest the correct value

Be strict and skeptical. Only mark as valid if clearly supported.

Evidence:
{evidence}

Extracted Fields to Validate:
{fields}

Respond in JSON format:
{{
    "validations": [
        {{
            "field": "field_name",
            "status": "valid|invalid|uncertain|no_evidence",
            "confidence": 0.0-1.0,
            "reasoning": "explanation",
            "suggested_value": null or corrected value
        }}
    ]
}}"""

    def __init__(self, config: Optional[CriticConfig] = None):
        """Initialize extraction critic."""
        self.config = config or CriticConfig()

    def validate_extraction(
        self,
        extracted_fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> ValidationResult:
        """
        Validate extracted fields against evidence.

        Args:
            extracted_fields: Dictionary of field_name -> value.
            evidence: List of evidence chunks; each chunk is a dict with a
                "text" (or "snippet") key and an optional "page" number.

        Returns:
            ValidationResult
        """
        if not extracted_fields:
            return ValidationResult(
                overall_status=ValidationStatus.ABSTAIN,
                overall_confidence=0.0,
                field_validations=[],
                should_accept=False,
                abstain_reason="No fields to validate",
            )

        if not evidence and self.config.evidence_required:
            return self._create_no_evidence_result(extracted_fields)

        # Validate each field against the evidence (LLM first, heuristics on failure).
        field_validations = self._validate_with_llm(extracted_fields, evidence)

        # Tally per-field statuses.
        valid_count = sum(1 for v in field_validations if v.status == ValidationStatus.VALID)
        invalid_count = sum(1 for v in field_validations if v.status == ValidationStatus.INVALID)
        uncertain_count = sum(1 for v in field_validations if v.status == ValidationStatus.UNCERTAIN)
        abstain_count = sum(1 for v in field_validations if v.status == ValidationStatus.ABSTAIN)

        # Overall confidence is the mean per-field confidence.
        if field_validations:
            overall_confidence = sum(v.confidence for v in field_validations) / len(field_validations)
        else:
            overall_confidence = 0.0

        # Any invalid field dominates; otherwise abstain/uncertain fields
        # outnumbering valid ones lower the overall status.
        if invalid_count > 0:
            overall_status = ValidationStatus.INVALID
        elif abstain_count > valid_count:
            overall_status = ValidationStatus.ABSTAIN
        elif uncertain_count > valid_count:
            overall_status = ValidationStatus.UNCERTAIN
        else:
            overall_status = ValidationStatus.VALID

        # Accept only when confidence clears the threshold and nothing was invalid.
        should_accept = (
            overall_confidence >= self.config.confidence_threshold
            and invalid_count == 0
            and overall_status in [ValidationStatus.VALID, ValidationStatus.UNCERTAIN]
        )

        abstain_reason = None
        if not should_accept:
            if overall_confidence < self.config.confidence_threshold:
                abstain_reason = f"Confidence ({overall_confidence:.2f}) below threshold ({self.config.confidence_threshold})"
            elif invalid_count > 0:
                abstain_reason = f"{invalid_count} field(s) validated as invalid"
            elif overall_status == ValidationStatus.ABSTAIN:
                abstain_reason = "Insufficient evidence to validate"

        return ValidationResult(
            overall_status=overall_status,
            overall_confidence=overall_confidence,
            field_validations=field_validations,
            valid_count=valid_count,
            invalid_count=invalid_count,
            uncertain_count=uncertain_count,
            abstain_count=abstain_count,
            should_accept=should_accept,
            abstain_reason=abstain_reason,
        )

    def _validate_with_llm(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Validate fields using the LLM, falling back to heuristics on failure."""
        evidence_text = self._format_evidence(evidence)

        fields_text = "\n".join(
            f"- {name}: {value}"
            for name, value in fields.items()
        )

        prompt = self.VALIDATION_PROMPT.format(
            evidence=evidence_text,
            fields=fields_text,
        )

        try:
            response = self._call_llm(prompt)
            validations = self._parse_validation_response(response, fields, evidence)
        except Exception as e:
            logger.error(f"LLM validation failed: {e}")
            # Fall back to simple string matching against the evidence.
            validations = self._heuristic_validation(fields, evidence)

        return validations

    def _format_evidence(self, evidence: List[Dict[str, Any]]) -> str:
        """Format evidence chunks for the validation prompt."""
        parts = []
        for i, ev in enumerate(evidence[:10], 1):
page = ev.get("page", "?") |
|
|
text = ev.get("text", ev.get("snippet", ""))[:500] |
|
|
parts.append(f"[{i}] Page {page}: {text}") |
|
|
return "\n\n".join(parts) |
|
|
|
|
|
def _call_llm(self, prompt: str) -> str: |
|
|
"""Call LLM for validation.""" |
|
|
if not HTTPX_AVAILABLE: |
|
|
raise ImportError("httpx required for LLM calls") |
|
|
|
|
|
with httpx.Client(timeout=self.config.timeout) as client: |
|
|
response = client.post( |
|
|
f"{self.config.ollama_base_url}/api/generate", |
|
|
json={ |
|
|
"model": self.config.ollama_model, |
|
|
"prompt": prompt, |
|
|
"stream": False, |
|
|
"options": {"temperature": 0.1}, |
|
|
}, |
|
|
) |
|
|
response.raise_for_status() |
|
|
return response.json().get("response", "") |
|
|
|
|
|
def _parse_validation_response( |
|
|
self, |
|
|
response: str, |
|
|
fields: Dict[str, Any], |
|
|
evidence: List[Dict[str, Any]], |
|
|
) -> List[FieldValidation]: |
|
|
"""Parse LLM validation response.""" |
|
|
        import json
        import re

        validations = []

        # Pull the first JSON object out of the response (models often wrap it in prose).
        json_match = re.search(r'\{[\s\S]*\}', response)
        if json_match:
            try:
                data = json.loads(json_match.group())
                llm_validations = data.get("validations", [])

                for v in llm_validations:
                    field_name = v.get("field", "")
                    if field_name not in fields:
                        continue

                    status_str = v.get("status", "uncertain").lower()
                    try:
                        status = ValidationStatus(status_str)
                    except ValueError:
                        status = ValidationStatus.UNCERTAIN

                    validation = FieldValidation(
                        field_name=field_name,
                        extracted_value=fields[field_name],
                        status=status,
                        confidence=float(v.get("confidence", 0.5)),
                        reasoning=v.get("reasoning", ""),
                        evidence_found=status != ValidationStatus.NO_EVIDENCE,
                        suggested_value=v.get("suggested_value"),
                    )
                    validations.append(validation)

            except json.JSONDecodeError:
                pass

        # Any field the LLM skipped is marked uncertain rather than dropped.
        validated_fields = {v.field_name for v in validations}
        for field_name, value in fields.items():
            if field_name not in validated_fields:
                validations.append(FieldValidation(
                    field_name=field_name,
                    extracted_value=value,
                    status=ValidationStatus.UNCERTAIN,
                    confidence=0.5,
                    reasoning="Could not validate",
                    evidence_found=False,
                ))

        return validations

    def _heuristic_validation(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Heuristic validation used when the LLM is unavailable or fails."""
        validations = []
        evidence_text = " ".join(
ev.get("text", ev.get("snippet", "")).lower() |
|
|
for ev in evidence |
|
|
) |
|
|
|
|
|
for field_name, value in fields.items(): |
|
|
|
|
|
value_str = str(value).lower() |
|
|
found = value_str in evidence_text if value_str else False |
|
|
|
|
|
if found: |
|
|
status = ValidationStatus.VALID |
|
|
confidence = 0.7 |
|
|
reasoning = "Value found in evidence" |
|
|
elif evidence: |
|
|
status = ValidationStatus.UNCERTAIN |
|
|
confidence = 0.4 |
|
|
reasoning = "Value not directly found in evidence" |
|
|
else: |
|
|
status = ValidationStatus.NO_EVIDENCE |
|
|
confidence = 0.2 |
|
|
reasoning = "No evidence available" |
|
|
|
|
|
validations.append(FieldValidation( |
|
|
field_name=field_name, |
|
|
extracted_value=value, |
|
|
status=status, |
|
|
confidence=confidence, |
|
|
reasoning=reasoning, |
|
|
evidence_found=found, |
|
|
)) |
|
|
|
|
|
return validations |
|
|
|
|
|
def _create_no_evidence_result( |
|
|
self, |
|
|
fields: Dict[str, Any], |
|
|
) -> ValidationResult: |
|
|
"""Create result when no evidence is available.""" |
|
|
validations = [ |
|
|
FieldValidation( |
|
|
                field_name=name,
                extracted_value=value,
                status=ValidationStatus.NO_EVIDENCE,
                confidence=0.0,
                reasoning="No evidence provided for validation",
                evidence_found=False,
            )
            for name, value in fields.items()
        ]

        return ValidationResult(
            overall_status=ValidationStatus.ABSTAIN,
            overall_confidence=0.0,
            field_validations=validations,
            abstain_count=len(validations),
            should_accept=False,
            abstain_reason="No evidence available for validation",
        )


_extraction_critic: Optional[ExtractionCritic] = None


def get_extraction_critic(
    config: Optional[CriticConfig] = None,
) -> ExtractionCritic:
    """Get or create singleton extraction critic."""
    global _extraction_critic
    if _extraction_critic is None:
        _extraction_critic = ExtractionCritic(config)
    return _extraction_critic


def reset_extraction_critic():
    """Reset the global critic instance."""
    global _extraction_critic
    _extraction_critic = None
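

if __name__ == "__main__":
    # Minimal usage sketch with illustrative values only. If no Ollama server is
    # reachable at the configured URL, the critic falls back to heuristic string
    # matching, so this demo also runs offline.
    critic = get_extraction_critic(CriticConfig(confidence_threshold=0.6))

    sample_fields = {"invoice_number": "INV-1042", "total_amount": "1,250.00"}
    sample_evidence = [
        {"page": 1, "text": "Invoice INV-1042 ... Total due: 1,250.00 USD"},
    ]

    result = critic.validate_extraction(sample_fields, sample_evidence)
    print(
        f"status={result.overall_status.value} "
        f"confidence={result.overall_confidence:.2f} accept={result.should_accept}"
    )
    for fv in result.field_validations:
        print(f"  {fv.field_name}: {fv.status.value} ({fv.confidence:.2f}) {fv.reasoning}")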