File size: 12,094 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
9e64e71
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
5dd1bb4
 
 
 
 
 
 
 
 
 
 
9e64e71
5dd1bb4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
"""Smoke tests for the structured SQL environment loop."""

import json
import sqlite3

import pytest

from sql_env.client import SQLEnvClient
from sql_env.models import SQLAction, SQLObservation, SQLState
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.mock_tokenizer import MockTokenizer


@pytest.fixture
def environment_paths(tmp_path):
    """Create a throwaway SQLite database and question file under ``tmp_path``.

    The database is written to ``<tmp_path>/databases/testdb/testdb.sqlite``
    and contains two tables: ``departments`` (2 rows) and ``employees``
    (25 rows, every one with ``dept`` set to ``"engineering"``). The question
    file is a JSON list of two entries that both reference ``testdb``.

    Returns:
        A ``(questions_path, db_root)`` tuple of strings, where ``db_root`` is
        the directory that contains the per-database folder.
    """
    db_id = "testdb"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.execute("CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT)")
        cursor.executemany(
            "INSERT INTO departments (id, name) VALUES (?, ?)",
            [(1, "engineering"), (2, "sales")],
        )
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [(idx, f"emp-{idx}", "engineering") for idx in range(1, 26)],
        )
        connection.commit()
    finally:
        # Release the file handle even when seeding fails, so the temporary
        # database file is never left open behind a failed fixture.
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions = [
        {
            "question": "How many employees are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM employees",
        },
        {
            "question": "How many departments are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM departments",
        },
    ]
    questions_path.write_text(json.dumps(questions), encoding="utf-8")

    return str(questions_path), str(db_root)


@pytest.fixture
def env(environment_paths):
    """Return a ``SQLEnvironment`` built from the temporary question/database paths."""
    question_file, database_root = environment_paths
    environment = SQLEnvironment(
        questions_path=question_file,
        db_dir=database_root,
        tokenizer=MockTokenizer(),
    )
    return environment


class TestModels:
    """Construction checks for the action, observation and state models."""

    def test_action_creation(self):
        """An action keeps the type and argument it was built with."""
        built = SQLAction(action_type="DESCRIBE", argument="employees")
        assert built.action_type == "DESCRIBE"
        assert built.argument == "employees"

    def test_observation_creation(self):
        """An observation built from explicit fields exposes them unchanged."""
        fields = {
            "question": "How many employees are there?",
            "schema_info": "Available tables:\n- employees",
            "result": "",
            "error": "",
            "step_count": 0,
            "budget_remaining": 15,
            "action_history": [],
            "done": False,
            "reward": None,
        }
        built = SQLObservation(**fields)
        assert built.done is False
        assert built.reward is None
        assert built.question.startswith("How many")

    def test_state_defaults(self):
        """A default state has no history and a ``QUERY`` action type."""
        fresh = SQLState()
        assert fresh.history_messages == []
        assert fresh.current_action_type == "QUERY"


