Spaces:
Sleeping
Sleeping
File size: 3,222 Bytes
def _normalize_row(row):
    """Return list/tuple rows as tuples; leave any other row type untouched.

    Lets a row compare equal regardless of whether one side stores it as a
    list and the other as a tuple.
    """
    return tuple(row) if isinstance(row, (list, tuple)) else row


def _row_match_fraction(expected, got):
    """Return the fraction of ``expected`` rows present in ``got`` (0.0-1.0).

    Each returned row can satisfy at most one expected row, so duplicate
    expected rows need duplicate returned rows. ``got`` may be ``None``,
    which is treated as an empty result.
    """
    expected_rows = [_normalize_row(row) for row in expected]
    unmatched = [_normalize_row(row) for row in (got or [])]
    returned_count = len(unmatched)

    if not expected_rows:
        # An empty expectation is met only by an empty result.
        return 1.0 if returned_count == 0 else 0.0

    matches = 0
    for row in expected_rows:
        try:
            # Consume the matched row so it cannot be counted twice.
            unmatched.remove(row)
        except ValueError:
            continue
        matches += 1

    fraction = matches / len(expected_rows)
    # Penalize extra rows (returned too many rows = wrong query).
    if returned_count > len(expected_rows) * 2:
        fraction *= 0.7  # 30% penalty for bloated results
    return fraction


def _plan_score(task, agent_query):
    """Return the share of the task's ``good_patterns`` found in the query.

    Matching is a case-insensitive substring test on the raw query text.
    """
    query_upper = agent_query.upper()
    good_patterns = task.get("good_patterns") or []
    # Each good pattern found = partial credit.
    found = sum(1 for pattern in good_patterns if pattern.upper() in query_upper)
    score = found / max(len(good_patterns), 1)
    # Also penalize if the query still looks like the correlated-subquery form.
    # NOTE(review): this is a coarse textual heuristic; it fires on any query
    # containing both "WHERE" and "SELECT AVG" -- confirm that is intended.
    if "WHERE" in query_upper and "SELECT AVG" in query_upper:
        score *= 0.3  # Heavy penalty: they didn't really optimize
    return score


def compute_reward(task: dict, agent_query: str, run_result: dict) -> dict:
    """Score an agent's SQL submission against a task.

    Args:
        task: Task description. Reads ``expected_rows`` (a sequence of rows,
            or ``None`` when the task has no fixed result), and optionally
            ``check_plan`` and ``good_patterns``.
        agent_query: The SQL string the agent submitted.
        run_result: Result of executing the query. Reads ``error`` (``None``
            on success) and ``rows``.

    Returns:
        A dict with keys ``value`` (final reward in [0, 1], rounded to three
        decimals on success), ``syntax_ok``, ``result_match_pct``,
        ``plan_score`` and ``message``.

    Raises:
        KeyError: If ``task`` lacks ``expected_rows`` or ``run_result`` lacks
            ``error``.
    """
    # -- Step 1: Did the query even run? --
    error = run_result["error"]
    if error is not None:
        # Give tiny credit for trying (not zero, so agent gets gradient signal).
        return {
            "value": 0.05,
            "syntax_ok": False,
            "result_match_pct": 0.0,
            "plan_score": 0.0,
            # str() so a non-string error (e.g. an exception) is still sliceable.
            "message": f"Syntax error: {str(error)[:100]}",
        }

    # -- Step 2: Did we get the right rows? --
    expected = task["expected_rows"]
    if expected is not None:
        result_match_pct = _row_match_fraction(expected, run_result.get("rows"))
    else:
        # No fixed rows for this task: give full match credit if the query runs.
        result_match_pct = 1.0

    # -- Step 3: Is the query plan good? (only when the task asks for it) --
    plan_score = _plan_score(task, agent_query) if task.get("check_plan") else 0.0

    # -- Step 4: Combine into final score --
    # Weights: syntax 20% + correctness 60% + plan 20%.
    # NOTE(review): with check_plan unset, plan_score stays 0.0, so the score
    # tops out at 0.8 and never reaches "perfect" -- confirm this is intended.
    base_score = 0.2 + (0.6 * result_match_pct) + (0.2 * plan_score)
    # Penalize absurdly long queries: nothing up to 800 chars, then linear.
    length_penalty = max(0.0, (len(agent_query) - 800) / 2000)
    final = max(0.0, min(1.0, base_score - length_penalty))

    if final >= 0.99:
        status = "perfect"
    elif final > 0.2:
        status = "partial"
    else:
        status = "wrong"
    msg = f"{status} | rows matched: {result_match_pct:.0%} | plan: {plan_score:.0%}"
    return {
        "value": round(final, 3),
        "syntax_ok": True,
        "result_match_pct": result_match_pct,
        "plan_score": plan_score,
        "message": msg,
    }