"""
RAG Tools for Document Intelligence
Provides RAG-powered tools for:
- IndexDocumentTool: Index documents into the vector store
- RetrieveChunksTool: Semantic retrieval with metadata filters
- RAGAnswerTool: Answer questions using RAG with grounded citations
- DeleteDocumentTool: Remove an indexed document from the vector store
- GetIndexStatsTool: Report vector store index statistics
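
Typical workflow (illustrative; "report.pdf" and the question are placeholders):

    index_tool = IndexDocumentTool()
    index_tool.execute(path="report.pdf")

    answer_tool = RAGAnswerTool()
    result = answer_tool.execute(question="What is the total revenue?")
    print(result.data["answer"])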
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
from .document_tools import DocumentTool, ToolResult
logger = logging.getLogger(__name__)
# Check RAG availability
try:
from ...rag import (
get_docint_indexer,
get_docint_retriever,
get_grounded_generator,
GeneratorConfig,
)
from ...rag.indexer import IndexerConfig
RAG_AVAILABLE = True
except ImportError:
RAG_AVAILABLE = False
logger.warning("RAG module not available")
class IndexDocumentTool(DocumentTool):
"""
Index a document into the vector store for RAG.
Input:
parse_result: Previously parsed document (ParseResult)
OR
path: Path to document file (will parse first)
max_pages: Optional maximum pages to process
Output:
IndexingResult with stats
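
Example (illustrative; "report.pdf" is a placeholder path):
    tool = IndexDocumentTool()
    result = tool.execute(path="report.pdf", max_pages=10)
    if result.success:
        print(result.data["chunks_indexed"])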
"""
name = "index_document"
description = "Index a document into the vector store for semantic retrieval"
def __init__(self, indexer_config: Optional[Any] = None):
self.indexer_config = indexer_config
def execute(
self,
parse_result: Optional[Any] = None,
path: Optional[str] = None,
max_pages: Optional[int] = None,
**kwargs
) -> ToolResult:
if not RAG_AVAILABLE:
return ToolResult(
success=False,
error="RAG module not available. Install chromadb: pip install chromadb"
)
try:
indexer = get_docint_indexer(config=self.indexer_config)
if parse_result is not None:
# Index already-parsed document
result = indexer.index_parse_result(parse_result)
elif path is not None:
# Parse and index document
result = indexer.index_document(path, max_pages=max_pages)
else:
return ToolResult(
success=False,
error="Either parse_result or path must be provided"
)
return ToolResult(
success=result.success,
data={
"document_id": result.document_id,
"source_path": result.source_path,
"chunks_indexed": result.num_chunks_indexed,
"chunks_skipped": result.num_chunks_skipped,
},
error=result.error,
)
except Exception as e:
logger.error(f"Index document failed: {e}")
return ToolResult(success=False, error=str(e))
class RetrieveChunksTool(DocumentTool):
"""
Retrieve relevant chunks using semantic search.
Input:
query: Search query
top_k: Number of results (default: 5)
document_id: Filter by document ID
chunk_types: Filter by chunk type(s) (e.g., ["paragraph", "table"])
page_range: Filter by page range (start, end)
Output:
List of relevant chunks with similarity scores
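
Example (illustrative; the query and filters are placeholders):
    tool = RetrieveChunksTool(similarity_threshold=0.6)
    result = tool.execute(
        query="quarterly revenue",
        top_k=3,
        chunk_types=["table"],
        page_range=(1, 10),
    )
    for chunk in result.data["chunks"]:
        print(chunk["similarity"], chunk["text"])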
"""
name = "retrieve_chunks"
description = "Retrieve relevant document chunks using semantic search"
def __init__(self, similarity_threshold: float = 0.5):
self.similarity_threshold = similarity_threshold
def execute(
self,
query: str,
top_k: int = 5,
document_id: Optional[str] = None,
chunk_types: Optional[List[str]] = None,
        page_range: Optional[Tuple[int, int]] = None,
include_evidence: bool = True,
**kwargs
) -> ToolResult:
if not RAG_AVAILABLE:
return ToolResult(
success=False,
error="RAG module not available. Install chromadb: pip install chromadb"
)
try:
retriever = get_docint_retriever(
similarity_threshold=self.similarity_threshold
)
if include_evidence:
chunks, evidence_refs = retriever.retrieve_with_evidence(
query=query,
top_k=top_k,
document_id=document_id,
chunk_types=chunk_types,
page_range=page_range,
)
evidence = [
{
"chunk_id": ev.chunk_id,
"page": ev.page,
"bbox": ev.bbox.xyxy if ev.bbox else None,
"snippet": ev.snippet,
"confidence": ev.confidence,
}
for ev in evidence_refs
]
else:
chunks = retriever.retrieve(
query=query,
top_k=top_k,
document_id=document_id,
chunk_types=chunk_types,
page_range=page_range,
)
evidence = []
return ToolResult(
success=True,
data={
"query": query,
"num_results": len(chunks),
"chunks": [
{
"chunk_id": c["chunk_id"],
"document_id": c["document_id"],
"text": c["text"][:500], # Truncate for display
"similarity": c["similarity"],
"page": c.get("page"),
"chunk_type": c.get("chunk_type"),
}
for c in chunks
],
},
evidence=evidence,
)
except Exception as e:
logger.error(f"Retrieve chunks failed: {e}")
return ToolResult(success=False, error=str(e))
class RAGAnswerTool(DocumentTool):
"""
Answer a question using RAG (Retrieval-Augmented Generation).
Input:
question: Question to answer
document_id: Filter to specific document
top_k: Number of chunks to retrieve (default: 5)
chunk_types: Filter by chunk type(s)
page_range: Filter by page range
Output:
Answer with citations and evidence
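
Example (illustrative; `my_client` is a placeholder LLM client, and
without one the tool falls back to returning the best-matching chunk):
    tool = RAGAnswerTool(llm_client=my_client, min_confidence=0.6)
    result = tool.execute(question="Who signed the contract?")
    if result.success and not result.data.get("abstained"):
        print(result.data["answer"])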
"""
name = "rag_answer"
description = "Answer a question using RAG with grounded citations"
def __init__(
self,
llm_client: Optional[Any] = None,
min_confidence: float = 0.5,
abstain_threshold: float = 0.3,
):
self.llm_client = llm_client
self.min_confidence = min_confidence
self.abstain_threshold = abstain_threshold
def execute(
self,
question: str,
document_id: Optional[str] = None,
top_k: int = 5,
chunk_types: Optional[List[str]] = None,
        page_range: Optional[Tuple[int, int]] = None,
**kwargs
) -> ToolResult:
if not RAG_AVAILABLE:
return ToolResult(
success=False,
error="RAG module not available. Install chromadb: pip install chromadb"
)
try:
# Retrieve relevant chunks
retriever = get_docint_retriever()
chunks, evidence_refs = retriever.retrieve_with_evidence(
query=question,
top_k=top_k,
document_id=document_id,
chunk_types=chunk_types,
page_range=page_range,
)
if not chunks:
return ToolResult(
success=True,
data={
"question": question,
"answer": "I could not find relevant information to answer this question.",
"confidence": 0.0,
"abstained": True,
"reason": "No relevant chunks found",
},
)
# Build context
context = retriever.build_context(chunks)
# Check if we have LLM for generation
if self.llm_client is None:
# Return context-based answer without LLM
best_chunk = chunks[0]
return ToolResult(
success=True,
data={
"question": question,
"answer": f"Based on the document: {best_chunk['text'][:500]}",
"confidence": best_chunk["similarity"],
"abstained": False,
"context_chunks": len(chunks),
},
evidence=[
{
"chunk_id": ev.chunk_id,
"page": ev.page,
"bbox": ev.bbox.xyxy if ev.bbox else None,
"snippet": ev.snippet,
}
for ev in evidence_refs
],
)
# Use grounded generator
generator_config = GeneratorConfig(
min_confidence=self.min_confidence,
abstain_on_low_confidence=True,
abstain_threshold=self.abstain_threshold,
)
generator = get_grounded_generator(
config=generator_config,
llm_client=self.llm_client,
)
answer = generator.generate_answer(
question=question,
context=context,
chunks=chunks,
)
return ToolResult(
success=True,
data={
"question": question,
"answer": answer.text,
"confidence": answer.confidence,
"abstained": answer.abstained,
"citations": [
{
"index": c.index,
"chunk_id": c.chunk_id,
"text": c.text,
}
for c in (answer.citations or [])
],
},
evidence=[
{
"chunk_id": ev.chunk_id,
"page": ev.page,
"bbox": ev.bbox.xyxy if ev.bbox else None,
"snippet": ev.snippet,
}
for ev in evidence_refs
],
)
except Exception as e:
logger.error(f"RAG answer failed: {e}")
return ToolResult(success=False, error=str(e))
class DeleteDocumentTool(DocumentTool):
"""
Delete a document from the vector store index.
Input:
document_id: ID of document to delete
Output:
Number of chunks deleted
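
Example (illustrative; the document ID is a placeholder):
    tool = DeleteDocumentTool()
    result = tool.execute(document_id="doc-123")
    print(result.data["chunks_deleted"])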
"""
name = "delete_document"
description = "Remove a document from the vector store index"
def execute(self, document_id: str, **kwargs) -> ToolResult:
if not RAG_AVAILABLE:
return ToolResult(
success=False,
error="RAG module not available"
)
try:
indexer = get_docint_indexer()
deleted_count = indexer.delete_document(document_id)
return ToolResult(
success=True,
data={
"document_id": document_id,
"chunks_deleted": deleted_count,
},
)
except Exception as e:
logger.error(f"Delete document failed: {e}")
return ToolResult(success=False, error=str(e))
class GetIndexStatsTool(DocumentTool):
"""
Get statistics about the vector store index.
Output:
Index statistics (total chunks, embedding model, etc.)
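
Example (illustrative):
    tool = GetIndexStatsTool()
    result = tool.execute()
    print(result.data)  # e.g., total chunks, embedding model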
"""
name = "get_index_stats"
description = "Get statistics about the vector store index"
def execute(self, **kwargs) -> ToolResult:
if not RAG_AVAILABLE:
return ToolResult(
success=False,
error="RAG module not available"
)
try:
indexer = get_docint_indexer()
stats = indexer.get_stats()
return ToolResult(
success=True,
data=stats,
)
except Exception as e:
logger.error(f"Get index stats failed: {e}")
return ToolResult(success=False, error=str(e))
# Tool registry for RAG tools
RAG_TOOLS = {
"index_document": IndexDocumentTool,
"retrieve_chunks": RetrieveChunksTool,
"rag_answer": RAGAnswerTool,
"delete_document": DeleteDocumentTool,
"get_index_stats": GetIndexStatsTool,
}
def get_rag_tool(name: str, **kwargs) -> DocumentTool:
"""Get a RAG tool instance by name."""
if name not in RAG_TOOLS:
        raise ValueError(f"Unknown RAG tool: {name}. Available: {sorted(RAG_TOOLS)}")
return RAG_TOOLS[name](**kwargs)
def list_rag_tools() -> List[Dict[str, str]]:
"""List all available RAG tools."""
return [
{"name": name, "description": cls.description}
for name, cls in RAG_TOOLS.items()
]
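

if __name__ == "__main__":
    # Illustrative smoke test. Listing tools needs no extra dependencies;
    # the stats call returns an error ToolResult if chromadb is not installed.
    for info in list_rag_tools():
        print(f"{info['name']}: {info['description']}")
    stats_result = get_rag_tool("get_index_stats").execute()
    print(stats_result.data if stats_result.success else stats_result.error)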