"""
Field Extraction Engine

Extracts structured data from parsed documents using schemas.
"""

import logging
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from ..chunks.models import (
    DocumentChunk,
    ExtractionResult,
    FieldExtraction,
    EvidenceRef,
    ParseResult,
    TableChunk,
    ChartChunk,
    ChunkType,
    ConfidenceLevel,
)
from ..grounding.evidence import EvidenceBuilder, EvidenceTracker
from .schema import ExtractionSchema, FieldSpec, FieldType

logger = logging.getLogger(__name__)


@dataclass
class ExtractionConfig:
    """Configuration for field extraction."""

    # Confidence thresholds
    min_field_confidence: float = 0.5
    min_overall_confidence: float = 0.5

    # Abstention behavior
    abstain_on_low_confidence: bool = True
    abstain_threshold: float = 0.3

    # Search behavior
    search_all_chunks: bool = True
    prefer_structured_sources: bool = True  # Tables, forms

    # Validation
    validate_extracted_values: bool = True
    normalize_values: bool = True


class FieldExtractor:
    """
    Extracts fields from parsed documents.

    Uses schema definitions to identify and extract
    structured data with evidence grounding.
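
    Example (a minimal usage sketch; the FieldSpec/ExtractionSchema
    constructor arguments and the upstream parse_result shown here are
    assumptions based on the attributes this module reads from them):

        schema = ExtractionSchema(fields=[
            FieldSpec(name="invoice_date", field_type=FieldType.DATE),
            FieldSpec(name="total_amount", field_type=FieldType.CURRENCY,
                      aliases=["amount due", "total"]),
        ])
        extractor = FieldExtractor(config=ExtractionConfig())
        result = extractor.extract(parse_result, schema)
        # result.data maps field names to extracted values;
        # result.abstained_fields lists fields the extractor declined to fill.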
    """

    def __init__(
        self,
        config: Optional[ExtractionConfig] = None,
        evidence_builder: Optional[EvidenceBuilder] = None,
    ):
        self.config = config or ExtractionConfig()
        self.evidence_builder = evidence_builder or EvidenceBuilder()
        self._normalizers: Dict[FieldType, Callable] = self._build_normalizers()
        self._validators: Dict[FieldType, Callable] = self._build_validators()

    def extract(
        self,
        parse_result: ParseResult,
        schema: ExtractionSchema,
    ) -> ExtractionResult:
        """
        Extract fields from a parsed document.

        Args:
            parse_result: Parsed document with chunks
            schema: Extraction schema defining fields

        Returns:
            ExtractionResult with extracted values and evidence
        """
        logger.info(f"Extracting {len(schema.fields)} fields from {parse_result.filename}")

        evidence_tracker = EvidenceTracker()
        field_extractions: List[FieldExtraction] = []
        extracted_data: Dict[str, Any] = {}
        abstained_fields: List[str] = []

        for field_spec in schema.fields:
            extraction = self._extract_field(
                field_spec=field_spec,
                chunks=parse_result.chunks,
                evidence_tracker=evidence_tracker,
            )

            if extraction:
                field_extractions.append(extraction)
                extracted_data[field_spec.name] = extraction.value

                # Check for abstention
                if extraction.confidence < self.config.abstain_threshold:
                    if self.config.abstain_on_low_confidence:
                        abstained_fields.append(field_spec.name)
                        extracted_data[field_spec.name] = None
            else:
                # Field not found
                if field_spec.required:
                    abstained_fields.append(field_spec.name)
                extracted_data[field_spec.name] = field_spec.default

        # Calculate overall confidence
        if field_extractions:
            overall_confidence = sum(f.confidence for f in field_extractions) / len(field_extractions)
        else:
            overall_confidence = 0.0

        return ExtractionResult(
            data=extracted_data,
            fields=field_extractions,
            evidence=evidence_tracker.get_all(),
            overall_confidence=overall_confidence,
            abstained_fields=abstained_fields,
        )

    def _extract_field(
        self,
        field_spec: FieldSpec,
        chunks: List[DocumentChunk],
        evidence_tracker: EvidenceTracker,
    ) -> Optional[FieldExtraction]:
        """Extract a single field from chunks."""
        candidates: List[Tuple[Any, float, DocumentChunk]] = []

        # Search relevant chunks
        relevant_chunks = self._find_relevant_chunks(field_spec, chunks)

        for chunk in relevant_chunks:
            value, confidence = self._extract_from_chunk(field_spec, chunk)

            if value is not None and confidence >= self.config.min_field_confidence:
                candidates.append((value, confidence, chunk))

        if not candidates:
            return None

        # Select best candidate
        candidates.sort(key=lambda x: x[1], reverse=True)
        best_value, best_confidence, best_chunk = candidates[0]

        # Normalize value
        if self.config.normalize_values:
            best_value = self._normalize_value(best_value, field_spec.field_type)

        # Validate
        if self.config.validate_extracted_values:
            is_valid = self._validate_value(best_value, field_spec)
            if not is_valid:
                best_confidence *= 0.5  # Penalize invalid values

        # Create evidence
        evidence = self.evidence_builder.create_evidence(
            chunk=best_chunk,
            value=best_value,
            field_name=field_spec.name,
        )
        evidence_tracker.add(evidence, field_spec.name)

        # Determine confidence level
        confidence_level = self._confidence_to_level(best_confidence)

        return FieldExtraction(
            field_name=field_spec.name,
            value=best_value,
            confidence=best_confidence,
            confidence_level=confidence_level,
            evidence=evidence,
            raw_text=best_chunk.text[:200],
        )

    def _find_relevant_chunks(
        self,
        field_spec: FieldSpec,
        chunks: List[DocumentChunk],
    ) -> List[DocumentChunk]:
        """Find chunks that might contain the field value."""
        # Build search terms
        search_terms = [field_spec.name.lower().replace("_", " ")]
        search_terms.extend(a.lower() for a in field_spec.aliases)
        search_terms.extend(h.lower() for h in field_spec.context_hints)

        relevant = []

        for chunk in chunks:
            # Prefer structured sources
            if self.config.prefer_structured_sources:
                if isinstance(chunk, TableChunk) or chunk.chunk_type == ChunkType.FORM_FIELD:
                    relevant.append(chunk)
                    continue

            # Check text content
            text_lower = chunk.text.lower()
            for term in search_terms:
                if term in text_lower:
                    relevant.append(chunk)
                    break

        # If no relevant chunks found and search_all_chunks enabled
        if not relevant and self.config.search_all_chunks:
            return chunks

        return relevant

    def _extract_from_chunk(
        self,
        field_spec: FieldSpec,
        chunk: DocumentChunk,
    ) -> Tuple[Optional[Any], float]:
        """Extract field value from a single chunk."""
        # Handle structured chunks specially
        if isinstance(chunk, TableChunk):
            return self._extract_from_table(field_spec, chunk)

        # Text-based extraction
        return self._extract_from_text(field_spec, chunk.text)

    def _extract_from_table(
        self,
        field_spec: FieldSpec,
        table: TableChunk,
    ) -> Tuple[Optional[Any], float]:
        """Extract field from a table chunk."""
        search_terms = [field_spec.name.lower().replace("_", " ")]
        search_terms.extend(a.lower() for a in field_spec.aliases)

        # Search in header row for field name
        for col_idx in range(table.num_cols):
            header_cell = table.get_cell(0, col_idx)
            if header_cell is None:
                continue

            header_text = header_cell.text.lower()
            for term in search_terms:
                if term in header_text:
                    # Found column - get value from first data row
                    value_cell = table.get_cell(1, col_idx)
                    if value_cell and value_cell.text:
                        return value_cell.text, value_cell.confidence

        # Search in first column for field name
        for row_idx in range(table.num_rows):
            label_cell = table.get_cell(row_idx, 0)
            if label_cell is None:
                continue

            label_text = label_cell.text.lower()
            for term in search_terms:
                if term in label_text:
                    # Found row - get value from second column
                    value_cell = table.get_cell(row_idx, 1)
                    if value_cell and value_cell.text:
                        return value_cell.text, value_cell.confidence

        return None, 0.0

    def _extract_from_text(
        self,
        field_spec: FieldSpec,
        text: str,
    ) -> Tuple[Optional[Any], float]:
        """Extract field from text using patterns."""
        # Build patterns based on field type
        patterns = self._get_extraction_patterns(field_spec)

        for pattern, confidence_boost in patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            if matches:
                # Return first match
                value = matches[0]
                if isinstance(value, tuple):
                    value = value[0]  # Take first capture group
                return value.strip(), 0.7 + confidence_boost

        # Try simple key-value pattern
        search_terms = [field_spec.name.replace("_", " ")]
        search_terms.extend(field_spec.aliases)

        for term in search_terms:
            # Pattern: "Term: Value" or "Term - Value"
            pattern = rf"{re.escape(term)}[\s:\-]+([^\n]+)"
            matches = re.findall(pattern, text, re.IGNORECASE)
            if matches:
                return matches[0].strip(), 0.6

        return None, 0.0

    def _get_extraction_patterns(
        self,
        field_spec: FieldSpec,
    ) -> List[Tuple[str, float]]:
        """Get regex patterns for field type."""
        patterns = []

        # Use custom pattern if provided
        if field_spec.pattern:
            patterns.append((field_spec.pattern, 0.2))

        # Type-specific patterns
        if field_spec.field_type == FieldType.DATE:
            patterns.extend([
                (r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b', 0.1),
                (r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b', 0.1),
                (r'\b([A-Z][a-z]+\s+\d{1,2},?\s+\d{4})\b', 0.1),
            ])
        elif field_spec.field_type == FieldType.CURRENCY:
            patterns.extend([
                (r'[\$\€\£][\s]*([\d,]+\.?\d*)', 0.2),
                (r'([\d,]+\.?\d*)\s*(?:USD|EUR|GBP)', 0.1),
            ])
        elif field_spec.field_type == FieldType.PERCENTAGE:
            patterns.append((r'([\d.]+)\s*%', 0.2))
        elif field_spec.field_type == FieldType.EMAIL:
            patterns.append((r'([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', 0.3))
        elif field_spec.field_type == FieldType.PHONE:
            patterns.extend([
                (r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}', 0.2),
                (r'\+\d{1,3}[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}', 0.2),
            ])
        elif field_spec.field_type == FieldType.INTEGER:
            patterns.append((r'\b(\d+)\b', 0.0))
        elif field_spec.field_type == FieldType.FLOAT:
            patterns.append((r'\b(\d+\.?\d*)\b', 0.0))

        return patterns

    def _normalize_value(self, value: Any, field_type: FieldType) -> Any:
        """Normalize extracted value."""
        normalizer = self._normalizers.get(field_type)
        if normalizer:
            try:
                return normalizer(value)
            except Exception:
                pass
        return value

    def _validate_value(self, value: Any, field_spec: FieldSpec) -> bool:
        """Validate extracted value against field spec."""
        if value is None:
            return not field_spec.required

        # Type validation
        validator = self._validators.get(field_spec.field_type)
        if validator and not validator(value):
            return False

        # Pattern validation
        if field_spec.pattern:
            if not re.match(field_spec.pattern, str(value)):
                return False

        # Range validation
        if field_spec.min_value is not None:
            try:
                if float(value) < field_spec.min_value:
                    return False
            except (ValueError, TypeError):
                pass

        if field_spec.max_value is not None:
            try:
                if float(value) > field_spec.max_value:
                    return False
            except (ValueError, TypeError):
                pass

        # Length validation
        if field_spec.min_length is not None:
            if len(str(value)) < field_spec.min_length:
                return False

        if field_spec.max_length is not None:
            if len(str(value)) > field_spec.max_length:
                return False

        # Allowed values
        if field_spec.allowed_values:
            if value not in field_spec.allowed_values:
                return False

        return True

    def _confidence_to_level(self, confidence: float) -> ConfidenceLevel:
        """Convert numeric confidence to level."""
        if confidence >= 0.9:
            return ConfidenceLevel.VERY_HIGH
        elif confidence >= 0.7:
            return ConfidenceLevel.HIGH
        elif confidence >= 0.5:
            return ConfidenceLevel.MEDIUM
        elif confidence >= 0.3:
            return ConfidenceLevel.LOW
        else:
            return ConfidenceLevel.VERY_LOW

    def _build_normalizers(self) -> Dict[FieldType, Callable]:
        """Build value normalizers for each type."""
        return {
            FieldType.STRING: lambda v: str(v).strip(),
            FieldType.INTEGER: lambda v: int(re.sub(r'[^\d-]', '', str(v))),
            FieldType.FLOAT: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
            FieldType.BOOLEAN: lambda v: str(v).lower() in ('true', 'yes', '1', 'y'),
            FieldType.CURRENCY: self._normalize_currency,
            FieldType.PERCENTAGE: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
            FieldType.EMAIL: lambda v: str(v).lower().strip(),
            FieldType.PHONE: self._normalize_phone,
        }

    def _build_validators(self) -> Dict[FieldType, Callable]:
        """Build validators for each type."""
        return {
            FieldType.EMAIL: lambda v: '@' in str(v) and '.' in str(v),
            FieldType.PHONE: lambda v: len(re.sub(r'\D', '', str(v))) >= 7,
            FieldType.DATE: lambda v: bool(re.search(r'\d', str(v))),
        }

    def _normalize_currency(self, value: str) -> str:
        """Normalize currency value."""
        # Remove currency symbols but keep the number
        amount = re.sub(r'[^\d.,]', '', str(value))
        # Handle European format (1.234,56) vs US format (1,234.56)
        if ',' in amount and '.' in amount:
            if amount.rfind(',') > amount.rfind('.'):
                # European format: dots are thousands separators, comma is decimal
                amount = amount.replace('.', '').replace(',', '.')
            else:
                # US format: strip thousands commas so the result is a plain decimal
                amount = amount.replace(',', '')
        elif ',' in amount:
            # Could be European decimal or US thousands
            parts = amount.split(',')
            if len(parts[-1]) == 2:
                # Likely European decimal
                amount = amount.replace(',', '.')
            else:
                # US thousands separator
                amount = amount.replace(',', '')
        return amount

    def _normalize_phone(self, value: str) -> str:
        """Normalize phone number."""
        digits = re.sub(r'\D', '', str(value))
        if len(digits) == 10:
            return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
        elif len(digits) == 11 and digits[0] == '1':
            return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
        return value