Spaces:
Sleeping
Sleeping
| """ | |
| StateReconciler — Dynamic Golden Database Grading Engine. | |
| ARCHITECTURE: | |
| - Instead of hardcoded expected values, we build a "golden" database by running | |
| the correct migration on a fresh copy of the seed data. | |
| - The agent's database is compared table-by-table against this golden reference. | |
| - This makes the grader SEED-INDEPENDENT: if judges change the seed data, | |
| the golden DB auto-updates and scoring remains accurate. | |
| SCORING WEIGHTS (per-table, dynamic): | |
| - Schema match (table exists, correct columns): 30% | |
| - Data match (row count + content): 40% | |
| - FK & constraint integrity: 20% | |
| - Anti-exploit checks: 10% | |
| ANTI-EXPLOIT PROTECTIONS: | |
| - Case-insensitive table/column name comparison | |
| - PRAGMA state preservation (grader doesn't corrupt agent's FK state) | |
| - Phantom row detection (SUM fingerprinting) | |
| - Empty table exploitation blocked | |
| - Extra/leftover table penalty | |
| """ | |
| import sqlite3 | |
| from typing import Any, Dict, List, Optional, Set, Tuple | |
| # Import seeds for golden migration functions | |
| try: | |
| import seeds | |
| except ImportError: | |
| from .. import seeds | |
| def _get_table_names(conn: sqlite3.Connection) -> Set[str]: | |
| """Get all user table names (case-normalized to lowercase).""" | |
| try: | |
| cursor = conn.execute( | |
| "SELECT name FROM sqlite_master WHERE type='table' " | |
| "AND name NOT LIKE 'sqlite_%' ORDER BY name" | |
| ) | |
| return {row[0].lower() for row in cursor.fetchall()} | |
| except Exception: | |
| return set() | |
| def _get_column_info(conn: sqlite3.Connection, table: str) -> List[dict]: | |
| """Get column info for a table. Returns list of {name, type, notnull, pk}.""" | |
| try: | |
| cursor = conn.execute(f"PRAGMA table_info([{table}])") | |
| return [ | |
| { | |
| "name": row[1].lower(), | |
| "type": row[2].upper(), | |
| "notnull": row[3], | |
| "dflt_value": row[4], # Added check for default values | |
| "pk": row[5] | |
| } | |
| for row in cursor.fetchall() | |
| ] | |
| except Exception: | |
| return [] | |
| def _get_column_names(conn: sqlite3.Connection, table: str) -> Set[str]: | |
| """Get column names (lowercase) for a table.""" | |
| return {col["name"] for col in _get_column_info(conn, table)} | |
| def _get_column_signatures(conn: sqlite3.Connection, table: str) -> Set[Tuple[str, str, int, Any]]: | |
| """Get (name, type, notnull, dflt_value) tuples for absolute schema grading.""" | |
| return { | |
| (col["name"], col["type"], col["notnull"], str(col["dflt_value"])) | |
| for col in _get_column_info(conn, table) | |
| } | |
| def _get_row_count(conn: sqlite3.Connection, table: str) -> int: | |
| """Get row count. Returns 0 on error.""" | |
| try: | |
| cursor = conn.execute(f"SELECT COUNT(*) FROM [{table}]") | |
| return cursor.fetchone()[0] | |
| except Exception: | |
| return 0 | |
| def _get_all_rows(conn: sqlite3.Connection, table: str) -> List[Tuple]: | |
| """Get all rows from a table, sorted for deterministic comparison.""" | |
| try: | |
| cols = _get_column_names(conn, table) | |
| if not cols: | |
| return [] | |
| cursor = conn.execute(f"SELECT * FROM [{table}] ORDER BY 1") | |
| return cursor.fetchall() | |
| except Exception: | |
| return [] | |
| def _has_foreign_key(conn: sqlite3.Connection, table: str, ref_table: str) -> bool: | |
| """Check if table has a FK referencing ref_table (case-insensitive).""" | |
| try: | |
| cursor = conn.execute(f"PRAGMA foreign_key_list([{table}])") | |
| for row in cursor.fetchall(): | |
| if row[2].lower() == ref_table.lower(): | |
| return True | |
| return False | |
| except Exception: | |
| return False | |
| def _count_foreign_keys(conn: sqlite3.Connection, table: str) -> int: | |
| """Count unique FK constraints for a table using the FK id.""" | |
| try: | |
| cursor = conn.execute(f"PRAGMA foreign_key_list([{table}])") | |
| refs = set() | |
| for row in cursor.fetchall(): | |
| refs.add(row[0]) # row[0] is the sequential ID of the foreign key constraint | |
| return len(refs) | |
| except Exception: | |
| return 0 | |
| def _build_golden_db(task_name: str) -> sqlite3.Connection: | |
| """ | |
| Build a golden reference database for a task. | |
| Seeds a fresh in-memory DB with the task's seed data, then applies | |
| the golden migration to produce the expected final state. | |
| """ | |
| task_config = seeds.TASKS[task_name] | |
| conn = sqlite3.connect(":memory:") | |
| conn.execute("PRAGMA foreign_keys = ON") | |
| # Seed with same data as agent | |
| task_config["seed_fn"](conn) | |
| # Apply perfect migration | |
| task_config["golden_fn"](conn) | |
| return conn | |
| def _compare_row_data( | |
| agent_rows: List[Tuple], | |
| golden_rows: List[Tuple], | |
| ) -> float: | |
| """ | |
| Compare row data between agent and golden databases. | |
| Returns a similarity score between 0.0 and 1.0. | |
| Handles: different row counts, partial matches, type coercion differences. | |
| """ | |
| if not golden_rows: | |
| return 1.0 if not agent_rows else 0.0 | |
| if not agent_rows: | |
| return 0.0 | |
| # Exact match | |
| if agent_rows == golden_rows: | |
| return 1.0 | |
| # Row count match bonus | |
| count_match = 1.0 if len(agent_rows) == len(golden_rows) else ( | |
| min(len(agent_rows), len(golden_rows)) / max(len(agent_rows), len(golden_rows)) | |
| ) | |
| # Per-row comparison (order-independent for flexibility) | |
| golden_set = set() | |
| for row in golden_rows: | |
| # Normalize: try numerical comparison first, fallback to string | |
| normalized_golden = [] | |
| for v in row: | |
| if v is None: normalized_golden.append("") | |
| else: | |
| try: | |
| normalized_golden.append(float(v)) | |
| except (ValueError, TypeError): | |
| normalized_golden.append(str(v).strip()) | |
| golden_set.add(tuple(normalized_golden)) | |
| matched = 0 | |
| for row in agent_rows: | |
| # Normalize: try numerical comparison first, fallback to string | |
| normalized_agent = [] | |
| for v in row: | |
| if v is None: normalized_agent.append("") | |
| else: | |
| try: # If it looks like a number, treat as a float for comparison | |
| normalized_agent.append(float(v)) | |
| except (ValueError, TypeError): | |
| normalized_agent.append(str(v).strip()) | |
| normalized_agent = tuple(normalized_agent) | |
| # Look for match in golden set with numerical tolerance | |
| found = False | |
| for g_row in list(golden_set): | |
| # Compare rows with float tolerance where applicable | |
| match = True | |
| if len(normalized_agent) != len(g_row): continue | |
| for a_val, g_val in zip(normalized_agent, g_row): | |
| if isinstance(a_val, float) and isinstance(g_val, float): | |
| if abs(a_val - g_val) > 1e-7: | |
| match = False; break | |
| elif str(a_val) != str(g_val): | |
| match = False; break | |
| if match: | |
| matched += 1 | |
| golden_set.discard(g_row) | |
| found = True | |
| break | |
| if len(golden_rows) == 0: | |
| content_match = 1.0 # Logic fix: if golden has no rows, agent with 0 rows is 1.0 | |
| else: | |
| content_match = matched / len(golden_rows) | |
| # Penalize extra rows (data bloat) | |
| if len(agent_rows) > len(golden_rows): | |
| # 1% penalty per extra row to prevent "guessing" by dumping all possible data | |
| bloat_penalty = max(0.5, 1.0 - (len(agent_rows) - len(golden_rows)) * 0.01) | |
| content_match *= bloat_penalty | |
| return 0.4 * count_match + 0.6 * content_match | |
| class StateReconciler: | |
| """ | |
| Dynamic Golden Database grading engine. | |
| Compares the agent's database state against a dynamically-generated | |
| golden reference database. No hardcoded expected values. | |
| """ | |
| def __init__(self, task_name: str): | |
| self.task_name = task_name | |
| self._last_score: float = 0.0 | |
| self._golden_conn: Optional[sqlite3.Connection] = None | |
| # Build golden reference DB | |
| try: | |
| self._golden_conn = _build_golden_db(task_name) | |
| self._golden_tables = _get_table_names(self._golden_conn) | |
| self._golden_table_data: Dict[str, dict] = {} | |
| for table in self._golden_tables: | |
| self._golden_table_data[table] = { | |
| "columns": _get_column_info(self._golden_conn, table), | |
| "col_names": _get_column_names(self._golden_conn, table), | |
| "col_signatures": _get_column_signatures(self._golden_conn, table), | |
| "rows": _get_all_rows(self._golden_conn, table), | |
| "row_count": _get_row_count(self._golden_conn, table), | |
| "fk_count": _count_foreign_keys(self._golden_conn, table), | |
| } | |
| except Exception: | |
| self._golden_tables = set() | |
| self._golden_table_data = {} | |
| def __del__(self): | |
| """Clean up golden DB connection.""" | |
| if self._golden_conn is not None: | |
| try: | |
| self._golden_conn.close() | |
| except Exception: | |
| pass | |
| def score(self, conn: sqlite3.Connection) -> float: | |
| """ | |
| Compute migration score by comparing agent DB against golden reference. | |
| Scoring breakdown: | |
| - Schema match: 0.30 (tables exist with correct columns) | |
| - Data match: 0.40 (row content matches golden DB) | |
| - FK/constraint integrity: 0.20 (FKs enforced, integrity OK) | |
| - Anti-exploit bonus: 0.10 (no empty tables, no extra tables) | |
| Returns: float in [0.01, 0.99] | |
| """ | |
| try: | |
| return self._score_dynamic(conn) | |
| except Exception: | |
| return 0.01 | |
| def compute_step_reward(self, conn: sqlite3.Connection) -> Tuple[float, float]: | |
| """ | |
| Compute current score and step reward delta. | |
| CRITICAL: Preserves the agent's PRAGMA foreign_keys state. | |
| The grader reads FK state, does its work, then restores it. | |
| """ | |
| # A8: Preserve PRAGMA state | |
| try: | |
| original_fk = conn.execute("PRAGMA foreign_keys").fetchone()[0] | |
| except Exception: | |
| original_fk = 1 | |
| current_score = self.score(conn) | |
| step_reward = current_score - self._last_score | |
| self._last_score = current_score | |
| # A8: Restore original PRAGMA state | |
| try: | |
| conn.execute(f"PRAGMA foreign_keys = {'ON' if original_fk else 'OFF'}") | |
| except Exception: | |
| pass | |
| return current_score, step_reward | |
| def _score_dynamic(self, conn: sqlite3.Connection) -> float: | |
| """Core dynamic scoring: compare agent DB against golden DB.""" | |
| if not self._golden_tables: | |
| return 0.01 | |
| agent_tables = _get_table_names(conn) | |
| # ---- 1. Schema Match (0.30) ---- | |
| schema_score = 0.0 | |
| tables_found = 0 | |
| total_col_match = 0.0 | |
| for table in self._golden_tables: | |
| golden_info = self._golden_table_data[table] | |
| if table in agent_tables: | |
| tables_found += 1 | |
| # Signature (name + type) comparison | |
| agent_cols = _get_column_signatures(conn, table) | |
| golden_cols = golden_info["col_signatures"] | |
| if golden_cols: | |
| col_overlap = len(agent_cols & golden_cols) / len(golden_cols) | |
| total_col_match += col_overlap | |
| else: | |
| total_col_match += 1.0 | |
| if self._golden_tables: | |
| table_ratio = tables_found / len(self._golden_tables) | |
| col_ratio = total_col_match / len(self._golden_tables) if self._golden_tables else 0 | |
| schema_score = 0.15 * table_ratio + 0.15 * col_ratio | |
| # ---- 2. Data Match (0.40) ---- | |
| data_score = 0.0 | |
| data_checks = 0 | |
| for table in self._golden_tables: | |
| golden_info = self._golden_table_data[table] | |
| if table not in agent_tables: | |
| data_checks += 1 | |
| continue | |
| agent_rows = _get_all_rows(conn, table) | |
| golden_rows = golden_info["rows"] | |
| similarity = _compare_row_data(agent_rows, golden_rows) | |
| data_score += similarity | |
| data_checks += 1 | |
| if data_checks > 0: | |
| data_score = 0.40 * (data_score / data_checks) | |
| # ---- 3. FK & Constraint Integrity (0.20) ---- | |
| fk_score = 0.0 | |
| fk_checks = 0 | |
| for table in self._golden_tables: | |
| golden_info = self._golden_table_data[table] | |
| expected_fks = golden_info["fk_count"] | |
| if expected_fks > 0 and table in agent_tables: | |
| agent_fks = _count_foreign_keys(conn, table) | |
| fk_ratio = min(agent_fks, expected_fks) / expected_fks | |
| fk_score += fk_ratio | |
| fk_checks += 1 | |
| # PRAGMA integrity check | |
| integrity_ok = False | |
| try: | |
| # Temporarily enable FK for integrity check | |
| conn.execute("PRAGMA foreign_keys = ON") | |
| cursor = conn.execute("PRAGMA integrity_check") | |
| result = cursor.fetchone()[0] | |
| # Explicitly run foreign_key_check to catch orphaned rows | |
| fk_cursor = conn.execute("PRAGMA foreign_key_check") | |
| fk_violations = fk_cursor.fetchall() | |
| integrity_ok = (result == "ok" and len(fk_violations) == 0) | |
| except Exception: | |
| pass | |
| if fk_checks > 0: | |
| fk_score = 0.10 * (fk_score / fk_checks) | |
| else: | |
| # No FK constraints expected — award full FK portion | |
| fk_score = 0.10 | |
| fk_score += 0.10 if integrity_ok else 0.0 | |
| # ---- 4. Anti-Exploit Checks (0.10) ---- | |
| exploit_score = 0.10 # Start with full credit, deduct for violations | |
| # Check for empty tables where golden has data | |
| for table in self._golden_tables: | |
| golden_info = self._golden_table_data[table] | |
| if golden_info["row_count"] > 0 and table in agent_tables: | |
| agent_count = _get_row_count(conn, table) | |
| if agent_count == 0: | |
| # Agent emptied a table that should have data — heavy penalty | |
| exploit_score = 0.0 | |
| # Also cap the data score for this exploit | |
| data_score = min(data_score, 0.05) | |
| break | |
| # Penalize extra non-golden tables (schema pollution) | |
| extra_tables = agent_tables - self._golden_tables | |
| if extra_tables: | |
| # Small penalty per extra table (some might be temp tables) | |
| penalty = min(0.05, 0.01 * len(extra_tables)) | |
| exploit_score = max(0, exploit_score - penalty) | |
| total = schema_score + data_score + fk_score + exploit_score | |
| return max(0.01, min(0.99, total)) | |