File size: 7,822 Bytes
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
#!/usr/bin/env python3
"""Verify SFT and GRPO see identical prompt formats.

Renders the same question through both pipelines and compares the
tokenized output. Run on Colab or any env with transformers installed:

    python scripts/test_sft_grpo_alignment.py
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from transformers import AutoTokenizer  # noqa: E402

from sql_env.training.trl_adapter import get_tool_definitions  # noqa: E402

from scripts.generate_sft_data import get_system_prompt  # noqa: E402


def render_sft_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render *messages* exactly as the SFT pipeline does.

    Applies the tokenizer's chat template with the tool schema attached
    and a generation prompt appended, returning the untokenized string.
    """
    template_kwargs = {
        "tools": tools,
        "tokenize": False,
        "add_generation_prompt": True,
    }
    return tokenizer.apply_chat_template(messages, **template_kwargs)


def render_grpo_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render *messages* as the GRPO pipeline does.

    Intentionally identical to the SFT rendering path: same chat
    template, same tool schema, generation prompt appended, no
    tokenization — parity between the two is what this script verifies.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        add_generation_prompt=True,
        tokenize=False,
    )
    return rendered


def test_tool_definitions_match_class():
    """Verify get_tool_definitions() extracts all SQLEnvTRL methods."""
    tools = get_tool_definitions()
    expected = {"describe", "sample", "query", "answer"}
    actual = {tool["function"]["name"] for tool in tools}
    assert actual == expected, (
        f"Tool mismatch: got {actual}, expected {expected}"
    )

    # Every tool must carry a complete, non-empty parameter schema.
    for tool in tools:
        func = tool["function"]
        for key in ("name", "description", "parameters"):
            assert key in func
        schema = func["parameters"]
        assert len(schema["properties"]) > 0, (
            f"{func['name']} has no parameters"
        )
        assert len(schema["required"]) > 0, (
            f"{func['name']} has no required params"
        )

    print("[PASS] Tool definitions match SQLEnvTRL methods")
    return tools


def test_prompt_parity(tokenizer, tools):
    """Verify SFT and GRPO render identical prompts.

    Builds one system+user message pair, renders it through both
    pipeline helpers, and asserts the outputs match byte-for-byte.
    Returns the rendered prompt so later tests can inspect it.
    """
    system_prompt = get_system_prompt(enable_thinking=False)
    # NOTE: adjacent string literals concatenate directly, so the space
    # after "horsepower?" is required — without it the prompt read
    # "...horsepower?Tables: ...".
    question = (
        "How many cars have a larger accelerate than the car with "
        "the largest horsepower? "
        "Tables: car_makers, car_names, cars_data, continents, "
        "countries, model_list. "
        "Use describe, sample, query, and answer tools."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    sft_rendered = render_sft_prompt(tokenizer, messages, tools)
    grpo_rendered = render_grpo_prompt(tokenizer, messages, tools)

    assert sft_rendered == grpo_rendered, (
        "SFT and GRPO prompts differ!\n"
        f"SFT length: {len(sft_rendered)}\n"
        f"GRPO length: {len(grpo_rendered)}"
    )
    print("[PASS] SFT and GRPO prompts are identical")
    return sft_rendered


def test_tools_in_rendered_prompt(rendered: str, tools: list[dict]):
    """Verify the rendered prompt contains tool definitions."""
    assert "<tools>" in rendered, "No <tools> block in rendered prompt"
    assert "</tools>" in rendered, "No </tools> block in rendered prompt"

    for tool in tools:
        name = tool["function"]["name"]
        assert f'"name": "{name}"' in rendered, (
            f"Tool '{name}' not found in rendered prompt"
        )

    print("[PASS] All tool definitions present in rendered prompt")


def test_sft_data_has_tools(tools: list[dict]):
    """Verify SFT data includes tool definitions."""
    sft_path = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
    if not sft_path.exists():
        # Soft-skip: the data file is produced by a separate script.
        print("[SKIP] SFT data not generated yet")
        return

    with open(sft_path) as f:
        data = json.load(f)

    total = len(data)
    with_tools = sum("tools" in row for row in data)

    if with_tools == 0:
        print(
            f"[WARN] SFT data has NO tool definitions ({total} "
            "trajectories). Regenerate with: "
            "python scripts/generate_sft_data.py"
        )
        return
    if with_tools < total:
        print(f"[WARN] Only {with_tools}/{total} trajectories have tools")
        return

    # Full coverage: cross-check the first trajectory's tool names
    # against the live schema.
    first_names = {t["function"]["name"] for t in data[0]["tools"]}
    expected_names = {t["function"]["name"] for t in tools}
    assert first_names == expected_names, (
        f"SFT data tools {first_names} != expected {expected_names}"
    )
    print(f"[PASS] All {total} SFT trajectories have matching tools")


def test_sft_tool_call_format(tokenizer, tools: list[dict]):
    """Verify SFT tool_calls render correctly through chat template."""
    messages = [
        {"role": "system", "content": "You are a SQL assistant."},
        {"role": "user", "content": "How many rows in employees?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "describe",
                        "arguments": {"table_name": "employees"},
                    },
                }
            ],
        },
        {"role": "tool", "content": "Table 'employees' columns:\n- id"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "query",
                        "arguments": {
                            "sql": "SELECT COUNT(*) FROM employees",
                        },
                    },
                }
            ],
        },
        {"role": "tool", "content": "1. 42"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "answer",
                        "arguments": {"value": "42"},
                    },
                }
            ],
        },
        {"role": "tool", "content": "Answer submitted: correct."},
    ]

    rendered = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
    )

    # Should contain tool_call tags for each assistant turn
    tool_call_count = rendered.count("<tool_call>")
    assert tool_call_count == 3, f"Expected 3 tool_calls, got {tool_call_count}"

    # Each tool call should have the function name
    assert '"name": "describe"' in rendered
    assert '"name": "query"' in rendered
    assert '"name": "answer"' in rendered

    # SQL should be present (not null)
    assert "SELECT COUNT" in rendered

    print("[PASS] Multi-turn tool_calls render correctly with tools")


def main():
    """Run the full SFT/GRPO alignment suite against one tokenizer.

    Ordering matters: the tool definitions feed the parity test, whose
    rendered prompt feeds the presence test and the final preview.
    """
    model_name = "Qwen/Qwen3-0.6B"
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("\n--- Tool Definition Tests ---")
    tools = test_tool_definitions_match_class()

    print("\n--- Prompt Parity Tests ---")
    rendered = test_prompt_parity(tokenizer, tools)

    print("\n--- Tool Presence Tests ---")
    test_tools_in_rendered_prompt(rendered, tools)

    print("\n--- SFT Data Tests ---")
    test_sft_data_has_tools(tools)

    print("\n--- Multi-Turn Rendering Tests ---")
    test_sft_tool_call_format(tokenizer, tools)

    print("\n--- Rendered Prompt Preview ---")
    preview_chars = 600
    print(rendered[:preview_chars])
    print("...")

    print("\n=== ALL TESTS PASSED ===")


if __name__ == "__main__":  # Standard script entry guard.
    main()