class TestEnvironment:
    """Smoke tests for ``SQLEnvironment``: reset, step, budget and safety checks."""

    def test_init_loads_questions(self, env):
        """The environment loads both fixture questions and reports a budget of 15."""
        assert len(env.questions) == 2
        assert env.step_budget == 15

    def test_reset_returns_rich_observation(self, env):
        """``reset`` yields a fresh observation that lists tables but not columns."""
        observation = env.reset(seed=42)
        assert isinstance(observation, SQLObservation)
        assert observation.done is False
        assert observation.reward is None
        assert observation.step_count == 0
        assert observation.budget_remaining == 15
        assert observation.error == ""
        assert observation.action_history == []
        assert "Available tables:" in observation.schema_info
        assert "employees" in observation.schema_info
        # Column details must not appear before a DESCRIBE has been issued.
        assert "name TEXT" not in observation.schema_info

    def test_reset_seed_determinism(self, env):
        """Resetting twice with the same seed selects the same question."""
        first = env.reset(seed=123)
        second = env.reset(seed=123)
        assert first.question == second.question

    def test_step_before_reset_is_graceful(self, env):
        """Stepping without an active episode reports an error instead of raising."""
        observation = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert "No active episode" in observation.error
        assert observation.done is False

    def test_describe_reveals_columns_and_updates_schema(self, env):
        """DESCRIBE returns column details and records them in ``schema_info``."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert "Table 'employees' columns:" in observation.result
        assert "- name: TEXT" in observation.result
        assert observation.error == ""
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(0.015)
        assert "Described tables:" in observation.schema_info
        assert "employees: id INTEGER" in observation.schema_info

    def test_sample_and_query_success(self, env):
        """SAMPLE and a valid SELECT both succeed and earn a positive reward."""
        env.reset(seed=42)
        sample_obs = env.step(SQLAction(action_type="SAMPLE", argument="employees"))
        assert "Sample from 'employees':" in sample_obs.result
        assert sample_obs.error == ""
        assert sample_obs.reward == pytest.approx(0.015)

        query_obs = env.step(
            SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees")
        )
        assert "25" in query_obs.result
        assert query_obs.error == ""
        assert query_obs.reward is not None
        assert query_obs.reward > 0

    def test_query_rejects_non_select(self, env):
        """A non-SELECT statement is refused, costs a step, and is penalised."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument="DROP TABLE x"))
        assert "Only SELECT queries are allowed" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(-0.005)

    def test_invalid_action_type_consumes_budget(self, env):
        """An unknown action type is reported as an error and still uses a step."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="HACK", argument="x"))
        assert "Unknown action type" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_empty_argument_consumes_budget(self, env):
        """A whitespace-only argument is reported as an error and still uses a step."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument="   "))
        assert "Argument cannot be empty" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_answer_ends_episode_without_budget_decrement(self, env):
        """A correct ANSWER ends the episode with reward 1.0 and no budget change."""
        env.reset(seed=42)
        before_budget = env._episode.budget
        observation = env.step(SQLAction(action_type="ANSWER", argument="25"))
        assert observation.done is True
        assert observation.reward == 1.0
        assert observation.budget_remaining == before_budget

    def test_step_after_done_is_unchanged(self, env):
        """Stepping after a terminal observation leaves step count and budget as-is."""
        env.reset(seed=42)
        terminal = env.step(SQLAction(action_type="ANSWER", argument="25"))
        again = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert again.done is True
        assert again.step_count == terminal.step_count
        assert again.budget_remaining == terminal.budget_remaining

    def test_budget_exhaustion_sets_done_and_zero_reward(self, environment_paths):
        """With a budget of 2, the second step ends the episode with reward 0.0."""
        questions_path, db_dir = environment_paths
        budget_env = SQLEnvironment(
            questions_path=questions_path,
            db_dir=db_dir,
            tokenizer=MockTokenizer(),
            step_budget=2,
        )
        budget_env.reset(seed=42)

        first = budget_env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert first.done is False
        assert first.budget_remaining == 1
        assert first.reward == pytest.approx(0.015)

        second = budget_env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert second.done is True
        assert second.budget_remaining == 0
        assert second.reward == 0.0

    def test_query_truncates_to_20_rows(self, env):
        """Selecting all 25 employees produces the 20-row truncation marker."""
        env.reset(seed=42)
        observation = env.step(
            SQLAction(action_type="QUERY", argument="SELECT id FROM employees")
        )
        assert "... (truncated to 20 rows)" in observation.result

    def test_query_timeout_returns_error(self, env, monkeypatch):
        """A timeout raised by ``_execute_sql`` is reported through ``error``."""
        env.reset(seed=42)

        def _timeout(*_args, **_kwargs):
            # Stand-in for the real executor: fail the way a timed-out query would.
            raise sqlite3.OperationalError("Query timed out after 5.0 seconds")

        monkeypatch.setattr(env, "_execute_sql", _timeout)

        observation = env.step(
            SQLAction(
                action_type="QUERY",
                argument=(
                    "SELECT e1.id "
                    "FROM employees e1 "
                    "JOIN employees e2 ON 1=1 "
                    "JOIN employees e3 ON 1=1"
                ),
            )
        )
        assert "timed out" in observation.error.lower()

    def test_open_db_connection_is_read_only(self, env):
        """Writes through the connection returned by ``_open_db`` are rejected."""
        connection = env._open_db("testdb")
        try:
            with pytest.raises(sqlite3.OperationalError):
                connection.execute(
                    "INSERT INTO departments (id, name) VALUES (3, 'hr')"
                )
        finally:
            # Close even when the expectation above fails, so a failing test
            # does not leave the database handle open.
            connection.close()


