# sql_env/tests/unit/test_oracle_policy.py
# Uploaded by hjerpe using huggingface_hub (commit 9e64e71, verified).
"""Unit tests for OraclePolicy action selection."""
from __future__ import annotations
from sql_env.evaluation.policies import Policy
from sql_env.evaluation.oracle_policy import OraclePolicy
from sql_env.models import QuestionRecord, SQLAction, SQLObservation
def _question(
    *,
    text: str,
    tables: list[str],
    gold_sql: str,
    gold_answer: str,
    question_id: str = "q1",
) -> QuestionRecord:
    """Build a minimal ``QuestionRecord`` for the oracle policy tests.

    Only the fields these tests vary are parameters. The database name,
    answer type and difficulty are fixed placeholders.

    Args:
        text: Question text, used as ``question_text``.
        tables: Table names stored as ``tables_involved``.
        gold_sql: Reference SQL for the question.
        gold_answer: Reference answer for the question.
        question_id: Identifier for the record; defaults to ``"q1"``.
    """
    record_fields = dict(
        question_id=question_id,
        question_text=text,
        database_name="db",
        gold_sql=gold_sql,
        gold_answer=gold_answer,
        answer_type="string",
        difficulty="easy",
        tables_involved=tables,
    )
    return QuestionRecord(**record_fields)
def _obs(
    *,
    question: str,
    step_count: int = 0,
    budget_remaining: int = 10,
    schema_info: str = "Available tables:\n- employees\n- departments",
) -> SQLObservation:
    """Build a fresh, not-yet-finished ``SQLObservation`` for ``question``.

    Args:
        question: Question text placed on the observation.
        step_count: Number of steps already taken in the episode.
        budget_remaining: Steps left before the episode's budget runs out.
        schema_info: Schema description shown to the policy. The default is
            the previously hard-coded two-table listing, so existing callers
            behave exactly as before. Pass a value matching the question's
            tables when a test needs them to agree.

    Returns:
        An observation with empty result/error, an empty action history,
        ``done=False`` and ``reward=None``.
    """
    return SQLObservation(
        question=question,
        schema_info=schema_info,
        result="",
        error="",
        step_count=step_count,
        budget_remaining=budget_remaining,
        action_history=[],  # a new list per call, so observations never share history
        done=False,
        reward=None,
    )
def test_init_builds_lookup_from_questions() -> None:
    """Each question's text becomes a key of the oracle's lookup table."""
    questions = [
        _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"),
        _question(
            text="Q2",
            tables=["t2"],
            gold_sql="SELECT 2",
            gold_answer="2",
            question_id="q2",
        ),
    ]
    oracle = OraclePolicy(questions)
    assert set(oracle._question_lookup) == {"Q1", "Q2"}
def test_init_empty_questions() -> None:
    """An oracle built from no questions has an empty lookup."""
    oracle = OraclePolicy([])
    lookup = oracle._question_lookup
    assert lookup == {}
def test_init_single_question() -> None:
    """A single question yields a lookup keyed by exactly its text."""
    only = _question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([only])
    assert set(oracle._question_lookup) == {"Q1"}
def test_init_duplicate_question_text() -> None:
    """When two questions share text, the later one wins in the lookup."""
    earlier = _question(
        text="Q1",
        tables=["t1"],
        gold_sql="SELECT 1",
        gold_answer="first",
    )
    later = _question(
        text="Q1",
        tables=["t2"],
        gold_sql="SELECT 2",
        gold_answer="second",
        question_id="q2",
    )
    oracle = OraclePolicy([earlier, later])
    winner = oracle._question_lookup["Q1"]
    assert winner.gold_answer == "second"
def test_init_state_defaults() -> None:
    """A freshly constructed oracle has no episode state yet."""
    record = _question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([record])
    assert oracle._current_question is None
    assert oracle._tables_to_describe == []
    assert oracle._gold_sql_sent is False
def test_select_action_describe_phase() -> None:
    """The first action for a question with tables describes its first table."""
    record = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([record])
    first_action = oracle.select_action(_obs(question="Q1"))
    expected = SQLAction(action_type="DESCRIBE", argument="t1")
    assert first_action == expected
def test_select_action_describe_second_table() -> None:
    """After describing the first table, the oracle describes the second one."""
    oracle = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["t1", "t2"],
                gold_sql="SELECT * FROM t1",
                gold_answer="A",
            )
        ]
    )
    _first, second = [
        oracle.select_action(_obs(question="Q1", step_count=step))
        for step in range(2)
    ]
    assert second == SQLAction(action_type="DESCRIBE", argument="t2")
def test_select_action_query_phase() -> None:
    """Once every table has been described, the oracle sends the gold SQL."""
    record = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([record])
    # Steps 0 and 1 describe t1 and t2; step 2 is the one under test.
    latest = None
    for step in range(3):
        latest = oracle.select_action(_obs(question="Q1", step_count=step))
    assert latest == SQLAction(action_type="QUERY", argument="SELECT * FROM t1")
def test_select_action_answer_phase() -> None:
    """For a single-table question, the third action answers with the gold answer."""
    oracle = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["t1"],
                gold_sql="SELECT * FROM t1",
                gold_answer="A",
            )
        ]
    )
    _describe, _query, answer = (
        oracle.select_action(_obs(question="Q1", step_count=step))
        for step in range(3)
    )
    assert answer == SQLAction(action_type="ANSWER", argument="A")
