| """Unit tests for OraclePolicy action selection.""" |
|
|
| from __future__ import annotations |
|
|
| from sql_env.evaluation.policies import Policy |
| from sql_env.evaluation.oracle_policy import OraclePolicy |
| from sql_env.models import QuestionRecord, SQLAction, SQLObservation |
|
|
|
|
def _question(
    *,
    text: str,
    tables: list[str],
    gold_sql: str,
    gold_answer: str,
    question_id: str = "q1",
) -> QuestionRecord:
    """Build a ``QuestionRecord`` fixture for the tests in this module.

    The caller supplies the question text, the involved tables, the gold
    SQL and the gold answer (plus an optional id). The database name,
    answer type and difficulty are fixed to ``"db"``, ``"string"`` and
    ``"easy"`` respectively.
    """
    record_fields = {
        "question_id": question_id,
        "question_text": text,
        "database_name": "db",
        "gold_sql": gold_sql,
        "gold_answer": gold_answer,
        "answer_type": "string",
        "difficulty": "easy",
        "tables_involved": tables,
    }
    return QuestionRecord(**record_fields)
|
|
|
|
def _obs(
    *,
    question: str,
    step_count: int = 0,
    budget_remaining: int = 10,
) -> SQLObservation:
    """Build a non-terminal ``SQLObservation`` fixture for *question*.

    Only the question text, step count and remaining budget vary; the
    schema text, empty result/error strings, empty action history,
    ``done=False`` and ``reward=None`` are constant.
    """
    observation_fields = {
        "question": question,
        "schema_info": "Available tables:\n- employees\n- departments",
        "result": "",
        "error": "",
        "step_count": step_count,
        "budget_remaining": budget_remaining,
        "action_history": [],
        "done": False,
        "reward": None,
    }
    return SQLObservation(**observation_fields)
|
|
|
|
def test_init_builds_lookup_from_questions() -> None:
    """The lookup built at construction is keyed by each question's text."""
    records = [
        _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"),
        _question(
            text="Q2",
            tables=["t2"],
            gold_sql="SELECT 2",
            gold_answer="2",
            question_id="q2",
        ),
    ]

    oracle = OraclePolicy(records)

    lookup_keys = set(oracle._question_lookup)
    assert lookup_keys == {"Q1", "Q2"}
|
|
|
|
def test_init_empty_questions() -> None:
    """Constructing the policy with no questions leaves the lookup empty."""
    lookup = OraclePolicy([])._question_lookup

    assert lookup == {}
|
|
|
|
def test_init_single_question() -> None:
    """A single question yields a lookup with exactly one key, its text."""
    only_question = _question(
        text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1"
    )

    oracle = OraclePolicy([only_question])

    lookup_keys = set(oracle._question_lookup)
    assert lookup_keys == {"Q1"}
|
|
|
|
def test_init_duplicate_question_text() -> None:
    """When two questions share the same text, the later one is kept."""
    earlier = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="first"
    )
    later = _question(
        text="Q1",
        tables=["t2"],
        gold_sql="SELECT 2",
        gold_answer="second",
        question_id="q2",
    )

    oracle = OraclePolicy([earlier, later])

    stored = oracle._question_lookup["Q1"]
    assert stored.gold_answer == "second"
|
|
|
|
def test_init_state_defaults() -> None:
    """A freshly built policy has no current question and no pending work."""
    seed_question = _question(
        text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1"
    )

    oracle = OraclePolicy([seed_question])

    # Checked before any select_action call has been made.
    assert oracle._current_question is None
    assert oracle._tables_to_describe == []
    assert oracle._gold_sql_sent is False
|
|
|
|
def test_select_action_describe_phase() -> None:
    """The first action for a two-table question is DESCRIBE on table one."""
    two_table_question = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([two_table_question])

    chosen = oracle.select_action(_obs(question="Q1"))

    expected = SQLAction(action_type="DESCRIBE", argument="t1")
    assert chosen == expected
