File size: 2,537 Bytes
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Tests for SQLEnvTRL question-database alignment and post-episode penalty."""

import pytest

from sql_env.training.trl_adapter import SQLEnvTRL, _POST_EPISODE_PENALTY


def test_reset_with_question_text_loads_correct_database():
    """Check that reset(question_text=...) yields the matching DB's observation.

    Two known training questions are used; each observation must mention
    a table name associated with that question's database.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)
    env = SQLEnvTRL()

    # Template-counting question: expected to map to cre_Doc_Template_Mgt.
    template_obs = env.reset(question_text="How many templates do we have?")
    assert "Templates" in template_obs, (
        f"Expected 'Templates' table for template question, got: {template_obs}"
    )

    # Bonus-total question: expected to map to employee_hire_evaluation.
    bonus_question = "Find the total amount of bonus given in all the evaluations."
    bonus_obs = env.reset(question_text=bonus_question)
    # Either table-name fragment is enough to identify the database.
    assert any(fragment in bonus_obs for fragment in ("employee", "evaluation")), (
        f"Expected employee/evaluation tables, got: {bonus_obs}"
    )


def test_reset_without_question_text_still_works():
    """Check that a bare reset() still returns the usual initial observation.

    With no question_text supplied, the observation must still contain
    the table listing header and the describe hint.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)

    initial_obs = SQLEnvTRL().reset()

    # Both markers must appear in the initial observation text.
    for marker in ("Tables:", "Use describe"):
        assert marker in initial_obs


def test_reset_with_unknown_question_falls_back():
    """Check that an unrecognised question_text still produces an observation.

    The question text below is not expected to match any known question;
    reset() must nonetheless return an observation listing tables.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)

    unknown_question = "This question does not exist"
    fallback_obs = SQLEnvTRL().reset(question_text=unknown_question)

    assert "Tables:" in fallback_obs


def test_post_episode_penalty_on_tool_call_after_done():
    """Check that tool calls on a finished episode raise and add the penalty.

    The episode is forced into the done state with a known reward. Each
    later tool call must raise ValueError and add one more
    _POST_EPISODE_PENALTY to the reward.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "step_budget": 5,
    }
    SQLEnvTRL._configure(**settings)
    env = SQLEnvTRL()
    env.reset()

    # Force the finished state directly instead of playing out an episode.
    starting_reward = 1.0
    env._done = True
    env.reward = starting_reward

    # Tool calls made after the episode ended, in order; penalties stack.
    late_calls = (
        lambda: env.describe("some_table"),
        lambda: env.query("SELECT 1"),
    )
    for calls_so_far, late_call in enumerate(late_calls, start=1):
        with pytest.raises(ValueError, match="Episode is over"):
            late_call()

        expected = starting_reward + calls_so_far * _POST_EPISODE_PENALTY
        assert env.reward == pytest.approx(expected)