Rifqi Hafizuddin committed on
Commit
bd2b1d9
·
1 Parent(s): f273db0

[NOTICKET] db_executor: DML check now walks the entire AST instead of only CTEs; schema: pass the embedding as a bound parameter (cast to vector) instead of string interpolation

Browse files
src/query/executors/db_executor.py CHANGED
@@ -316,10 +316,9 @@ class DbExecutor(BaseExecutor):
316
  if not isinstance(parsed, exp.Select):
317
  return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
318
 
319
- # Check for DML inside CTEs
320
- for cte in parsed.find_all(exp.With):
321
- for node in cte.find_all((exp.Insert, exp.Update, exp.Delete)):
322
- return f"DML ({type(node).__name__}) inside CTE is not allowed."
323
 
324
  # Layer 2: schema grounding — table names
325
  known_tables = {t.lower() for t in schema}
@@ -342,12 +341,13 @@ class DbExecutor(BaseExecutor):
342
  if existing:
343
  current = int(existing.expression.this)
344
  if current > limit:
345
- existing.expression.set("this", str(limit))
346
  else:
347
  parsed = parsed.limit(limit)
348
  return parsed.sql()
349
 
350
  def _run_sql(self, engine: Any, sql: str) -> list[dict]:
 
351
  with engine.connect() as conn:
352
  result = conn.execute(text(sql))
353
  return [dict(row) for row in result.mappings()]
 
316
  if not isinstance(parsed, exp.Select):
317
  return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
318
 
319
+ # Check for DML anywhere in the AST (including writeable CTEs)
320
+ for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)):
321
+ return f"DML ({type(node).__name__}) is not allowed."
 
322
 
323
  # Layer 2: schema grounding — table names
324
  known_tables = {t.lower() for t in schema}
 
341
  if existing:
342
  current = int(existing.expression.this)
343
  if current > limit:
344
+ existing.expression.set("this", limit)
345
  else:
346
  parsed = parsed.limit(limit)
347
  return parsed.sql()
348
 
349
  def _run_sql(self, engine: Any, sql: str) -> list[dict]:
350
+ # Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
351
  with engine.connect() as conn:
352
  result = conn.execute(text(sql))
353
  return [dict(row) for row in result.mappings()]
src/rag/retrievers/schema.py CHANGED
@@ -52,11 +52,11 @@ class SchemaRetriever(BaseRetriever):
52
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
53
 
54
  if operator == "<#>":
55
- score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
56
  elif operator == "<->":
57
- score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
58
  else:
59
- score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
60
 
61
  sql = text(f"""
62
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
@@ -65,12 +65,12 @@ class SchemaRetriever(BaseRetriever):
65
  WHERE lpc.name = 'document_embeddings'
66
  AND lpe.cmetadata->>'user_id' = :user_id
67
  AND lpe.cmetadata->>'source_type' = 'database'
68
- ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
69
  LIMIT :k
70
  """)
71
 
72
  async with _pgvector_engine.connect() as conn:
73
- result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
74
  rows = result.fetchall()
75
 
76
  return [
@@ -90,11 +90,11 @@ class SchemaRetriever(BaseRetriever):
90
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
91
 
92
  if operator == "<#>":
93
- score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
94
  elif operator == "<->":
95
- score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
96
  else:
97
- score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
98
 
99
  sql = text(f"""
100
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
@@ -105,12 +105,12 @@ class SchemaRetriever(BaseRetriever):
105
  AND lpe.cmetadata->>'source_type' = 'document'
106
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
107
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
108
- ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
109
  LIMIT :k
110
  """)
111
 
112
  async with _pgvector_engine.connect() as conn:
113
- result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
114
  rows = result.fetchall()
115
 
116
  results = []
 
52
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
53
 
54
  if operator == "<#>":
55
+ score_sql = "(lpe.embedding <#> :emb::vector) * -1"
56
  elif operator == "<->":
57
+ score_sql = "1.0 / (1.0 + (lpe.embedding <-> :emb::vector))"
58
  else:
59
+ score_sql = "1.0 - (lpe.embedding <=> :emb::vector)"
60
 
61
  sql = text(f"""
62
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
 
65
  WHERE lpc.name = 'document_embeddings'
66
  AND lpe.cmetadata->>'user_id' = :user_id
67
  AND lpe.cmetadata->>'source_type' = 'database'
68
+ ORDER BY lpe.embedding {operator} :emb::vector ASC
69
  LIMIT :k
70
  """)
71
 
72
  async with _pgvector_engine.connect() as conn:
73
+ result = await conn.execute(sql, {"user_id": user_id, "k": k * 4, "emb": emb_str})
74
  rows = result.fetchall()
75
 
76
  return [
 
90
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
91
 
92
  if operator == "<#>":
93
+ score_sql = "(lpe.embedding <#> :emb::vector) * -1"
94
  elif operator == "<->":
95
+ score_sql = "1.0 / (1.0 + (lpe.embedding <-> :emb::vector))"
96
  else:
97
+ score_sql = "1.0 - (lpe.embedding <=> :emb::vector)"
98
 
99
  sql = text(f"""
100
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
 
105
  AND lpe.cmetadata->>'source_type' = 'document'
106
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
107
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
108
+ ORDER BY lpe.embedding {operator} :emb::vector ASC
109
  LIMIT :k
110
  """)
111
 
112
  async with _pgvector_engine.connect() as conn:
113
+ result = await conn.execute(sql, {"user_id": user_id, "k": k * 4, "emb": emb_str})
114
  rows = result.fetchall()
115
 
116
  results = []