# ishaq101 · feat/Catalog Retrieval System (#1) · 6bff5d9
"""IRValidator — checks a QueryIR against a user's catalog.
See ARCHITECTURE.md §7 for the validation rules. On failure, the planner
is re-prompted with the error context (max 3 retries) — error messages
must therefore be specific enough that the LLM can self-correct.
"""
from ...catalog.models import Catalog, Column, Source, Table
from .models import QueryIR
from .operators import (
ALLOWED_AGG_FNS,
ALLOWED_FILTER_OPS,
LIMIT_HARD_CAP,
TYPE_COMPATIBILITY,
)
_NULLARY_FILTER_OPS = frozenset({"is_null", "is_not_null"})
class IRValidationError(Exception):
    """Signal that a QueryIR failed validation against a catalog.

    The message describes the offending field and what was expected.
    """
class IRValidator:
    """Validate a QueryIR against the catalog it is meant to query.

    ``validate`` returns ``None`` for an acceptable IR and raises
    ``IRValidationError`` on the first violation found. Checks run in this
    order:

    1. ``source_id`` names a source in the catalog.
    2. ``table_id`` names a table of that source.
    3. select items: plain columns must exist in the table; aggregates need
       a whitelisted ``fn`` and an existing column (only ``count`` may omit
       the column).
    4. filters: the column exists, ``op`` is whitelisted and, unless the op
       is in ``_NULLARY_FILTER_OPS``, ``value_type`` is allowed for the
       column's ``data_type`` according to ``TYPE_COMPATIBILITY``.
    5. every ``group_by`` column exists in the table.
    6. every ``order_by`` entry names a table column or a select alias.
    7. ``limit``, when set, is > 0 and ≤ ``LIMIT_HARD_CAP``.
    """

    def validate(self, ir: QueryIR, catalog: Catalog) -> None:
        """Check ``ir`` against ``catalog``.

        Args:
            ir: The query IR to validate.
            catalog: Catalog whose sources/tables/columns ``ir`` may reference.

        Raises:
            IRValidationError: On the first rule violation encountered.
        """
        owning_source = self._find_source(catalog, ir.source_id)
        target_table = self._find_table(owning_source, ir.table_id)
        known_columns: dict[str, Column] = {
            column.column_id: column for column in target_table.columns
        }

        aliases = self._check_select(ir.select, known_columns)
        self._check_filters(ir.filters, known_columns)
        for position, column_id in enumerate(ir.group_by):
            self._require_column(known_columns, column_id, f"group_by[{position}]")
        self._check_order_by(ir.order_by, ir.table_id, known_columns, aliases)
        self._check_limit(ir.limit)

    def _check_select(self, select_items, known_columns: dict[str, Column]) -> set[str]:
        """Validate every select item and return the set of non-empty aliases."""
        aliases: set[str] = set()
        for position, item in enumerate(select_items):
            label = f"select[{position}]"
            if item.kind != "column":
                # Anything that is not a plain column is handled as an aggregate.
                self._check_aggregate(item, known_columns, label)
            else:
                self._require_column(known_columns, item.column_id, label)
            if item.alias:
                aliases.add(item.alias)
        return aliases

    def _check_aggregate(self, item, known_columns: dict[str, Column], label: str) -> None:
        """Validate one aggregate select item: whitelisted fn, then its column."""
        if item.fn not in ALLOWED_AGG_FNS:
            raise IRValidationError(
                f"{label}.fn: must be in {sorted(ALLOWED_AGG_FNS)}, "
                f"got {item.fn!r}"
            )
        if item.column_id is None:
            # Only COUNT(*) is allowed to aggregate without a column.
            if item.fn != "count":
                raise IRValidationError(
                    f"{label}.fn={item.fn!r} requires a column_id "
                    "(only 'count' may omit it for COUNT(*))"
                )
            return
        self._require_column(known_columns, item.column_id, label)

    def _check_filters(self, filters, known_columns: dict[str, Column]) -> None:
        """Validate each filter's column, operator and value_type."""
        for position, flt in enumerate(filters):
            label = f"filters[{position}]"
            column = self._require_column(known_columns, flt.column_id, label)
            if flt.op not in ALLOWED_FILTER_OPS:
                raise IRValidationError(
                    f"{label}.op: must be in {sorted(ALLOWED_FILTER_OPS)}, "
                    f"got {flt.op!r}"
                )
            if flt.op in _NULLARY_FILTER_OPS:
                # These ops are exempt from the value_type check.
                continue
            # A data_type missing from the table allows no value_type at all.
            permitted = TYPE_COMPATIBILITY.get(column.data_type, frozenset())
            if flt.value_type not in permitted:
                raise IRValidationError(
                    f"{label}: value_type {flt.value_type!r} incompatible with "
                    f"column.data_type {column.data_type!r} "
                    f"(allowed: {sorted(permitted)})"
                )

    @staticmethod
    def _check_order_by(
        order_by,
        table_id: str,
        known_columns: dict[str, Column],
        aliases: set[str],
    ) -> None:
        """Require each order_by target to be a table column or a select alias."""
        for position, ordering in enumerate(order_by):
            if ordering.column_id in known_columns or ordering.column_id in aliases:
                continue
            raise IRValidationError(
                f"order_by[{position}].column_id: {ordering.column_id!r} not found in table "
                f"{table_id!r} columns or select aliases "
                f"(known columns: {sorted(known_columns)}, "
                f"aliases: {sorted(aliases)})"
            )

    @staticmethod
    def _check_limit(limit) -> None:
        """Accept an unset limit; otherwise require 0 < limit ≤ LIMIT_HARD_CAP."""
        if limit is None:
            return
        if limit <= 0:
            raise IRValidationError(f"limit must be positive, got {limit}")
        if limit > LIMIT_HARD_CAP:
            raise IRValidationError(
                f"limit {limit} exceeds hard cap {LIMIT_HARD_CAP}"
            )

    @staticmethod
    def _find_source(catalog: Catalog, source_id: str) -> Source:
        """Return the first catalog source whose id equals ``source_id``.

        Raises:
            IRValidationError: If no source matches; the message lists the
                known source ids.
        """
        match = next(
            (candidate for candidate in catalog.sources if candidate.source_id == source_id),
            None,
        )
        if match is not None:
            return match
        raise IRValidationError(
            f"source_id {source_id!r} not in catalog "
            f"(known: {[s.source_id for s in catalog.sources]})"
        )

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the first table of ``source`` whose id equals ``table_id``.

        Raises:
            IRValidationError: If no table matches; the message lists the
                known table ids of that source.
        """
        match = next(
            (candidate for candidate in source.tables if candidate.table_id == table_id),
            None,
        )
        if match is not None:
            return match
        raise IRValidationError(
            f"table_id {table_id!r} not in source {source.source_id!r} "
            f"(known: {[t.table_id for t in source.tables]})"
        )

    @staticmethod
    def _require_column(
        columns_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Look up ``col_id`` and return its Column.

        Args:
            columns_by_id: Mapping of column id to Column for the target table.
            col_id: Column id referenced by the IR.
            where: IR path (e.g. ``select[0]``) used as the error-message prefix.

        Raises:
            IRValidationError: If ``col_id`` is not a known column; the
                message lists the known column ids.
        """
        found = columns_by_id.get(col_id)
        if found is not None:
            return found
        raise IRValidationError(
            f"{where}.column_id: {col_id!r} not in table "
            f"(known: {sorted(columns_by_id.keys())})"
        )