MHamdan's picture
Initial commit: SPARKNET framework
d520909
"""
Core Document Intelligence Schemas
Pydantic models for OCR regions, layout regions, chunks, and evidence.
These form the foundation of the document processing pipeline.
"""
import hashlib
import json
from datetime import datetime, timezone
from enum import Enum
from typing import List, Dict, Any, Optional, Tuple

from pydantic import BaseModel, Field, field_validator
class BoundingBox(BaseModel):
    """
    Axis-aligned rectangle stored as (x_min, y_min, x_max, y_max).

    Coordinates are either pixels or page fractions (0-1); the ``normalized``
    flag records which. ``normalize``/``denormalize`` convert between the two
    and store the page size they used in ``page_width``/``page_height``.
    """

    x_min: float = Field(..., description="Left edge coordinate")
    y_min: float = Field(..., description="Top edge coordinate")
    x_max: float = Field(..., description="Right edge coordinate")
    y_max: float = Field(..., description="Bottom edge coordinate")

    # Coordinate-space marker plus the (optional) source page size in pixels.
    normalized: bool = Field(default=False, description="True if coordinates are 0-1 normalized")
    page_width: Optional[int] = Field(default=None, description="Original page width in pixels")
    page_height: Optional[int] = Field(default=None, description="Original page height in pixels")

    @field_validator('x_max')
    @classmethod
    def x_max_greater_than_x_min(cls, v, info):
        """Reject a box whose right edge lies left of its left edge."""
        # x_min is absent from info.data when its own validation failed;
        # then there is nothing to compare against.
        left = info.data.get('x_min')
        if left is not None and v < left:
            raise ValueError('x_max must be >= x_min')
        return v

    @field_validator('y_max')
    @classmethod
    def y_max_greater_than_y_min(cls, v, info):
        """Reject a box whose bottom edge lies above its top edge."""
        top = info.data.get('y_min')
        if top is not None and v < top:
            raise ValueError('y_max must be >= y_min')
        return v

    @property
    def width(self) -> float:
        """Horizontal extent, ``x_max - x_min``."""
        return self.x_max - self.x_min

    @property
    def height(self) -> float:
        """Vertical extent, ``y_max - y_min``."""
        return self.y_max - self.y_min

    @property
    def area(self) -> float:
        """Width times height."""
        return self.width * self.height

    @property
    def center(self) -> Tuple[float, float]:
        """Midpoint ``(cx, cy)`` of the box."""
        cx = (self.x_min + self.x_max) / 2
        cy = (self.y_min + self.y_max) / 2
        return (cx, cy)

    def to_xyxy(self) -> Tuple[float, float, float, float]:
        """Return the corners as ``(x_min, y_min, x_max, y_max)``."""
        return (self.x_min, self.y_min, self.x_max, self.y_max)

    def to_xywh(self) -> Tuple[float, float, float, float]:
        """Return the top-left corner plus size as ``(x, y, width, height)``."""
        left, top = self.x_min, self.y_min
        return (left, top, self.width, self.height)

    def normalize(self, width: int, height: int) -> "BoundingBox":
        """
        Return this box in 0-1 coordinates for a ``width`` x ``height`` page.

        x coordinates are divided by ``width`` and y coordinates by ``height``,
        so a zero dimension raises ``ZeroDivisionError``. A box that is
        already normalized is returned unchanged (the same object).
        """
        if not self.normalized:
            return BoundingBox(
                x_min=self.x_min / width,
                y_min=self.y_min / height,
                x_max=self.x_max / width,
                y_max=self.y_max / height,
                normalized=True,
                page_width=width,
                page_height=height,
            )
        return self

    def denormalize(self, width: int, height: int) -> "BoundingBox":
        """
        Return this box in pixel coordinates for a ``width`` x ``height`` page.

        A box that is already in pixel coordinates is returned unchanged
        (the same object).
        """
        if self.normalized:
            return BoundingBox(
                x_min=self.x_min * width,
                y_min=self.y_min * height,
                x_max=self.x_max * width,
                y_max=self.y_max * height,
                normalized=False,
                page_width=width,
                page_height=height,
            )
        return self

    def iou(self, other: "BoundingBox") -> float:
        """
        Intersection over Union with ``other``.

        Non-overlapping boxes give 0.0, as does a union that is not positive.
        Both boxes are assumed to share a coordinate space; this is not checked.
        """
        overlap_w = min(self.x_max, other.x_max) - max(self.x_min, other.x_min)
        overlap_h = min(self.y_max, other.y_max) - max(self.y_min, other.y_min)
        if overlap_w < 0 or overlap_h < 0:
            return 0.0
        inter = overlap_w * overlap_h
        union = self.area + other.area - inter
        # `not union > 0` (rather than `union <= 0`) also catches NaN.
        if not union > 0:
            return 0.0
        return inter / union

    def contains(self, other: "BoundingBox") -> bool:
        """True when ``other`` lies entirely inside this box (shared edges allowed)."""
        return all((
            other.x_min >= self.x_min,
            other.y_min >= self.y_min,
            other.x_max <= self.x_max,
            other.y_max <= self.y_max,
        ))
