# Page header (not part of the module): ishaq101's picture
# feat/Catalog Retrieval System (#1)
# 6bff5d9
"""SqlCompiler — IR → (SQL string, named-params dict).
Identifiers (table / column names) come from the catalog and are quoted
verbatim — they were verified by the IR validator against the catalog,
so injection through identifiers is not possible at this layer.
Values from filter clauses are ALWAYS parameterized.
The output `CompiledSql.sql` uses SQLAlchemy-style named placeholders
(`:p_0, :p_1, ...`) so it can be executed via `text(sql)` with a params
dict on a sync SQLAlchemy engine.
v1 supports the Postgres dialect only. Supabase reuses the same compiler
output (Supabase = Postgres). MySQL / BigQuery / Snowflake compilers will
be separate classes that implement `BaseCompiler`.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from ...catalog.models import Catalog, Column, Source, Table
from ..ir.models import (
AggSelect,
ColumnSelect,
FilterClause,
OrderByClause,
QueryIR,
SelectItem,
)
from .base import BaseCompiler
@dataclass
class CompiledSql:
    """Result of compiling one query IR: SQL text plus its bound values."""

    # SQL text containing `:p_N` style named placeholders.
    sql: str
    # Placeholder name (without the leading colon) -> value to bind.
    params: dict[str, Any] = field(default_factory=dict)
class SqlCompilerError(Exception):
    """Raised when a query IR cannot be compiled into SQL."""
# Filter operators, grouped by the shape of value each one takes.
_NULLARY_OPS = frozenset(("is_null", "is_not_null"))  # no value
_LIST_OPS = frozenset(("in", "not_in"))  # non-empty list of values
_COMPARISON_OPS = frozenset(("=", "!=", "<", "<=", ">", ">="))  # one value
class SqlCompiler(BaseCompiler):
    """Deterministic IR → Postgres SQL. No LLM.

    `compile` turns a `QueryIR` into a `CompiledSql`: a single-table
    ``SELECT ... FROM ... [WHERE] [GROUP BY] [ORDER BY] [LIMIT]`` statement
    plus a dict of named parameters. Table and column names are looked up
    in the catalog and emitted as double-quoted identifiers; filter values
    are never interpolated into the SQL text, only bound as ``:p_N``
    parameters. The few IR strings that do end up in the SQL text
    (aggregate function name, sort direction, limit) are validated before
    being interpolated.
    """

    # The only sort directions that may be interpolated after an ORDER BY term.
    _ORDER_DIRECTIONS = frozenset({"ASC", "DESC"})

    def __init__(self, catalog: Catalog, dialect: str = "postgres") -> None:
        """Bind the compiler to a catalog.

        Raises:
            SqlCompilerError: if `dialect` is not 'postgres' or 'supabase'.
        """
        if dialect not in {"postgres", "supabase"}:
            raise SqlCompilerError(
                f"only 'postgres' / 'supabase' supported in v1, got {dialect!r}"
            )
        self._catalog = catalog
        self._dialect = dialect

    def compile(self, ir: QueryIR) -> CompiledSql:
        """Compile `ir` into SQL text and its named-parameter dict.

        Raises:
            SqlCompilerError: if the IR references a source, table or column
                that is not in the catalog, or contains a clause that cannot
                be rendered.
        """
        _, table, cols_by_id = self._lookup(ir)
        params: dict[str, Any] = {}
        # One-element list used as a counter the filter helpers mutate in place.
        param_seq = [0]

        select_clause, select_aliases = self._build_select(
            ir.select, table, cols_by_id
        )
        from_clause = self._build_from(table)
        # Built in statement order; each one is "" when the IR omits it.
        optional_clauses = (
            self._build_where(ir.filters, table, cols_by_id, params, param_seq),
            self._build_groupby(ir.group_by, table, cols_by_id),
            self._build_orderby(ir.order_by, table, cols_by_id, select_aliases),
            self._build_limit(ir.limit),
        )
        parts = [select_clause, from_clause]
        parts.extend(clause for clause in optional_clauses if clause)
        return CompiledSql(sql=" ".join(parts), params=params)

    # ------------------------------------------------------------------
    # Catalog lookup
    # ------------------------------------------------------------------
    def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]:
        """Resolve the IR's source and table in the catalog.

        Returns:
            The source, the table, and a ``column_id -> Column`` map for
            that table.

        Raises:
            SqlCompilerError: if the source or the table is not found.
        """
        source = next(
            (s for s in self._catalog.sources if s.source_id == ir.source_id), None
        )
        if source is None:
            raise SqlCompilerError(f"source_id {ir.source_id!r} not in catalog")
        table = next(
            (t for t in source.tables if t.table_id == ir.table_id), None
        )
        if table is None:
            raise SqlCompilerError(
                f"table_id {ir.table_id!r} not in source {ir.source_id!r}"
            )
        return source, table, {c.column_id: c for c in table.columns}

    # ------------------------------------------------------------------
    # Identifier quoting
    # ------------------------------------------------------------------
    @staticmethod
    def _qident(name: str) -> str:
        """Postgres-style double-quoted identifier with embedded-quote escape.

        NOTE(review): the compiled SQL is meant to be run through SQLAlchemy
        `text()`, which presumably scans the whole string for `:name`
        placeholders, quoted identifiers included — confirm catalog names
        never contain a colon followed by word characters.
        """
        return '"' + name.replace('"', '""') + '"'

    def _qcol(self, table: Table, col: Column) -> str:
        """Return the table-qualified, quoted reference ``"table"."column"``."""
        return f"{self._qident(table.name)}.{self._qident(col.name)}"

    # ------------------------------------------------------------------
    # Clauses
    # ------------------------------------------------------------------
    def _build_select(
        self,
        items: list[SelectItem],
        table: Table,
        cols_by_id: dict[str, Column],
    ) -> tuple[str, set[str]]:
        """Render the SELECT list.

        Returns:
            The ``SELECT ...`` clause and the set of aliases it defines
            (used later so ORDER BY can refer to them).

        Raises:
            SqlCompilerError: if `items` is empty or an item is invalid.
        """
        if not items:
            raise SqlCompilerError("select clause cannot be empty")
        parts: list[str] = []
        aliases: set[str] = set()
        for i, item in enumerate(items):
            expr, alias = self._select_item(item, table, cols_by_id, i)
            if alias:
                parts.append(f"{expr} AS {self._qident(alias)}")
                aliases.add(alias)
            else:
                parts.append(expr)
        return "SELECT " + ", ".join(parts), aliases

    def _select_item(
        self,
        item: SelectItem,
        table: Table,
        cols_by_id: dict[str, Column],
        index: int,
    ) -> tuple[str, str | None]:
        """Render one select item; return ``(expression, alias or None)``.

        Raises:
            SqlCompilerError: if the item is neither a `ColumnSelect` nor an
                `AggSelect`, or references an unknown column.
        """
        if isinstance(item, ColumnSelect):
            col = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return self._qcol(table, col), item.alias
        if not isinstance(item, AggSelect):
            raise SqlCompilerError(
                f"select[{index}]: unknown SelectItem kind {type(item).__name__}"
            )
        return self._compile_agg(item, table, cols_by_id, index), item.alias

    def _compile_agg(
        self,
        item: AggSelect,
        table: Table,
        cols_by_id: dict[str, Column],
        index: int,
    ) -> str:
        """Render an aggregate expression.

        ``count_distinct`` becomes ``COUNT(DISTINCT col)``; ``count`` with no
        column becomes ``COUNT(*)``; anything else becomes ``FN(col)`` with
        the function name upper-cased.

        Raises:
            SqlCompilerError: if a required column_id is missing or unknown,
                or if the function name is not a plain identifier.
        """
        if item.fn == "count_distinct":
            if item.column_id is None:
                raise SqlCompilerError(
                    f"select[{index}].fn=count_distinct requires column_id"
                )
            col = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return f"COUNT(DISTINCT {self._qcol(table, col)})"
        if item.column_id is None:
            if item.fn != "count":
                raise SqlCompilerError(
                    f"select[{index}].fn={item.fn!r} requires column_id "
                    "(only 'count' may omit it for COUNT(*))"
                )
            return "COUNT(*)"
        col = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
        # The function name is spliced into the SQL text as-is, so refuse
        # anything that is not a bare ASCII identifier.
        if not (item.fn.isascii() and item.fn.isidentifier()):
            raise SqlCompilerError(
                f"select[{index}].fn={item.fn!r} is not a valid function name"
            )
        return f"{item.fn.upper()}({self._qcol(table, col)})"

    def _build_from(self, table: Table) -> str:
        """Render the FROM clause for the single target table."""
        return f"FROM {self._qident(table.name)}"

    def _build_where(
        self,
        filters: list[FilterClause],
        table: Table,
        cols_by_id: dict[str, Column],
        params: dict[str, Any],
        param_seq: list[int],
    ) -> str:
        """Render the WHERE clause, AND-ing every filter; "" if there are none.

        Bound values are added to `params` as a side effect.
        """
        if not filters:
            return ""
        parts = [
            self._compile_filter(f, table, cols_by_id, params, param_seq, index=i)
            for i, f in enumerate(filters)
        ]
        return "WHERE " + " AND ".join(parts)

    def _compile_filter(
        self,
        f: FilterClause,
        table: Table,
        cols_by_id: dict[str, Column],
        params: dict[str, Any],
        param_seq: list[int],
        index: int,
    ) -> str:
        """Render one filter predicate, binding its value(s) into `params`.

        Raises:
            SqlCompilerError: if the column is unknown, the value has the
                wrong shape for the operator, or the operator is unhandled.
        """
        col = self._require_col(cols_by_id, f.column_id, f"filters[{index}]")
        col_ref = self._qcol(table, col)
        op = f.op

        # Null checks take no value and bind no parameter.
        if op == "is_null":
            return f"{col_ref} IS NULL"
        if op == "is_not_null":
            return f"{col_ref} IS NOT NULL"

        if op in _LIST_OPS:
            if not isinstance(f.value, list) or not f.value:
                raise SqlCompilerError(
                    f"filters[{index}]: op {op!r} requires a non-empty list value"
                )
            # One placeholder per list element.
            placeholders = [
                ":" + self._next_param(params, param_seq, v) for v in f.value
            ]
            sql_op = "IN" if op == "in" else "NOT IN"
            return f"{col_ref} {sql_op} ({', '.join(placeholders)})"

        if op == "between":
            if not isinstance(f.value, list) or len(f.value) != 2:
                raise SqlCompilerError(
                    f"filters[{index}]: op 'between' requires a list of two values"
                )
            lo = self._next_param(params, param_seq, f.value[0])
            hi = self._next_param(params, param_seq, f.value[1])
            return f"{col_ref} BETWEEN :{lo} AND :{hi}"

        if op == "like":
            p = self._next_param(params, param_seq, f.value)
            return f"{col_ref} LIKE :{p}"

        if op in _COMPARISON_OPS:
            # `op` is one of a fixed set of operator tokens, so it is safe
            # to place in the SQL text.
            p = self._next_param(params, param_seq, f.value)
            return f"{col_ref} {op} :{p}"

        # Should not reach here — IRValidator already filters disallowed ops
        raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}")

    def _build_groupby(
        self,
        group_by: list[str],
        table: Table,
        cols_by_id: dict[str, Column],
    ) -> str:
        """Render the GROUP BY clause from column ids; "" if there are none.

        Raises:
            SqlCompilerError: if a column id is not in the table.
        """
        if not group_by:
            return ""
        parts = [
            self._qcol(table, self._require_col(cols_by_id, col_id, f"group_by[{i}]"))
            for i, col_id in enumerate(group_by)
        ]
        return "GROUP BY " + ", ".join(parts)

    def _build_orderby(
        self,
        order_by: list[OrderByClause],
        table: Table,
        cols_by_id: dict[str, Column],
        select_aliases: set[str],
    ) -> str:
        """Render the ORDER BY clause; "" if there is nothing to order by.

        Each `column_id` is resolved first as a table column and, failing
        that, as an alias defined in the SELECT list.

        Raises:
            SqlCompilerError: if a `column_id` is neither a table column nor
                a select alias, or if a direction is not asc/desc.
        """
        if not order_by:
            return ""
        parts: list[str] = []
        for i, ob in enumerate(order_by):
            if ob.column_id in cols_by_id:
                ref = self._qcol(table, cols_by_id[ob.column_id])
            elif ob.column_id in select_aliases:
                ref = self._qident(ob.column_id)
            else:
                raise SqlCompilerError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not in table "
                    "columns or select aliases"
                )
            # The direction is spliced into the SQL text, so only the two
            # known keywords are allowed through.
            direction = ob.dir.upper()
            if direction not in self._ORDER_DIRECTIONS:
                raise SqlCompilerError(
                    f"order_by[{i}].dir: {ob.dir!r} must be 'asc' or 'desc'"
                )
            parts.append(f"{ref} {direction}")
        return "ORDER BY " + ", ".join(parts)

    def _build_limit(self, limit: int | None) -> str:
        """Render the LIMIT clause; "" when `limit` is None.

        Raises:
            SqlCompilerError: if `limit` is negative.
        """
        if limit is None:
            return ""
        value = int(limit)
        # Fail at compile time rather than emitting a statement the
        # database will reject.
        if value < 0:
            raise SqlCompilerError(f"limit must be non-negative, got {value}")
        return f"LIMIT {value}"

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _next_param(
        params: dict[str, Any], param_seq: list[int], value: Any
    ) -> str:
        """Store `value` under the next ``p_N`` name and return that name.

        `param_seq[0]` is the running counter and is incremented in place.
        The returned name has no leading colon.
        """
        name = f"p_{param_seq[0]}"
        param_seq[0] += 1
        params[name] = value
        return name

    @staticmethod
    def _require_col(
        cols_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Return the column for `col_id`.

        Args:
            where: IR location (e.g. ``select[0]``) used in the error message.

        Raises:
            SqlCompilerError: if `col_id` is not in `cols_by_id`.
        """
        col = cols_by_id.get(col_id)
        if col is None:
            raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in table")
        return col