|
|
""" |
|
|
Field Extraction Engine |
|
|
|
|
|
Extracts structured data from parsed documents using schemas. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import re |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
from ..chunks.models import ( |
|
|
DocumentChunk, |
|
|
ExtractionResult, |
|
|
FieldExtraction, |
|
|
EvidenceRef, |
|
|
ParseResult, |
|
|
TableChunk, |
|
|
ChartChunk, |
|
|
ChunkType, |
|
|
ConfidenceLevel, |
|
|
) |
|
|
from ..grounding.evidence import EvidenceBuilder, EvidenceTracker |
|
|
from .schema import ExtractionSchema, FieldSpec, FieldType |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass
class ExtractionConfig:
    """Configuration for field extraction."""

    # Minimum per-candidate confidence for an extracted value to be kept.
    min_field_confidence: float = 0.5
    # Minimum document-level confidence.
    # NOTE(review): not read anywhere in this module — presumably enforced
    # by a caller of FieldExtractor.extract(); confirm before removing.
    min_overall_confidence: float = 0.5

    # When a field's final confidence falls below abstain_threshold, report
    # the field as abstained (data value forced to None) instead of guessing.
    abstain_on_low_confidence: bool = True
    abstain_threshold: float = 0.3

    # If no chunk looks relevant for a field, fall back to searching every chunk.
    search_all_chunks: bool = True
    # Always treat structured chunks (tables, form fields) as candidates,
    # regardless of whether their text mentions the field name.
    prefer_structured_sources: bool = True

    # Post-extraction toggles: validate against the FieldSpec constraints,
    # and normalize values to canonical forms per field type.
    validate_extracted_values: bool = True
    normalize_values: bool = True
|
|
|
|
|
|
|
|
class FieldExtractor:
    """
    Extracts fields from parsed documents.

    Uses schema definitions to identify and extract
    structured data with evidence grounding.
    """

    def __init__(
        self,
        config: Optional[ExtractionConfig] = None,
        evidence_builder: Optional[EvidenceBuilder] = None,
    ):
        """
        Args:
            config: Extraction tuning knobs; defaults to ExtractionConfig().
            evidence_builder: Builds evidence records for extracted values;
                a default EvidenceBuilder is created when omitted.
        """
        self.config = config or ExtractionConfig()
        self.evidence_builder = evidence_builder or EvidenceBuilder()
        # Built once up front: FieldType -> callable maps for value cleanup
        # and cheap structural validation.
        self._normalizers: Dict[FieldType, Callable] = self._build_normalizers()
        self._validators: Dict[FieldType, Callable] = self._build_validators()

    def extract(
        self,
        parse_result: ParseResult,
        schema: ExtractionSchema,
    ) -> ExtractionResult:
        """
        Extract fields from a parsed document.

        Args:
            parse_result: Parsed document with chunks
            schema: Extraction schema defining fields

        Returns:
            ExtractionResult with extracted values and evidence. Fields that
            were found but fell below the abstain threshold are listed in
            abstained_fields and reported as None in ``data``; their
            FieldExtraction (with the original value) remains in ``fields``,
            so overall_confidence includes them.
        """
        # Lazy %-formatting: the message is only built when INFO is enabled.
        logger.info(
            "Extracting %d fields from %s",
            len(schema.fields),
            parse_result.filename,
        )

        evidence_tracker = EvidenceTracker()
        field_extractions: List[FieldExtraction] = []
        extracted_data: Dict[str, Any] = {}
        abstained_fields: List[str] = []

        for field_spec in schema.fields:
            extraction = self._extract_field(
                field_spec=field_spec,
                chunks=parse_result.chunks,
                evidence_tracker=evidence_tracker,
            )

            if extraction:
                field_extractions.append(extraction)
                extracted_data[field_spec.name] = extraction.value

                # Low-confidence hit: optionally abstain rather than guess.
                if extraction.confidence < self.config.abstain_threshold:
                    if self.config.abstain_on_low_confidence:
                        abstained_fields.append(field_spec.name)
                        extracted_data[field_spec.name] = None
            else:
                # Nothing found at all; a required field abstains and falls
                # back to its schema-declared default.
                if field_spec.required:
                    abstained_fields.append(field_spec.name)
                    extracted_data[field_spec.name] = field_spec.default

        # Mean confidence over every produced extraction (abstained included).
        if field_extractions:
            overall_confidence = sum(
                f.confidence for f in field_extractions
            ) / len(field_extractions)
        else:
            overall_confidence = 0.0

        return ExtractionResult(
            data=extracted_data,
            fields=field_extractions,
            evidence=evidence_tracker.get_all(),
            overall_confidence=overall_confidence,
            abstained_fields=abstained_fields,
        )

    def _extract_field(
        self,
        field_spec: FieldSpec,
        chunks: List[DocumentChunk],
        evidence_tracker: EvidenceTracker,
    ) -> Optional[FieldExtraction]:
        """Extract a single field from chunks.

        Returns None when no candidate reaches min_field_confidence.
        """
        candidates: List[Tuple[Any, float, DocumentChunk]] = []

        relevant_chunks = self._find_relevant_chunks(field_spec, chunks)

        for chunk in relevant_chunks:
            value, confidence = self._extract_from_chunk(field_spec, chunk)
            if value is not None and confidence >= self.config.min_field_confidence:
                candidates.append((value, confidence, chunk))

        if not candidates:
            return None

        # Highest-confidence candidate wins.
        candidates.sort(key=lambda c: c[1], reverse=True)
        best_value, best_confidence, best_chunk = candidates[0]

        if self.config.normalize_values:
            best_value = self._normalize_value(best_value, field_spec.field_type)

        # A failed validation halves confidence rather than discarding the
        # value outright; the abstain logic upstream may still drop it.
        if self.config.validate_extracted_values:
            if not self._validate_value(best_value, field_spec):
                best_confidence *= 0.5

        evidence = self.evidence_builder.create_evidence(
            chunk=best_chunk,
            value=best_value,
            field_name=field_spec.name,
        )
        evidence_tracker.add(evidence, field_spec.name)

        return FieldExtraction(
            field_name=field_spec.name,
            value=best_value,
            confidence=best_confidence,
            confidence_level=self._confidence_to_level(best_confidence),
            evidence=evidence,
            raw_text=best_chunk.text[:200],  # short provenance snippet
        )

    @staticmethod
    def _search_terms(
        field_spec: FieldSpec,
        include_hints: bool = False,
    ) -> List[str]:
        """Lower-cased terms that identify a field: name, aliases, and
        (optionally) context hints. Shared by chunk and table searches."""
        terms = [field_spec.name.lower().replace("_", " ")]
        terms.extend(alias.lower() for alias in field_spec.aliases)
        if include_hints:
            terms.extend(hint.lower() for hint in field_spec.context_hints)
        return terms

    def _find_relevant_chunks(
        self,
        field_spec: FieldSpec,
        chunks: List[DocumentChunk],
    ) -> List[DocumentChunk]:
        """Find chunks that might contain the field value."""
        search_terms = self._search_terms(field_spec, include_hints=True)

        relevant: List[DocumentChunk] = []
        for chunk in chunks:
            # Structured sources (tables, form fields) are always candidates
            # when preferred — no text match required.
            if self.config.prefer_structured_sources:
                if isinstance(chunk, TableChunk) or chunk.chunk_type == ChunkType.FORM_FIELD:
                    relevant.append(chunk)
                    continue

            text_lower = chunk.text.lower()
            if any(term in text_lower for term in search_terms):
                relevant.append(chunk)

        # Nothing matched: optionally widen the search to every chunk.
        if not relevant and self.config.search_all_chunks:
            return chunks
        return relevant

    def _extract_from_chunk(
        self,
        field_spec: FieldSpec,
        chunk: DocumentChunk,
    ) -> Tuple[Optional[Any], float]:
        """Extract field value from a single chunk as (value, confidence)."""
        # Tables get structured header/label matching; everything else is
        # treated as free text.
        if isinstance(chunk, TableChunk):
            return self._extract_from_table(field_spec, chunk)
        return self._extract_from_text(field_spec, chunk.text)

    def _extract_from_table(
        self,
        field_spec: FieldSpec,
        table: TableChunk,
    ) -> Tuple[Optional[Any], float]:
        """Extract field from a table chunk.

        Tries a column-oriented layout first (header in row 0, value in
        row 1), then a row-oriented layout (label in column 0, value in
        column 1). Returns (None, 0.0) when neither matches.
        """
        search_terms = self._search_terms(field_spec)

        # Column layout: match a header cell, take the cell directly below.
        for col_idx in range(table.num_cols):
            header_cell = table.get_cell(0, col_idx)
            if header_cell is None:
                continue
            header_text = header_cell.text.lower()
            for term in search_terms:
                if term in header_text:
                    value_cell = table.get_cell(1, col_idx)
                    if value_cell and value_cell.text:
                        return value_cell.text, value_cell.confidence

        # Row layout: match a label cell, take the cell to its right.
        for row_idx in range(table.num_rows):
            label_cell = table.get_cell(row_idx, 0)
            if label_cell is None:
                continue
            label_text = label_cell.text.lower()
            for term in search_terms:
                if term in label_text:
                    value_cell = table.get_cell(row_idx, 1)
                    if value_cell and value_cell.text:
                        return value_cell.text, value_cell.confidence

        return None, 0.0

    def _extract_from_text(
        self,
        field_spec: FieldSpec,
        text: str,
    ) -> Tuple[Optional[Any], float]:
        """Extract field from text using patterns.

        Type-specific regexes are tried first (base confidence 0.7 plus a
        per-pattern boost, capped at 1.0); if none yields a non-empty value,
        falls back to a "<label>: <value>" search over the field name and
        aliases (confidence 0.6).
        """
        for pattern, confidence_boost in self._get_extraction_patterns(field_spec):
            for match in re.findall(pattern, text, re.IGNORECASE):
                if isinstance(match, tuple):
                    # Multi-group pattern: use the first capture group.
                    match = match[0]
                value = match.strip()
                # Skip empty captures instead of returning "" with high
                # confidence (bug fix).
                if value:
                    return value, min(1.0, 0.7 + confidence_boost)

        # Fallback: label followed by separator, value = rest of the line.
        search_terms = [field_spec.name.replace("_", " ")]
        search_terms.extend(field_spec.aliases)

        for term in search_terms:
            pattern = rf"{re.escape(term)}[\s::\-]+([^\n]+)"
            for match in re.findall(pattern, text, re.IGNORECASE):
                value = match.strip()
                if value:
                    return value, 0.6

        return None, 0.0

    def _get_extraction_patterns(
        self,
        field_spec: FieldSpec,
    ) -> List[Tuple[str, float]]:
        """Get (regex, confidence_boost) pairs for the field's type."""
        patterns: List[Tuple[str, float]] = []

        # A schema-supplied pattern is most specific, so it is tried first.
        if field_spec.pattern:
            patterns.append((field_spec.pattern, 0.2))

        ftype = field_spec.field_type
        if ftype == FieldType.DATE:
            patterns.extend([
                (r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b', 0.1),    # 12/31/24
                (r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b', 0.1),      # 2024-12-31
                (r'\b([A-Z][a-z]+\s+\d{1,2},?\s+\d{4})\b', 0.1),  # May 1, 2024
            ])
        elif ftype == FieldType.CURRENCY:
            patterns.extend([
                (r'[\$\€\£][\s]*([\d,]+\.?\d*)', 0.2),            # $1,234.56
                (r'([\d,]+\.?\d*)\s*(?:USD|EUR|GBP)', 0.1),       # 1,234.56 USD
            ])
        elif ftype == FieldType.PERCENTAGE:
            patterns.append((r'([\d.]+)\s*%', 0.2))
        elif ftype == FieldType.EMAIL:
            patterns.append((r'([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', 0.3))
        elif ftype == FieldType.PHONE:
            # NOTE: no capture groups here, so findall yields full matches.
            patterns.extend([
                (r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}', 0.2),
                (r'\+\d{1,3}[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}', 0.2),
            ])
        elif ftype == FieldType.INTEGER:
            patterns.append((r'\b(\d+)\b', 0.0))
        elif ftype == FieldType.FLOAT:
            patterns.append((r'\b(\d+\.?\d*)\b', 0.0))

        return patterns

    def _normalize_value(self, value: Any, field_type: FieldType) -> Any:
        """Normalize an extracted value; return it unchanged on failure."""
        normalizer = self._normalizers.get(field_type)
        if normalizer:
            try:
                return normalizer(value)
            except Exception as exc:
                # Best-effort by design: keep the raw value, but leave a
                # trace instead of swallowing the failure silently.
                logger.debug(
                    "Normalization of %r as %s failed: %s", value, field_type, exc
                )
        return value

    def _validate_value(self, value: Any, field_spec: FieldSpec) -> bool:
        """Validate an extracted value against the field spec constraints."""
        if value is None:
            # A missing value is only acceptable for optional fields.
            return not field_spec.required

        # Type-specific sanity check (email shape, phone digit count, ...).
        validator = self._validators.get(field_spec.field_type)
        if validator and not validator(value):
            return False

        # Schema regex. IGNORECASE matches the case-insensitive extraction
        # above, so values found case-insensitively are not rejected here
        # (bug fix: was case-sensitive, inconsistent with extraction).
        if field_spec.pattern:
            if not re.match(field_spec.pattern, str(value), re.IGNORECASE):
                return False

        # Numeric range checks; non-numeric values skip them silently.
        if field_spec.min_value is not None:
            try:
                if float(value) < field_spec.min_value:
                    return False
            except (ValueError, TypeError):
                pass

        if field_spec.max_value is not None:
            try:
                if float(value) > field_spec.max_value:
                    return False
            except (ValueError, TypeError):
                pass

        # Length constraints apply to the string form of the value.
        if field_spec.min_length is not None:
            if len(str(value)) < field_spec.min_length:
                return False

        if field_spec.max_length is not None:
            if len(str(value)) > field_spec.max_length:
                return False

        # Closed-vocabulary check.
        if field_spec.allowed_values:
            if value not in field_spec.allowed_values:
                return False

        return True

    def _confidence_to_level(self, confidence: float) -> ConfidenceLevel:
        """Convert numeric confidence to a discrete level."""
        if confidence >= 0.9:
            return ConfidenceLevel.VERY_HIGH
        elif confidence >= 0.7:
            return ConfidenceLevel.HIGH
        elif confidence >= 0.5:
            return ConfidenceLevel.MEDIUM
        elif confidence >= 0.3:
            return ConfidenceLevel.LOW
        else:
            return ConfidenceLevel.VERY_LOW

    def _build_normalizers(self) -> Dict[FieldType, Callable]:
        """Build value normalizers for each field type."""
        return {
            FieldType.STRING: lambda v: str(v).strip(),
            # Strip everything except digits/sign before converting.
            FieldType.INTEGER: lambda v: int(re.sub(r'[^\d-]', '', str(v))),
            FieldType.FLOAT: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
            FieldType.BOOLEAN: lambda v: str(v).lower() in ('true', 'yes', '1', 'y'),
            FieldType.CURRENCY: self._normalize_currency,
            FieldType.PERCENTAGE: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
            FieldType.EMAIL: lambda v: str(v).lower().strip(),
            FieldType.PHONE: self._normalize_phone,
        }

    def _build_validators(self) -> Dict[FieldType, Callable]:
        """Build cheap structural validators for each field type."""
        return {
            FieldType.EMAIL: lambda v: '@' in str(v) and '.' in str(v),
            FieldType.PHONE: lambda v: len(re.sub(r'\D', '', str(v))) >= 7,
            FieldType.DATE: lambda v: bool(re.search(r'\d', str(v))),
        }

    def _normalize_currency(self, value: str) -> str:
        """Normalize a currency string to a plain decimal (still a str).

        Handles US ("1,234.56") and European ("1.234,56") separators.
        """
        # Keep only digits and separators.
        amount = re.sub(r'[^\d.,]', '', str(value))

        if ',' in amount and '.' in amount:
            if amount.rfind(',') > amount.rfind('.'):
                # European style: '.' thousands, ',' decimal.
                amount = amount.replace('.', '').replace(',', '.')
            else:
                # US style: ',' thousands, '.' decimal — drop the commas
                # (bug fix: previously left "1,234.56" unchanged, unlike
                # the comma-only thousands branch below).
                amount = amount.replace(',', '')
        elif ',' in amount:
            parts = amount.split(',')
            if len(parts[-1]) == 2:
                # Trailing 2-digit group looks like a decimal part.
                amount = amount.replace(',', '.')
            else:
                # Otherwise treat commas as thousands separators.
                amount = amount.replace(',', '')
        return amount

    def _normalize_phone(self, value: str) -> str:
        """Normalize a phone number to (NXX) NXX-XXXX form when possible."""
        digits = re.sub(r'\D', '', str(value))
        if len(digits) == 10:
            return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
        elif len(digits) == 11 and digits[0] == '1':
            return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
        return value  # unrecognized length: leave unchanged
|
|
|