|
|
|
|
def test_select_action_describe_second_table() -> None:
    """The second action for a two-table question is DESCRIBE on table two."""
    two_table_question = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([two_table_question])
    # Consume the first action so the next call is the one under test.
    oracle.select_action(_obs(question="Q1"))

    chosen = oracle.select_action(_obs(question="Q1", step_count=1))

    expected = SQLAction(action_type="DESCRIBE", argument="t2")
    assert chosen == expected
|
|
|
|
def test_select_action_query_phase() -> None:
    """After both tables were handled, the policy issues the gold SQL query."""
    two_table_question = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([two_table_question])
    # Two warm-up calls (steps 0 and 1) precede the call under test.
    for warmup_step in range(2):
        oracle.select_action(_obs(question="Q1", step_count=warmup_step))

    chosen = oracle.select_action(_obs(question="Q1", step_count=2))

    expected = SQLAction(action_type="QUERY", argument="SELECT * FROM t1")
    assert chosen == expected
|
|
|
|
def test_select_action_answer_phase() -> None:
    """The third action for a one-table question is ANSWER with the gold answer."""
    one_table_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    oracle = OraclePolicy([one_table_question])
    # Two warm-up calls (steps 0 and 1) precede the call under test.
    for warmup_step in range(2):
        oracle.select_action(_obs(question="Q1", step_count=warmup_step))

    chosen = oracle.select_action(_obs(question="Q1", step_count=2))

    expected = SQLAction(action_type="ANSWER", argument="A")
    assert chosen == expected
|
|
|
|
def test_full_episode_sequence() -> None:
    """Three consecutive calls yield DESCRIBE, QUERY and ANSWER in that order."""
    employees_question = _question(
        text="Q1",
        tables=["employees"],
        gold_sql="SELECT COUNT(*) FROM employees",
        gold_answer="3",
    )
    oracle = OraclePolicy([employees_question])

    # Collect every action first, then inspect their types.
    episode_actions = [
        oracle.select_action(_obs(question="Q1", step_count=step))
        for step in range(3)
    ]

    observed_types = [taken.action_type for taken in episode_actions]
    assert observed_types == ["DESCRIBE", "QUERY", "ANSWER"]
|
|
|
|
def test_new_episode_resets_state() -> None:
    """Seeing a different question restarts the sequence for that question."""
    first_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    second_question = _question(
        text="Q2", tables=["t2"], gold_sql="SELECT * FROM t2", gold_answer="B"
    )
    oracle = OraclePolicy([first_question, second_question])

    # Two calls on Q1, then switch to Q2 at step 0.
    for prior_step in range(2):
        oracle.select_action(_obs(question="Q1", step_count=prior_step))
    chosen = oracle.select_action(_obs(question="Q2", step_count=0))

    expected = SQLAction(action_type="DESCRIBE", argument="t2")
    assert chosen == expected
    assert oracle._gold_sql_sent is False
|
|
|
|
def test_new_episode_lookup() -> None:
    """After switching questions, the current question is the new record itself."""
    first_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    second_question = _question(
        text="Q2", tables=["t2"], gold_sql="SELECT * FROM t2", gold_answer="B"
    )
    oracle = OraclePolicy([first_question, second_question])

    oracle.select_action(_obs(question="Q1"))
    oracle.select_action(_obs(question="Q2", step_count=0))

    # Identity, not equality: the very same record object must be stored.
    tracked = oracle._current_question
    assert tracked is second_question
