"""
Simple schema linking for Spider-style Text-to-SQL.

Goal:
- Given (question, db_id), select a small set of relevant tables/columns
  to include in the prompt (RAG-style schema retrieval).

Design constraints:
- Pure Python (no heavy external deps).
- Robust to missing/odd schemas: never crash.
"""
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import sqlite3 |
| from contextlib import closing |
| from dataclasses import dataclass |
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple |
|
|
|
|
| _ALNUM_RE = re.compile(r"[A-Za-z0-9]+") |
| _CAMEL_RE = re.compile(r"([a-z])([A-Z])") |
|
|
|
|
def _normalize_identifier(text: str) -> str:
    """Return *text* as a lowercase, space-separated phrase.

    Underscores are turned into spaces, a space is inserted wherever a
    lowercase letter is directly followed by an uppercase one
    (``fooBar`` -> ``foo bar``), and the result is lowercased.
    Falsy input (e.g. ``None`` or ``""``) yields the empty string.
    """
    raw = str(text) if text else ""
    spaced = _CAMEL_RE.sub(r"\1 \2", raw.replace("_", " "))
    return spaced.lower()
|
|
|
|
def _tokenize(text: str) -> List[str]:
    """Split *text* into lowercase ASCII alphanumeric tokens.

    The text is first run through ``_normalize_identifier`` so that
    snake_case and camelCase identifiers break into their component words;
    the tokens are then the matches of ``_ALNUM_RE`` in left-to-right order.
    """
    return _ALNUM_RE.findall(_normalize_identifier(text))
|
|
|
|
@dataclass(frozen=True)
class TableSchema:
    """Immutable record describing one table: its name and its column names."""

    # Table name as it appears in the schema source.
    table_name: str
    # Column names of this table, kept as a tuple so the record stays hashable.
    columns: Tuple[str, ...]
