""" Relevance checker module for document retrieval quality assessment. """ from langchain_google_genai import ChatGoogleGenerativeAI from pydantic import BaseModel, Field from typing import Literal, Optional, List import logging from configuration.parameters import parameters logger = logging.getLogger(__name__) def estimate_tokens(text: str, chars_per_token: int = 4) -> int: """Estimate token count from text length.""" return len(text) // chars_per_token # ============================================================================ # Structured Output Models # ============================================================================ class ContextValidationClassification(BaseModel): """Structured output for context validation classification.""" classification: Literal["CAN_ANSWER", "PARTIAL", "NO_MATCH"] = Field( description=( "CAN_ANSWER: Passages contain enough info to fully answer. " "PARTIAL: Passages mention the topic but incomplete. " "NO_MATCH: Passages don't discuss the topic at all." ) ) confidence: Literal["HIGH", "MEDIUM", "LOW"] = Field( default="MEDIUM", description="Confidence level in the classification" ) reasoning: str = Field( default="", description="Brief explanation for the classification" ) class ContextQueryExpansion(BaseModel): """Structured output for query expansion/rewriting.""" rewritten_query: str = Field( description="A rephrased version of the original query" ) key_terms: List[str] = Field( default_factory=list, description="Key terms and synonyms to search for" ) search_strategy: str = Field( default="", description="Brief explanation of the search approach" ) class ContextValidator: """ Checks context relevance of retrieved documents to a user's question. Uses Gemini model with structured output to classify coverage and provides query rewriting for improved retrieval. """ VALID_LABELS = {"CAN_ANSWER", "PARTIAL", "NO_MATCH"} def __init__(self): """Initialize the context validator.""" logger.info("Initializing ContextValidator...") base_llm = ChatGoogleGenerativeAI( model=parameters.RELEVANCE_CHECKER_MODEL, google_api_key=parameters.GOOGLE_API_KEY, temperature=0, max_output_tokens=100, ) self.llm = base_llm self.structured_llm = base_llm.with_structured_output(ContextValidationClassification) self.query_expansion_llm = base_llm.with_structured_output(ContextQueryExpansion) logger.info(f"ContextValidator initialized (model={parameters.RELEVANCE_CHECKER_MODEL})") def context_query_rewrite(self, original_query: str, context_hint: Optional[str] = None) -> Optional[ContextQueryExpansion]: """ Rewrite a query to potentially retrieve better results. Args: original_query: The original user query context_hint: Optional hint about available documents Returns: ContextQueryExpansion with rewritten query, or None on failure """ logger.debug(f"Rewriting query: {original_query[:80]}...") context_section = f"\n**Available Context:** {context_hint}\n" if context_hint else "" prompt = f"""Rewrite this query to improve document retrieval. **Original Query:** {original_query} {context_section} **Instructions:** 1. Rephrase to be more specific and searchable 2. Extract key terms and add synonyms 3. Consider exact phrases in formal documents""" try: result: ContextQueryExpansion = self.query_expansion_llm.invoke(prompt) logger.debug(f"Query rewritten: {result.rewritten_query[:60]}...") return result except Exception as e: logger.error(f"Query rewrite failed: {e}") return None def context_validate(self, question: str, retriever, k: int = 3) -> str: """ Retrieve top-k passages and classify coverage. Args: question: The user's question retriever: The retriever for fetching documents k: Number of top documents to consider Returns: Classification: "CAN_ANSWER", "PARTIAL", or "NO_MATCH" """ if not question or not question.strip(): logger.warning("Empty question provided") return "NO_MATCH" if k < 1: k = 3 logger.info(f"Checking context relevance for: {question[:60]}...") # Retrieve documents try: top_docs = retriever.invoke(question) except Exception as e: logger.error(f"Retriever invocation failed: {e}") return "NO_MATCH" if not top_docs: logger.info("No documents returned") return "NO_MATCH" logger.debug(f"Retrieved {len(top_docs)} documents") passages = "\n\n".join(doc.page_content for doc in top_docs[:k]) prompt = f"""Classify how well the passages address the question. **Question:** {question} **Passages:** {passages} Classify as CAN_ANSWER (fully answers), PARTIAL (mentions topic), or NO_MATCH (unrelated).""" try: result: ContextValidationClassification = self.structured_llm.invoke(prompt) logger.info(f"Context relevance: {result.classification} ({result.confidence})") return result.classification except Exception as e: logger.error(f"Structured output failed: {e}") # Fallback to text parsing try: response = self.llm.invoke(prompt) raw_response = response.content if hasattr(response, "content") else str(response) llm_response = raw_response.strip().upper() for label in self.VALID_LABELS: if label in llm_response: logger.info(f"Fallback classification: {label}") return label return "NO_MATCH" except Exception as fallback_error: logger.error(f"Fallback failed: {fallback_error}") return "NO_MATCH" def context_validate_with_rewrite(self, question: str, retriever, k: int = 3, max_rewrites: int = 1) -> dict: """ Check relevance with automatic query rewriting if needed. Args: question: The user's question retriever: The retriever to use k: Number of top documents max_rewrites: Maximum rewrite attempts Returns: Dict with classification, query_used, and was_rewritten """ classification = self.context_validate(question, retriever, k) if classification == "CAN_ANSWER" or max_rewrites <= 0: return { "classification": classification, "query_used": question, "was_rewritten": False } # Try query rewriting for poor results if classification in ["PARTIAL", "NO_MATCH"]: logger.info("Attempting query rewrite...") expansion = self.context_query_rewrite(question) if expansion and expansion.rewritten_query != question: new_classification = self.context_validate(expansion.rewritten_query, retriever, k) if self._is_better_classification(new_classification, classification): logger.info(f"Rewrite improved: {classification} -> {new_classification}") return { "classification": new_classification, "query_used": expansion.rewritten_query, "was_rewritten": True, "key_terms": expansion.key_terms } return { "classification": classification, "query_used": question, "was_rewritten": False } def _is_better_classification(self, new: str, old: str) -> bool: """Check if new classification is better than old.""" ranking = {"NO_MATCH": 0, "PARTIAL": 1, "CAN_ANSWER": 2} return ranking.get(new, 0) > ranking.get(old, 0)