"""
Extraction Critic for Validation

Validates extracted information against source evidence.
Provides confidence scoring and abstention recommendations.
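
Typical usage (illustrative sketch; a reachable Ollama server is assumed for
LLM-backed validation, otherwise a heuristic fallback is used):

    critic = get_extraction_critic()
    result = critic.validate_extraction(extracted_fields, evidence)
    if not result.should_accept:
        ...  # surface result.abstain_reason to the caller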
"""

from typing import List, Optional, Dict, Any
from enum import Enum
from pydantic import BaseModel, Field
from loguru import logger

try:
    import httpx
    HTTPX_AVAILABLE = True
except ImportError:
    HTTPX_AVAILABLE = False


class ValidationStatus(str, Enum):
    """Validation status codes."""
    VALID = "valid"
    INVALID = "invalid"
    UNCERTAIN = "uncertain"
    ABSTAIN = "abstain"
    NO_EVIDENCE = "no_evidence"


class CriticConfig(BaseModel):
    """Configuration for extraction critic."""
    # LLM settings
    llm_provider: str = Field(default="ollama", description="LLM provider")
    ollama_base_url: str = Field(default="http://localhost:11434")
    ollama_model: str = Field(default="llama3.2:3b")

    # Validation thresholds
    confidence_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum confidence for valid extraction"
    )
    evidence_required: bool = Field(
        default=True,
        description="Require evidence for validation"
    )
    strict_mode: bool = Field(
        default=False,
        description="Strict validation mode"
    )

    # Processing
    max_fields_per_request: int = Field(default=10, ge=1)
    timeout: float = Field(default=60.0, ge=1.0)


class FieldValidation(BaseModel):
    """Validation result for a single field."""
    field_name: str
    extracted_value: Any
    status: ValidationStatus
    confidence: float
    reasoning: str

    # Evidence
    evidence_found: bool = False
    evidence_snippet: Optional[str] = None
    evidence_page: Optional[int] = None

    # Suggestions
    suggested_value: Optional[Any] = None
    correction_reason: Optional[str] = None


class ValidationResult(BaseModel):
    """Complete validation result."""
    overall_status: ValidationStatus
    overall_confidence: float
    field_validations: List[FieldValidation]

    # Statistics
    valid_count: int = 0
    invalid_count: int = 0
    uncertain_count: int = 0
    abstain_count: int = 0

    # Recommendations
    should_accept: bool
    abstain_reason: Optional[str] = None


class ExtractionCritic:
    """
    Critic for validating extracted information.

    Features:
    - Validates extracted fields against source evidence
    - Provides confidence scores
    - Recommends abstention when uncertain
    - Suggests corrections when possible
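
    Example (illustrative; the field names and evidence text are made up, and a
    reachable Ollama server is assumed for the LLM-backed path):

        critic = ExtractionCritic(CriticConfig(confidence_threshold=0.8))
        result = critic.validate_extraction(
            extracted_fields={"invoice_number": "INV-1042", "total": "199.00"},
            evidence=[{"page": 1, "text": "Invoice INV-1042 ... Total due: 199.00"}],
        )
        if not result.should_accept:
            logger.info(f"Rejecting extraction: {result.abstain_reason}")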
    """

    VALIDATION_PROMPT = """You are a critical validator for document extraction.
Your task is to validate extracted information against the source evidence.

For each field, determine:
1. Is the extracted value supported by the evidence? (yes/no/partially)
2. Confidence score (0.0 to 1.0)
3. Brief reasoning
4. If incorrect, suggest the correct value

Be strict and skeptical. Only mark as valid if clearly supported.

Evidence:
{evidence}

Extracted Fields to Validate:
{fields}

Respond in JSON format:
{{
  "validations": [
    {{
      "field": "field_name",
      "status": "valid|invalid|uncertain|no_evidence",
      "confidence": 0.0-1.0,
      "reasoning": "explanation",
      "suggested_value": null or corrected value
    }}
  ]
}}"""

    def __init__(self, config: Optional[CriticConfig] = None):
        """Initialize extraction critic."""
        self.config = config or CriticConfig()

    def validate_extraction(
        self,
        extracted_fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> ValidationResult:
        """
        Validate extracted fields against evidence.

        Args:
            extracted_fields: Dictionary of field_name -> value
            evidence: List of evidence chunks; each dict may include "text"
                (or "snippet") and an optional "page" number.

        Returns:
            ValidationResult
        """
        if not extracted_fields:
            return ValidationResult(
                overall_status=ValidationStatus.ABSTAIN,
                overall_confidence=0.0,
                field_validations=[],
                should_accept=False,
                abstain_reason="No fields to validate",
            )

        # Check if evidence is available
        if not evidence and self.config.evidence_required:
            return self._create_no_evidence_result(extracted_fields)

        # Validate using LLM
        field_validations = self._validate_with_llm(extracted_fields, evidence)

        # Calculate overall statistics
        valid_count = sum(1 for v in field_validations if v.status == ValidationStatus.VALID)
        invalid_count = sum(1 for v in field_validations if v.status == ValidationStatus.INVALID)
        uncertain_count = sum(1 for v in field_validations if v.status == ValidationStatus.UNCERTAIN)
        abstain_count = sum(1 for v in field_validations if v.status == ValidationStatus.ABSTAIN)

        # Calculate overall confidence
        if field_validations:
            overall_confidence = sum(v.confidence for v in field_validations) / len(field_validations)
        else:
            overall_confidence = 0.0

        # Determine overall status
        if invalid_count > 0:
            overall_status = ValidationStatus.INVALID
        elif abstain_count > valid_count:
            overall_status = ValidationStatus.ABSTAIN
        elif uncertain_count > valid_count:
            overall_status = ValidationStatus.UNCERTAIN
        else:
            overall_status = ValidationStatus.VALID

        # Determine if should accept
        should_accept = (
            overall_confidence >= self.config.confidence_threshold
            and invalid_count == 0
            and overall_status in [ValidationStatus.VALID, ValidationStatus.UNCERTAIN]
        )

        # Abstain reason
        abstain_reason = None
        if not should_accept:
            if overall_confidence < self.config.confidence_threshold:
                abstain_reason = f"Confidence ({overall_confidence:.2f}) below threshold ({self.config.confidence_threshold})"
            elif invalid_count > 0:
                abstain_reason = f"{invalid_count} field(s) validated as invalid"
            elif overall_status == ValidationStatus.ABSTAIN:
                abstain_reason = "Insufficient evidence to validate"

        return ValidationResult(
            overall_status=overall_status,
            overall_confidence=overall_confidence,
            field_validations=field_validations,
            valid_count=valid_count,
            invalid_count=invalid_count,
            uncertain_count=uncertain_count,
            abstain_count=abstain_count,
            should_accept=should_accept,
            abstain_reason=abstain_reason,
        )

    def _validate_with_llm(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Validate fields using LLM."""
        # Format evidence
        evidence_text = self._format_evidence(evidence)

        # Format fields
        fields_text = "\n".join(
            f"- {name}: {value}"
            for name, value in fields.items()
        )

        # Build prompt
        prompt = self.VALIDATION_PROMPT.format(
            evidence=evidence_text,
            fields=fields_text,
        )

        # Call LLM
        try:
            response = self._call_llm(prompt)
            validations = self._parse_validation_response(response, fields, evidence)
        except Exception as e:
            logger.error(f"LLM validation failed: {e}")
            # Fall back to heuristic validation
            validations = self._heuristic_validation(fields, evidence)

        return validations

    def _format_evidence(self, evidence: List[Dict[str, Any]]) -> str:
        """Format evidence for prompt."""
        parts = []
        for i, ev in enumerate(evidence[:10], 1):  # Limit to 10 chunks
            page = ev.get("page", "?")
            text = ev.get("text", ev.get("snippet", ""))[:500]
            parts.append(f"[{i}] Page {page}: {text}")
        return "\n\n".join(parts)

    def _call_llm(self, prompt: str) -> str:
        """Call LLM for validation."""
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx required for LLM calls")

        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{self.config.ollama_base_url}/api/generate",
                json={
                    "model": self.config.ollama_model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {"temperature": 0.1},
                },
            )
            response.raise_for_status()
            return response.json().get("response", "")

    def _parse_validation_response(
        self,
        response: str,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Parse LLM validation response."""
        import json
        import re

        validations = []

        # Try to extract JSON from response
        json_match = re.search(r'\{[\s\S]*\}', response)
        if json_match:
            try:
                data = json.loads(json_match.group())
                llm_validations = data.get("validations", [])

                for v in llm_validations:
                    field_name = v.get("field", "")
                    if field_name not in fields:
                        continue

                    status_str = v.get("status", "uncertain").lower()
                    try:
                        status = ValidationStatus(status_str)
                    except ValueError:
                        status = ValidationStatus.UNCERTAIN

                    validation = FieldValidation(
                        field_name=field_name,
                        extracted_value=fields[field_name],
                        status=status,
                        confidence=max(0.0, min(1.0, float(v.get("confidence", 0.5)))),
                        reasoning=v.get("reasoning", ""),
                        evidence_found=status != ValidationStatus.NO_EVIDENCE,
                        suggested_value=v.get("suggested_value"),
                    )
                    validations.append(validation)

            except (json.JSONDecodeError, TypeError, ValueError) as e:
                logger.warning(f"Could not parse LLM validation response: {e}")

        # Add any missing fields
        validated_fields = {v.field_name for v in validations}
        for field_name, value in fields.items():
            if field_name not in validated_fields:
                validations.append(FieldValidation(
                    field_name=field_name,
                    extracted_value=value,
                    status=ValidationStatus.UNCERTAIN,
                    confidence=0.5,
                    reasoning="Could not validate",
                    evidence_found=False,
                ))

        return validations

    def _heuristic_validation(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Heuristic validation when LLM fails."""
        validations = []
        evidence_text = " ".join(
            ev.get("text", ev.get("snippet", "")).lower()
            for ev in evidence
        )

        for field_name, value in fields.items():
            # Simple substring matching
            value_str = str(value).lower()
            found = value_str in evidence_text if value_str else False

            if found:
                status = ValidationStatus.VALID
                confidence = 0.7
                reasoning = "Value found in evidence"
            elif evidence:
                status = ValidationStatus.UNCERTAIN
                confidence = 0.4
                reasoning = "Value not directly found in evidence"
            else:
                status = ValidationStatus.NO_EVIDENCE
                confidence = 0.2
                reasoning = "No evidence available"

            validations.append(FieldValidation(
                field_name=field_name,
                extracted_value=value,
                status=status,
                confidence=confidence,
                reasoning=reasoning,
                evidence_found=found,
            ))

        return validations

    def _create_no_evidence_result(
        self,
        fields: Dict[str, Any],
    ) -> ValidationResult:
        """Create result when no evidence is available."""
        validations = [
            FieldValidation(
                field_name=name,
                extracted_value=value,
                status=ValidationStatus.NO_EVIDENCE,
                confidence=0.0,
                reasoning="No evidence provided for validation",
                evidence_found=False,
            )
            for name, value in fields.items()
        ]

        return ValidationResult(
            overall_status=ValidationStatus.ABSTAIN,
            overall_confidence=0.0,
            field_validations=validations,
            abstain_count=len(validations),
            should_accept=False,
            abstain_reason="No evidence available for validation",
        )


# Global instance and factory
_extraction_critic: Optional[ExtractionCritic] = None


def get_extraction_critic(
    config: Optional[CriticConfig] = None,
) -> ExtractionCritic:
    """Get or create singleton extraction critic."""
    global _extraction_critic
    if _extraction_critic is None:
        _extraction_critic = ExtractionCritic(config)
    return _extraction_critic


def reset_extraction_critic():
    """Reset the global critic instance."""
    global _extraction_critic
    _extraction_critic = None
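

# Minimal smoke test (sketch): the field values below are invented. With no
# evidence and evidence_required=True (the default), this exercises the offline
# abstention path and does not call the LLM.
if __name__ == "__main__":
    critic = get_extraction_critic()
    result = critic.validate_extraction(
        extracted_fields={"invoice_number": "INV-1042", "total": "199.00"},
        evidence=[],
    )
    print(result.overall_status.value, result.should_accept, result.abstain_reason)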