|
|
|
|
def test_zero_tables_skips_describe() -> None:
    """With no tables involved, the first action is already the gold QUERY."""
    tableless_question = _question(
        text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([tableless_question])

    chosen = oracle.select_action(_obs(question="Q1"))

    expected = SQLAction(action_type="QUERY", argument="SELECT 1")
    assert chosen == expected
|
|
|
|
def test_zero_tables_then_answer() -> None:
    """With no tables involved, the second action is ANSWER with the gold answer."""
    tableless_question = _question(
        text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([tableless_question])
    # Consume the first action so the next call is the one under test.
    oracle.select_action(_obs(question="Q1"))

    chosen = oracle.select_action(_obs(question="Q1", step_count=1))

    expected = SQLAction(action_type="ANSWER", argument="1")
    assert chosen == expected
|
|
|
|
def test_unknown_question_returns_empty_answer() -> None:
    """A question text absent from the lookup produces an empty ANSWER."""
    known_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([known_question])

    chosen = oracle.select_action(_obs(question="UNKNOWN"))

    expected = SQLAction(action_type="ANSWER", argument="")
    assert chosen == expected
|
|
|
|
def test_unknown_question_no_crash() -> None:
    """An unrecognised question still returns an ``SQLAction`` of type ANSWER."""
    known_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([known_question])

    outcome = oracle.select_action(_obs(question="UNKNOWN"))

    assert isinstance(outcome, SQLAction)
    assert outcome.action_type == "ANSWER"
|
|
|
|
def test_budget_one_forces_answer() -> None:
    """With one step of budget left mid-describe, the policy answers at once."""
    two_table_question = _question(
        text="Q1",
        tables=["t1", "t2"],
        gold_sql="SELECT * FROM t1",
        gold_answer="A",
    )
    oracle = OraclePolicy([two_table_question])
    # One normal call first; the second table has not been requested yet.
    oracle.select_action(_obs(question="Q1"))

    last_chance = _obs(question="Q1", step_count=1, budget_remaining=1)
    chosen = oracle.select_action(last_chance)

    expected = SQLAction(action_type="ANSWER", argument="A")
    assert chosen == expected
|
|
|
|
def test_budget_one_forces_answer_before_query() -> None:
    """With one step of budget left, the policy answers instead of querying."""
    one_table_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    oracle = OraclePolicy([one_table_question])
    # One normal call first; the gold query has not been sent yet.
    oracle.select_action(_obs(question="Q1"))

    last_chance = _obs(question="Q1", step_count=1, budget_remaining=1)
    chosen = oracle.select_action(last_chance)

    expected = SQLAction(action_type="ANSWER", argument="A")
    assert chosen == expected
|
|
|
|
def test_budget_one_unknown_question() -> None:
    """An unknown question with one step of budget left gets an empty ANSWER."""
    known_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([known_question])

    last_chance = _obs(question="UNKNOWN", budget_remaining=1)
    chosen = oracle.select_action(last_chance)

    expected = SQLAction(action_type="ANSWER", argument="")
    assert chosen == expected
|
|
|
|
def test_select_action_returns_sql_action() -> None:
    """``select_action`` returns an ``SQLAction`` instance for a known question."""
    known_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([known_question])

    outcome = oracle.select_action(_obs(question="Q1"))

    assert isinstance(outcome, SQLAction)
|
|
|
|
def test_select_action_valid_action_types() -> None:
    """Every action type emitted over three steps is DESCRIBE, QUERY or ANSWER."""
    known_question = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1"
    )
    oracle = OraclePolicy([known_question])

    emitted = [
        oracle.select_action(_obs(question="Q1", step_count=step))
        for step in range(3)
    ]

    allowed_types = {"DESCRIBE", "QUERY", "ANSWER"}
    seen_types = {taken.action_type for taken in emitted}
    assert seen_types <= allowed_types
|
|
|
|
def test_oracle_satisfies_policy_protocol() -> None:
    """An ``OraclePolicy`` instance passes an ``isinstance`` check against ``Policy``."""
    seed_question = _question(
        text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1"
    )

    oracle = OraclePolicy([seed_question])

    assert isinstance(oracle, Policy)
|
|
|
|
def test_oracle_has_select_action_method() -> None:
    """The ``OraclePolicy`` class exposes a callable ``select_action`` attribute."""
    # Default of None keeps a missing attribute from raising AttributeError.
    candidate = getattr(OraclePolicy, "select_action", None)

    assert callable(candidate)
|
|