File size: 5,192 Bytes
72805b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""

SQLite Database Manager for SQL Arena.



Creates in-memory SQLite databases for each task.

Executes agent queries safely and formats results.



Key design decisions:

- In-memory databases (fast, no disk I/O, no cleanup needed)

- Each reset() creates a fresh database

- Agent queries run against the task's own in-memory database rather than
  a read-only connection; writes are tolerated because the database is
  ephemeral and rebuilt on every reset()

"""

import sqlite3
from typing import Tuple, Optional, Any, Dict


class DatabaseManager:
    """
    Manage a per-task SQLite in-memory database for SQL challenges.

    Each call to create_database() discards the previous database and builds
    a fresh one from the task's setup SQL. Agent queries are executed against
    that database via execute_query(), and results can be rendered as a text
    table for the agent with format_result().
    """

    def __init__(self):
        # Active connection, or None before create_database() / after close().
        self.conn: Optional[sqlite3.Connection] = None

    @staticmethod
    def _authorize(action, arg1, arg2, db_name, trigger_or_view):
        """
        SQLite authorizer callback installed on the agent-facing connection.

        Denies ATTACH/DETACH so agent SQL cannot open (and thereby create)
        database files on disk, which would escape the in-memory database.
        Every other action is allowed.
        """
        if action in (sqlite3.SQLITE_ATTACH, sqlite3.SQLITE_DETACH):
            return sqlite3.SQLITE_DENY
        return sqlite3.SQLITE_OK

    def create_database(self, setup_sql: str) -> None:
        """
        Create a new in-memory database with the given schema and data.

        Any existing connection is closed first. If the setup script fails,
        the half-built database is closed, ``self.conn`` stays None and the
        error is re-raised, so later queries report "no connection" instead
        of silently running against partial data.

        Args:
            setup_sql: SQL script containing CREATE TABLE and INSERT statements.

        Raises:
            sqlite3.Error: (or any other exception) propagated from running
                the setup script.
        """
        self.close()

        conn = sqlite3.connect(":memory:")
        try:
            conn.execute("PRAGMA foreign_keys = ON")
            # executescript() accepts many statements in a single string.
            conn.executescript(setup_sql)
            conn.commit()
        except BaseException:
            # Don't keep a partially initialised database around; re-raise.
            conn.close()
            raise

        # Installed only after setup, so the trusted task SQL is unrestricted;
        # it applies to every statement prepared on this connection from now on.
        conn.set_authorizer(self._authorize)
        self.conn = conn

    def execute_query(self, sql: str) -> Tuple[bool, Optional[Dict], Optional[str]]:
        """
        Execute a SQL statement and return its results.

        This is the entry point for agent-submitted queries. Errors are
        reported through the return value rather than raised, so a bad query
        cannot crash the caller.

        Args:
            sql: The SQL statement to execute. Connection.execute() runs a
                single statement; multiple statements are reported as errors.

        Returns:
            Tuple of (success, result_dict, error_message):
            - success: True if the statement executed without error.
            - result_dict: {"columns": [...], "rows": [...]} on success, None
              on failure. Statements without a result set (e.g. DDL) yield
              empty "columns" and "rows".
            - error_message: Error string on failure, None on success.
        """
        if self.conn is None:
            return False, None, "No database connection. Call create_database() first."

        try:
            cursor = self.conn.execute(sql)
            # description is None for statements that return no result set.
            description = cursor.description
            columns = [desc[0] for desc in description] if description else []
            rows = cursor.fetchall()
        except sqlite3.Error as e:
            return False, None, f"SQL Error: {e}"
        except Exception as e:
            # Deliberate catch-all: anything that is not an sqlite3.Error is
            # still reported to the agent instead of propagating.
            return False, None, f"Execution Error: {e}"

        return True, {"columns": columns, "rows": rows}, None

    def format_result(self, result: Dict, max_rows: int = 20) -> str:
        """
        Format a query result as a human-readable text table.

        The string is shown to the agent in its observation so it can see
        what its query returned. Column widths fit the header and the
        displayed rows only.

        Args:
            result: Dict with "columns" and "rows" keys, as produced by
                execute_query().
            max_rows: Maximum number of rows to display. Negative values are
                treated as 0.

        Returns:
            The formatted table, followed by a truncation notice (if rows were
            hidden) and the total row count. A result without columns renders
            as "(empty result set)"; a result with columns but no rows lists
            the column names followed by "(0 rows returned)".
        """
        if not result or not result.get("columns"):
            return "(empty result set)"

        columns = result["columns"]
        rows = result["rows"]

        if not rows:
            return f"Columns: {', '.join(str(c) for c in columns)}\n(0 rows returned)"

        # Clamp so a negative limit cannot turn rows[:max_rows] into
        # "all but the last N rows" and inflate the hidden-row count.
        max_rows = max(0, max_rows)
        shown = rows[:max_rows]

        # Each column is at least as wide as its header.
        col_widths = [len(str(c)) for c in columns]
        for row in shown:
            for i, val in enumerate(row):
                if i < len(col_widths):
                    col_widths[i] = max(col_widths[i], len(str(val)))

        lines = [
            " | ".join(str(c).ljust(w) for c, w in zip(columns, col_widths)),
            "-+-".join("-" * w for w in col_widths),
        ]
        lines.extend(
            " | ".join(str(v).ljust(w) for v, w in zip(row, col_widths))
            for row in shown
        )
        table_str = "\n".join(lines)

        hidden = len(rows) - len(shown)
        if hidden > 0:
            table_str += f"\n... ({hidden} more rows not shown)"

        table_str += f"\n\n({len(rows)} row{'s' if len(rows) != 1 else ''} returned)"
        return table_str

    def close(self) -> None:
        """Close the database connection, if any, and reset ``self.conn`` to None."""
        # Clear the attribute first so the manager is reset even if close fails.
        conn, self.conn = self.conn, None
        if conn is not None:
            try:
                conn.close()
            except sqlite3.Error:
                # Best-effort: a failing close must not block building a
                # fresh database.
                pass