# SmartDocAI / intelligence / context_validator.py
# TilanB's picture
# Initial commit for Hugging Face Space
# 50fcf88
"""
Relevance checker module for document retrieval quality assessment.
"""
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field
from typing import Literal, Optional, List
import logging
from configuration.parameters import parameters
logger = logging.getLogger(__name__)
def estimate_tokens(text: str, chars_per_token: int = 4) -> int:
    """Estimate token count from text length.

    This is a length-based heuristic: the character count is floor-divided
    by ``chars_per_token``, so text shorter than one "token" estimates to 0.

    Args:
        text: The text to measure.
        chars_per_token: Assumed average number of characters per token.
            Must be a positive integer.

    Returns:
        The estimated number of tokens (``len(text) // chars_per_token``).

    Raises:
        ValueError: If ``chars_per_token`` is zero or negative.
    """
    # Reject non-positive ratios up front: zero would raise a bare
    # ZeroDivisionError and a negative value would yield a negative count.
    if chars_per_token <= 0:
        raise ValueError(
            f"chars_per_token must be a positive integer, got {chars_per_token}"
        )
    return len(text) // chars_per_token
# ============================================================================
# Structured Output Models
# ============================================================================
class ContextValidationClassification(BaseModel):
    """Structured output for context validation classification."""

    # NOTE: the class docstring and every ``description`` below are kept
    # verbatim. pydantic publishes them in the generated schema, so they may
    # be part of what the model is shown when structured output is requested.

    # Coverage verdict. Required (no default) and limited to the three labels.
    classification: Literal["CAN_ANSWER", "PARTIAL", "NO_MATCH"] = Field(
        ...,
        description=" ".join(
            (
                "CAN_ANSWER: Passages contain enough info to fully answer.",
                "PARTIAL: Passages mention the topic but incomplete.",
                "NO_MATCH: Passages don't discuss the topic at all.",
            )
        ),
    )

    # Self-reported certainty of the verdict; falls back to "MEDIUM".
    confidence: Literal["HIGH", "MEDIUM", "LOW"] = Field(
        "MEDIUM",
        description="Confidence level in the classification",
    )

    # Optional free-text justification; empty string when omitted.
    reasoning: str = Field(
        "",
        description="Brief explanation for the classification",
    )
class ContextQueryExpansion(BaseModel):
    """Structured output for query expansion/rewriting."""

    # NOTE: the class docstring and every ``description`` below are kept
    # verbatim. pydantic publishes them in the generated schema, so they may
    # be part of what the model is shown when structured output is requested.

    # The reformulated query. Required (no default).
    rewritten_query: str = Field(
        ...,
        description="A rephrased version of the original query",
    )

    # Extra search vocabulary; a fresh empty list per instance when omitted.
    key_terms: List[str] = Field(
        description="Key terms and synonyms to search for",
        default_factory=list,
    )

    # Optional free-text rationale; empty string when omitted.
    search_strategy: str = Field(
        "",
        description="Brief explanation of the search approach",
    )
