File size: 2,537 Bytes
"""Tests for SQLEnvTRL question-database alignment and post-episode penalty."""
import pytest
from sql_env.training.trl_adapter import SQLEnvTRL, _POST_EPISODE_PENALTY
def test_reset_with_question_text_loads_correct_database():
    """Check that reset(question_text=...) surfaces the tables of the matching DB.

    Two known questions are passed in turn; each observation must mention a
    table name expected for that question's database.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)
    trl_env = SQLEnvTRL()

    # The template-count question is expected to map to cre_Doc_Template_Mgt,
    # whose observation should list a "Templates" table.
    obs = trl_env.reset(question_text="How many templates do we have?")
    assert "Templates" in obs, (
        f"Expected 'Templates' table for template question, got: {obs}"
    )

    # The bonus/evaluations question is expected to map to
    # employee_hire_evaluation; either table-name fragment is accepted.
    bonus_question = "Find the total amount of bonus given in all the evaluations."
    obs = trl_env.reset(question_text=bonus_question)
    assert any(fragment in obs for fragment in ("employee", "evaluation")), (
        f"Expected employee/evaluation tables, got: {obs}"
    )
def test_reset_without_question_text_still_works():
    """Check that a bare reset() (no question_text) still returns an observation.

    The observation must contain both the table listing header and the
    describe hint.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)
    trl_env = SQLEnvTRL()

    obs = trl_env.reset()

    # Both markers must appear; checked in the same order as before.
    for marker in ("Tables:", "Use describe"):
        assert marker in obs
def test_reset_with_unknown_question_falls_back():
    """Check that an unrecognised question_text still yields a table listing.

    The test only asserts that the observation contains "Tables:"; which
    question gets selected as the fallback is not inspected here.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)
    trl_env = SQLEnvTRL()

    unknown_question = "This question does not exist"
    obs = trl_env.reset(question_text=unknown_question)

    assert "Tables:" in obs
def test_post_episode_penalty_on_tool_call_after_done():
    """Check that tool calls on a finished episode raise and accumulate a penalty.

    Each call made after the episode is marked done must raise ValueError
    matching "Episode is over" and add _POST_EPISODE_PENALTY to the reward.
    """
    config = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**config)
    trl_env = SQLEnvTRL()
    trl_env.reset()

    # Force the finished state directly instead of playing out an episode.
    starting_reward = 1.0
    trl_env._done = True
    trl_env.reward = starting_reward

    # First post-episode call: one penalty applied.
    with pytest.raises(ValueError, match="Episode is over"):
        trl_env.describe("some_table")
    # Checked outside the raises block so the assertion actually executes.
    assert trl_env.reward == pytest.approx(
        starting_reward + _POST_EPISODE_PENALTY
    )

    # Second post-episode call: the penalty stacks on top of the first.
    with pytest.raises(ValueError, match="Episode is over"):
        trl_env.query("SELECT 1")
    assert trl_env.reward == pytest.approx(
        starting_reward + 2 * _POST_EPISODE_PENALTY
    )
|