def test_full_episode_sequence() -> None:
    """A one-table question plays out as DESCRIBE -> QUERY -> ANSWER.

    The test pins both the action types and their arguments: the table name,
    the gold SQL and the gold answer. Checking only the types would let an
    end-to-end regression slip through, such as describing the wrong table
    or answering with the wrong value.
    """
    policy = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["employees"],
                gold_sql="SELECT COUNT(*) FROM employees",
                gold_answer="3",
            )
        ]
    )
    actions = [
        policy.select_action(_obs(question="Q1", step_count=step))
        for step in range(3)
    ]
    assert actions == [
        SQLAction(action_type="DESCRIBE", argument="employees"),
        SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees"),
        SQLAction(action_type="ANSWER", argument="3"),
    ]
def test_new_episode_resets_state() -> None:
    """Switching to a new question restarts the episode from the describe phase."""
    first_q = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    second_q = _question(
        text="Q2", tables=["t2"], gold_sql="SELECT * FROM t2", gold_answer="B"
    )
    oracle = OraclePolicy([first_q, second_q])
    # Play Q1 through its first two steps before switching questions.
    for step in (0, 1):
        oracle.select_action(_obs(question="Q1", step_count=step))
    switched = oracle.select_action(_obs(question="Q2", step_count=0))
    assert switched == SQLAction(action_type="DESCRIBE", argument="t2")
    assert oracle._gold_sql_sent is False
def test_new_episode_lookup() -> None:
    """Observing a different question makes it the oracle's current question."""
    first_q = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    second_q = _question(
        text="Q2", tables=["t2"], gold_sql="SELECT * FROM t2", gold_answer="B"
    )
    oracle = OraclePolicy([first_q, second_q])
    for question_text in ("Q1", "Q2"):
        oracle.select_action(_obs(question=question_text, step_count=0))
    assert oracle._current_question is second_q
def test_zero_tables_skips_describe() -> None:
    """With no tables to describe, the first action is already the QUERY."""
    record = _question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([record])
    got = oracle.select_action(_obs(question="Q1"))
    expected = SQLAction(action_type="QUERY", argument="SELECT 1")
    assert got == expected
def test_zero_tables_then_answer() -> None:
    """With no tables, the action after the QUERY is the gold ANSWER."""
    record = _question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([record])
    *_, last = [
        oracle.select_action(_obs(question="Q1", step_count=step))
        for step in range(2)
    ]
    assert last == SQLAction(action_type="ANSWER", argument="1")
def test_unknown_question_returns_empty_answer() -> None:
    """A question text missing from the lookup yields an empty ANSWER."""
    known = _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([known])
    reply = oracle.select_action(_obs(question="UNKNOWN"))
    assert reply == SQLAction(action_type="ANSWER", argument="")
def test_unknown_question_no_crash() -> None:
    """An unknown question still produces an ``SQLAction`` of type ANSWER."""
    known = _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([known])
    outcome = oracle.select_action(_obs(question="UNKNOWN"))
    assert isinstance(outcome, SQLAction)
    assert outcome.action_type == "ANSWER"
def test_budget_one_forces_answer() -> None:
    """With one step of budget left, the oracle answers instead of describing t2."""
    record = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([record])
    oracle.select_action(_obs(question="Q1"))  # describes t1
    last_chance = _obs(question="Q1", step_count=1, budget_remaining=1)
    final = oracle.select_action(last_chance)
    assert final == SQLAction(action_type="ANSWER", argument="A")
def test_budget_one_forces_answer_before_query() -> None:
    """With one step of budget left, the oracle answers instead of querying."""
    record = _question(
        text="Q1",
        tables=["t1"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([record])
    oracle.select_action(_obs(question="Q1"))  # describes t1
    last_chance = _obs(question="Q1", step_count=1, budget_remaining=1)
    final = oracle.select_action(last_chance)
    assert final == SQLAction(action_type="ANSWER", argument="A")
def test_budget_one_unknown_question() -> None:
    """An unknown question with a budget of one still yields an empty ANSWER."""
    oracle = OraclePolicy(
        [_question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")]
    )
    observation = _obs(question="UNKNOWN", budget_remaining=1)
    empty_answer = SQLAction(action_type="ANSWER", argument="")
    assert oracle.select_action(observation) == empty_answer
def test_select_action_returns_sql_action() -> None:
    """``select_action`` returns an ``SQLAction`` instance."""
    record = _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([record])
    assert isinstance(oracle.select_action(_obs(question="Q1")), SQLAction)
def test_select_action_valid_action_types() -> None:
    """Every action over a three-step episode uses one of the known types."""
    record = _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")
    oracle = OraclePolicy([record])
    allowed = {"DESCRIBE", "QUERY", "ANSWER"}
    seen_types = {
        oracle.select_action(_obs(question="Q1", step_count=step)).action_type
        for step in range(3)
    }
    assert seen_types <= allowed
def test_oracle_satisfies_policy_protocol() -> None:
    """An ``OraclePolicy`` instance passes an ``isinstance`` check against ``Policy``."""
    record = _question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")
    assert isinstance(OraclePolicy([record]), Policy)
def test_oracle_has_select_action_method() -> None:
    """``OraclePolicy`` exposes a callable ``select_action`` attribute."""
    method = getattr(OraclePolicy, "select_action", None)
    assert callable(method)