File size: 12,094 Bytes
5dd1bb4 9e64e71 5dd1bb4 9e64e71 5dd1bb4 9e64e71 5dd1bb4 9e64e71 5dd1bb4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 | """Smoke tests for the structured SQL environment loop."""
import json
import sqlite3
import pytest
from sql_env.client import SQLEnvClient
from sql_env.models import SQLAction, SQLObservation, SQLState
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.mock_tokenizer import MockTokenizer
@pytest.fixture
def environment_paths(tmp_path):
    """Build a throwaway SQLite database and question file for one test.

    Layout created under ``tmp_path``::

        databases/testdb/testdb.sqlite   # employees (25 rows), departments (2 rows)
        questions.json                   # two COUNT(*) questions against testdb

    Returns:
        tuple[str, str]: ``(questions_path, db_root)`` as strings, where
        ``db_root`` holds one sub-directory per ``db_id``.
    """
    db_id = "testdb"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    # Close in ``finally`` so a failing statement cannot leave the database
    # file open (an open handle can block tmp_path cleanup on some platforms).
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.execute("CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT)")
        cursor.executemany(
            "INSERT INTO departments (id, name) VALUES (?, ?)",
            [(1, "engineering"), (2, "sales")],
        )
        # 25 employees: enough rows to exercise the 20-row result truncation.
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [(idx, f"emp-{idx}", "engineering") for idx in range(1, 26)],
        )
        connection.commit()
    finally:
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions = [
        {
            "question": "How many employees are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM employees",
        },
        {
            "question": "How many departments are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM departments",
        },
    ]
    questions_path.write_text(json.dumps(questions), encoding="utf-8")
    return str(questions_path), str(db_root)
@pytest.fixture
def env(environment_paths):
    """Provide an ``SQLEnvironment`` wired to the temporary fixture data.

    Uses ``MockTokenizer`` and the environment's default step budget.
    """
    questions_file, databases_root = environment_paths
    tokenizer = MockTokenizer()
    return SQLEnvironment(
        questions_path=questions_file,
        db_dir=databases_root,
        tokenizer=tokenizer,
    )
class TestModels:
    """Construction and default-value checks for the action/observation/state models."""

    def test_action_creation(self):
        describe = SQLAction(action_type="DESCRIBE", argument="employees")
        assert describe.action_type == "DESCRIBE"
        assert describe.argument == "employees"

    def test_observation_creation(self):
        # Build an initial (pre-step) observation: nothing executed yet, no reward.
        initial_fields = dict(
            question="How many employees are there?",
            schema_info="Available tables:\n- employees",
            result="",
            error="",
            step_count=0,
            budget_remaining=15,
            action_history=[],
            done=False,
            reward=None,
        )
        obs = SQLObservation(**initial_fields)
        assert obs.done is False
        assert obs.reward is None
        assert obs.question.startswith("How many")

    def test_state_defaults(self):
        default_state = SQLState()
        assert default_state.history_messages == []
        assert default_state.current_action_type == "QUERY"
class TestEnvironment:
    """End-to-end checks of ``SQLEnvironment`` reset/step behaviour.

    The reward values asserted here pin the environment's current reward
    shaping:

    - 0.015 for a successful DESCRIBE or SAMPLE;
    - -0.005 for a rejected query;
    - 1.0 for a correct ANSWER;
    - 0.0 when the step budget runs out.
    """

    @staticmethod
    def _gold_answer(question):
        """Return the correct answer for a question from ``environment_paths``.

        The fixture defines exactly two questions: an employee count (25 rows)
        and a department count (2 rows). Deriving the answer from the sampled
        question keeps the tests independent of which question a seed selects.
        """
        return "25" if "employees" in question else "2"

    def test_init_loads_questions(self, env):
        assert len(env.questions) == 2
        assert env.step_budget == 15

    def test_reset_returns_rich_observation(self, env):
        observation = env.reset(seed=42)
        assert isinstance(observation, SQLObservation)
        assert observation.done is False
        assert observation.reward is None
        assert observation.step_count == 0
        assert observation.budget_remaining == 15
        assert observation.error == ""
        assert observation.action_history == []
        assert "Available tables:" in observation.schema_info
        assert "employees" in observation.schema_info
        # Column details are withheld until an explicit DESCRIBE.
        assert "name TEXT" not in observation.schema_info

    def test_reset_seed_determinism(self, env):
        first = env.reset(seed=123)
        second = env.reset(seed=123)
        assert first.question == second.question

    def test_step_before_reset_is_graceful(self, env):
        observation = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert "No active episode" in observation.error
        assert observation.done is False

    def test_describe_reveals_columns_and_updates_schema(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert "Table 'employees' columns:" in observation.result
        assert "- name: TEXT" in observation.result
        assert observation.error == ""
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(0.015)
        assert "Described tables:" in observation.schema_info
        assert "employees: id INTEGER" in observation.schema_info

    def test_sample_and_query_success(self, env):
        env.reset(seed=42)
        sample_obs = env.step(SQLAction(action_type="SAMPLE", argument="employees"))
        assert "Sample from 'employees':" in sample_obs.result
        assert sample_obs.error == ""
        assert sample_obs.reward == pytest.approx(0.015)
        query_obs = env.step(
            SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees")
        )
        assert "25" in query_obs.result
        assert query_obs.error == ""
        assert query_obs.reward is not None
        assert query_obs.reward > 0

    def test_query_rejects_non_select(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument="DROP TABLE x"))
        assert "Only SELECT queries are allowed" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(-0.005)

    def test_invalid_action_type_consumes_budget(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="HACK", argument="x"))
        assert "Unknown action type" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_empty_argument_consumes_budget(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument=" "))
        assert "Argument cannot be empty" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_answer_ends_episode_without_budget_decrement(self, env):
        start = env.reset(seed=42)
        before_budget = env._episode.budget
        observation = env.step(
            SQLAction(action_type="ANSWER", argument=self._gold_answer(start.question))
        )
        assert observation.done is True
        assert observation.reward == 1.0
        assert observation.budget_remaining == before_budget

    def test_step_after_done_is_unchanged(self, env):
        start = env.reset(seed=42)
        terminal = env.step(
            SQLAction(action_type="ANSWER", argument=self._gold_answer(start.question))
        )
        again = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert again.done is True
        assert again.step_count == terminal.step_count
        assert again.budget_remaining == terminal.budget_remaining

    def test_budget_exhaustion_sets_done_and_zero_reward(self, environment_paths):
        questions_path, db_dir = environment_paths
        budget_env = SQLEnvironment(
            questions_path=questions_path,
            db_dir=db_dir,
            tokenizer=MockTokenizer(),
            step_budget=2,
        )
        budget_env.reset(seed=42)
        first = budget_env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert first.done is False
        assert first.budget_remaining == 1
        assert first.reward == pytest.approx(0.015)
        second = budget_env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert second.done is True
        assert second.budget_remaining == 0
        assert second.reward == 0.0

    def test_query_truncates_to_20_rows(self, env):
        env.reset(seed=42)
        # The fixture inserts 25 employees, so this result exceeds the 20-row cap.
        observation = env.step(
            SQLAction(action_type="QUERY", argument="SELECT id FROM employees")
        )
        assert "... (truncated to 20 rows)" in observation.result

    def test_query_timeout_returns_error(self, env, monkeypatch):
        env.reset(seed=42)

        def _timeout(*_args, **_kwargs):
            # Simulate the executor's timeout without running a slow query.
            raise sqlite3.OperationalError("Query timed out after 5.0 seconds")

        monkeypatch.setattr(env, "_execute_sql", _timeout)
        observation = env.step(
            SQLAction(
                action_type="QUERY",
                argument=(
                    "SELECT e1.id "
                    "FROM employees e1 "
                    "JOIN employees e2 ON 1=1 "
                    "JOIN employees e3 ON 1=1"
                ),
            )
        )
        assert "timed out" in observation.error.lower()

    def test_open_db_connection_is_read_only(self, env):
        connection = env._open_db("testdb")
        # Close even if the write unexpectedly succeeds and pytest.raises fails.
        try:
            with pytest.raises(sqlite3.OperationalError):
                connection.execute("INSERT INTO departments (id, name) VALUES (3, 'hr')")
        finally:
            connection.close()
