"""SqlCompiler — IR → (SQL string, named-params dict). Identifiers (table / column names) come from the catalog and are quoted verbatim — they were verified by the IR validator against the catalog, so injection through identifiers is not possible at this layer. Values from filter clauses are ALWAYS parameterized. The output `CompiledSql.sql` uses SQLAlchemy-style named placeholders (`:p_0, :p_1, ...`) so it can be executed via `text(sql)` with a params dict on a sync SQLAlchemy engine. v1 supports the Postgres dialect only. Supabase reuses the same compiler output (Supabase = Postgres). MySQL / BigQuery / Snowflake compilers will be separate classes that implement `BaseCompiler`. """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any from ...catalog.models import Catalog, Column, Source, Table from ..ir.models import ( AggSelect, ColumnSelect, FilterClause, OrderByClause, QueryIR, SelectItem, ) from .base import BaseCompiler @dataclass class CompiledSql: sql: str params: dict[str, Any] = field(default_factory=dict) class SqlCompilerError(Exception): pass _NULLARY_OPS = frozenset({"is_null", "is_not_null"}) _LIST_OPS = frozenset({"in", "not_in"}) _COMPARISON_OPS = frozenset({"=", "!=", "<", "<=", ">", ">="}) class SqlCompiler(BaseCompiler): """Deterministic IR → Postgres SQL. No LLM.""" def __init__(self, catalog: Catalog, dialect: str = "postgres") -> None: if dialect not in {"postgres", "supabase"}: raise SqlCompilerError( f"only 'postgres' / 'supabase' supported in v1, got {dialect!r}" ) self._catalog = catalog self._dialect = dialect def compile(self, ir: QueryIR) -> CompiledSql: _, table, cols_by_id = self._lookup(ir) params: dict[str, Any] = {} param_seq = [0] select_clause, select_aliases = self._build_select(ir.select, table, cols_by_id) from_clause = self._build_from(table) where_clause = self._build_where(ir.filters, table, cols_by_id, params, param_seq) groupby_clause = self._build_groupby(ir.group_by, table, cols_by_id) orderby_clause = self._build_orderby( ir.order_by, table, cols_by_id, select_aliases ) limit_clause = self._build_limit(ir.limit) parts: list[str] = [select_clause, from_clause] for clause in (where_clause, groupby_clause, orderby_clause, limit_clause): if clause: parts.append(clause) return CompiledSql(sql=" ".join(parts), params=params) # ------------------------------------------------------------------ # Catalog lookup # ------------------------------------------------------------------ def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]: source = next( (s for s in self._catalog.sources if s.source_id == ir.source_id), None ) if source is None: raise SqlCompilerError(f"source_id {ir.source_id!r} not in catalog") table = next( (t for t in source.tables if t.table_id == ir.table_id), None ) if table is None: raise SqlCompilerError( f"table_id {ir.table_id!r} not in source {ir.source_id!r}" ) return source, table, {c.column_id: c for c in table.columns} # ------------------------------------------------------------------ # Identifier quoting # ------------------------------------------------------------------ @staticmethod def _qident(name: str) -> str: """Postgres-style double-quoted identifier with embedded-quote escape.""" return '"' + name.replace('"', '""') + '"' def _qcol(self, table: Table, col: Column) -> str: return f"{self._qident(table.name)}.{self._qident(col.name)}" # ------------------------------------------------------------------ # Clauses # ------------------------------------------------------------------ def _build_select( self, items: list[SelectItem], table: Table, cols_by_id: dict[str, Column], ) -> tuple[str, set[str]]: if not items: raise SqlCompilerError("select clause cannot be empty") parts: list[str] = [] aliases: set[str] = set() for i, item in enumerate(items): expr, alias = self._select_item(item, table, cols_by_id, i) if alias: parts.append(f"{expr} AS {self._qident(alias)}") aliases.add(alias) else: parts.append(expr) return "SELECT " + ", ".join(parts), aliases def _select_item( self, item: SelectItem, table: Table, cols_by_id: dict[str, Column], index: int, ) -> tuple[str, str | None]: if isinstance(item, ColumnSelect): col = self._require_col(cols_by_id, item.column_id, f"select[{index}]") return self._qcol(table, col), item.alias if not isinstance(item, AggSelect): raise SqlCompilerError( f"select[{index}]: unknown SelectItem kind {type(item).__name__}" ) return self._compile_agg(item, table, cols_by_id, index), item.alias def _compile_agg( self, item: AggSelect, table: Table, cols_by_id: dict[str, Column], index: int, ) -> str: if item.fn == "count_distinct": if item.column_id is None: raise SqlCompilerError( f"select[{index}].fn=count_distinct requires column_id" ) col = self._require_col(cols_by_id, item.column_id, f"select[{index}]") return f"COUNT(DISTINCT {self._qcol(table, col)})" if item.column_id is None: if item.fn != "count": raise SqlCompilerError( f"select[{index}].fn={item.fn!r} requires column_id " "(only 'count' may omit it for COUNT(*))" ) return "COUNT(*)" col = self._require_col(cols_by_id, item.column_id, f"select[{index}]") return f"{item.fn.upper()}({self._qcol(table, col)})" def _build_from(self, table: Table) -> str: return f"FROM {self._qident(table.name)}" def _build_where( self, filters: list[FilterClause], table: Table, cols_by_id: dict[str, Column], params: dict[str, Any], param_seq: list[int], ) -> str: if not filters: return "" parts = [ self._compile_filter(f, table, cols_by_id, params, param_seq, index=i) for i, f in enumerate(filters) ] return "WHERE " + " AND ".join(parts) def _compile_filter( self, f: FilterClause, table: Table, cols_by_id: dict[str, Column], params: dict[str, Any], param_seq: list[int], index: int, ) -> str: col = self._require_col(cols_by_id, f.column_id, f"filters[{index}]") col_ref = self._qcol(table, col) op = f.op if op == "is_null": return f"{col_ref} IS NULL" if op == "is_not_null": return f"{col_ref} IS NOT NULL" if op in _LIST_OPS: if not isinstance(f.value, list) or not f.value: raise SqlCompilerError( f"filters[{index}]: op {op!r} requires a non-empty list value" ) placeholders = [ ":" + self._next_param(params, param_seq, v) for v in f.value ] sql_op = "IN" if op == "in" else "NOT IN" return f"{col_ref} {sql_op} ({', '.join(placeholders)})" if op == "between": if not isinstance(f.value, list) or len(f.value) != 2: raise SqlCompilerError( f"filters[{index}]: op 'between' requires a list of two values" ) lo = self._next_param(params, param_seq, f.value[0]) hi = self._next_param(params, param_seq, f.value[1]) return f"{col_ref} BETWEEN :{lo} AND :{hi}" if op == "like": p = self._next_param(params, param_seq, f.value) return f"{col_ref} LIKE :{p}" if op in _COMPARISON_OPS: p = self._next_param(params, param_seq, f.value) return f"{col_ref} {op} :{p}" # Should not reach here — IRValidator already filters disallowed ops raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}") def _build_groupby( self, group_by: list[str], table: Table, cols_by_id: dict[str, Column], ) -> str: if not group_by: return "" parts = [ self._qcol(table, self._require_col(cols_by_id, col_id, f"group_by[{i}]")) for i, col_id in enumerate(group_by) ] return "GROUP BY " + ", ".join(parts) def _build_orderby( self, order_by: list[OrderByClause], table: Table, cols_by_id: dict[str, Column], select_aliases: set[str], ) -> str: if not order_by: return "" parts: list[str] = [] for i, ob in enumerate(order_by): if ob.column_id in cols_by_id: ref = self._qcol(table, cols_by_id[ob.column_id]) elif ob.column_id in select_aliases: ref = self._qident(ob.column_id) else: raise SqlCompilerError( f"order_by[{i}].column_id: {ob.column_id!r} not in table " "columns or select aliases" ) parts.append(f"{ref} {ob.dir.upper()}") return "ORDER BY " + ", ".join(parts) def _build_limit(self, limit: int | None) -> str: if limit is None: return "" return f"LIMIT {int(limit)}" # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ @staticmethod def _next_param( params: dict[str, Any], param_seq: list[int], value: Any ) -> str: name = f"p_{param_seq[0]}" param_seq[0] += 1 params[name] = value return name @staticmethod def _require_col( cols_by_id: dict[str, Column], col_id: str, where: str ) -> Column: col = cols_by_id.get(col_id) if col is None: raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in table") return col