class OCRRegion(BaseModel):
    """
    A single piece of text recognised by an OCR engine.

    Holds the recognised string, its confidence score, its location (box and
    optional polygon) on a zero-indexed page, optional line/word grouping
    indices, and which engine produced it.
    """

    text: str = Field(..., description="Recognized text content")
    confidence: float = Field(..., ge=0.0, le=1.0, description="OCR confidence score")
    bbox: BoundingBox = Field(..., description="Bounding box of the text region")
    polygon: Optional[List[Tuple[float, float]]] = Field(
        None, description="Polygon points for non-rectangular regions"
    )
    page: int = Field(..., ge=0, description="Zero-indexed page number")
    line_id: Optional[int] = Field(None, description="Line grouping ID")
    word_id: Optional[int] = Field(None, description="Word index within line")

    # Provenance of the recognition result.
    engine: str = Field("unknown", description="OCR engine used (paddle/tesseract)")
    language: Optional[str] = Field(None, description="Detected language code")

    def __hash__(self):
        """Hash on text, page and the top-left corner of the box."""
        key = (self.text, self.page, self.bbox.x_min, self.bbox.y_min)
        return hash(key)
class LayoutType(str, Enum):
    """
    Category assigned to a ``LayoutRegion`` (its ``type`` field).

    The enum mixes in ``str``, so each member compares equal to its string
    value, e.g. ``LayoutType.TABLE == "table"``.
    """
    # Text-bearing regions
    TEXT = "text"
    TITLE = "title"
    HEADING = "heading"
    PARAGRAPH = "paragraph"
    LIST = "list"
    # Non-prose content
    TABLE = "table"
    FIGURE = "figure"
    CHART = "chart"
    FORMULA = "formula"
    # Page furniture and annotations
    HEADER = "header"
    FOOTER = "footer"
    PAGE_NUMBER = "page_number"
    CAPTION = "caption"
    FOOTNOTE = "footnote"
    WATERMARK = "watermark"
    LOGO = "logo"
    SIGNATURE = "signature"
    # NOTE(review): presumably the fallback for labels matching none of the above; confirm with the detector
    UNKNOWN = "unknown"