class TestMessageToAction:
    """Parsing of chat-style ``{"role", "content"}`` messages into ``SQLAction``s."""

    def test_parses_prefixed_message(self, env):
        env.reset(seed=42)
        message = {"role": "user", "content": "DESCRIBE employees"}
        parsed = env.message_to_action(message)
        assert parsed.action_type == "DESCRIBE"
        assert parsed.argument == "employees"

    def test_defaults_to_query_for_unprefixed_message(self, env):
        env.reset(seed=42)
        message = {"role": "user", "content": "SELECT COUNT(*) FROM employees"}
        parsed = env.message_to_action(message)
        assert parsed.action_type == "QUERY"
        assert parsed.argument == "SELECT COUNT(*) FROM employees"

    def test_validates_message_shape(self, env):
        env.reset(seed=42)
        # Each malformed message must be rejected, checked in this order.
        malformed_messages = (
            {"content": "missing role"},
            {"role": "user"},
            {"role": "user", "content": None},
        )
        for malformed in malformed_messages:
            with pytest.raises(ValueError):
                env.message_to_action(malformed)
class TestClientSerialization:
    """(De)serialisation helpers on ``SQLEnvClient``, exercised without a server."""

    @staticmethod
    def _bare_client():
        # Bypass __init__ (which presumably sets up a transport — not visible
        # here) so only the pure payload helpers are exercised.
        return SQLEnvClient.__new__(SQLEnvClient)

    def test_step_payload_serialization(self):
        payload = self._bare_client()._step_payload(
            SQLAction(action_type="QUERY", argument="SELECT 1")
        )
        assert payload["action_type"] == "QUERY"
        assert payload["argument"] == "SELECT 1"
        assert "metadata" in payload

    def test_parse_result_observation_payload(self):
        client = self._bare_client()
        observation_fields = {
            "question": "How many employees are there?",
            "schema_info": "Available tables:\n- employees",
            "result": "1. 25",
            "error": "",
            "step_count": 1,
            "budget_remaining": 14,
            "action_history": ["QUERY -> 1. 25"],
            "done": False,
            "reward": None,
        }
        parsed = client._parse_result(
            {"observation": observation_fields, "done": False, "reward": None}
        )
        assert parsed.observation.question == "How many employees are there?"
        assert parsed.observation.step_count == 1
        assert parsed.done is False

    def test_parse_state_deserializes(self):
        client = self._bare_client()
        raw_state = {
            "episode_id": "ep-1",
            "step_count": 2,
            "history_messages": [{"role": "user", "content": "hi"}],
            "current_action_type": "QUERY",
        }
        state = client._parse_state(raw_state)
        assert state.episode_id == "ep-1"
        assert state.step_count == 2
        assert state.history_messages == [{"role": "user", "content": "hi"}]

    def test_client_message_to_action_infers_action(self):
        client = self._bare_client()
        inferred = client.message_to_action(
            {"role": "user", "content": "show me sample rows from employees"},
            tokenizer=MockTokenizer(),
        )
        assert inferred.action_type == "SAMPLE"
        assert "sample" in inferred.argument.lower()
|