Spaces:
Sleeping
Sleeping
File size: 3,938 Bytes
210535c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | """Quick smoke test for all 3 tasks."""
import sys, json
sys.path.insert(0, ".")
from env.environment import SQLOptimizerEnv
from env.models import Action
env = SQLOptimizerEnv()
# ββ Task 1 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print("=== Task 1 (Easy): fix-broken-join ===")
obs = env.reset(1)
print(f" task: {obs.task_name}")
action = Action(
rewritten_query=(
"SELECT o.order_id, c.name, o.total "
"FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id "
"WHERE o.total > 100"
),
explanation="Replaced comma cross-join with INNER JOIN ON customer_id",
is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
print(f" feedback: {reward.feedback}")
assert obs2.done == True, "done should be True"
assert info["grader_score"] >= 0.8, f"Expected >=0.8, got {info['grader_score']}"
# ββ Task 2 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print()
print("=== Task 2 (Medium): eliminate-n-plus-one ===")
obs = env.reset(2)
print(f" task: {obs.task_name}")
action = Action(
rewritten_query=(
"SELECT e.name, d.dept_name "
"FROM employees e "
"LEFT JOIN departments d ON e.dept_id = d.dept_id "
"WHERE e.salary > 50000"
),
explanation="Replaced correlated subquery with a single LEFT JOIN",
is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
print(f" feedback: {reward.feedback}")
assert info["grader_score"] >= 0.7, f"Expected >=0.7, got {info['grader_score']}"
# ββ Task 3 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print()
print("=== Task 3 (Hard): full-optimization ===")
obs = env.reset(3)
print(f" task: {obs.task_name}")
action = Action(
rewritten_query=(
"-- Index hint: consider CREATE INDEX ON products(category, price)\n"
"SELECT p.name, p.category, p.price, oi.quantity, oi.unit_price\n"
"FROM products p\n"
"JOIN order_items oi ON p.product_id = oi.product_id\n"
"WHERE p.price >= 100 AND p.price < 200\n"
" AND p.category = 'Electronics'\n"
"ORDER BY p.name"
),
explanation="Removed DISTINCT and SELECT *, replaced CAST LIKE with range, added index hint",
is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
print(f" feedback: {reward.feedback}")
assert info["grader_score"] >= 0.9, f"Expected >=0.9, got {info['grader_score']}"
# ββ state() βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print()
print("=== state() ===")
print(json.dumps(env.state(), indent=2))
# ββ invalid action penalty βββββββββββββββββββββββββββββββββββββββββββββββββββ
print()
print("=== Invalid action test ===")
env.reset(1)
obs2, reward, done, info = env.step(Action(rewritten_query="", explanation="", is_done=False))
print(f" step_reward={reward.score} is_invalid={info['is_invalid']}")
assert info["is_invalid"] == True, "Empty query should be flagged invalid"
print()
print("ALL TESTS PASSED")
|