"""
Document Intelligence Tools for Agents

Tool implementations for DocumentAgent integration.
Each tool is designed for ReAct-style agent execution.
"""

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

logger = logging.getLogger(__name__)


@dataclass
class ToolResult:
    """Result from a tool execution."""

    success: bool
    data: Any = None
    error: Optional[str] = None
    evidence: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "evidence": self.evidence,
        }
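

# A minimal sketch of how a ToolResult typically flows back to an agent
# loop (illustrative only; the agent runtime that consumes to_dict() is
# assumed, not defined in this module):
#
#     result = ToolResult(success=True, data={"num_pages": 3})
#     observation = result.to_dict()
#     # -> {"success": True, "data": {"num_pages": 3},
#     #     "error": None, "evidence": []}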


class DocumentTool:
    """Base class for document tools."""

    name: str = "base_tool"
    description: str = "Base document tool"

    def execute(self, **kwargs) -> ToolResult:
        """Execute the tool."""
        raise NotImplementedError


class ParseDocumentTool(DocumentTool):
    """
    Parse a document into semantic chunks.

    Input:
        path: Path to document file
        max_pages: Optional maximum pages to process

    Output:
        ParseResult with chunks and metadata
    """

    name = "parse_document"
    description = "Parse a document into semantic chunks with OCR and layout detection"

    def __init__(self, parser=None):
        from ..parsing import DocumentParser

        self.parser = parser or DocumentParser()

    def execute(
        self,
        path: str,
        max_pages: Optional[int] = None,
        **kwargs
    ) -> ToolResult:
        try:
            # Explicit None check so a caller-supplied 0 is not silently ignored.
            if max_pages is not None:
                self.parser.config.max_pages = max_pages

            result = self.parser.parse(path)

            return ToolResult(
                success=True,
                data={
                    "doc_id": result.doc_id,
                    "filename": result.filename,
                    "num_pages": result.num_pages,
                    "num_chunks": len(result.chunks),
                    # Truncate text and cap the chunk list to keep the
                    # observation small enough for an agent context window.
                    "chunks": [
                        {
                            "chunk_id": c.chunk_id,
                            "type": c.chunk_type.value,
                            "text": c.text[:500],
                            "page": c.page,
                            "confidence": c.confidence,
                        }
                        for c in result.chunks[:20]
                    ],
                    "markdown_preview": result.markdown_full[:2000],
                },
            )
        except Exception as e:
            logger.error(f"Parse document failed: {e}")
            return ToolResult(success=False, error=str(e))
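

# Typical first step in an agent trajectory (illustrative; the file path is
# hypothetical):
#
#     tool = ParseDocumentTool()
#     result = tool.execute(path="invoice.pdf", max_pages=5)
#     if result.success:
#         chunk_ids = [c["chunk_id"] for c in result.data["chunks"]]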


class ExtractFieldsTool(DocumentTool):
    """
    Extract fields from a parsed document using a schema.

    Input:
        parse_result: Previously parsed document
        schema: Extraction schema (dict or ExtractionSchema)
        fields: Optional list of specific fields to extract

    Output:
        ExtractionResult with values and evidence
    """

    name = "extract_fields"
    description = "Extract structured fields from document using a schema"

    def __init__(self, extractor=None):
        from ..extraction import FieldExtractor

        self.extractor = extractor or FieldExtractor()

    def execute(
        self,
        parse_result: Any,
        schema: Union[Dict, Any],
        fields: Optional[List[str]] = None,
        **kwargs
    ) -> ToolResult:
        try:
            from ..extraction import ExtractionSchema

            if isinstance(schema, dict):
                schema = ExtractionSchema.from_json_schema(schema)

            if fields:
                schema.fields = [f for f in schema.fields if f.name in fields]

            result = self.extractor.extract(parse_result, schema)

            return ToolResult(
                success=True,
                data={
                    "extracted_data": result.data,
                    "confidence": result.overall_confidence,
                    "abstained_fields": result.abstained_fields,
                },
                evidence=[
                    {
                        "chunk_id": e.chunk_id,
                        "page": e.page,
                        "bbox": e.bbox.xyxy,
                        "snippet": e.snippet,
                        "confidence": e.confidence,
                    }
                    for e in result.evidence
                ],
            )
        except Exception as e:
            logger.error(f"Extract fields failed: {e}")
            return ToolResult(success=False, error=str(e))
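

# Illustrative call. The exact JSON-schema shape that
# ExtractionSchema.from_json_schema accepts is an assumption here (a standard
# object schema with named properties), not confirmed by this module:
#
#     schema = {
#         "type": "object",
#         "properties": {
#             "invoice_number": {"type": "string"},
#             "total": {"type": "number"},
#         },
#     }
#     result = ExtractFieldsTool().execute(parse_result=parsed, schema=schema)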


class SearchChunksTool(DocumentTool):
    """
    Search for chunks containing specific text or matching criteria.

    Input:
        parse_result: Parsed document
        query: Search query
        chunk_types: Optional list of chunk types to filter
        top_k: Maximum results to return

    Output:
        List of matching chunks with scores
    """

    name = "search_chunks"
    description = "Search document chunks for specific content"

    def execute(
        self,
        parse_result: Any,
        query: str,
        chunk_types: Optional[List[str]] = None,
        top_k: int = 10,
        **kwargs
    ) -> ToolResult:
        try:
            query_lower = query.lower()
            results = []

            for chunk in parse_result.chunks:
                if chunk_types and chunk.chunk_type.value not in chunk_types:
                    continue

                text_lower = chunk.text.lower()
                if query_lower in text_lower:
                    # Score by occurrence count, with a small bonus for
                    # matches near the start of the chunk.
                    count = text_lower.count(query_lower)
                    position = text_lower.find(query_lower)
                    score = count * 10 + (1 / (position + 1)) * 5

                    results.append({
                        "chunk_id": chunk.chunk_id,
                        "type": chunk.chunk_type.value,
                        "text": chunk.text[:300],
                        "page": chunk.page,
                        "score": score,
                        "bbox": chunk.bbox.xyxy,
                    })

            results.sort(key=lambda x: x["score"], reverse=True)
            # Report the total before truncating to top_k.
            total_matches = len(results)
            results = results[:top_k]

            return ToolResult(
                success=True,
                data={
                    "query": query,
                    "total_matches": total_matches,
                    "results": results,
                },
            )
        except Exception as e:
            logger.error(f"Search chunks failed: {e}")
            return ToolResult(success=False, error=str(e))
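

# Worked example of the scoring heuristic above: a query that occurs twice
# in a chunk, first at character position 4, scores
# 2 * 10 + (1 / (4 + 1)) * 5 = 21.0, so frequent and early matches rank first.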


class GetChunkDetailsTool(DocumentTool):
    """
    Get detailed information about a specific chunk.

    Input:
        parse_result: Parsed document
        chunk_id: ID of chunk to retrieve

    Output:
        Full chunk details including content and metadata
    """

    name = "get_chunk_details"
    description = "Get detailed information about a specific chunk"

    def execute(
        self,
        parse_result: Any,
        chunk_id: str,
        **kwargs
    ) -> ToolResult:
        try:
            from ..chunks import TableChunk, ChartChunk

            chunk = None
            for c in parse_result.chunks:
                if c.chunk_id == chunk_id:
                    chunk = c
                    break

            if chunk is None:
                return ToolResult(
                    success=False,
                    error=f"Chunk not found: {chunk_id}"
                )

            data = {
                "chunk_id": chunk.chunk_id,
                "doc_id": chunk.doc_id,
                "type": chunk.chunk_type.value,
                "text": chunk.text,
                "page": chunk.page,
                "bbox": {
                    "x_min": chunk.bbox.x_min,
                    "y_min": chunk.bbox.y_min,
                    "x_max": chunk.bbox.x_max,
                    "y_max": chunk.bbox.y_max,
                    "normalized": chunk.bbox.normalized,
                },
                "confidence": chunk.confidence,
                "sequence_index": chunk.sequence_index,
            }

            if isinstance(chunk, TableChunk):
                data["table"] = {
                    "num_rows": chunk.num_rows,
                    "num_cols": chunk.num_cols,
                    "markdown": chunk.to_markdown(),
                    "csv": chunk.to_csv(),
                }
            elif isinstance(chunk, ChartChunk):
                data["chart"] = {
                    "chart_type": chunk.chart_type,
                    "title": chunk.title,
                    "data_points": len(chunk.data_points),
                    "trends": chunk.trends,
                }

            return ToolResult(success=True, data=data)

        except Exception as e:
            logger.error(f"Get chunk details failed: {e}")
            return ToolResult(success=False, error=str(e))


class GetTableDataTool(DocumentTool):
    """
    Get structured data from a table chunk.

    Input:
        parse_result: Parsed document
        chunk_id: ID of table chunk
        format: Output format (json, csv, markdown)

    Output:
        Table data in requested format
    """

    name = "get_table_data"
    description = "Extract structured data from a table"

    def execute(
        self,
        parse_result: Any,
        chunk_id: str,
        format: str = "json",
        **kwargs
    ) -> ToolResult:
        try:
            from ..chunks import TableChunk

            table = None
            for c in parse_result.chunks:
                if c.chunk_id == chunk_id and isinstance(c, TableChunk):
                    table = c
                    break

            if table is None:
                return ToolResult(
                    success=False,
                    error=f"Table chunk not found: {chunk_id}"
                )

            if format == "csv":
                data = table.to_csv()
            elif format == "markdown":
                data = table.to_markdown()
            else:
                data = table.to_structured_json()

            return ToolResult(
                success=True,
                data={
                    "chunk_id": chunk_id,
                    "format": format,
                    "num_rows": table.num_rows,
                    "num_cols": table.num_cols,
                    "content": data,
                },
                evidence=[{
                    "chunk_id": chunk_id,
                    "page": table.page,
                    "bbox": table.bbox.xyxy,
                    "source_type": "table",
                }],
            )
        except Exception as e:
            logger.error(f"Get table data failed: {e}")
            return ToolResult(success=False, error=str(e))
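

# Illustrative follow-up once a table chunk_id is known (the chunk id shown
# is hypothetical):
#
#     tool = GetTableDataTool()
#     result = tool.execute(parse_result=parsed, chunk_id="tbl_0", format="csv")
#     csv_text = result.data["content"] if result.success else None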


class AnswerQuestionTool(DocumentTool):
    """
    Answer a question about the document using available chunks.

    Input:
        parse_result: Parsed document
        question: Question to answer
        use_rag: Whether to use RAG for retrieval (requires indexed document)
        document_id: Document ID for RAG retrieval (defaults to parse_result.doc_id)
        top_k: Number of chunks to consider

    Output:
        Answer with supporting evidence
    """

    name = "answer_question"
    description = "Answer a question about the document content"

    def __init__(self, llm_client=None):
        self.llm_client = llm_client

    def execute(
        self,
        parse_result: Any,
        question: str,
        use_rag: bool = False,
        document_id: Optional[str] = None,
        top_k: int = 5,
        **kwargs
    ) -> ToolResult:
        try:
            if use_rag:
                return self._answer_with_rag(
                    question=question,
                    document_id=document_id or (parse_result.doc_id if parse_result else None),
                    top_k=top_k,
                )

            return self._answer_with_keywords(
                parse_result=parse_result,
                question=question,
                top_k=top_k,
            )

        except Exception as e:
            logger.error(f"Answer question failed: {e}")
            return ToolResult(success=False, error=str(e))

    def _answer_with_rag(
        self,
        question: str,
        document_id: Optional[str],
        top_k: int,
    ) -> ToolResult:
        """Answer using RAG retrieval."""
        try:
            from .rag_tools import RAGAnswerTool

            rag_tool = RAGAnswerTool(llm_client=self.llm_client)
            return rag_tool.execute(
                question=question,
                document_id=document_id,
                top_k=top_k,
            )
        except ImportError:
            return ToolResult(
                success=False,
                error="RAG module not available. Use use_rag=False or install chromadb."
            )

    def _answer_with_keywords(
        self,
        parse_result: Any,
        question: str,
        top_k: int,
    ) -> ToolResult:
        """Answer using keyword-based search on parse_result."""
        if parse_result is None:
            return ToolResult(
                success=False,
                error="parse_result is required when use_rag=False"
            )

        # Rank chunks by how many question keywords (words longer than
        # three characters) they contain.
        question_lower = question.lower()
        keywords = [w for w in question_lower.split() if len(w) > 3]
        relevant_chunks = []

        for chunk in parse_result.chunks:
            text_lower = chunk.text.lower()
            matches = sum(1 for k in keywords if k in text_lower)
            if matches > 0:
                relevant_chunks.append((chunk, matches))

        relevant_chunks.sort(key=lambda x: x[1], reverse=True)
        top_chunks = relevant_chunks[:top_k]

        if not top_chunks:
            return ToolResult(
                success=True,
                data={
                    "question": question,
                    "answer": "I could not find relevant information in the document to answer this question.",
                    "confidence": 0.0,
                    "abstained": True,
                },
            )

        context = "\n\n".join(
            f"[Page {c.page}] {c.text}"
            for c, _ in top_chunks
        )

        # The same evidence payload is attached to every answer path below.
        evidence = [
            {
                "chunk_id": c.chunk_id,
                "page": c.page,
                "bbox": c.bbox.xyxy,
                "snippet": c.text[:200],
            }
            for c, _ in top_chunks
        ]

        # Without an LLM, fall back to an extractive answer taken from the
        # top-ranked chunk.
        if self.llm_client is None:
            return ToolResult(
                success=True,
                data={
                    "question": question,
                    "answer": f"Based on the document: {top_chunks[0][0].text[:500]}",
                    "confidence": 0.6,
                    "context_chunks": len(top_chunks),
                },
                evidence=evidence,
            )

        try:
            from ...rag import get_grounded_generator

            generator = get_grounded_generator(llm_client=self.llm_client)

            chunk_dicts = [
                {
                    "chunk_id": c.chunk_id,
                    "document_id": c.doc_id,
                    "text": c.text,
                    "similarity": score / 10.0,  # rough normalization of keyword-match count
                    "page": c.page,
                    "chunk_type": c.chunk_type.value,
                }
                for c, score in top_chunks
            ]

            answer = generator.generate_answer(
                question=question,
                context=context,
                chunks=chunk_dicts,
            )

            return ToolResult(
                success=True,
                data={
                    "question": question,
                    "answer": answer.text,
                    "confidence": answer.confidence,
                    "abstained": answer.abstained,
                },
                evidence=evidence,
            )

        except ImportError:
            # Grounded generation unavailable: use the same extractive fallback.
            return ToolResult(
                success=True,
                data={
                    "question": question,
                    "answer": f"Based on the document: {top_chunks[0][0].text[:500]}",
                    "confidence": 0.6,
                    "context_chunks": len(top_chunks),
                },
                evidence=evidence,
            )


class CropRegionTool(DocumentTool):
    """
    Crop a region from a document page image.

    Input:
        doc_path: Path to document
        page: Page number (1-indexed)
        bbox: Bounding box (x_min, y_min, x_max, y_max), normalized coordinates
        output_path: Optional path to save crop

    Output:
        Crop image path, or a truncated base64 preview if no output_path is given
    """

    name = "crop_region"
    description = "Crop a specific region from a document page"

    def execute(
        self,
        doc_path: str,
        page: int,
        bbox: List[float],
        output_path: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        try:
            from ..io import load_document, RenderOptions
            from ..grounding import crop_region
            from ..chunks import BoundingBox
            from PIL import Image

            # Render the requested page, then release the loader.
            loader, renderer = load_document(doc_path)
            page_image = renderer.render_page(page, RenderOptions(dpi=200))
            loader.close()

            bbox_obj = BoundingBox(
                x_min=bbox[0],
                y_min=bbox[1],
                x_max=bbox[2],
                y_max=bbox[3],
                normalized=True,
            )

            crop = crop_region(page_image, bbox_obj)

            if output_path:
                Image.fromarray(crop).save(output_path)
                return ToolResult(
                    success=True,
                    data={
                        "output_path": output_path,
                        "width": crop.shape[1],
                        "height": crop.shape[0],
                    },
                )
            else:
                import base64
                import io

                pil_img = Image.fromarray(crop)
                buffer = io.BytesIO()
                pil_img.save(buffer, format="PNG")
                b64 = base64.b64encode(buffer.getvalue()).decode()

                return ToolResult(
                    success=True,
                    data={
                        "width": crop.shape[1],
                        "height": crop.shape[0],
                        # Truncated to a preview so the tool observation stays
                        # small; pass output_path to get the full image on disk.
                        "base64": b64[:100] + "...",
                    },
                )

        except Exception as e:
            logger.error(f"Crop region failed: {e}")
            return ToolResult(success=False, error=str(e))


DOCUMENT_TOOLS = {
    "parse_document": ParseDocumentTool,
    "extract_fields": ExtractFieldsTool,
    "search_chunks": SearchChunksTool,
    "get_chunk_details": GetChunkDetailsTool,
    "get_table_data": GetTableDataTool,
    "answer_question": AnswerQuestionTool,
    "crop_region": CropRegionTool,
}


def get_tool(name: str, **kwargs) -> DocumentTool:
    """Get a tool instance by name."""
    if name not in DOCUMENT_TOOLS:
        raise ValueError(f"Unknown tool: {name}")
    return DOCUMENT_TOOLS[name](**kwargs)


def list_tools() -> List[Dict[str, str]]:
    """List all available tools."""
    return [
        {"name": name, "description": cls.description}
        for name, cls in DOCUMENT_TOOLS.items()
    ]
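

# Minimal registry usage sketch (illustrative; "total amount" and the parsed
# document are hypothetical). Note that downstream tools expect the ParseResult
# object itself as parse_result, not the serialized dict that ParseDocumentTool
# returns, so an agent runtime is expected to keep the parsed object in its
# own state:
#
#     for info in list_tools():
#         print(f"{info['name']}: {info['description']}")
#
#     search = get_tool("search_chunks")
#     hits = search.execute(parse_result=parsed, query="total amount", top_k=3)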