| """SqlCompiler — IR → (SQL string, named-params dict). |
| |
| Identifiers (table / column names) come from the catalog and are emitted |
| as double-quoted identifiers (embedded `"` doubled) — they were verified |
| by the IR validator against the catalog, so injection through identifiers |
| is not possible at this layer. |
| Values from filter clauses are ALWAYS parameterized. |
| |
| The output `CompiledSql.sql` uses SQLAlchemy-style named placeholders |
| (`:p_0, :p_1, ...`) so it can be executed via `text(sql)` with a params |
| dict on a sync SQLAlchemy engine. |
| |
| v1 supports the Postgres dialect only. Supabase reuses the same compiler |
| output (Supabase = Postgres). MySQL / BigQuery / Snowflake compilers will |
| be separate classes that implement `BaseCompiler`. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Any |
|
|
| from ...catalog.models import Catalog, Column, Source, Table |
| from ..ir.models import ( |
| AggSelect, |
| ColumnSelect, |
| FilterClause, |
| OrderByClause, |
| QueryIR, |
| SelectItem, |
| ) |
| from .base import BaseCompiler |
|
|
|
|
@dataclass
class CompiledSql:
    """Output of one compilation: SQL text plus its bind parameters.

    As produced by ``SqlCompiler``, ``sql`` contains named placeholders
    (``:p_0``, ``:p_1``, ...) and ``params`` maps each placeholder name
    (without the leading colon) to the value to bind.
    """

    # Statement text, suitable for SQLAlchemy ``text(sql)``.
    sql: str
    # Placeholder name -> bound value; each instance gets its own dict.
    params: dict[str, Any] = field(default_factory=dict)
|
|
|
|
class SqlCompilerError(Exception):
    """Raised when an IR cannot be turned into SQL (unknown ids, bad shapes)."""
|
|
|
|
# Filter ops that carry no value (IS NULL / IS NOT NULL).
_NULLARY_OPS = frozenset(("is_null", "is_not_null"))
# Filter ops whose value is a non-empty list, bound element by element.
_LIST_OPS = frozenset(("in", "not_in"))
# Binary comparison ops written verbatim into the SQL text; membership in
# this set is what keeps arbitrary operator strings out of the statement.
_COMPARISON_OPS = frozenset("= != < <= > >=".split())
|
|
|
|
class SqlCompiler(BaseCompiler):
    """Deterministic IR → Postgres SQL. No LLM.

    Table/column names are resolved through the catalog and emitted as
    double-quoted identifiers; filter values are always bound as named
    parameters (``:p_0``, ``:p_1``, ...). The few IR fields that must be
    spliced into the SQL text itself — aggregate function name, sort
    direction, comparison operator — are validated first, so a malformed IR
    cannot smuggle arbitrary SQL through them.
    """

    # ORDER BY directions accepted from the IR (compared case-insensitively).
    _ORDER_DIRECTIONS = frozenset({"asc", "desc"})

    def __init__(self, catalog: Catalog, dialect: str = "postgres") -> None:
        """Create a compiler bound to *catalog*.

        Raises:
            SqlCompilerError: if *dialect* is not ``"postgres"`` or
                ``"supabase"`` (the only dialects supported in v1).
        """
        if dialect not in {"postgres", "supabase"}:
            raise SqlCompilerError(
                f"only 'postgres' / 'supabase' supported in v1, got {dialect!r}"
            )
        self._catalog = catalog
        self._dialect = dialect

    def compile(self, ir: QueryIR) -> CompiledSql:
        """Compile *ir* into a SQL string plus its named-parameter dict.

        Clauses are emitted in SQL order: SELECT and FROM always, then
        WHERE / GROUP BY / ORDER BY / LIMIT only when the IR populates them.

        Raises:
            SqlCompilerError: if the IR references an unknown source, table
                or column, or contains a malformed select/filter/order/limit.
        """
        _, table, cols_by_id = self._lookup(ir)
        params: dict[str, Any] = {}
        # One-element list used as a shared mutable counter so every helper
        # draws placeholder names from the same p_0, p_1, ... sequence.
        param_seq = [0]

        select_clause, select_aliases = self._build_select(
            ir.select, table, cols_by_id
        )
        from_clause = self._build_from(table)
        where_clause = self._build_where(
            ir.filters, table, cols_by_id, params, param_seq
        )
        groupby_clause = self._build_groupby(ir.group_by, table, cols_by_id)
        orderby_clause = self._build_orderby(
            ir.order_by, table, cols_by_id, select_aliases
        )
        limit_clause = self._build_limit(ir.limit)

        # Optional clauses come back as "" when absent; drop them so the
        # joined statement has no stray separators.
        optional = (where_clause, groupby_clause, orderby_clause, limit_clause)
        parts = [select_clause, from_clause, *(c for c in optional if c)]
        return CompiledSql(sql=" ".join(parts), params=params)

    def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]:
        """Resolve ``ir.source_id`` / ``ir.table_id`` against the catalog.

        Returns:
            ``(source, table, cols_by_id)`` where *cols_by_id* maps each
            ``column_id`` of the table to its ``Column``.

        Raises:
            SqlCompilerError: if the source or the table is not in the catalog.
        """
        source = next(
            (s for s in self._catalog.sources if s.source_id == ir.source_id), None
        )
        if source is None:
            raise SqlCompilerError(f"source_id {ir.source_id!r} not in catalog")
        table = next(
            (t for t in source.tables if t.table_id == ir.table_id), None
        )
        if table is None:
            raise SqlCompilerError(
                f"table_id {ir.table_id!r} not in source {ir.source_id!r}"
            )
        return source, table, {c.column_id: c for c in table.columns}

    @staticmethod
    def _qident(name: str) -> str:
        """Postgres-style double-quoted identifier with embedded-quote escape."""
        return '"' + name.replace('"', '""') + '"'

    def _qcol(self, table: Table, col: Column) -> str:
        """Return the quoted, table-qualified reference ``"table"."column"``."""
        return f"{self._qident(table.name)}.{self._qident(col.name)}"

    def _build_select(
        self,
        items: list[SelectItem],
        table: Table,
        cols_by_id: dict[str, Column],
    ) -> tuple[str, set[str]]:
        """Render the SELECT list and collect the aliases it defines.

        The alias set is returned so ORDER BY can reference select aliases.

        Raises:
            SqlCompilerError: if *items* is empty or any item is invalid.
        """
        if not items:
            raise SqlCompilerError("select clause cannot be empty")
        parts: list[str] = []
        aliases: set[str] = set()
        for i, item in enumerate(items):
            expr, alias = self._select_item(item, table, cols_by_id, i)
            if alias:
                parts.append(f"{expr} AS {self._qident(alias)}")
                aliases.add(alias)
            else:
                parts.append(expr)
        return "SELECT " + ", ".join(parts), aliases

    def _select_item(
        self,
        item: SelectItem,
        table: Table,
        cols_by_id: dict[str, Column],
        index: int,
    ) -> tuple[str, str | None]:
        """Render one select item; return ``(sql_expression, alias_or_None)``.

        Raises:
            SqlCompilerError: for an unknown column or SelectItem kind.
        """
        if isinstance(item, ColumnSelect):
            col = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return self._qcol(table, col), item.alias
        if not isinstance(item, AggSelect):
            raise SqlCompilerError(
                f"select[{index}]: unknown SelectItem kind {type(item).__name__}"
            )
        return self._compile_agg(item, table, cols_by_id, index), item.alias

    def _compile_agg(
        self,
        item: AggSelect,
        table: Table,
        cols_by_id: dict[str, Column],
        index: int,
    ) -> str:
        """Render one aggregate expression.

        ``count_distinct`` → ``COUNT(DISTINCT col)``; ``count`` without a
        column → ``COUNT(*)``; any other fn → ``FN(col)`` with the name
        upper-cased.

        Raises:
            SqlCompilerError: if ``fn`` is not a plain identifier (it is
                spliced into the SQL text, so anything else could inject),
                if a required column_id is missing, or the column is unknown.
        """
        fn = item.fn
        if not isinstance(fn, str) or not fn.isidentifier():
            raise SqlCompilerError(
                f"select[{index}].fn={fn!r} is not a valid aggregate name"
            )
        if fn == "count_distinct":
            if item.column_id is None:
                raise SqlCompilerError(
                    f"select[{index}].fn=count_distinct requires column_id"
                )
            col = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return f"COUNT(DISTINCT {self._qcol(table, col)})"
        if item.column_id is None:
            if fn != "count":
                raise SqlCompilerError(
                    f"select[{index}].fn={fn!r} requires column_id "
                    "(only 'count' may omit it for COUNT(*))"
                )
            return "COUNT(*)"
        col = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
        return f"{fn.upper()}({self._qcol(table, col)})"

    def _build_from(self, table: Table) -> str:
        """Render the FROM clause for the IR's single table."""
        return f"FROM {self._qident(table.name)}"

    def _build_where(
        self,
        filters: list[FilterClause],
        table: Table,
        cols_by_id: dict[str, Column],
        params: dict[str, Any],
        param_seq: list[int],
    ) -> str:
        """AND together all filter predicates; return "" when there are none."""
        if not filters:
            return ""
        parts = [
            self._compile_filter(f, table, cols_by_id, params, param_seq, index=i)
            for i, f in enumerate(filters)
        ]
        return "WHERE " + " AND ".join(parts)

    def _compile_filter(
        self,
        f: FilterClause,
        table: Table,
        cols_by_id: dict[str, Column],
        params: dict[str, Any],
        param_seq: list[int],
        index: int,
    ) -> str:
        """Render one filter predicate, binding its value(s) into *params*.

        ``None`` is never bound as a parameter, because any SQL comparison
        with NULL is unknown and silently matches nothing. ``=`` / ``!=``
        against ``None`` are rendered as ``IS NULL`` / ``IS NOT NULL``; every
        other op rejects a null value.

        Raises:
            SqlCompilerError: for an unknown column, a malformed value for
                the op, a disallowed null, or an unhandled op.
        """
        col = self._require_col(cols_by_id, f.column_id, f"filters[{index}]")
        col_ref = self._qcol(table, col)
        op = f.op
        value = f.value

        if op in _NULLARY_OPS:
            return (
                f"{col_ref} IS NULL" if op == "is_null" else f"{col_ref} IS NOT NULL"
            )

        if op in _LIST_OPS:
            if not isinstance(value, list) or not value:
                raise SqlCompilerError(
                    f"filters[{index}]: op {op!r} requires a non-empty list value"
                )
            if any(v is None for v in value):
                # NOT IN (..., NULL) never matches any row, and IN silently
                # ignores the NULL; both surprise callers, so refuse.
                raise SqlCompilerError(
                    f"filters[{index}]: op {op!r} list must not contain null; "
                    "use 'is_null' / 'is_not_null' instead"
                )
            placeholders = ", ".join(
                ":" + self._next_param(params, param_seq, v) for v in value
            )
            sql_op = "IN" if op == "in" else "NOT IN"
            return f"{col_ref} {sql_op} ({placeholders})"

        if op == "between":
            if not isinstance(value, list) or len(value) != 2:
                raise SqlCompilerError(
                    f"filters[{index}]: op 'between' requires a list of two values"
                )
            lo_val, hi_val = value
            if lo_val is None or hi_val is None:
                raise SqlCompilerError(
                    f"filters[{index}]: op 'between' bounds must not be null"
                )
            lo = self._next_param(params, param_seq, lo_val)
            hi = self._next_param(params, param_seq, hi_val)
            return f"{col_ref} BETWEEN :{lo} AND :{hi}"

        if op == "like" or op in _COMPARISON_OPS:
            if value is None:
                # Same convention as SQLAlchemy: equality with None means
                # IS [NOT] NULL; ordering / pattern ops against NULL are errors.
                if op == "=":
                    return f"{col_ref} IS NULL"
                if op == "!=":
                    return f"{col_ref} IS NOT NULL"
                raise SqlCompilerError(
                    f"filters[{index}]: op {op!r} cannot compare against null; "
                    "use 'is_null' / 'is_not_null' instead"
                )
            p = self._next_param(params, param_seq, value)
            # op is safe to splice: it is either "like" or in _COMPARISON_OPS.
            sql_op = "LIKE" if op == "like" else op
            return f"{col_ref} {sql_op} :{p}"

        raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}")

    def _build_groupby(
        self,
        group_by: list[str],
        table: Table,
        cols_by_id: dict[str, Column],
    ) -> str:
        """Render GROUP BY over table columns; return "" when empty.

        Raises:
            SqlCompilerError: if a column_id is not in the table.
        """
        if not group_by:
            return ""
        parts = [
            self._qcol(table, self._require_col(cols_by_id, col_id, f"group_by[{i}]"))
            for i, col_id in enumerate(group_by)
        ]
        return "GROUP BY " + ", ".join(parts)

    def _build_orderby(
        self,
        order_by: list[OrderByClause],
        table: Table,
        cols_by_id: dict[str, Column],
        select_aliases: set[str],
    ) -> str:
        """Render ORDER BY; return "" when empty.

        Each key is resolved as a table column id first, then as a select
        alias. The direction is spliced into the SQL text, so only
        ``asc`` / ``desc`` (any case) are accepted.

        Raises:
            SqlCompilerError: for an unresolvable key or an invalid direction.
        """
        if not order_by:
            return ""
        parts: list[str] = []
        for i, ob in enumerate(order_by):
            if ob.column_id in cols_by_id:
                ref = self._qcol(table, cols_by_id[ob.column_id])
            elif ob.column_id in select_aliases:
                ref = self._qident(ob.column_id)
            else:
                raise SqlCompilerError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not in table "
                    "columns or select aliases"
                )
            direction = ob.dir.lower() if isinstance(ob.dir, str) else None
            if direction not in self._ORDER_DIRECTIONS:
                raise SqlCompilerError(
                    f"order_by[{i}].dir: {ob.dir!r} must be 'asc' or 'desc'"
                )
            parts.append(f"{ref} {direction.upper()}")
        return "ORDER BY " + ", ".join(parts)

    def _build_limit(self, limit: int | None) -> str:
        """Render ``LIMIT n``; ``None`` means no limit and returns "".

        Raises:
            SqlCompilerError: if *limit* is negative (Postgres rejects a
                negative LIMIT when the statement executes).
        """
        if limit is None:
            return ""
        n = int(limit)
        if n < 0:
            raise SqlCompilerError(f"limit must be non-negative, got {limit!r}")
        return f"LIMIT {n}"

    @staticmethod
    def _next_param(
        params: dict[str, Any], param_seq: list[int], value: Any
    ) -> str:
        """Bind *value* under the next name ``p_N`` and return that name.

        The returned name has no leading colon; callers add it in the SQL.
        """
        name = f"p_{param_seq[0]}"
        param_seq[0] += 1
        params[name] = value
        return name

    @staticmethod
    def _require_col(
        cols_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Return the column for *col_id*.

        Raises:
            SqlCompilerError: if the column is missing; the message is
                prefixed with the IR location *where*.
        """
        col = cols_by_id.get(col_id)
        if col is None:
            raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in table")
        return col
|
|