Rifqi Hafizuddin commited on
Commit
abc494f
·
1 Parent(s): 1fef470

[KM-512] connect query executor to user question. add logging for db_executor

Browse files
src/query/base.py CHANGED
@@ -27,5 +27,6 @@ class BaseExecutor(ABC):
27
  results: list[RetrievalResult],
28
  user_id: str,
29
  db: AsyncSession,
 
30
  limit: int = 100,
31
  ) -> list[QueryResult]: ...
 
27
  results: list[RetrievalResult],
28
  user_id: str,
29
  db: AsyncSession,
30
+ question: str,
31
  limit: int = 100,
32
  ) -> list[QueryResult]: ...
src/query/executors/db_executor.py CHANGED
@@ -18,6 +18,7 @@ from typing import Any
18
 
19
  import sqlglot
20
  import sqlglot.expressions as exp
 
21
  from langchain_core.prompts import ChatPromptTemplate
22
  from langchain_openai import AzureChatOpenAI
23
  from sqlalchemy import text
@@ -35,6 +36,8 @@ from src.utils.db_credential_encryption import decrypt_credentials_dict
35
 
36
  logger = get_logger("db_executor")
37
 
 
 
38
  _SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
39
  _MAX_RETRIES = 3
40
  _MAX_LIMIT = 500
