"""
SQLite Database Manager for SQL Arena.
Creates in-memory SQLite databases for each task.
Executes agent queries (reporting errors instead of raising) and formats results.
Key design decisions:
- In-memory databases (fast, no disk I/O, no cleanup needed)
- Each reset() creates a fresh database
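- Query execution is not sandboxed: agent queries may modify data. A
read-only connection would be ideal, but each in-memory database is
ephemeral and replaced on the next reset(), so changes do not persist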
"""
import sqlite3
import time
from typing import Tuple, Optional, Any, Dict
class DatabaseManager:
    """
    Manages SQLite in-memory databases for SQL challenges.

    Each task gets its own fresh database with schema and sample data.
    The agent's queries are executed against this database.
    """

    # Number of SQLite VM instructions between deadline checks when
    # execute_query() is given a timeout.
    _PROGRESS_INTERVAL = 1000

    def __init__(self):
        # Active connection: None until create_database() succeeds, and
        # None again after close().
        self.conn: Optional[sqlite3.Connection] = None

    def create_database(self, setup_sql: str) -> None:
        """
        Create a new in-memory database with the given schema and data.

        Any previously open database is closed first. If ``setup_sql``
        fails, the partially-initialised database is discarded (``conn``
        stays None) and the error is re-raised. Later queries therefore
        cannot silently run against a half-built schema.

        Args:
            setup_sql: SQL string containing CREATE TABLE and INSERT statements

        Raises:
            Any exception raised while running ``setup_sql`` (typically
            ``sqlite3.Error``) is propagated unchanged.
        """
        # Close any existing connection
        self.close()

        # Build on a local handle; publish it only once setup has succeeded.
        conn = sqlite3.connect(":memory:")
        try:
            # Enable foreign keys
            conn.execute("PRAGMA foreign_keys = ON")
            # Run the setup SQL (creates tables and inserts data)
            conn.executescript(setup_sql)
            conn.commit()
        except BaseException:
            conn.close()
            raise
        self.conn = conn

    def execute_query(
        self,
        sql: str,
        *,
        timeout: Optional[float] = None,
    ) -> Tuple[bool, Optional[Dict], Optional[str]]:
        """
        Execute a SQL query and return results.

        This is the main method called when the agent submits a query.
        Failures are reported through the return value instead of being
        raised, so a bad query cannot crash the caller. Only one statement
        is accepted per call; a multi-statement string is reported as a
        SQL error.

        Args:
            sql: The SQL query string to execute
            timeout: Optional wall-clock limit in seconds. If the query is
                still running when it expires (e.g. an unbounded recursive
                CTE), SQLite interrupts it and the call fails with a
                time-limit message. None (the default) means no limit.

        Returns:
            Tuple of (success, result_dict, error_message):
            - success: True if query executed without error
            - result_dict: {"columns": [...], "rows": [...]} if successful
            - error_message: Error string if failed, None if success
        """
        conn = self.conn
        if conn is None:
            return False, None, "No database connection. Call create_database() first."

        timed_out = False
        if timeout is not None:
            deadline = time.monotonic() + timeout

            def _abort_if_late() -> int:
                # SQLite calls this periodically while a statement runs;
                # a non-zero return interrupts the statement.
                nonlocal timed_out
                if time.monotonic() >= deadline:
                    timed_out = True
                    return 1
                return 0

            conn.set_progress_handler(_abort_if_late, self._PROGRESS_INTERVAL)

        try:
            cursor = conn.execute(sql)
            # description is None for statements that produce no columns.
            columns = [desc[0] for desc in cursor.description or ()]
            rows = cursor.fetchall()
        # Older Python versions raise sqlite3.Warning (not a subclass of
        # sqlite3.Error) for multi-statement strings; treat it as a SQL error.
        except (sqlite3.Error, sqlite3.Warning) as e:
            if timed_out:
                return False, None, f"SQL Error: query exceeded time limit of {timeout}s"
            return False, None, f"SQL Error: {e}"
        except Exception as e:
            return False, None, f"Execution Error: {e}"
        finally:
            if timeout is not None:
                # Uninstall so later queries are not bound by this deadline.
                conn.set_progress_handler(None, 0)

        return True, {"columns": columns, "rows": rows}, None

    def format_result(self, result: Dict, max_rows: int = 20) -> str:
        """
        Format query result as a human-readable table string.

        This formatted string is shown to the agent in the observation
        so it can see what its query returned. Column widths fit the
        header and the rows actually displayed. Values are rendered with
        ``str()``, so a Python ``None`` (SQL NULL) appears as ``None``.

        Args:
            result: Dict with "columns" and "rows" keys
            max_rows: Maximum number of rows to display; negative values
                are treated as 0.

        Returns:
            Formatted table string. A truncation notice follows when rows
            were hidden, then the total row count.
        """
        if not result or not result.get("columns"):
            return "(empty result set)"

        columns = result["columns"]
        rows = result["rows"]
        if not rows:
            return f"Columns: {', '.join(map(str, columns))}\n(0 rows returned)"

        shown = rows[: max(max_rows, 0)]

        # Column widths: at least as wide as the header, widened to fit the
        # displayed cells; cells beyond the last header column are ignored.
        col_widths = [len(str(c)) for c in columns]
        for row in shown:
            for i, val in enumerate(row):
                if i < len(col_widths):
                    col_widths[i] = max(col_widths[i], len(str(val)))

        lines = [
            " | ".join(str(c).ljust(w) for c, w in zip(columns, col_widths)),
            "-+-".join("-" * w for w in col_widths),
        ]
        lines.extend(
            " | ".join(str(v).ljust(w) for v, w in zip(row, col_widths))
            for row in shown
        )
        table_str = "\n".join(lines)

        # Truncation notice
        hidden = len(rows) - len(shown)
        if hidden > 0:
            table_str += f"\n... ({hidden} more rows not shown)"

        # Row count
        total = len(rows)
        table_str += f"\n\n({total} row{'s' if total != 1 else ''} returned)"
        return table_str

    def close(self) -> None:
        """
        Close the database connection (if any) and forget it.

        Safe to call repeatedly. SQLite errors raised while closing are
        ignored because the connection is being discarded either way.
        """
        conn, self.conn = self.conn, None
        if conn is None:
            return
        try:
            conn.close()
        except sqlite3.Error:
            pass  # best-effort: the handle is dropped regardless