# File size: 3,938 Bytes
# 210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Quick smoke test for all 3 tasks."""
import json
import sys

# Relative sys.path entries resolve against the current working directory, so
# run this script from the directory that contains the local `env` package.
sys.path.insert(0, ".")

from env.environment import SQLOptimizerEnv
from env.models import Action

# A single environment instance is reused by every section below; each task
# section calls reset() with its task id before stepping.
env = SQLOptimizerEnv()

# ── Task 1 ──────────────────────────────────────────────────────────────────
# Submit an explicit INNER JOIN rewrite as a single final step, then require
# the observation to be marked done and the grader score to reach 0.8.
print("=== Task 1 (Easy): fix-broken-join ===")
obs = env.reset(1)
print(f"  task: {obs.task_name}")
action = Action(
    rewritten_query=(
        "SELECT o.order_id, c.name, o.total "
        "FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id "
        "WHERE o.total > 100"
    ),
    explanation="Replaced comma cross-join with INNER JOIN ON customer_id",
    is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f"  grader_score={info['grader_score']:.3f}  step_reward={reward.score:.4f}  done={done}")
print(f"  feedback: {reward.feedback}")
# Assert truthiness directly instead of comparing against True with `==`.
assert obs2.done, "done should be True"
assert info["grader_score"] >= 0.8, f"Expected >=0.8, got {info['grader_score']}"

# ── Task 2 ──────────────────────────────────────────────────────────────────
print()
print("=== Task 2 (Medium): eliminate-n-plus-one ===")
obs = env.reset(2)
print(f"  task: {obs.task_name}")

# One clause per entry; joining with a single space yields the one-line query.
task2_sql = " ".join(
    [
        "SELECT e.name, d.dept_name",
        "FROM employees e",
        "LEFT JOIN departments d ON e.dept_id = d.dept_id",
        "WHERE e.salary > 50000",
    ]
)
action = Action(
    rewritten_query=task2_sql,
    explanation="Replaced correlated subquery with a single LEFT JOIN",
    is_done=True,
)

# Step once with the final answer and unpack the four-part result.
step_result = env.step(action)
obs2, reward, done, info = step_result
print(f"  grader_score={info['grader_score']:.3f}  step_reward={reward.score:.4f}  done={done}")
print(f"  feedback: {reward.feedback}")

# The medium task must reach a grader score of at least 0.7.
assert info["grader_score"] >= 0.7, f"Expected >=0.7, got {info['grader_score']}"

# ── Task 3 ──────────────────────────────────────────────────────────────────
print()
print("=== Task 3 (Hard): full-optimization ===")
obs = env.reset(3)
print(f"  task: {obs.task_name}")

# Multi-line query: one SQL line per entry, joined with newlines (no trailing
# newline after the ORDER BY clause).
task3_sql = "\n".join(
    [
        "-- Index hint: consider CREATE INDEX ON products(category, price)",
        "SELECT p.name, p.category, p.price, oi.quantity, oi.unit_price",
        "FROM   products p",
        "JOIN   order_items oi ON p.product_id = oi.product_id",
        "WHERE  p.price >= 100 AND p.price < 200",
        "  AND  p.category = 'Electronics'",
        "ORDER  BY p.name",
    ]
)
action = Action(
    rewritten_query=task3_sql,
    explanation="Removed DISTINCT and SELECT *, replaced CAST LIKE with range, added index hint",
    is_done=True,
)

# Step once with the final answer and unpack the four-part result.
step_result = env.step(action)
obs2, reward, done, info = step_result
print(f"  grader_score={info['grader_score']:.3f}  step_reward={reward.score:.4f}  done={done}")
print(f"  feedback: {reward.feedback}")

# The hard task must reach a grader score of at least 0.9.
assert info["grader_score"] >= 0.9, f"Expected >=0.9, got {info['grader_score']}"

# ── state() ─────────────────────────────────────────────────────────────────
print()
print("=== state() ===")
# Capture the environment's current state, then pretty-print it as JSON.
snapshot = env.state()
print(json.dumps(snapshot, indent=2))

# ── invalid action penalty ───────────────────────────────────────────────────
print()
print("=== Invalid action test ===")
env.reset(1)
# Step with an empty query and explanation; the step info is expected to flag
# the action as invalid.
obs2, reward, done, info = env.step(Action(rewritten_query="", explanation="", is_done=False))
print(f"  step_reward={reward.score}  is_invalid={info['is_invalid']}")
# Assert truthiness directly instead of comparing against True with `==`.
assert info["is_invalid"], "Empty query should be flagged invalid"

# Reached only if nothing above raised (a failed assert aborts the script).
# Emits a blank line followed by the success banner.
print("", "ALL TESTS PASSED", sep="\n")