Rifqi Hafizuddin commited on
Commit ·
bd2b1d9
1
Parent(s): f273db0
[NOTICKET] db_executor: CTE DML check now walks entire AST root, schema: cast instead of string interpolation
Browse files
src/query/executors/db_executor.py
CHANGED
|
@@ -316,10 +316,9 @@ class DbExecutor(BaseExecutor):
|
|
| 316 |
if not isinstance(parsed, exp.Select):
|
| 317 |
return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
|
| 318 |
|
| 319 |
-
# Check for DML
|
| 320 |
-
for
|
| 321 |
-
|
| 322 |
-
return f"DML ({type(node).__name__}) inside CTE is not allowed."
|
| 323 |
|
| 324 |
# Layer 2: schema grounding — table names
|
| 325 |
known_tables = {t.lower() for t in schema}
|
|
@@ -342,12 +341,13 @@ class DbExecutor(BaseExecutor):
|
|
| 342 |
if existing:
|
| 343 |
current = int(existing.expression.this)
|
| 344 |
if current > limit:
|
| 345 |
-
existing.expression.set("this",
|
| 346 |
else:
|
| 347 |
parsed = parsed.limit(limit)
|
| 348 |
return parsed.sql()
|
| 349 |
|
| 350 |
def _run_sql(self, engine: Any, sql: str) -> list[dict]:
|
|
|
|
| 351 |
with engine.connect() as conn:
|
| 352 |
result = conn.execute(text(sql))
|
| 353 |
return [dict(row) for row in result.mappings()]
|
|
|
|
| 316 |
if not isinstance(parsed, exp.Select):
|
| 317 |
return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
|
| 318 |
|
| 319 |
+
# Check for DML anywhere in the AST (including writeable CTEs)
|
| 320 |
+
for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)):
|
| 321 |
+
return f"DML ({type(node).__name__}) is not allowed."
|
|
|
|
| 322 |
|
| 323 |
# Layer 2: schema grounding — table names
|
| 324 |
known_tables = {t.lower() for t in schema}
|
|
|
|
| 341 |
if existing:
|
| 342 |
current = int(existing.expression.this)
|
| 343 |
if current > limit:
|
| 344 |
+
existing.expression.set("this", limit)
|
| 345 |
else:
|
| 346 |
parsed = parsed.limit(limit)
|
| 347 |
return parsed.sql()
|
| 348 |
|
| 349 |
def _run_sql(self, engine: Any, sql: str) -> list[dict]:
|
| 350 |
+
# Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
|
| 351 |
with engine.connect() as conn:
|
| 352 |
result = conn.execute(text(sql))
|
| 353 |
return [dict(row) for row in result.mappings()]
|
src/rag/retrievers/schema.py
CHANGED
|
@@ -52,11 +52,11 @@ class SchemaRetriever(BaseRetriever):
|
|
| 52 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 53 |
|
| 54 |
if operator == "<#>":
|
| 55 |
-
score_sql =
|
| 56 |
elif operator == "<->":
|
| 57 |
-
score_sql =
|
| 58 |
else:
|
| 59 |
-
score_sql =
|
| 60 |
|
| 61 |
sql = text(f"""
|
| 62 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
@@ -65,12 +65,12 @@ class SchemaRetriever(BaseRetriever):
|
|
| 65 |
WHERE lpc.name = 'document_embeddings'
|
| 66 |
AND lpe.cmetadata->>'user_id' = :user_id
|
| 67 |
AND lpe.cmetadata->>'source_type' = 'database'
|
| 68 |
-
ORDER BY lpe.embedding {operator}
|
| 69 |
LIMIT :k
|
| 70 |
""")
|
| 71 |
|
| 72 |
async with _pgvector_engine.connect() as conn:
|
| 73 |
-
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
|
| 74 |
rows = result.fetchall()
|
| 75 |
|
| 76 |
return [
|
|
@@ -90,11 +90,11 @@ class SchemaRetriever(BaseRetriever):
|
|
| 90 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 91 |
|
| 92 |
if operator == "<#>":
|
| 93 |
-
score_sql =
|
| 94 |
elif operator == "<->":
|
| 95 |
-
score_sql =
|
| 96 |
else:
|
| 97 |
-
score_sql =
|
| 98 |
|
| 99 |
sql = text(f"""
|
| 100 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
@@ -105,12 +105,12 @@ class SchemaRetriever(BaseRetriever):
|
|
| 105 |
AND lpe.cmetadata->>'source_type' = 'document'
|
| 106 |
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 107 |
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 108 |
-
ORDER BY lpe.embedding {operator}
|
| 109 |
LIMIT :k
|
| 110 |
""")
|
| 111 |
|
| 112 |
async with _pgvector_engine.connect() as conn:
|
| 113 |
-
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
|
| 114 |
rows = result.fetchall()
|
| 115 |
|
| 116 |
results = []
|
|
|
|
| 52 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 53 |
|
| 54 |
if operator == "<#>":
|
| 55 |
+
score_sql = "(lpe.embedding <#> :emb::vector) * -1"
|
| 56 |
elif operator == "<->":
|
| 57 |
+
score_sql = "1.0 / (1.0 + (lpe.embedding <-> :emb::vector))"
|
| 58 |
else:
|
| 59 |
+
score_sql = "1.0 - (lpe.embedding <=> :emb::vector)"
|
| 60 |
|
| 61 |
sql = text(f"""
|
| 62 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
|
|
| 65 |
WHERE lpc.name = 'document_embeddings'
|
| 66 |
AND lpe.cmetadata->>'user_id' = :user_id
|
| 67 |
AND lpe.cmetadata->>'source_type' = 'database'
|
| 68 |
+
ORDER BY lpe.embedding {operator} :emb::vector ASC
|
| 69 |
LIMIT :k
|
| 70 |
""")
|
| 71 |
|
| 72 |
async with _pgvector_engine.connect() as conn:
|
| 73 |
+
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4, "emb": emb_str})
|
| 74 |
rows = result.fetchall()
|
| 75 |
|
| 76 |
return [
|
|
|
|
| 90 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 91 |
|
| 92 |
if operator == "<#>":
|
| 93 |
+
score_sql = "(lpe.embedding <#> :emb::vector) * -1"
|
| 94 |
elif operator == "<->":
|
| 95 |
+
score_sql = "1.0 / (1.0 + (lpe.embedding <-> :emb::vector))"
|
| 96 |
else:
|
| 97 |
+
score_sql = "1.0 - (lpe.embedding <=> :emb::vector)"
|
| 98 |
|
| 99 |
sql = text(f"""
|
| 100 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
|
|
| 105 |
AND lpe.cmetadata->>'source_type' = 'document'
|
| 106 |
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 107 |
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 108 |
+
ORDER BY lpe.embedding {operator} :emb::vector ASC
|
| 109 |
LIMIT :k
|
| 110 |
""")
|
| 111 |
|
| 112 |
async with _pgvector_engine.connect() as conn:
|
| 113 |
+
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4, "emb": emb_str})
|
| 114 |
rows = result.fetchall()
|
| 115 |
|
| 116 |
results = []
|