"""IRValidator — checks a QueryIR against a user's catalog.

See ARCHITECTURE.md §7 for the validation rules. On failure, the planner is
re-prompted with the error context (max 3 retries) — error messages must
therefore be specific enough that the LLM can self-correct.
"""

from ...catalog.models import Catalog, Column, Source, Table
from .models import QueryIR
from .operators import (
    ALLOWED_AGG_FNS,
    ALLOWED_FILTER_OPS,
    LIMIT_HARD_CAP,
    TYPE_COMPATIBILITY,
)

# Filter operators that take no comparison value; the value_type check is
# skipped for these.
_NULLARY_FILTER_OPS = frozenset({"is_null", "is_not_null"})


class IRValidationError(Exception):
    """Raised when a QueryIR violates a validation rule.

    The message names the offending IR path (e.g. ``filters[2].op``) and,
    where possible, the accepted values, so it can be fed back to the
    planner as retry context.
    """


class IRValidator:
    """Reject IRs that reference unknown sources/tables/columns or use disallowed ops.

    Rules:
      - source_id exists in catalog for this user
      - table_id belongs to that source
      - every column_id exists in that table
      - every agg.fn and filter.op is whitelisted (see operators.py)
      - value_type consistent with column.data_type (TYPE_COMPATIBILITY)
      - select aliases are unique (order_by may reference them by name)
      - limit positive int, ≤ LIMIT_HARD_CAP
    """

    def validate(self, ir: QueryIR, catalog: Catalog) -> None:
        """Validate *ir* against *catalog*; return ``None`` if it is valid.

        Checks run in this order: source, table, select, filters, group_by,
        order_by, limit. Only the first violation found is reported.

        Raises:
            IRValidationError: on the first rule violation, with a message
                naming the offending IR field.
        """
        source = self._find_source(catalog, ir.source_id)
        table = self._find_table(source, ir.table_id)
        columns_by_id: dict[str, Column] = {c.column_id: c for c in table.columns}

        select_aliases = self._check_select(ir, columns_by_id)
        self._check_filters(ir, columns_by_id)
        self._check_group_by(ir, columns_by_id)
        self._check_order_by(ir, columns_by_id, select_aliases)
        self._check_limit(ir)

    @classmethod
    def _check_select(
        cls, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> set[str]:
        """Validate ``ir.select`` and return the set of non-empty aliases it defines."""
        select_aliases: set[str] = set()
        for i, item in enumerate(ir.select):
            where = f"select[{i}]"
            if item.kind == "column":
                cls._require_column(columns_by_id, item.column_id, where)
            else:  # "agg"
                if item.fn not in ALLOWED_AGG_FNS:
                    raise IRValidationError(
                        f"{where}.fn: must be in {sorted(ALLOWED_AGG_FNS)}, "
                        f"got {item.fn!r}"
                    )
                if item.column_id is not None:
                    cls._require_column(columns_by_id, item.column_id, where)
                elif item.fn != "count":
                    raise IRValidationError(
                        f"{where}.fn={item.fn!r} requires a column_id "
                        "(only 'count' may omit it for COUNT(*))"
                    )
            if item.alias:
                # order_by may reference an alias by name; a repeated alias
                # would make that reference ambiguous at execution time, so
                # reject it here where the planner can still self-correct.
                if item.alias in select_aliases:
                    raise IRValidationError(
                        f"{where}.alias: {item.alias!r} is already used by an "
                        "earlier select item; aliases must be unique "
                        f"(aliases so far: {sorted(select_aliases)})"
                    )
                select_aliases.add(item.alias)
        return select_aliases

    @classmethod
    def _check_filters(cls, ir: QueryIR, columns_by_id: dict[str, Column]) -> None:
        """Validate each filter's column, operator, and value_type compatibility."""
        for i, flt in enumerate(ir.filters):
            where = f"filters[{i}]"
            col = cls._require_column(columns_by_id, flt.column_id, where)
            if flt.op not in ALLOWED_FILTER_OPS:
                raise IRValidationError(
                    f"{where}.op: must be in {sorted(ALLOWED_FILTER_OPS)}, "
                    f"got {flt.op!r}"
                )
            if flt.op not in _NULLARY_FILTER_OPS:
                # A data_type missing from TYPE_COMPATIBILITY accepts no value_type.
                allowed = TYPE_COMPATIBILITY.get(col.data_type, frozenset())
                if flt.value_type not in allowed:
                    raise IRValidationError(
                        f"{where}: value_type {flt.value_type!r} incompatible with "
                        f"column.data_type {col.data_type!r} "
                        f"(allowed: {sorted(allowed)})"
                    )

    @classmethod
    def _check_group_by(cls, ir: QueryIR, columns_by_id: dict[str, Column]) -> None:
        """Validate that every group_by entry is a column of the table."""
        for i, col_id in enumerate(ir.group_by):
            cls._require_column(columns_by_id, col_id, f"group_by[{i}]")

    @staticmethod
    def _check_order_by(
        ir: QueryIR, columns_by_id: dict[str, Column], select_aliases: set[str]
    ) -> None:
        """Validate that every order_by target is a table column or a select alias."""
        for i, ob in enumerate(ir.order_by):
            if ob.column_id not in columns_by_id and ob.column_id not in select_aliases:
                raise IRValidationError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not found in table "
                    f"{ir.table_id!r} columns or select aliases "
                    f"(known columns: {sorted(columns_by_id.keys())}, "
                    f"aliases: {sorted(select_aliases)})"
                )

    @staticmethod
    def _check_limit(ir: QueryIR) -> None:
        """Validate that ``ir.limit``, when set, lies in ``1..LIMIT_HARD_CAP``."""
        if ir.limit is not None:
            if ir.limit <= 0:
                raise IRValidationError(f"limit must be positive, got {ir.limit}")
            if ir.limit > LIMIT_HARD_CAP:
                raise IRValidationError(
                    f"limit {ir.limit} exceeds hard cap {LIMIT_HARD_CAP}"
                )

    @staticmethod
    def _find_source(catalog: Catalog, source_id: str) -> Source:
        """Return the catalog source whose ``source_id`` matches, or raise."""
        for s in catalog.sources:
            if s.source_id == source_id:
                return s
        raise IRValidationError(
            f"source_id {source_id!r} not in catalog "
            f"(known: {[s.source_id for s in catalog.sources]})"
        )

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the table in *source* whose ``table_id`` matches, or raise."""
        for t in source.tables:
            if t.table_id == table_id:
                return t
        raise IRValidationError(
            f"table_id {table_id!r} not in source {source.source_id!r} "
            f"(known: {[t.table_id for t in source.tables]})"
        )

    @staticmethod
    def _require_column(
        columns_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Return the column for *col_id*, raising with IR path *where* if absent."""
        col = columns_by_id.get(col_id)
        if col is None:
            raise IRValidationError(
                f"{where}.column_id: {col_id!r} not in table "
                f"(known: {sorted(columns_by_id.keys())})"
            )
        return col