class LayoutRegion(BaseModel):
    """
    A structural region found by layout detection (table, figure, paragraph, ...).

    Besides its type, score and location, a region can carry its position in
    reading order, parent/child links to other regions (by id), the indices
    of the OCR regions it covers, and free-form type-specific metadata.
    """

    id: str = Field(..., description="Unique region identifier")
    type: LayoutType = Field(..., description="Region type classification")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Detection confidence")
    bbox: BoundingBox = Field(..., description="Bounding box of the region")
    page: int = Field(..., ge=0, description="Zero-indexed page number")

    # Ordering within the page.
    reading_order: Optional[int] = Field(None, description="Position in reading order (0 = first)")

    # Nesting, expressed through region ids.
    parent_id: Optional[str] = Field(None, description="Parent region ID")
    children_ids: List[str] = Field(default_factory=list, description="Child region IDs")

    # Links into the OCR results by index.
    ocr_region_ids: List[int] = Field(
        default_factory=list, description="Indices of OCR regions within this layout region"
    )

    # Free-form per-type metadata.
    extra: Dict[str, Any] = Field(default_factory=dict, description="Type-specific metadata")

    def __hash__(self):
        """Hash on the region id alone."""
        region_id = self.id
        return hash(region_id)
class ChunkType(str, Enum):
    """
    Semantic category of a ``DocumentChunk`` (its ``chunk_type`` field).

    Like ``LayoutType`` this is a ``str``-valued enum, so members compare equal
    to their string values. Compared with ``LayoutType`` it adds
    ``LIST_ITEM``, ``TABLE_CELL`` and ``METADATA``. It has no ``LIST``,
    ``PAGE_NUMBER``, ``WATERMARK``, ``LOGO``, ``SIGNATURE`` or ``UNKNOWN``.
    """
    # Text-bearing chunks
    TEXT = "text"
    TITLE = "title"
    HEADING = "heading"
    PARAGRAPH = "paragraph"
    LIST_ITEM = "list_item"
    # Tabular content: a whole table or one of its cells
    TABLE = "table"
    TABLE_CELL = "table_cell"
    # Non-prose content
    FIGURE = "figure"
    CHART = "chart"
    FORMULA = "formula"
    # Annotations and page furniture
    CAPTION = "caption"
    FOOTNOTE = "footnote"
    HEADER = "header"
    FOOTER = "footer"
    METADATA = "metadata"
class DocumentChunk(BaseModel):
    """
    Semantic chunk of a document for retrieval and processing.

    Contains the chunk text, its location (page + bounding box) for grounding,
    its position in reading order, optional table/caption/reference links,
    an optional embedding vector, and free-form metadata.
    """

    chunk_id: str = Field(..., description="Unique chunk identifier")
    chunk_type: ChunkType = Field(..., description="Semantic type of chunk")
    text: str = Field(..., description="Text content of the chunk")
    bbox: BoundingBox = Field(..., description="Bounding box covering the chunk")
    page: int = Field(..., ge=0, description="Zero-indexed page number")

    # Source tracking
    document_id: str = Field(..., description="Parent document identifier")
    source_path: Optional[str] = Field(default=None, description="Original file path")

    # Sequence position
    sequence_index: int = Field(..., ge=0, description="Position in document reading order")

    # Confidence and quality
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Chunk extraction confidence"
    )

    # Table-specific fields (all default to None)
    table_cell_ids: Optional[List[str]] = Field(
        default=None,
        description="Cell IDs if this is a table chunk"
    )
    row_index: Optional[int] = Field(default=None, description="Table row index")
    col_index: Optional[int] = Field(default=None, description="Table column index")

    # Caption/reference linking
    caption: Optional[str] = Field(default=None, description="Associated caption text")
    references: List[str] = Field(
        default_factory=list,
        description="References to other chunks"
    )

    # Embedding placeholder
    embedding: Optional[List[float]] = Field(
        default=None,
        description="Vector embedding for retrieval"
    )

    # Additional metadata
    extra: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    @property
    def content_hash(self) -> str:
        """
        Return a 16-hex-character SHA-256 prefix of (text, page, chunk type).

        Intended for deduplication. The chunk type is rendered through its
        ``.value``. Formatting a ``(str, Enum)`` member directly in an f-string
        gives ``"text"`` on Python <= 3.10 but ``"ChunkType.TEXT"`` on 3.11+,
        which would make the hash depend on the interpreter version.
        """
        content = f"{self.text}:{self.page}:{self.chunk_type.value}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]

    def to_retrieval_dict(self) -> Dict[str, Any]:
        """
        Convert to a flat dictionary for vector store metadata.

        Only identifying/positional fields are included. The chunk text,
        embedding, caption, references, table fields and ``extra`` are not.
        ``bbox_xyxy`` is the ``(x_min, y_min, x_max, y_max)`` tuple.
        """
        return {
            "chunk_id": self.chunk_id,
            "chunk_type": self.chunk_type.value,
            "page": self.page,
            "document_id": self.document_id,
            "source_path": self.source_path,
            "bbox_xyxy": self.bbox.to_xyxy(),
            "sequence_index": self.sequence_index,
            "confidence": self.confidence,
        }

    def __hash__(self):
        """Hash on ``chunk_id`` only."""
        return hash(self.chunk_id)
