|
|
""" |
|
|
Query Planner Agent |
|
|
|
|
|
Decomposes complex queries into sub-queries and identifies query intent. |
|
|
Follows the "Decomposed Prompting" approach (Khot et al., 2022, arXiv:2210.02406).
|
|
|
|
|
Key Features: |
|
|
- Multi-hop query decomposition |
|
|
- Query intent classification (factoid, comparison, aggregation, etc.) |
|
|
- Dependency graph for sub-queries |
|
|
- Query expansion with synonyms and related terms |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Dict, Any, Literal |
|
|
from pydantic import BaseModel, Field |
|
|
from loguru import logger |
|
|
from enum import Enum |
|
|
import json |
|
|
import re |
|
|
|
|
|
try: |
|
|
import httpx |
|
|
HTTPX_AVAILABLE = True |
|
|
except ImportError: |
|
|
HTTPX_AVAILABLE = False |
|
|
|
|
|
|
|
|
class QueryIntent(str, Enum):
    """Classification of query intent.

    Inherits from ``str`` so members compare equal to their plain string
    values (e.g. ``QueryIntent.FACTOID == "factoid"``) and serialize
    cleanly through pydantic/JSON.
    """

    FACTOID = "factoid"          # single-fact lookup
    COMPARISON = "comparison"    # contrast two or more entities
    AGGREGATION = "aggregation"  # summarize / combine information
    CAUSAL = "causal"            # why / what-causes questions
    PROCEDURAL = "procedural"    # how-to / step-by-step questions
    DEFINITION = "definition"    # "what is X" style questions
    LIST = "list"                # enumerate items
    MULTI_HOP = "multi_hop"      # requires chaining several sub-answers
|
|
|
|
|
|
|
|
class SubQuery(BaseModel):
    """A decomposed sub-query."""

    # Unique identifier within the plan (e.g. "sq1"); referenced by
    # other sub-queries' depends_on lists.
    id: str
    # The sub-query text to run against retrieval.
    query: str
    # Intent classification for this sub-query alone.
    intent: QueryIntent
    # IDs of sub-queries whose answers must be available first
    # (forms the dependency graph for execution ordering).
    depends_on: List[str] = Field(default_factory=list)
    # Execution priority, 1 (first) through 5 (last).
    priority: int = Field(default=1, ge=1, le=5)
    # Optional metadata filters to narrow retrieval for this sub-query.
    filters: Dict[str, Any] = Field(default_factory=dict)
    # Expected answer shape: "text", "number", "date", "list", or "boolean".
    expected_answer_type: str = Field(default="text")
|
|
|
|
|
|
|
|
class QueryPlan(BaseModel):
    """Complete query execution plan."""

    # The user's query verbatim, before decomposition.
    original_query: str
    # Overall intent of the original query.
    intent: QueryIntent
    # Ordered decomposition; a simple query yields a single sub-query
    # identical to the original.
    sub_queries: List[SubQuery]
    # Synonyms / related terms to broaden retrieval recall.
    expanded_terms: List[str] = Field(default_factory=list)
    # True when sub-query answers must be combined (aggregation/list intents).
    requires_aggregation: bool = False
    # Planner's self-assessed confidence in this plan, 0.0-1.0.
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)
|
|
|
|
|
|
|
|
class QueryPlannerAgent:
    """
    Plans and decomposes queries for optimal retrieval.

    Capabilities:
    1. Identify query complexity and intent
    2. Decompose multi-hop queries into atomic sub-queries
    3. Build dependency graph for sub-query execution
    4. Expand queries with related terms

    Planning is two-tiered: a fast rule-based planner always runs and serves
    as the fallback; an optional LLM planner (via the Ollama HTTP API) refines
    the plan when enabled and reachable.
    """

    SYSTEM_PROMPT = """You are a query planning expert. Your job is to analyze user queries and create optimal retrieval plans.

