sql-query-optimizer / env /environment.py
Param20h's picture
Upload folder using huggingface_hub
429a3ac verified
"""
Core OpenEnv environment: SQLOptimizerEnv
Implements the three required methods:
    reset(task_id) → Observation
    step(action)   → (Observation, Reward, done, info)
    state()        → dict (current internal snapshot)
"""
from __future__ import annotations

import math
from typing import Any, Dict, Optional, Tuple

from .models import Action, Observation, Reward, RewardBreakdown
from .reward import compute_step_reward
from .tasks import TASKS, TaskDef, get_task
# Inclusive bounds applied to every score this module reports: values are
# clamped so they never reach exactly 0.0 or 1.0.
_MIN_SCORE_EPS = 0.001
_MAX_SCORE_EPS = 0.999


def _strict_score(value: float) -> float:
    """Clamp *value* into ``[_MIN_SCORE_EPS, _MAX_SCORE_EPS]`` and round it.

    Args:
        value: Raw score; anything ``float()`` accepts.

    Returns:
        The score limited to the bounds above, rounded to 4 decimal places.
        NaN is mapped to ``_MIN_SCORE_EPS``.
    """
    score = float(value)
    # NaN compares false against everything, so min()/max() would hand it
    # back unchanged and it would escape the clamp; treat it as the floor.
    if math.isnan(score):
        return _MIN_SCORE_EPS
    return round(min(max(score, _MIN_SCORE_EPS), _MAX_SCORE_EPS), 4)
class SQLOptimizerEnv:
    """SQL Query Optimizer OpenEnv environment.

    Holds the state of a single episode: the active task, the step counter,
    the most recent grader score, the clamped running reward total and a
    per-step history log. Call ``reset()`` to begin an episode, ``step()`` to
    submit a rewritten query, and ``state()`` to inspect progress.
    """

    def __init__(self) -> None:
        """Create an environment with no active task."""
        self._task: Optional[TaskDef] = None
        self._step_number: int = 0
        self._done: bool = False
        self._cumulative_score: float = 0.0
        self._prev_grader_score: float = 0.0
        self._history: list[Dict[str, Any]] = []
        self._last_grader_score: float = 0.0

    # ──────────────────────────────────────────────────────────────────────────
    # reset
    # ──────────────────────────────────────────────────────────────────────────
    def reset(self, task_id: int = 1) -> Observation:
        """Start a fresh episode for the given task.

        Args:
            task_id: Identifier passed to ``get_task`` to select the task.

        Returns:
            The initial observation (step 0, not done).
        """
        self._task = get_task(task_id)
        self._step_number = 0
        self._done = False
        self._cumulative_score = 0.0
        self._prev_grader_score = 0.0
        self._last_grader_score = 0.0
        self._history = []
        return self._make_observation()

    # ──────────────────────────────────────────────────────────────────────────
    # step
    # ──────────────────────────────────────────────────────────────────────────
    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Advance the environment by one step.

        Args:
            action: The agent's submission; ``rewritten_query`` and ``is_done``
                are read from it.

        Returns:
            observation: next Observation
            reward: Reward for this step
            done: whether the episode has ended
            info: auxiliary dict

        Raises:
            RuntimeError: If ``reset()`` has not been called yet, or the
                current episode has already finished.
        """
        if self._task is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        task = self._task

        # A missing, empty or whitespace-only query never reaches the grader.
        is_invalid = not action.rewritten_query or not action.rewritten_query.strip()

        if is_invalid:
            # Carry the previous grader score forward unchanged.
            grader_result_score = self._prev_grader_score
            breakdown = RewardBreakdown()
            feedback = "Empty or invalid query submitted."
        else:
            gr = task.grader(action.rewritten_query)
            grader_result_score = gr.score
            breakdown = RewardBreakdown(
                correctness=gr.correctness,
                performance=gr.performance,
                style=gr.style,
                step_penalty=0.0,
            )
            feedback = gr.feedback
        grader_result_score = _strict_score(grader_result_score)

        # Shaped reward; receives the pre-increment step counter.
        step_reward = compute_step_reward(
            grader_score=grader_result_score,
            prev_grader_score=self._prev_grader_score,
            step_number=self._step_number,
            max_steps=task.max_steps,
            is_done=action.is_done,
            is_invalid=is_invalid,
        )

        # Record a late-step penalty in the breakdown once the (pre-increment)
        # step counter is past the halfway point and the agent has not
        # signalled completion. This only annotates the breakdown; step_reward
        # was already computed above.
        halfway = math.ceil(task.max_steps / 2)
        if self._step_number > halfway and not action.is_done:
            breakdown.step_penalty = -0.02

        self._cumulative_score = _strict_score(self._cumulative_score + step_reward)
        self._prev_grader_score = grader_result_score
        self._last_grader_score = grader_result_score
        self._step_number += 1

        # Episode ends if agent signals done OR max steps reached
        self._done = action.is_done or self._step_number >= task.max_steps

        # History entries use the post-increment (1-based) step number.
        self._history.append(
            {
                "step": self._step_number,
                "rewritten_query": action.rewritten_query,
                "grader_score": grader_result_score,
                "step_reward": step_reward,
                "is_done": action.is_done,
            }
        )

        reward = Reward(
            score=_strict_score(step_reward),
            grader_score=grader_result_score,
            breakdown=breakdown,
            feedback=feedback,
            cumulative_score=self._cumulative_score,
        )
        info = {
            "step_number": self._step_number,
            "grader_score": grader_result_score,
            "cumulative_score": self._cumulative_score,
            "is_invalid": is_invalid,
        }
        return self._make_observation(), reward, self._done, info

    # ──────────────────────────────────────────────────────────────────────────
    # state
    # ──────────────────────────────────────────────────────────────────────────
    def state(self) -> Dict[str, Any]:
        """Return a snapshot of the current internal state.

        Returns:
            ``{"status": "not_started"}`` before the first ``reset()``;
            otherwise a dict with task metadata, progress counters, scores
            and the step history. The history list and its entries are
            copies, so mutating the snapshot does not alter the environment.
        """
        if self._task is None:
            return {"status": "not_started"}
        return {
            "task_id": self._task.id,
            "task_name": self._task.name,
            "difficulty": self._task.difficulty,
            "step_number": self._step_number,
            "max_steps": self._task.max_steps,
            "done": self._done,
            "cumulative_score": self._cumulative_score,
            "last_grader_score": self._last_grader_score,
            # Copy so callers cannot edit the live episode log.
            "history": [dict(entry) for entry in self._history],
        }

    # ──────────────────────────────────────────────────────────────────────────
    # Internal helpers
    # ──────────────────────────────────────────────────────────────────────────
    def _make_observation(self) -> Observation:
        """Build an Observation from the active task and current progress."""
        assert self._task is not None
        return Observation(
            task_id=self._task.id,
            task_name=self._task.name,
            task_description=self._task.description,
            query=self._task.query,
            schema_context=self._task.schema_context,
            hint=self._task.hint,
            step_number=self._step_number,
            max_steps=self._task.max_steps,
            done=self._done,
        )