class EvidenceRef(BaseModel):
    """
    Evidence reference for grounding extracted information.

    Links extracted data back to a source chunk, page and bounding box, with
    a short text snippet, a confidence score and an optional image crop.
    """

    chunk_id: str = Field(..., description="Referenced chunk ID")
    page: int = Field(..., ge=0, description="Page number")
    bbox: BoundingBox = Field(..., description="Bounding box of evidence")
    source_type: str = Field(..., description="Type of source (text/table/figure)")
    snippet: str = Field(..., max_length=500, description="Text snippet as evidence")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Evidence confidence")

    # Optional visual evidence
    image_base64: Optional[str] = Field(
        default=None,
        description="Base64-encoded crop of the evidence region"
    )

    # Metadata
    extra: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    def to_citation(self, max_chars: int = 100) -> str:
        """
        Format as a human-readable citation.

        Example output: ``[Page 3, table]: "Total revenue was..."``.
        ``page`` is zero-indexed and displayed one-based. The snippet is cut
        to ``max_chars`` characters. "..." is appended only when the snippet
        was actually shortened, so short snippets are not falsely marked
        as truncated.

        Args:
            max_chars: Maximum number of snippet characters to include.

        Returns:
            The citation string.

        Raises:
            ValueError: If ``max_chars`` is negative.
        """
        if max_chars < 0:
            raise ValueError("max_chars must be non-negative")
        excerpt = self.snippet[:max_chars]
        if len(excerpt) < len(self.snippet):
            excerpt += "..."
        return f"[Page {self.page + 1}, {self.source_type}]: \"{excerpt}\""
class ExtractionResult(BaseModel):
    """
    Outcome of a field-extraction or analysis task.

    Bundles the extracted ``data`` with the evidence references that support
    it, accumulated warnings, an overall confidence, the names of fields the
    extractor declined to fill, and optional timing/model information.
    """

    data: Dict[str, Any] = Field(..., description="Extracted data dictionary")
    evidence: List[EvidenceRef] = Field(
        default_factory=list, description="Evidence supporting the extraction"
    )
    warnings: List[str] = Field(
        default_factory=list, description="Warnings or issues encountered"
    )
    confidence: float = Field(
        1.0, ge=0.0, le=1.0, description="Overall extraction confidence"
    )

    # Fields deliberately left unextracted (see `abstain`).
    abstained_fields: List[str] = Field(
        default_factory=list,
        description="Fields where extraction was abstained due to low confidence",
    )

    # Bookkeeping about how the result was produced.
    processing_time_ms: Optional[float] = Field(
        None, description="Processing time in milliseconds"
    )
    model_used: Optional[str] = Field(None, description="Model used for extraction")

    @property
    def is_grounded(self) -> bool:
        """
        True when there is at least one evidence reference and no abstentions.

        This does not verify that every key of ``data`` has its own evidence.
        """
        if self.abstained_fields:
            return False
        return bool(self.evidence)

    def add_warning(self, warning: str):
        """Append ``warning`` to the ``warnings`` list."""
        self.warnings.append(warning)

    def abstain(self, field: str, reason: str):
        """Record ``field`` as abstained and log a warning carrying ``reason``."""
        self.abstained_fields.append(field)
        note = f"Abstained from extracting '{field}': {reason}"
        self.warnings.append(note)