For each query, you must:
1. Classify the query intent (factoid, comparison, aggregation, causal, procedural, definition, list, multi_hop)
2. Decompose complex queries into simpler sub-queries
3. Identify dependencies between sub-queries
4. Suggest query expansions (synonyms, related terms)

Output your analysis as JSON with this structure:
{
  "intent": "factoid|comparison|aggregation|causal|procedural|definition|list|multi_hop",
  "sub_queries": [
    {
      "id": "sq1",
      "query": "the sub-query text",
      "intent": "factoid",
      "depends_on": [],
      "priority": 1,
      "expected_answer_type": "text|number|date|list|boolean"
    }
  ],
  "expanded_terms": ["synonym1", "related_term1"],
  "requires_aggregation": false,
  "confidence": 0.95
}

For simple queries, return a single sub-query matching the original.
For complex queries requiring multiple steps, break them down logically.
"""

    def __init__(
        self,
        model: str = "llama3.2:3b",
        base_url: str = "http://localhost:11434",
        temperature: float = 0.1,
        use_llm: bool = True,
    ):
        """
        Initialize Query Planner.

        Args:
            model: LLM model for planning
            base_url: Ollama API URL
            temperature: LLM temperature (lower = more deterministic)
            use_llm: If False, use rule-based planning only
        """
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.temperature = temperature
        self.use_llm = use_llm

        logger.info(f"QueryPlannerAgent initialized (model={model}, use_llm={use_llm})")

    def plan(self, query: str) -> QueryPlan:
        """
        Create execution plan for a query.

        The rule-based plan is always computed first so a safe fallback
        exists if the LLM is disabled, httpx is missing, or the call fails.

        Args:
            query: User's natural language query

        Returns:
            QueryPlan with sub-queries and metadata
        """
        rule_based_plan = self._rule_based_planning(query)

        if not self.use_llm or not HTTPX_AVAILABLE:
            return rule_based_plan

        try:
            llm_plan = self._llm_planning(query)

            # Merge rule-based expansions into the LLM plan. dict.fromkeys
            # (not set) keeps de-duplication order-preserving and
            # deterministic across runs.
            if rule_based_plan.expanded_terms:
                llm_plan.expanded_terms = list(dict.fromkeys(
                    llm_plan.expanded_terms + rule_based_plan.expanded_terms
                ))

            return llm_plan

        except Exception as e:
            logger.warning(f"LLM planning failed, using rule-based: {e}")
            return rule_based_plan

    def _rule_based_planning(self, query: str) -> QueryPlan:
        """Fast rule-based query planning (no network, no LLM)."""
        query_lower = query.lower().strip()

        # Classify intent from surface patterns.
        intent = self._detect_intent(query_lower)

        # Add synonyms / related terms from the static expansion map.
        expansions = self._expand_query(query)

        # Break comparison / multi-hop queries into atomic sub-queries.
        sub_queries = self._decompose_if_needed(query, intent)

        return QueryPlan(
            original_query=query,
            intent=intent,
            sub_queries=sub_queries,
            expanded_terms=expansions,
            requires_aggregation=intent in [QueryIntent.AGGREGATION, QueryIntent.LIST],
            confidence=0.8,  # heuristic plans are assumed less reliable than LLM plans
        )

    def _detect_intent(self, query: str) -> QueryIntent:
        """Detect query intent from surface patterns.

        Expects an already lower-cased query (callers pass query.lower()).
        Checks run from most to least specific; the first match wins.
        """
        # Definition questions must start with the pattern.
        if re.match(r"^(what is|define|what are|what does .* mean)", query):
            return QueryIntent.DEFINITION

        # Word boundaries stop short tokens like "vs" from matching inside
        # longer words (plain substring "vs" matches e.g. "obvs").
        if re.search(r"\b(?:vs\.?|versus)\b", query) or any(
            p in query for p in ["compare", "difference between", "better than"]
        ):
            return QueryIntent.COMPARISON

        # "\blist\b" avoids false hits on "listed", "blacklist", etc.
        if re.search(r"\blist\b", query) or any(
            p in query for p in ["what are all", "give me all", "enumerate"]
        ):
            return QueryIntent.LIST

        if any(p in query for p in ["why", "how does", "what causes", "reason for"]):
            return QueryIntent.CAUSAL

        if any(p in query for p in ["how to", "steps to", "process for", "how can i"]):
            return QueryIntent.PROCEDURAL

        if any(p in query for p in ["summarize", "overview", "summary of", "main points"]):
            return QueryIntent.AGGREGATION

        # A conjunction inside a question, or several question marks,
        # signals a multi-hop query.
        if " and " in query and "?" in query:
            return QueryIntent.MULTI_HOP
        if query.count("?") > 1:
            return QueryIntent.MULTI_HOP

        return QueryIntent.FACTOID

    def _expand_query(self, query: str) -> List[str]:
        """Generate query expansions (synonyms, related terms)."""
        expansions = []
        query_lower = query.lower()

        # Static domain synonym map (IP/legal vocabulary).
        expansion_map = {
            "patent": ["intellectual property", "IP", "invention", "claim"],
            "license": ["licensing", "agreement", "contract", "terms"],
            "royalty": ["royalties", "payment", "fee", "compensation"],
            "open source": ["OSS", "FOSS", "free software", "open-source"],
            "trademark": ["brand", "mark", "logo"],
            "copyright": ["rights", "authorship", "protection"],
            "infringement": ["violation", "breach", "unauthorized use"],
            "disclosure": ["reveal", "publish", "filing"],
        }

        for term, synonyms in expansion_map.items():
            if term in query_lower:
                expansions.extend(synonyms)

        # De-duplicate while preserving insertion order — list(set(...))
        # is nondeterministic, so the [:10] cap would drop arbitrary terms.
        return list(dict.fromkeys(expansions))[:10]

    def _decompose_if_needed(self, query: str, intent: QueryIntent) -> List[SubQuery]:
        """Decompose the query into sub-queries when the intent calls for it."""

        # Comparison: fetch each entity's characteristics first, then a
        # final comparison step that depends on all of them.
        if intent == QueryIntent.COMPARISON:
            entities = self._extract_comparison_entities(query)
            if len(entities) >= 2:
                sub_queries = []
                for i, entity in enumerate(entities):
                    sub_queries.append(SubQuery(
                        id=f"sq{i+1}",
                        query=f"What are the key characteristics of {entity}?",
                        intent=QueryIntent.FACTOID,
                        priority=1,
                        expected_answer_type="text",
                    ))

                sub_queries.append(SubQuery(
                    id=f"sq{len(entities)+1}",
                    query=query,
                    intent=QueryIntent.COMPARISON,
                    depends_on=[f"sq{i+1}" for i in range(len(entities))],
                    priority=2,
                    expected_answer_type="text",
                ))
                return sub_queries

        # Multi-hop: split on "and" into independent factoid sub-queries.
        if intent == QueryIntent.MULTI_HOP and " and " in query.lower():
            parts = re.split(r'\s+and\s+', query, flags=re.IGNORECASE)
            sub_queries = []
            for i, part in enumerate(parts):
                part = part.strip().rstrip("?") + "?"
                sub_queries.append(SubQuery(
                    id=f"sq{i+1}",
                    query=part,
                    intent=QueryIntent.FACTOID,
                    # Clamp: SubQuery.priority is constrained to 1..5, and a
                    # query with more than five "and"-joined parts would
                    # otherwise raise a ValidationError.
                    priority=min(i + 1, 5),
                    expected_answer_type="text",
                ))
            return sub_queries

        # Default: a single sub-query identical to the original.
        return [SubQuery(
            id="sq1",
            query=query,
            intent=intent,
            priority=1,
            expected_answer_type="text",
        )]

    def _extract_comparison_entities(self, query: str) -> List[str]:
        """Extract the two entities being compared, or [] if none found."""
        # Ordered from most to least specific; first pattern to match wins.
        patterns = [
            r"(?:compare|difference between)\s+(.+?)\s+(?:and|vs|versus)\s+(.+?)(?:\?|$)",
            r"(.+?)\s+(?:vs|versus)\s+(.+?)(?:\?|$)",
            r"(?:between)\s+(.+?)\s+(?:and)\s+(.+?)(?:\?|$)",
        ]

        for pattern in patterns:
            match = re.search(pattern, query, re.IGNORECASE)
            if match:
                return [match.group(1).strip(), match.group(2).strip()]

        return []

    @staticmethod
    def _coerce_intent(value: Any) -> QueryIntent:
        """Map an LLM-supplied intent value to QueryIntent; FACTOID on junk."""
        try:
            return QueryIntent(value)
        except (ValueError, TypeError):
            return QueryIntent.FACTOID

    def _llm_planning(self, query: str) -> QueryPlan:
        """Use the LLM (Ollama /api/generate) for sophisticated query planning.

        Raises on network/HTTP errors; the caller (plan) catches and falls
        back to the rule-based plan. Malformed *fields* in otherwise-valid
        LLM JSON are repaired here rather than discarding the whole plan.
        """
        prompt = f"""Analyze this query and create a retrieval plan:

