"""
Document Intelligence Tools for Agents
Tool implementations for DocumentAgent integration.
Each tool is designed for ReAct-style agent execution.
"""
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
logger = logging.getLogger(__name__)
@dataclass
class ToolResult:
"""Result from a tool execution."""
success: bool
data: Any = None
error: Optional[str] = None
    evidence: Optional[List[Dict[str, Any]]] = None
def __post_init__(self):
if self.evidence is None:
self.evidence = []
def to_dict(self) -> Dict[str, Any]:
return {
"success": self.success,
"data": self.data,
"error": self.error,
"evidence": self.evidence,
}
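# Usage sketch (values are illustrative): every tool returns a ToolResult,
# which serializes to a plain dict for the agent's observation step:
#
#     result = ToolResult(success=True, data={"num_chunks": 12})
#     observation = result.to_dict()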
class DocumentTool:
"""Base class for document tools."""
name: str = "base_tool"
description: str = "Base document tool"
def execute(self, **kwargs) -> ToolResult:
"""Execute the tool."""
raise NotImplementedError
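# Dispatch sketch: a ReAct-style agent resolves a tool by name via get_tool()
# (defined at the bottom of this module) and forwards the parsed action
# arguments as keyword arguments:
#
#     tool = get_tool("search_chunks")
#     observation = tool.execute(parse_result=doc, query="total revenue")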
class ParseDocumentTool(DocumentTool):
"""
Parse a document into semantic chunks.
Input:
path: Path to document file
max_pages: Optional maximum pages to process
Output:
ParseResult with chunks and metadata
"""
name = "parse_document"
description = "Parse a document into semantic chunks with OCR and layout detection"
def __init__(self, parser=None):
from ..parsing import DocumentParser
self.parser = parser or DocumentParser()
def execute(
self,
path: str,
max_pages: Optional[int] = None,
**kwargs
) -> ToolResult:
try:
            # Apply the max_pages override when one is explicitly given
            if max_pages is not None:
                self.parser.config.max_pages = max_pages
result = self.parser.parse(path)
return ToolResult(
success=True,
data={
"doc_id": result.doc_id,
"filename": result.filename,
"num_pages": result.num_pages,
"num_chunks": len(result.chunks),
"chunks": [
{
"chunk_id": c.chunk_id,
"type": c.chunk_type.value,
"text": c.text[:500], # Truncate for display
"page": c.page,
"confidence": c.confidence,
}
for c in result.chunks[:20] # Limit for display
],
"markdown_preview": result.markdown_full[:2000],
},
)
except Exception as e:
logger.error(f"Parse document failed: {e}")
return ToolResult(success=False, error=str(e))
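# Example (sketch; "invoice.pdf" is a hypothetical local file):
#
#     parsed = ParseDocumentTool().execute(path="invoice.pdf", max_pages=2)
#     if parsed.success:
#         print(parsed.data["num_chunks"], "chunks on",
#               parsed.data["num_pages"], "pages")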
class ExtractFieldsTool(DocumentTool):
"""
Extract fields from a parsed document using a schema.
Input:
parse_result: Previously parsed document
schema: Extraction schema (dict or ExtractionSchema)
fields: Optional list of specific fields to extract
Output:
ExtractionResult with values and evidence
"""
name = "extract_fields"
description = "Extract structured fields from document using a schema"
def __init__(self, extractor=None):
from ..extraction import FieldExtractor
self.extractor = extractor or FieldExtractor()
def execute(
self,
parse_result: Any,
schema: Union[Dict, Any],
fields: Optional[List[str]] = None,
**kwargs
) -> ToolResult:
try:
from ..extraction import ExtractionSchema
# Convert dict schema to ExtractionSchema
if isinstance(schema, dict):
schema = ExtractionSchema.from_json_schema(schema)
# Filter fields if specified
if fields:
schema.fields = [f for f in schema.fields if f.name in fields]
result = self.extractor.extract(parse_result, schema)
return ToolResult(
success=True,
data={
"extracted_data": result.data,
"confidence": result.overall_confidence,
"abstained_fields": result.abstained_fields,
},
evidence=[
{
"chunk_id": e.chunk_id,
"page": e.page,
"bbox": e.bbox.xyxy,
"snippet": e.snippet,
"confidence": e.confidence,
}
for e in result.evidence
],
)
except Exception as e:
logger.error(f"Extract fields failed: {e}")
return ToolResult(success=False, error=str(e))
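# Example (sketch; the dict layout assumes whatever JSON-Schema subset
# ExtractionSchema.from_json_schema accepts -- the field name is hypothetical):
#
#     schema = {"properties": {"invoice_number": {"type": "string"}}}
#     fields = ExtractFieldsTool().execute(
#         parse_result=doc, schema=schema, fields=["invoice_number"]
#     )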
class SearchChunksTool(DocumentTool):
"""
Search for chunks containing specific text or matching criteria.
Input:
parse_result: Parsed document
query: Search query
chunk_types: Optional list of chunk types to filter
top_k: Maximum results to return
Output:
List of matching chunks with scores
"""
name = "search_chunks"
description = "Search document chunks for specific content"
def execute(
self,
parse_result: Any,
query: str,
chunk_types: Optional[List[str]] = None,
top_k: int = 10,
**kwargs
) -> ToolResult:
try:
query_lower = query.lower()
results = []
for chunk in parse_result.chunks:
                # Skip chunks whose type is not in the requested filter
                if chunk_types and chunk.chunk_type.value not in chunk_types:
                    continue
# Simple text matching with scoring
text_lower = chunk.text.lower()
if query_lower in text_lower:
                    # Relevance: weight match frequency, with a small bonus
                    # for an early first occurrence
count = text_lower.count(query_lower)
position = text_lower.find(query_lower)
score = count * 10 + (1 / (position + 1)) * 5
results.append({
"chunk_id": chunk.chunk_id,
"type": chunk.chunk_type.value,
"text": chunk.text[:300],
"page": chunk.page,
"score": score,
"bbox": chunk.bbox.xyxy,
})
# Sort by score and limit
results.sort(key=lambda x: x["score"], reverse=True)
results = results[:top_k]
return ToolResult(
success=True,
data={
"query": query,
"total_matches": len(results),
"results": results,
},
)
except Exception as e:
logger.error(f"Search chunks failed: {e}")
return ToolResult(success=False, error=str(e))
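# Example (sketch; the "table" type string assumes the corresponding
# ChunkType value in ..chunks -- adjust to the actual enum values):
#
#     hits = SearchChunksTool().execute(
#         parse_result=doc, query="total", chunk_types=["table"], top_k=5
#     )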
class GetChunkDetailsTool(DocumentTool):
"""
Get detailed information about a specific chunk.
Input:
parse_result: Parsed document
chunk_id: ID of chunk to retrieve
Output:
Full chunk details including content and metadata
"""
name = "get_chunk_details"
description = "Get detailed information about a specific chunk"
def execute(
self,
parse_result: Any,
chunk_id: str,
**kwargs
) -> ToolResult:
try:
from ..chunks import TableChunk, ChartChunk
            # Locate the requested chunk by id
            chunk = next(
                (c for c in parse_result.chunks if c.chunk_id == chunk_id),
                None,
            )
if chunk is None:
return ToolResult(
success=False,
error=f"Chunk not found: {chunk_id}"
)
data = {
"chunk_id": chunk.chunk_id,
"doc_id": chunk.doc_id,
"type": chunk.chunk_type.value,
"text": chunk.text,
"page": chunk.page,
"bbox": {
"x_min": chunk.bbox.x_min,
"y_min": chunk.bbox.y_min,
"x_max": chunk.bbox.x_max,
"y_max": chunk.bbox.y_max,
"normalized": chunk.bbox.normalized,
},
"confidence": chunk.confidence,
"sequence_index": chunk.sequence_index,
}
# Add type-specific data
if isinstance(chunk, TableChunk):
data["table"] = {
"num_rows": chunk.num_rows,
"num_cols": chunk.num_cols,
"markdown": chunk.to_markdown(),
"csv": chunk.to_csv(),
}
elif isinstance(chunk, ChartChunk):
data["chart"] = {
"chart_type": chunk.chart_type,
"title": chunk.title,
"data_points": len(chunk.data_points),
"trends": chunk.trends,
}
return ToolResult(success=True, data=data)
except Exception as e:
logger.error(f"Get chunk details failed: {e}")
return ToolResult(success=False, error=str(e))
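# Example (sketch; the chunk id is hypothetical and would come from a prior
# parse_document or search_chunks observation):
#
#     details = GetChunkDetailsTool().execute(parse_result=doc,
#                                             chunk_id="c-0007")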
class GetTableDataTool(DocumentTool):
"""
Get structured data from a table chunk.
Input:
parse_result: Parsed document
chunk_id: ID of table chunk
format: Output format (json, csv, markdown)
Output:
Table data in requested format
"""
name = "get_table_data"
description = "Extract structured data from a table"
def execute(
self,
parse_result: Any,
chunk_id: str,
format: str = "json",
**kwargs
) -> ToolResult:
try:
from ..chunks import TableChunk
            # Locate the requested table chunk by id
            table = next(
                (c for c in parse_result.chunks
                 if c.chunk_id == chunk_id and isinstance(c, TableChunk)),
                None,
            )
if table is None:
return ToolResult(
success=False,
error=f"Table chunk not found: {chunk_id}"
)
if format == "csv":
data = table.to_csv()
elif format == "markdown":
data = table.to_markdown()
            else:  # default: structured JSON
data = table.to_structured_json()
return ToolResult(
success=True,
data={
"chunk_id": chunk_id,
"format": format,
"num_rows": table.num_rows,
"num_cols": table.num_cols,
"content": data,
},
evidence=[{
"chunk_id": chunk_id,
"page": table.page,
"bbox": table.bbox.xyxy,
"source_type": "table",
}],
)
except Exception as e:
logger.error(f"Get table data failed: {e}")
return ToolResult(success=False, error=str(e))
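# Example (sketch; chunk id is hypothetical):
#
#     table = GetTableDataTool().execute(
#         parse_result=doc, chunk_id="c-0012", format="csv"
#     )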
class AnswerQuestionTool(DocumentTool):
"""
Answer a question about the document using available chunks.
Input:
parse_result: Parsed document
question: Question to answer
use_rag: Whether to use RAG for retrieval (requires indexed document)
document_id: Document ID for RAG retrieval (defaults to parse_result.doc_id)
top_k: Number of chunks to consider
Output:
Answer with supporting evidence
"""
name = "answer_question"
description = "Answer a question about the document content"
def __init__(self, llm_client=None):
self.llm_client = llm_client
def execute(
self,
parse_result: Any,
question: str,
use_rag: bool = False,
document_id: Optional[str] = None,
top_k: int = 5,
**kwargs
) -> ToolResult:
try:
# Use RAG if requested and available
if use_rag:
return self._answer_with_rag(
question=question,
document_id=document_id or (parse_result.doc_id if parse_result else None),
top_k=top_k,
)
# Fall back to keyword-based search on parse_result
return self._answer_with_keywords(
parse_result=parse_result,
question=question,
top_k=top_k,
)
except Exception as e:
logger.error(f"Answer question failed: {e}")
return ToolResult(success=False, error=str(e))
def _answer_with_rag(
self,
question: str,
document_id: Optional[str],
top_k: int,
) -> ToolResult:
"""Answer using RAG retrieval."""
try:
from .rag_tools import RAGAnswerTool
rag_tool = RAGAnswerTool(llm_client=self.llm_client)
return rag_tool.execute(
question=question,
document_id=document_id,
top_k=top_k,
)
except ImportError:
return ToolResult(
success=False,
error="RAG module not available. Use use_rag=False or install chromadb."
)
def _answer_with_keywords(
self,
parse_result: Any,
question: str,
top_k: int,
) -> ToolResult:
"""Answer using keyword-based search on parse_result."""
if parse_result is None:
return ToolResult(
success=False,
error="parse_result is required when use_rag=False"
)
        # Score chunks by crude keyword overlap: words longer than three
        # characters are treated as content words
        question_lower = question.lower()
        keywords = [w for w in question_lower.split() if len(w) > 3]
        relevant_chunks = []
        for chunk in parse_result.chunks:
            text_lower = chunk.text.lower()
            matches = sum(1 for k in keywords if k in text_lower)
            if matches > 0:
                relevant_chunks.append((chunk, matches))
# Sort by relevance
relevant_chunks.sort(key=lambda x: x[1], reverse=True)
top_chunks = relevant_chunks[:top_k]
if not top_chunks:
return ToolResult(
success=True,
data={
"question": question,
"answer": "I could not find relevant information in the document to answer this question.",
"confidence": 0.0,
"abstained": True,
},
)
# Build context
context = "\n\n".join(
f"[Page {c.page}] {c.text}"
for c, _ in top_chunks
)
# If no LLM, return context-based answer
if self.llm_client is None:
return ToolResult(
success=True,
data={
"question": question,
"answer": f"Based on the document: {top_chunks[0][0].text[:500]}",
"confidence": 0.6,
"context_chunks": len(top_chunks),
},
evidence=[
{
"chunk_id": c.chunk_id,
"page": c.page,
"bbox": c.bbox.xyxy,
"snippet": c.text[:200],
}
for c, _ in top_chunks
],
)
# Use LLM to generate answer if available
try:
from ...rag import get_grounded_generator
generator = get_grounded_generator(llm_client=self.llm_client)
# Convert chunks to format expected by generator
chunk_dicts = [
{
"chunk_id": c.chunk_id,
"document_id": c.doc_id,
"text": c.text,
"similarity": score / 10.0, # Normalize score
"page": c.page,
"chunk_type": c.chunk_type.value,
}
for c, score in top_chunks
]
answer = generator.generate_answer(
question=question,
context=context,
chunks=chunk_dicts,
)
return ToolResult(
success=True,
data={
"question": question,
"answer": answer.text,
"confidence": answer.confidence,
"abstained": answer.abstained,
},
evidence=[
{
"chunk_id": c.chunk_id,
"page": c.page,
"bbox": c.bbox.xyxy,
"snippet": c.text[:200],
}
for c, _ in top_chunks
],
)
except ImportError:
# Fall back to simple answer without LLM generation
return ToolResult(
success=True,
data={
"question": question,
"answer": f"Based on the document: {top_chunks[0][0].text[:500]}",
"confidence": 0.6,
"context_chunks": len(top_chunks),
},
evidence=[
{
"chunk_id": c.chunk_id,
"page": c.page,
"bbox": c.bbox.xyxy,
"snippet": c.text[:200],
}
for c, _ in top_chunks
],
)
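# Example (sketch): the keyword path needs a parse_result; the RAG path only
# needs a previously indexed document id (see .rag_tools.RAGAnswerTool):
#
#     qa = AnswerQuestionTool()
#     local = qa.execute(parse_result=doc, question="What is the due date?")
#     ragged = qa.execute(parse_result=None, question="What is the due date?",
#                         use_rag=True, document_id="doc-1")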
class CropRegionTool(DocumentTool):
"""
Crop a region from a document page image.
Input:
doc_path: Path to document
page: Page number (1-indexed)
bbox: Bounding box (x_min, y_min, x_max, y_max)
output_path: Optional path to save crop
Output:
Crop image path or base64 data
"""
name = "crop_region"
description = "Crop a specific region from a document page"
def execute(
self,
doc_path: str,
page: int,
bbox: List[float],
output_path: Optional[str] = None,
**kwargs
) -> ToolResult:
try:
from ..io import load_document, RenderOptions
from ..grounding import crop_region
from ..chunks import BoundingBox
from PIL import Image
# Load and render page
loader, renderer = load_document(doc_path)
page_image = renderer.render_page(page, RenderOptions(dpi=200))
loader.close()
# Create bbox
bbox_obj = BoundingBox(
x_min=bbox[0],
y_min=bbox[1],
x_max=bbox[2],
y_max=bbox[3],
                normalized=True,  # coordinates are assumed normalized to [0, 1]
)
# Crop
crop = crop_region(page_image, bbox_obj)
# Save or return
if output_path:
Image.fromarray(crop).save(output_path)
return ToolResult(
success=True,
data={
"output_path": output_path,
"width": crop.shape[1],
"height": crop.shape[0],
},
)
else:
import base64
import io
pil_img = Image.fromarray(crop)
buffer = io.BytesIO()
pil_img.save(buffer, format="PNG")
b64 = base64.b64encode(buffer.getvalue()).decode()
return ToolResult(
success=True,
data={
"width": crop.shape[1],
"height": crop.shape[0],
"base64": b64[:100] + "...", # Truncated for display
},
)
except Exception as e:
logger.error(f"Crop region failed: {e}")
return ToolResult(success=False, error=str(e))
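# Example (sketch; bbox is a normalized [x_min, y_min, x_max, y_max] list,
# matching the normalized=True assumption above; paths are hypothetical):
#
#     crop = CropRegionTool().execute(
#         doc_path="invoice.pdf", page=1, bbox=[0.1, 0.2, 0.6, 0.4],
#         output_path="crop.png"
#     )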
# Tool registry for agent use
DOCUMENT_TOOLS = {
"parse_document": ParseDocumentTool,
"extract_fields": ExtractFieldsTool,
"search_chunks": SearchChunksTool,
"get_chunk_details": GetChunkDetailsTool,
"get_table_data": GetTableDataTool,
"answer_question": AnswerQuestionTool,
"crop_region": CropRegionTool,
}
def get_tool(name: str, **kwargs) -> DocumentTool:
"""Get a tool instance by name."""
    if name not in DOCUMENT_TOOLS:
        raise ValueError(
            f"Unknown tool: {name}. Available: {sorted(DOCUMENT_TOOLS)}"
        )
return DOCUMENT_TOOLS[name](**kwargs)
def list_tools() -> List[Dict[str, str]]:
"""List all available tools."""
return [
{"name": name, "description": cls.description}
for name, cls in DOCUMENT_TOOLS.items()
]
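if __name__ == "__main__":
    # Smoke check: enumerate the registered tools and their descriptions.
    # (Run from the package root, e.g. `python -m <package>.agents.tools`;
    # list_tools() reads only class attributes, so the deferred relative
    # imports inside the tool constructors are never triggered here.)
    for tool_info in list_tools():
        print(f"{tool_info['name']}: {tool_info['description']}")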