|
|
|
|
class SchemaLinker:
    """Pick the tables/columns of a Spider database most relevant to a question.

    Schemas are read eagerly from a Spider ``tables.json`` file. When
    ``db_root`` is given, ``<db_root>/<db_id>/<db_id>.sqlite`` is consulted
    lazily (and cached per ``db_id``) for the actual column lists.
    Relevance is a lightweight token-overlap score between the question and
    the table/column names.
    """

    def __init__(self, tables_json_path: str, db_root: Optional[str] = None):
        """Load ``tables.json`` immediately.

        Args:
            tables_json_path: Path to the Spider ``tables.json`` file.
            db_root: Optional directory holding ``<db_id>/<db_id>.sqlite``.

        Raises:
            OSError: If ``tables_json_path`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        self.tables_json_path = tables_json_path
        self.db_root = db_root
        self._tables_by_db: Dict[str, List[TableSchema]] = {}
        self._sqlite_schema_cache: Dict[str, Dict[str, List[str]]] = {}
        self._load_tables_json()

    def _load_tables_json(self) -> None:
        """Parse ``tables.json`` into ``self._tables_by_db``.

        Malformed entries (non-dict entries, entries without ``db_id``,
        column records that are not ``[table_index, name, ...]`` with a valid
        integer index) are skipped instead of raising.
        """
        # Explicit encoding: JSON is UTF-8 and the platform default may differ.
        with open(self.tables_json_path, encoding="utf-8") as f:
            entries = json.load(f)

        tables_by_db: Dict[str, List[TableSchema]] = {}
        # A top-level value that is not a list of entries means "no schemas".
        for entry in entries if isinstance(entries, list) else ():
            if not isinstance(entry, dict) or "db_id" not in entry:
                continue
            db_id = entry["db_id"]

            table_names = entry.get("table_names_original") or entry.get("table_names") or []
            col_names = entry.get("column_names_original") or entry.get("column_names") or []
            if not isinstance(table_names, (list, tuple)):
                table_names = []
            if not isinstance(col_names, (list, tuple)):
                col_names = []

            columns_per_table: List[List[str]] = [[] for _ in table_names]
            for col in col_names:
                if not isinstance(col, (list, tuple)) or len(col) < 2:
                    continue
                table_idx, col_name = col[0], col[1]
                # Spider uses index -1 for the synthetic "*" column; that and
                # any out-of-range or non-integer index is ignored.
                if not isinstance(table_idx, int) or isinstance(table_idx, bool):
                    continue
                if not 0 <= table_idx < len(columns_per_table):
                    continue
                columns_per_table[table_idx].append(str(col_name))

            tables_by_db[db_id] = [
                TableSchema(table_name=str(tname), columns=tuple(cols))
                for tname, cols in zip(table_names, columns_per_table)
            ]

        self._tables_by_db = tables_by_db

    def _db_path(self, db_id: str) -> Optional[str]:
        """Return ``<db_root>/<db_id>/<db_id>.sqlite`` if it exists, else ``None``."""
        if not self.db_root:
            return None
        path = os.path.join(self.db_root, db_id, f"{db_id}.sqlite")
        return path if os.path.exists(path) else None

    def _load_sqlite_schema(self, db_id: str) -> Dict[str, List[str]]:
        """Load the actual SQLite schema (table -> columns), cached per ``db_id``.

        Returns an empty dict when no database file is available or when
        SQLite reports an error while reading it. The returned dict is the
        cached object; callers inside this class must not mutate it.
        """
        if db_id in self._sqlite_schema_cache:
            return self._sqlite_schema_cache[db_id]

        schema: Dict[str, List[str]] = {}
        db_path = self._db_path(db_id)
        if not db_path:
            self._sqlite_schema_cache[db_id] = schema
            return schema

        try:
            with closing(sqlite3.connect(db_path)) as conn:
                cursor = conn.cursor()
                # Names starting with "sqlite_" are reserved for SQLite's own
                # bookkeeping tables (e.g. sqlite_sequence); they are not
                # part of the user schema.
                rows = cursor.execute(
                    "SELECT name FROM sqlite_master "
                    "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
                ).fetchall()
                for (table_name,) in rows:
                    # Quote the identifier so names that are keywords or that
                    # contain spaces/quotes (e.g. order, "line item") work.
                    quoted = '"' + str(table_name).replace('"', '""') + '"'
                    columns = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
                    # Index 1 of each table_info row is the column name.
                    schema[str(table_name)] = [str(col[1]) for col in columns]
        except sqlite3.Error:
            # Best effort: an unreadable database is treated as "no schema".
            schema = {}

        self._sqlite_schema_cache[db_id] = schema
        return schema

    def get_schema(self, db_id: str) -> List[TableSchema]:
        """Return the table schemas for ``db_id``.

        Prefers ``tables.json`` (Spider canonical) and falls back to the
        SQLite file when ``tables.json`` has no tables for this database.
        The returned list is a fresh copy; an unknown ``db_id`` yields ``[]``.
        """
        tables = self._tables_by_db.get(db_id, [])
        if tables:
            # Copy so callers cannot alter the internal per-db list.
            return list(tables)

        sqlite_schema = self._load_sqlite_schema(db_id)
        return [TableSchema(table_name=t, columns=tuple(cols)) for t, cols in sqlite_schema.items()]

    def score_tables(self, question: str, db_id: str) -> List[Tuple[float, TableSchema]]:
        """Score every table of ``db_id`` against ``question`` by token overlap.

        Each question token shared with the table name counts 3.0, each token
        shared with any column name counts 1.0, and 0.5 is added when the
        normalized table name occurs as a substring of the normalized
        question.

        Returns:
            ``(score, table)`` pairs sorted by descending score; ties are
            ordered by descending table name.
        """
        q_tokens = set(_tokenize(question))
        # Loop-invariant: normalized once for the substring bonus below.
        q_text = _normalize_identifier(question)

        scored: List[Tuple[float, TableSchema]] = []
        for t in self.get_schema(db_id):
            table_tokens = set(_tokenize(t.table_name))
            col_tokens = {tok for c in t.columns for tok in _tokenize(c)}

            # Table-name matches weigh more than column-name matches.
            score = 3.0 * len(q_tokens & table_tokens) + 1.0 * len(q_tokens & col_tokens)

            # Small bonus when the whole (possibly multi-word) name appears.
            if t.table_name and _normalize_identifier(t.table_name) in q_text:
                score += 0.5

            scored.append((score, t))

        scored.sort(key=lambda x: (x[0], x[1].table_name), reverse=True)
        return scored

    def select_top_tables(self, question: str, db_id: str, top_k: int = 4) -> List[TableSchema]:
        """Return up to ``top_k`` (at least 1) tables most relevant to ``question``.

        When no table scores above zero, the first ``top_k`` tables in schema
        order are returned instead of an arbitrary ranking. A database with
        no tables yields ``[]``.
        """
        scored = self.score_tables(question, db_id)
        if not scored:
            return []
        top_k = max(1, int(top_k))

        # Nothing matched: fall back to schema order rather than tie order.
        if scored[0][0] <= 0:
            return self.get_schema(db_id)[:top_k]

        return [t for _, t in scored[:top_k]]

    def columns_for_selected_tables(self, db_id: str, selected_tables: Iterable[TableSchema]) -> Dict[str, List[str]]:
        """Map each selected table name to its column list.

        Prefers the columns found in the SQLite file (actual DB) when that
        table is present there with a non-empty column list; otherwise falls
        back to the columns carried by the ``TableSchema`` itself. Every list
        in the result is a fresh copy that the caller may mutate freely.
        """
        sqlite_schema = self._load_sqlite_schema(db_id)
        out: Dict[str, List[str]] = {}
        for t in selected_tables:
            sqlite_cols = sqlite_schema.get(t.table_name)
            # Copy in both branches so the cache / frozen schema stay intact.
            out[t.table_name] = list(sqlite_cols) if sqlite_cols else list(t.columns)
        return out

    def format_relevant_schema(self, question: str, db_id: str, top_k: int = 4) -> Tuple[List[str], Dict[str, List[str]]]:
        """Select relevant tables and render them for inclusion in a prompt.

        Returns:
            - lines: ``["table(col1, col2)", ...]`` in selection order
            - selected: ``{table: [cols...], ...}``
        """
        selected_tables = self.select_top_tables(question, db_id, top_k=top_k)
        selected = self.columns_for_selected_tables(db_id, selected_tables)

        lines = [f"{table_name}({', '.join(cols)})" for table_name, cols in selected.items()]
        return lines, selected
|
|
|
|