class ContextValidator:
    """
    Checks context relevance of retrieved documents to a user's question.

    Uses Gemini model with structured output to classify coverage
    and provides query rewriting for improved retrieval.
    """

    VALID_LABELS = {"CAN_ANSWER", "PARTIAL", "NO_MATCH"}

    def __init__(self):
        """Initialize the context validator.

        Builds one chat model from ``parameters`` and derives two
        structured-output wrappers from it: one for classification and one
        for query expansion. The plain model is kept for the text fallback.
        """
        logger.info("Initializing ContextValidator...")
        base_llm = ChatGoogleGenerativeAI(
            model=parameters.RELEVANCE_CHECKER_MODEL,
            google_api_key=parameters.GOOGLE_API_KEY,
            temperature=0,
            # NOTE(review): 100 output tokens is tight for structured replies
            # that also carry free-text fields (reasoning, key_terms,
            # search_strategy) -- confirm responses are not being truncated.
            max_output_tokens=100,
        )
        self.llm = base_llm
        self.structured_llm = base_llm.with_structured_output(ContextValidationClassification)
        self.query_expansion_llm = base_llm.with_structured_output(ContextQueryExpansion)
        logger.info(f"ContextValidator initialized (model={parameters.RELEVANCE_CHECKER_MODEL})")

    # The return annotation is quoted so that defining this class does not
    # depend on ContextQueryExpansion already being bound at that moment.
    def context_query_rewrite(
        self, original_query: str, context_hint: Optional[str] = None
    ) -> "Optional[ContextQueryExpansion]":
        """
        Rewrite a query to potentially retrieve better results.

        Args:
            original_query: The original user query
            context_hint: Optional hint about available documents

        Returns:
            ContextQueryExpansion with rewritten query, or None when the
            query is empty/blank or the model call fails
        """
        # Nothing to rewrite: skip the model call, mirroring the empty-input
        # guard in context_validate.
        if not original_query or not original_query.strip():
            logger.warning("Empty query provided for rewrite")
            return None

        logger.debug(f"Rewriting query: {original_query[:80]}...")
        context_section = f"\n**Available Context:** {context_hint}\n" if context_hint else ""
        # Assembled piecewise so the prompt text does not depend on how this
        # method is indented.
        prompt = (
            "Rewrite this query to improve document retrieval.\n"
            f"**Original Query:** {original_query}\n"
            f"{context_section}\n"
            "**Instructions:**\n"
            "1. Rephrase to be more specific and searchable\n"
            "2. Extract key terms and add synonyms\n"
            "3. Consider exact phrases in formal documents"
        )
        try:
            result = self.query_expansion_llm.invoke(prompt)
            # Attribute access stays inside the try so that an unusable
            # result (e.g. None) is reported as a failed rewrite.
            logger.debug(f"Query rewritten: {result.rewritten_query[:60]}...")
            return result
        except Exception as e:
            logger.error(f"Query rewrite failed: {e}")
            return None

    def context_validate(self, question: str, retriever, k: int = 3) -> str:
        """
        Retrieve top-k passages and classify coverage.

        Every failure path (blank question, retriever error, no documents,
        model errors) degrades to "NO_MATCH" rather than raising.

        Args:
            question: The user's question
            retriever: The retriever for fetching documents; only its
                ``invoke(question)`` method is used
            k: Number of top documents to consider; values below 1 are
                replaced by 3

        Returns:
            Classification: "CAN_ANSWER", "PARTIAL", or "NO_MATCH"
        """
        if not question or not question.strip():
            logger.warning("Empty question provided")
            return "NO_MATCH"
        if k < 1:
            k = 3

        logger.info(f"Checking context relevance for: {question[:60]}...")

        # Retrieve documents
        try:
            top_docs = retriever.invoke(question)
        except Exception as e:
            logger.error(f"Retriever invocation failed: {e}")
            return "NO_MATCH"
        if not top_docs:
            logger.info("No documents returned")
            return "NO_MATCH"
        logger.debug(f"Retrieved {len(top_docs)} documents")

        # Only the first k documents are shown to the model.
        passages = "\n\n".join(doc.page_content for doc in top_docs[:k])
        prompt = (
            "Classify how well the passages address the question.\n"
            f"**Question:** {question}\n"
            "**Passages:**\n"
            f"{passages}\n"
            "Classify as CAN_ANSWER (fully answers), PARTIAL (mentions topic), or NO_MATCH (unrelated)."
        )

        try:
            result = self.structured_llm.invoke(prompt)
            logger.info(f"Context relevance: {result.classification} ({result.confidence})")
            return result.classification
        except Exception as e:
            logger.error(f"Structured output failed: {e}")

        # Fallback to text parsing
        return self._fallback_classify(prompt)

    def _fallback_classify(self, prompt: str) -> str:
        """Classify with the plain model and parse the label out of its text.

        Returns "NO_MATCH" when no label is found or the call fails.
        """
        try:
            response = self.llm.invoke(prompt)
            raw_response = response.content if hasattr(response, "content") else str(response)
            # Coerce non-string content so label parsing cannot fail on it.
            if not isinstance(raw_response, str):
                raw_response = str(raw_response)
            label = self._extract_label(raw_response)
            if label is None:
                return "NO_MATCH"
            logger.info(f"Fallback classification: {label}")
            return label
        except Exception as fallback_error:
            logger.error(f"Fallback failed: {fallback_error}")
            return "NO_MATCH"

    @classmethod
    def _extract_label(cls, text: str) -> Optional[str]:
        """Return the valid label mentioned earliest in ``text``, or None.

        Matching is case-insensitive and by substring. Choosing the earliest
        mention keeps the result deterministic when several labels appear;
        iterating the VALID_LABELS set and returning the first hit would
        depend on set iteration order.
        """
        upper = text.strip().upper()
        hits = []
        for label in cls.VALID_LABELS:
            position = upper.find(label)
            if position != -1:
                hits.append((position, label))
        if not hits:
            return None
        # Positions are distinct, so min() is decided by position alone.
        return min(hits)[1]

    def context_validate_with_rewrite(
        self, question: str, retriever, k: int = 3, max_rewrites: int = 1
    ) -> dict:
        """
        Check relevance with automatic query rewriting if needed.

        The original question is classified first. If it is not "CAN_ANSWER",
        up to ``max_rewrites`` rewrites are tried, each one rewriting the
        previous candidate query. The best-scoring query seen is returned.

        Args:
            question: The user's question
            retriever: The retriever to use
            k: Number of top documents
            max_rewrites: Maximum rewrite attempts; 0 or less disables
                rewriting

        Returns:
            Dict with classification, query_used, and was_rewritten. When a
            rewrite won, the dict also carries that rewrite's key_terms.
        """
        classification = self.context_validate(question, retriever, k)
        best = {
            "classification": classification,
            "query_used": question,
            "was_rewritten": False,
        }
        if classification == "CAN_ANSWER" or max_rewrites <= 0:
            return best

        # Try query rewriting for poor results
        tried = {question}
        candidate = question
        for _ in range(max_rewrites):
            logger.info("Attempting query rewrite...")
            expansion = self.context_query_rewrite(candidate)
            # Stop when rewriting fails or only reproduces a query already
            # tried: further attempts would repeat the same work.
            if not expansion or expansion.rewritten_query in tried:
                break
            candidate = expansion.rewritten_query
            tried.add(candidate)

            new_classification = self.context_validate(candidate, retriever, k)
            current_best = best["classification"]
            if self._is_better_classification(new_classification, current_best):
                logger.info(f"Rewrite improved: {current_best} -> {new_classification}")
                best = {
                    "classification": new_classification,
                    "query_used": candidate,
                    "was_rewritten": True,
                    "key_terms": expansion.key_terms,
                }
                if new_classification == "CAN_ANSWER":
                    break
        return best

    def _is_better_classification(self, new: str, old: str) -> bool:
        """Check if new classification is better than old.

        Labels rank NO_MATCH < PARTIAL < CAN_ANSWER; an unrecognised label is
        ranked like NO_MATCH.
        """
        ranking = {"NO_MATCH": 0, "PARTIAL": 1, "CAN_ANSWER": 2}
        return ranking.get(new, 0) > ranking.get(old, 0)