| """IRValidator — checks a QueryIR against a user's catalog. |
| |
| See ARCHITECTURE.md §7 for the validation rules. On failure, the planner |
| is re-prompted with the error context (max 3 retries) — error messages |
| must therefore be specific enough that the LLM can self-correct. |
| """ |
|
|
| from ...catalog.models import Catalog, Column, Source, Table |
| from .models import QueryIR |
| from .operators import ( |
| ALLOWED_AGG_FNS, |
| ALLOWED_FILTER_OPS, |
| LIMIT_HARD_CAP, |
| TYPE_COMPATIBILITY, |
| ) |
|
|
| _NULLARY_FILTER_OPS = frozenset({"is_null", "is_not_null"}) |
|
|
|
|
class IRValidationError(Exception):
    """Raised when a QueryIR violates a catalog or whitelist rule.

    The message is returned to the planner as retry context, so it should
    name the offending IR field and, where possible, the valid alternatives.
    """
|
|
|
|
class IRValidator:
    """Reject IRs that reference unknown sources/tables/columns or use disallowed ops.

    Rules:
    - source_id exists in catalog for this user
    - table_id belongs to that source
    - every column_id exists in that table
    - every agg.fn and filter.op is whitelisted (see operators.py)
    - value_type consistent with column.data_type (TYPE_COMPATIBILITY)
    - limit, if given, is an int (not bool) with 1 <= limit <= LIMIT_HARD_CAP

    Validation stops at the first violation and raises IRValidationError.
    Messages name the offending IR field (e.g. ``filters[2].op``) and, where
    possible, list the valid alternatives, because they are fed back to the
    planner so it can self-correct.
    """

    def validate(self, ir: QueryIR, catalog: Catalog) -> None:
        """Validate *ir* against *catalog*; return None on success.

        Checks run in this order: source, table, select, filters, group_by,
        order_by, limit. The first failing check raises.

        Raises:
            IRValidationError: if any rule in the class docstring is violated.
        """
        source = self._find_source(catalog, ir.source_id)
        table = self._find_table(source, ir.table_id)
        columns_by_id: dict[str, Column] = {c.column_id: c for c in table.columns}

        select_aliases = self._validate_select(ir, columns_by_id)
        self._validate_filters(ir, columns_by_id)
        self._validate_group_by(ir, columns_by_id)
        self._validate_order_by(ir, columns_by_id, select_aliases)
        self._validate_limit(ir)

    # -- per-clause checks (called by validate, in clause order) -------------

    def _validate_select(
        self, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> set[str]:
        """Check every select item and return the aliases the select declares."""
        select_aliases: set[str] = set()
        for i, item in enumerate(ir.select):
            where = f"select[{i}]"
            if item.kind == "column":
                self._require_column(columns_by_id, item.column_id, where)
            else:
                # Every item whose kind is not "column" is validated as an aggregate.
                if item.fn not in ALLOWED_AGG_FNS:
                    raise IRValidationError(
                        f"{where}.fn: must be in {sorted(ALLOWED_AGG_FNS)}, "
                        f"got {item.fn!r}"
                    )
                if item.column_id is not None:
                    self._require_column(columns_by_id, item.column_id, where)
                elif item.fn != "count":
                    raise IRValidationError(
                        f"{where}.fn={item.fn!r} requires a column_id "
                        "(only 'count' may omit it for COUNT(*))"
                    )
            if item.alias:
                # Collected so order_by may reference aliases as well as columns.
                select_aliases.add(item.alias)
        return select_aliases

    def _validate_filters(
        self, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> None:
        """Check each filter's column, op whitelist, and value_type compatibility."""
        for i, f in enumerate(ir.filters):
            where = f"filters[{i}]"
            col = self._require_column(columns_by_id, f.column_id, where)
            if f.op not in ALLOWED_FILTER_OPS:
                raise IRValidationError(
                    f"{where}.op: must be in {sorted(ALLOWED_FILTER_OPS)}, "
                    f"got {f.op!r}"
                )
            # Nullary ops (is_null / is_not_null) compare against no value,
            # so there is no value_type to check.
            if f.op not in _NULLARY_FILTER_OPS:
                # Unknown data_types map to an empty set, i.e. nothing is allowed.
                allowed = TYPE_COMPATIBILITY.get(col.data_type, frozenset())
                if f.value_type not in allowed:
                    raise IRValidationError(
                        f"{where}: value_type {f.value_type!r} incompatible with "
                        f"column.data_type {col.data_type!r} "
                        f"(allowed: {sorted(allowed)})"
                    )

    def _validate_group_by(
        self, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> None:
        """Check every group_by entry names a column of the table."""
        for i, col_id in enumerate(ir.group_by):
            self._require_column(columns_by_id, col_id, f"group_by[{i}]")

    def _validate_order_by(
        self,
        ir: QueryIR,
        columns_by_id: dict[str, Column],
        select_aliases: set[str],
    ) -> None:
        """Check every order_by key is a table column or a select alias."""
        for i, ob in enumerate(ir.order_by):
            if ob.column_id not in columns_by_id and ob.column_id not in select_aliases:
                raise IRValidationError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not found in table "
                    f"{ir.table_id!r} columns or select aliases "
                    f"(known columns: {sorted(columns_by_id.keys())}, "
                    f"aliases: {sorted(select_aliases)})"
                )

    def _validate_limit(self, ir: QueryIR) -> None:
        """Check limit is None or an int (not bool) in [1, LIMIT_HARD_CAP]."""
        limit = ir.limit
        if limit is None:
            return
        # bool is an int subclass, so True would otherwise pass as 1. A
        # non-numeric limit would otherwise raise TypeError from the
        # comparisons below instead of a retryable IRValidationError.
        if isinstance(limit, bool) or not isinstance(limit, int):
            raise IRValidationError(f"limit must be an integer, got {limit!r}")
        if limit <= 0:
            raise IRValidationError(f"limit must be positive, got {limit}")
        if limit > LIMIT_HARD_CAP:
            raise IRValidationError(
                f"limit {limit} exceeds hard cap {LIMIT_HARD_CAP}"
            )

    # -- catalog lookups ------------------------------------------------------

    @staticmethod
    def _find_source(catalog: Catalog, source_id: str) -> Source:
        """Return the catalog source with *source_id*, or raise listing known ids."""
        for s in catalog.sources:
            if s.source_id == source_id:
                return s
        raise IRValidationError(
            f"source_id {source_id!r} not in catalog "
            f"(known: {[s.source_id for s in catalog.sources]})"
        )

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the table of *source* with *table_id*, or raise listing known ids."""
        for t in source.tables:
            if t.table_id == table_id:
                return t
        raise IRValidationError(
            f"table_id {table_id!r} not in source {source.source_id!r} "
            f"(known: {[t.table_id for t in source.tables]})"
        )

    @staticmethod
    def _require_column(
        columns_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Return the column for *col_id*.

        Raises:
            IRValidationError: prefixed with *where* (the IR path being
                checked) and listing the known column ids, if *col_id* is absent.
        """
        col = columns_by_id.get(col_id)
        if col is None:
            raise IRValidationError(
                f"{where}.column_id: {col_id!r} not in table "
                f"(known: {sorted(columns_by_id.keys())})"
            )
        return col
|
|