class TestMessageToAction:
    """Checks for converting chat-style messages into ``SQLAction`` objects."""

    def test_parses_prefixed_message(self, env):
        """A leading action keyword becomes the action type; the rest is the argument."""
        env.reset(seed=42)
        message = {"role": "user", "content": "DESCRIBE employees"}
        parsed = env.message_to_action(message)
        assert parsed.action_type == "DESCRIBE"
        assert parsed.argument == "employees"

    def test_defaults_to_query_for_unprefixed_message(self, env):
        """Content without an action keyword is treated as a QUERY verbatim."""
        env.reset(seed=42)
        message = {"role": "user", "content": "SELECT COUNT(*) FROM employees"}
        parsed = env.message_to_action(message)
        assert parsed.action_type == "QUERY"
        assert parsed.argument == "SELECT COUNT(*) FROM employees"

    def test_validates_message_shape(self, env):
        """Missing role, missing content, and ``None`` content each raise ``ValueError``."""
        env.reset(seed=42)
        malformed_messages = (
            {"content": "missing role"},
            {"role": "user"},
            {"role": "user", "content": None},
        )
        for message in malformed_messages:
            with pytest.raises(ValueError):
                env.message_to_action(message)


class TestClientSerialization:
    """Payload (de)serialization checks on a client created without ``__init__``."""

    def test_step_payload_serialization(self):
        """The step payload carries the action fields plus a ``metadata`` entry."""
        stub = SQLEnvClient.__new__(SQLEnvClient)
        serialized = stub._step_payload(
            SQLAction(action_type="QUERY", argument="SELECT 1")
        )
        assert serialized["action_type"] == "QUERY"
        assert serialized["argument"] == "SELECT 1"
        assert "metadata" in serialized

    def test_parse_result_observation_payload(self):
        """A step response is parsed into a result with a populated observation."""
        stub = SQLEnvClient.__new__(SQLEnvClient)
        observation_fields = {
            "question": "How many employees are there?",
            "schema_info": "Available tables:\n- employees",
            "result": "1. 25",
            "error": "",
            "step_count": 1,
            "budget_remaining": 14,
            "action_history": ["QUERY -> 1. 25"],
            "done": False,
            "reward": None,
        }
        response = {
            "observation": observation_fields,
            "done": False,
            "reward": None,
        }
        parsed = stub._parse_result(response)
        assert parsed.observation.question == "How many employees are there?"
        assert parsed.observation.step_count == 1
        assert parsed.done is False

    def test_parse_state_deserializes(self):
        """A state payload round-trips its id, step count and message history."""
        stub = SQLEnvClient.__new__(SQLEnvClient)
        raw_state = {
            "episode_id": "ep-1",
            "step_count": 2,
            "history_messages": [{"role": "user", "content": "hi"}],
            "current_action_type": "QUERY",
        }
        parsed = stub._parse_state(raw_state)
        assert parsed.episode_id == "ep-1"
        assert parsed.step_count == 2
        assert parsed.history_messages == [{"role": "user", "content": "hi"}]

    def test_client_message_to_action_infers_action(self):
        """A free-text request mentioning sample rows maps to a SAMPLE action."""
        stub = SQLEnvClient.__new__(SQLEnvClient)
        message = {"role": "user", "content": "show me sample rows from employees"}
        inferred = stub.message_to_action(message, tokenizer=MockTokenizer())
        assert inferred.action_type == "SAMPLE"
        assert "sample" in inferred.argument.lower()