class DocumentMetadata(BaseModel):
    """
    Metadata about a processed document.

    Covers file identity, page count and sizes, processing timestamps,
    content statistics, language detection and average quality scores.
    """

    document_id: str = Field(..., description="Unique document identifier")
    source_path: str = Field(..., description="Original file path")
    filename: str = Field(..., description="Original filename")
    file_type: str = Field(..., description="File type (pdf/image/etc)")
    file_size_bytes: int = Field(..., ge=0, description="File size in bytes")

    # Page information
    num_pages: int = Field(..., ge=1, description="Total number of pages")
    page_dimensions: List[Tuple[int, int]] = Field(
        default_factory=list,
        description="(width, height) for each page"
    )

    # Processing timestamps.
    # datetime.utcnow() is deprecated since Python 3.12. This factory yields the
    # same naive-UTC value without the DeprecationWarning. tzinfo is stripped
    # so the value stays naive, as it was with utcnow().
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    processed_at: Optional[datetime] = Field(default=None)

    # Content statistics
    total_chunks: int = Field(default=0, description="Number of chunks extracted")
    total_characters: int = Field(default=0, description="Total character count")

    # Language detection
    detected_language: Optional[str] = Field(default=None, description="Primary language")
    language_confidence: Optional[float] = Field(default=None)

    # Quality metrics
    ocr_confidence_avg: Optional[float] = Field(default=None)
    layout_confidence_avg: Optional[float] = Field(default=None)

    # Additional metadata
    extra: Dict[str, Any] = Field(default_factory=dict)
class ProcessedDocument(BaseModel):
    """
    Full result of running the document processing pipeline on one document.

    Aggregates the document metadata, raw OCR regions, layout regions,
    retrieval chunks and reading-order full text, plus status, errors and
    warnings. Offers per-page/per-type chunk lookups and JSON persistence.
    """

    metadata: DocumentMetadata = Field(..., description="Document metadata")

    # Stage outputs, in pipeline order: OCR, then layout, then chunking.
    ocr_regions: List[OCRRegion] = Field(
        default_factory=list, description="All OCR regions"
    )
    layout_regions: List[LayoutRegion] = Field(
        default_factory=list, description="All layout regions"
    )
    chunks: List[DocumentChunk] = Field(
        default_factory=list, description="Document chunks for retrieval"
    )

    # Whole-document text in reading order.
    full_text: str = Field("", description="Full text in reading order")

    # Run outcome.
    status: str = Field("completed", description="Processing status")
    errors: List[str] = Field(default_factory=list, description="Processing errors")
    warnings: List[str] = Field(default_factory=list, description="Processing warnings")

    def get_page_chunks(self, page: int) -> List[DocumentChunk]:
        """Return the chunks whose ``page`` equals ``page``, in stored order."""
        return list(filter(lambda chunk: chunk.page == page, self.chunks))

    def get_chunks_by_type(self, chunk_type: ChunkType) -> List[DocumentChunk]:
        """Return the chunks whose ``chunk_type`` equals ``chunk_type``, in stored order."""
        return list(filter(lambda chunk: chunk.chunk_type == chunk_type, self.chunks))

    def to_json(self, indent: int = 2) -> str:
        """Serialize the whole document to a JSON string with the given indent."""
        return self.model_dump_json(indent=indent)

    @classmethod
    def from_json(cls, json_str: str) -> "ProcessedDocument":
        """Build (and validate) a document from a JSON string."""
        return cls.model_validate_json(json_str)

    def save(self, path: str):
        """Write the JSON serialization to ``path`` as UTF-8, overwriting it."""
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(self.to_json())

    @classmethod
    def load(cls, path: str) -> "ProcessedDocument":
        """Read a UTF-8 JSON file written by ``save`` and rebuild the document."""
        with open(path, "r", encoding="utf-8") as handle:
            raw = handle.read()
        return cls.from_json(raw)