@@ -43,13 +46,15 @@ _SQL_SYSTEM_PROMPT = """\
43
  You are a SQL data analyst working with a user's database.
44
  Generate a single SQL SELECT statement that answers the user's question.
45
 
 
 
46
  Rules:
47
  - ONLY reference tables and columns listed in the schema below. Do not invent names.
48
  - Always include a LIMIT clause (max {limit}).
49
  - Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
50
  - Prefer explicit JOINs over subqueries when combining tables.
51
  - For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
52
- - For date filtering, use standard SQL date functions appropriate for the dialect.
53
 
54
  Schema:
55
  {schema}
@@ -81,6 +86,7 @@ class DbExecutor(BaseExecutor):
81
  results: list[RetrievalResult],
82
  user_id: str,
83
  db: AsyncSession,
 
84
  limit: int = 100,
85
  ) -> list[QueryResult]:
86
  db_results = [r for r in results if r.source_type == "database"]
@@ -99,7 +105,7 @@ class DbExecutor(BaseExecutor):
99
  query_results: list[QueryResult] = []
100
  for client_id, client_results in by_client.items():
101
  try:
102
- qr = await self._execute_for_client(client_id, client_results, user_id, db, limit)
103
  if qr:
104
  query_results.append(qr)
105
  except Exception as e:
@@ -117,6 +123,7 @@ class DbExecutor(BaseExecutor):
117
  results: list[RetrievalResult],
118
  user_id: str,
119
  db: AsyncSession,
 
120
  limit: int,
121
  ) -> QueryResult | None:
122
  client = await database_client_service.get(db, client_id)
@@ -143,17 +150,30 @@ class DbExecutor(BaseExecutor):
143
  return None
144
 
145
  schema_ctx = self._build_schema_context(full_schema)
146
- question = self._extract_question(results)
147
  capped_limit = min(limit, _MAX_LIMIT)
 
148
 
149
  # SQL generation with retry
150
  validated_sql: str | None = None
151
  prev_error: str = ""
 
152
  for attempt in range(_MAX_RETRIES):
153
- error_section = f"Previous attempt failed: {prev_error}\nFix the issue above." if prev_error else ""
 
 
 
 
 
 
 
154
  try:
 
 
 
 
155
  result: SQLQuery = await self._chain.ainvoke({
156
  "schema": schema_ctx,
 
157
  "limit": capped_limit,
158
  "error_section": error_section,
159
  "question": question,
@@ -162,10 +182,19 @@ class DbExecutor(BaseExecutor):
162
  validation_error = self._validate(sql, full_schema, capped_limit)
163
  if validation_error:
164
  prev_error = validation_error
 
165
  logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
166
  continue
167
- validated_sql = sql
168
- logger.info("sql generated", attempt=attempt + 1, reasoning=result.reasoning)
 
 
 
 
 
 
 
 
169
  break
170
  except Exception as e:
171
  prev_error = str(e)
@@ -272,13 +301,6 @@ class DbExecutor(BaseExecutor):
272
  lines.append("")
273
  return "\n".join(lines).strip()
274
 
275
- def _extract_question(self, results: list[RetrievalResult]) -> str:
276
- # The search_query rewritten by the orchestrator is not in RetrievalResult —
277
- # the content field carries schema descriptions. Return a generic fallback;
278
- # callers that have the original question should pass it explicitly.
279
- # TODO: thread the original user question through to execute() when wiring into the agent.
280
- return "Answer the user's data question using the schema provided."
281
-
282
  # ------------------------------------------------------------------
283
  # Guardrails
284
  # ------------------------------------------------------------------
 
18
 
19
  import sqlglot
20
  import sqlglot.expressions as exp
21
+ import tiktoken
22
  from langchain_core.prompts import ChatPromptTemplate
23
  from langchain_openai import AzureChatOpenAI
24
  from sqlalchemy import text
 
36
 
37
  logger = get_logger("db_executor")
38
 
39
+ _enc = tiktoken.get_encoding("cl100k_base")
40
+
41
  _SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
42
  _MAX_RETRIES = 3
43
  _MAX_LIMIT = 500
 
46
  You are a SQL data analyst working with a user's database.
47
  Generate a single SQL SELECT statement that answers the user's question.
48
 
49
+ Database dialect: {dialect}
50
+
51
  Rules:
52
  - ONLY reference tables and columns listed in the schema below. Do not invent names.
53
  - Always include a LIMIT clause (max {limit}).
54
  - Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
55
  - Prefer explicit JOINs over subqueries when combining tables.
56
  - For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
57
+ - For date filtering, use dialect-appropriate functions ({dialect} syntax).
58
 
59
  Schema:
60
  {schema}
 
86
  results: list[RetrievalResult],
87
  user_id: str,
88
  db: AsyncSession,
89
+ question: str,
90
  limit: int = 100,
91
  ) -> list[QueryResult]:
92
  db_results = [r for r in results if r.source_type == "database"]
 
105
  query_results: list[QueryResult] = []
106
  for client_id, client_results in by_client.items():
107
  try:
108
+ qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
109
  if qr:
110
  query_results.append(qr)
111
  except Exception as e:
 
123
  results: list[RetrievalResult],
124
  user_id: str,
125
  db: AsyncSession,
126
+ question: str,
127
  limit: int,
128
  ) -> QueryResult | None:
129
  client = await database_client_service.get(db, client_id)
 
150
  return None
151
 
152
  schema_ctx = self._build_schema_context(full_schema)
 
153
  capped_limit = min(limit, _MAX_LIMIT)
154
+ dialect = client.db_type
155
 
156
  # SQL generation with retry
157
  validated_sql: str | None = None
158
  prev_error: str = ""
159
+ prev_reasoning: str = ""
160
  for attempt in range(_MAX_RETRIES):
161
+ if prev_error:
162
+ error_section = (
163
+ f"Previous attempt reasoning: {prev_reasoning}\n"
164
+ f"Previous attempt failed: {prev_error}\n"
165
+ "Fix the issue above."
166
+ )
167
+ else:
168
+ error_section = ""
169
  try:
170
+ prompt_text = schema_ctx + error_section + question
171
+ input_tokens = len(_enc.encode(prompt_text))
172
+ logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)
173
+
174
  result: SQLQuery = await self._chain.ainvoke({
175
  "schema": schema_ctx,
176
+ "dialect": dialect,
177
  "limit": capped_limit,
178
  "error_section": error_section,
179
  "question": question,
 
182
  validation_error = self._validate(sql, full_schema, capped_limit)
183
  if validation_error:
184
  prev_error = validation_error
185
+ prev_reasoning = result.reasoning
186
  logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
187
  continue
188
+ validated_sql = self._enforce_limit(sql, capped_limit)
189
+ output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
190
+ logger.info(
191
+ "sql generated",
192
+ attempt=attempt + 1,
193
+ input_tokens=input_tokens,
194
+ output_tokens=output_tokens,
195
+ total_tokens=input_tokens + output_tokens,
196
+ reasoning=result.reasoning,
197
+ )
198
  break
199
  except Exception as e:
200
  prev_error = str(e)
 
301
  lines.append("")
302
  return "\n".join(lines).strip()
303
 
 
 
 
 
 
 
 
304
  # ------------------------------------------------------------------
305
  # Guardrails
306
  # ------------------------------------------------------------------
src/query/query_executor.py CHANGED
@@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
6
 
7
  from src.middlewares.logging import get_logger
8
  from src.query.base import QueryResult
9
- from src.query.executors.db import db_executor
10
  from src.query.executors.tabular import tabular_executor
11
  from src.rag.base import RetrievalResult
12
 
@@ -19,6 +19,7 @@ class QueryExecutor:
19
  results: list[RetrievalResult],
20
  user_id: str,
21
  db: AsyncSession,
 
22
  limit: int = 100,
23
  ) -> list[QueryResult]:
24
  db_results = [r for r in results if r.source_type == "database"]
@@ -32,8 +33,8 @@ class QueryExecutor:
32
  return []
33
 
34
  batches = await asyncio.gather(
35
- db_executor.execute(db_results, user_id, db, limit) if db_results else _empty(),
36
- tabular_executor.execute(tabular_results, user_id, db, limit) if tabular_results else _empty(),
37
  return_exceptions=True,
38
  )
39
 
 
6
 
7
  from src.middlewares.logging import get_logger
8
  from src.query.base import QueryResult
9
+ from src.query.executors.db_executor import db_executor
10
  from src.query.executors.tabular import tabular_executor
11
  from src.rag.base import RetrievalResult
12
 
 
19
  results: list[RetrievalResult],
20
  user_id: str,
21
  db: AsyncSession,
22
+ question: str,
23
  limit: int = 100,
24
  ) -> list[QueryResult]:
25
  db_results = [r for r in results if r.source_type == "database"]
 
33
  return []
34
 
35
  batches = await asyncio.gather(
36
+ db_executor.execute(db_results, user_id, db, question, limit) if db_results else _empty(),
37
+ tabular_executor.execute(tabular_results, user_id, db, question, limit) if tabular_results else _empty(),
38
  return_exceptions=True,
39
  )
40