Query: {query}

Provide your analysis as JSON."""

        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "prompt": prompt,
                    "system": self.SYSTEM_PROMPT,
                    "stream": False,
                    "options": {
                        "temperature": self.temperature,
                        "num_predict": 1024,
                    },
                },
            )
            response.raise_for_status()
            result = response.json()

        response_text = result.get("response", "")
        plan_data = self._parse_json_response(response_text)

        sub_queries = []
        for sq_data in plan_data.get("sub_queries", []):
            # Sanitize untrusted LLM fields so one bad value does not
            # invalidate the whole plan via a pydantic ValidationError.
            priority = sq_data.get("priority", 1)
            if not isinstance(priority, int) or not (1 <= priority <= 5):
                priority = 1
            depends_on = sq_data.get("depends_on", [])
            if not isinstance(depends_on, list):
                depends_on = []
            sub_queries.append(SubQuery(
                id=sq_data.get("id", "sq1"),
                query=sq_data.get("query", query),
                intent=self._coerce_intent(sq_data.get("intent", "factoid")),
                depends_on=depends_on,
                priority=priority,
                expected_answer_type=sq_data.get("expected_answer_type", "text"),
            ))

        if not sub_queries:
            sub_queries = [SubQuery(
                id="sq1",
                query=query,
                intent=QueryIntent.FACTOID,
                priority=1,
            )]

        # Clamp confidence into the model's allowed [0, 1] range.
        try:
            confidence = min(1.0, max(0.0, float(plan_data.get("confidence", 0.9))))
        except (TypeError, ValueError):
            confidence = 0.9

        return QueryPlan(
            original_query=query,
            intent=self._coerce_intent(plan_data.get("intent", "factoid")),
            sub_queries=sub_queries,
            expanded_terms=plan_data.get("expanded_terms", []),
            requires_aggregation=bool(plan_data.get("requires_aggregation", False)),
            confidence=confidence,
        )

    def _parse_json_response(self, text: str) -> Dict[str, Any]:
        """Extract the first {...} JSON object from an LLM response.

        Returns a minimal default plan dict if no parseable JSON is found.
        """
        # Greedy first-{ to last-} span: tolerates prose around the JSON.
        json_match = re.search(r'\{[\s\S]*\}', text)
        if json_match:
            try:
                return json.loads(json_match.group())
            except json.JSONDecodeError:
                pass

        # Fallback: empty plan skeleton with reduced confidence.
        return {
            "intent": "factoid",
            "sub_queries": [],
            "expanded_terms": [],
            "requires_aggregation": False,
            "confidence": 0.7,
        }
|
|
|