Rifqi Hafizuddin committed on
Commit ·
abc494f
1
Parent(s): 1fef470
[KM-512] connect query executor to user question. add logging for db_executor
Browse files
- src/query/base.py +1 -0
- src/query/executors/db_executor.py +35 -13
- src/query/query_executor.py +4 -3
src/query/base.py
CHANGED
|
@@ -27,5 +27,6 @@ class BaseExecutor(ABC):
|
|
| 27 |
results: list[RetrievalResult],
|
| 28 |
user_id: str,
|
| 29 |
db: AsyncSession,
|
|
|
|
| 30 |
limit: int = 100,
|
| 31 |
) -> list[QueryResult]: ...
|
|
|
|
| 27 |
results: list[RetrievalResult],
|
| 28 |
user_id: str,
|
| 29 |
db: AsyncSession,
|
| 30 |
+
question: str,
|
| 31 |
limit: int = 100,
|
| 32 |
) -> list[QueryResult]: ...
|
src/query/executors/db_executor.py
CHANGED
|
@@ -18,6 +18,7 @@ from typing import Any
|
|
| 18 |
|
| 19 |
import sqlglot
|
| 20 |
import sqlglot.expressions as exp
|
|
|
|
| 21 |
from langchain_core.prompts import ChatPromptTemplate
|
| 22 |
from langchain_openai import AzureChatOpenAI
|
| 23 |
from sqlalchemy import text
|
|
@@ -35,6 +36,8 @@ from src.utils.db_credential_encryption import decrypt_credentials_dict
|
|
| 35 |
|
| 36 |
logger = get_logger("db_executor")
|
| 37 |
|
|
|
|
|
|
|
| 38 |
_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
|
| 39 |
_MAX_RETRIES = 3
|
| 40 |
_MAX_LIMIT = 500
|
|
@@ -43,13 +46,15 @@ _SQL_SYSTEM_PROMPT = """\
|
|
| 43 |
You are a SQL data analyst working with a user's database.
|
| 44 |
Generate a single SQL SELECT statement that answers the user's question.
|
| 45 |
|
|
|
|
|
|
|
| 46 |
Rules:
|
| 47 |
- ONLY reference tables and columns listed in the schema below. Do not invent names.
|
| 48 |
- Always include a LIMIT clause (max {limit}).
|
| 49 |
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
|
| 50 |
- Prefer explicit JOINs over subqueries when combining tables.
|
| 51 |
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
|
| 52 |
-
- For date filtering, use
|
| 53 |
|
| 54 |
Schema:
|
| 55 |
{schema}
|
|
@@ -81,6 +86,7 @@ class DbExecutor(BaseExecutor):
|
|
| 81 |
results: list[RetrievalResult],
|
| 82 |
user_id: str,
|
| 83 |
db: AsyncSession,
|
|
|
|
| 84 |
limit: int = 100,
|
| 85 |
) -> list[QueryResult]:
|
| 86 |
db_results = [r for r in results if r.source_type == "database"]
|
|
@@ -99,7 +105,7 @@ class DbExecutor(BaseExecutor):
|
|
| 99 |
query_results: list[QueryResult] = []
|
| 100 |
for client_id, client_results in by_client.items():
|
| 101 |
try:
|
| 102 |
-
qr = await self._execute_for_client(client_id, client_results, user_id, db, limit)
|
| 103 |
if qr:
|
| 104 |
query_results.append(qr)
|
| 105 |
except Exception as e:
|
|
@@ -117,6 +123,7 @@ class DbExecutor(BaseExecutor):
|
|
| 117 |
results: list[RetrievalResult],
|
| 118 |
user_id: str,
|
| 119 |
db: AsyncSession,
|
|
|
|
| 120 |
limit: int,
|
| 121 |
) -> QueryResult | None:
|
| 122 |
client = await database_client_service.get(db, client_id)
|
|
@@ -143,17 +150,30 @@ class DbExecutor(BaseExecutor):
|
|
| 143 |
return None
|
| 144 |
|
| 145 |
schema_ctx = self._build_schema_context(full_schema)
|
| 146 |
-
question = self._extract_question(results)
|
| 147 |
capped_limit = min(limit, _MAX_LIMIT)
|
|
|
|
| 148 |
|
| 149 |
# SQL generation with retry
|
| 150 |
validated_sql: str | None = None
|
| 151 |
prev_error: str = ""
|
|
|
|
| 152 |
for attempt in range(_MAX_RETRIES):
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
result: SQLQuery = await self._chain.ainvoke({
|
| 156 |
"schema": schema_ctx,
|
|
|
|
| 157 |
"limit": capped_limit,
|
| 158 |
"error_section": error_section,
|
| 159 |
"question": question,
|
|
@@ -162,10 +182,19 @@ class DbExecutor(BaseExecutor):
|
|
| 162 |
validation_error = self._validate(sql, full_schema, capped_limit)
|
| 163 |
if validation_error:
|
| 164 |
prev_error = validation_error
|
|
|
|
| 165 |
logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
|
| 166 |
continue
|
| 167 |
-
validated_sql = sql
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
break
|
| 170 |
except Exception as e:
|
| 171 |
prev_error = str(e)
|
|
@@ -272,13 +301,6 @@ class DbExecutor(BaseExecutor):
|
|
| 272 |
lines.append("")
|
| 273 |
return "\n".join(lines).strip()
|
| 274 |
|
| 275 |
-
def _extract_question(self, results: list[RetrievalResult]) -> str:
|
| 276 |
-
# The search_query rewritten by the orchestrator is not in RetrievalResult —
|
| 277 |
-
# the content field carries schema descriptions. Return a generic fallback;
|
| 278 |
-
# callers that have the original question should pass it explicitly.
|
| 279 |
-
# TODO: thread the original user question through to execute() when wiring into the agent.
|
| 280 |
-
return "Answer the user's data question using the schema provided."
|
| 281 |
-
|
| 282 |
# ------------------------------------------------------------------
|
| 283 |
# Guardrails
|
| 284 |
# ------------------------------------------------------------------
|
|
|
|
| 18 |
|
| 19 |
import sqlglot
|
| 20 |
import sqlglot.expressions as exp
|
| 21 |
+
import tiktoken
|
| 22 |
from langchain_core.prompts import ChatPromptTemplate
|
| 23 |
from langchain_openai import AzureChatOpenAI
|
| 24 |
from sqlalchemy import text
|
|
|
|
| 36 |
|
| 37 |
logger = get_logger("db_executor")
|
| 38 |
|
| 39 |
+
_enc = tiktoken.get_encoding("cl100k_base")
|
| 40 |
+
|
| 41 |
_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
|
| 42 |
_MAX_RETRIES = 3
|
| 43 |
_MAX_LIMIT = 500
|
|
|
|
| 46 |
You are a SQL data analyst working with a user's database.
|
| 47 |
Generate a single SQL SELECT statement that answers the user's question.
|
| 48 |
|
| 49 |
+
Database dialect: {dialect}
|
| 50 |
+
|
| 51 |
Rules:
|
| 52 |
- ONLY reference tables and columns listed in the schema below. Do not invent names.
|
| 53 |
- Always include a LIMIT clause (max {limit}).
|
| 54 |
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
|
| 55 |
- Prefer explicit JOINs over subqueries when combining tables.
|
| 56 |
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
|
| 57 |
+
- For date filtering, use dialect-appropriate functions ({dialect} syntax).
|
| 58 |
|
| 59 |
Schema:
|
| 60 |
{schema}
|
|
|
|
| 86 |
results: list[RetrievalResult],
|
| 87 |
user_id: str,
|
| 88 |
db: AsyncSession,
|
| 89 |
+
question: str,
|
| 90 |
limit: int = 100,
|
| 91 |
) -> list[QueryResult]:
|
| 92 |
db_results = [r for r in results if r.source_type == "database"]
|
|
|
|
| 105 |
query_results: list[QueryResult] = []
|
| 106 |
for client_id, client_results in by_client.items():
|
| 107 |
try:
|
| 108 |
+
qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
|
| 109 |
if qr:
|
| 110 |
query_results.append(qr)
|
| 111 |
except Exception as e:
|
|
|
|
| 123 |
results: list[RetrievalResult],
|
| 124 |
user_id: str,
|
| 125 |
db: AsyncSession,
|
| 126 |
+
question: str,
|
| 127 |
limit: int,
|
| 128 |
) -> QueryResult | None:
|
| 129 |
client = await database_client_service.get(db, client_id)
|
|
|
|
| 150 |
return None
|
| 151 |
|
| 152 |
schema_ctx = self._build_schema_context(full_schema)
|
|
|
|
| 153 |
capped_limit = min(limit, _MAX_LIMIT)
|
| 154 |
+
dialect = client.db_type
|
| 155 |
|
| 156 |
# SQL generation with retry
|
| 157 |
validated_sql: str | None = None
|
| 158 |
prev_error: str = ""
|
| 159 |
+
prev_reasoning: str = ""
|
| 160 |
for attempt in range(_MAX_RETRIES):
|
| 161 |
+
if prev_error:
|
| 162 |
+
error_section = (
|
| 163 |
+
f"Previous attempt reasoning: {prev_reasoning}\n"
|
| 164 |
+
f"Previous attempt failed: {prev_error}\n"
|
| 165 |
+
"Fix the issue above."
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
error_section = ""
|
| 169 |
try:
|
| 170 |
+
prompt_text = schema_ctx + error_section + question
|
| 171 |
+
input_tokens = len(_enc.encode(prompt_text))
|
| 172 |
+
logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)
|
| 173 |
+
|
| 174 |
result: SQLQuery = await self._chain.ainvoke({
|
| 175 |
"schema": schema_ctx,
|
| 176 |
+
"dialect": dialect,
|
| 177 |
"limit": capped_limit,
|
| 178 |
"error_section": error_section,
|
| 179 |
"question": question,
|
|
|
|
| 182 |
validation_error = self._validate(sql, full_schema, capped_limit)
|
| 183 |
if validation_error:
|
| 184 |
prev_error = validation_error
|
| 185 |
+
prev_reasoning = result.reasoning
|
| 186 |
logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
|
| 187 |
continue
|
| 188 |
+
validated_sql = self._enforce_limit(sql, capped_limit)
|
| 189 |
+
output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
|
| 190 |
+
logger.info(
|
| 191 |
+
"sql generated",
|
| 192 |
+
attempt=attempt + 1,
|
| 193 |
+
input_tokens=input_tokens,
|
| 194 |
+
output_tokens=output_tokens,
|
| 195 |
+
total_tokens=input_tokens + output_tokens,
|
| 196 |
+
reasoning=result.reasoning,
|
| 197 |
+
)
|
| 198 |
break
|
| 199 |
except Exception as e:
|
| 200 |
prev_error = str(e)
|
|
|
|
| 301 |
lines.append("")
|
| 302 |
return "\n".join(lines).strip()
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
# ------------------------------------------------------------------
|
| 305 |
# Guardrails
|
| 306 |
# ------------------------------------------------------------------
|
src/query/query_executor.py
CHANGED
|
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
| 6 |
|
| 7 |
from src.middlewares.logging import get_logger
|
| 8 |
from src.query.base import QueryResult
|
| 9 |
-
from src.query.executors.
|
| 10 |
from src.query.executors.tabular import tabular_executor
|
| 11 |
from src.rag.base import RetrievalResult
|
| 12 |
|
|
@@ -19,6 +19,7 @@ class QueryExecutor:
|
|
| 19 |
results: list[RetrievalResult],
|
| 20 |
user_id: str,
|
| 21 |
db: AsyncSession,
|
|
|
|
| 22 |
limit: int = 100,
|
| 23 |
) -> list[QueryResult]:
|
| 24 |
db_results = [r for r in results if r.source_type == "database"]
|
|
@@ -32,8 +33,8 @@ class QueryExecutor:
|
|
| 32 |
return []
|
| 33 |
|
| 34 |
batches = await asyncio.gather(
|
| 35 |
-
db_executor.execute(db_results, user_id, db, limit) if db_results else _empty(),
|
| 36 |
-
tabular_executor.execute(tabular_results, user_id, db, limit) if tabular_results else _empty(),
|
| 37 |
return_exceptions=True,
|
| 38 |
)
|
| 39 |
|
|
|
|
| 6 |
|
| 7 |
from src.middlewares.logging import get_logger
|
| 8 |
from src.query.base import QueryResult
|
| 9 |
+
from src.query.executors.db_executor import db_executor
|
| 10 |
from src.query.executors.tabular import tabular_executor
|
| 11 |
from src.rag.base import RetrievalResult
|
| 12 |
|
|
|
|
| 19 |
results: list[RetrievalResult],
|
| 20 |
user_id: str,
|
| 21 |
db: AsyncSession,
|
| 22 |
+
question: str,
|
| 23 |
limit: int = 100,
|
| 24 |
) -> list[QueryResult]:
|
| 25 |
db_results = [r for r in results if r.source_type == "database"]
|
|
|
|
| 33 |
return []
|
| 34 |
|
| 35 |
batches = await asyncio.gather(
|
| 36 |
+
db_executor.execute(db_results, user_id, db, question, limit) if db_results else _empty(),
|
| 37 |
+
tabular_executor.execute(tabular_results, user_id, db, question, limit) if tabular_results else _empty(),
|
| 38 |
return_exceptions=True,
|
| 39 |
)
|
| 40 |
|