Spaces:
Sleeping
Sleeping
"""
Core OpenEnv environment: SQLOptimizerEnv

Implements the three required methods:
    reset(task_id) -> Observation
    step(action)   -> (Observation, Reward, done, info)
    state()        -> dict (current internal snapshot)
"""
from __future__ import annotations

import math
from typing import Any, Dict, Optional, Tuple

from .models import Action, Observation, Reward, RewardBreakdown
from .reward import compute_step_reward
from .tasks import TASKS, TaskDef, get_task
| _MIN_SCORE_EPS = 0.001 | |
| _MAX_SCORE_EPS = 0.999 | |
| def _strict_score(value: float) -> float: | |
| return round(min(max(float(value), _MIN_SCORE_EPS), _MAX_SCORE_EPS), 4) | |
class SQLOptimizerEnv:
    """SQL Query Optimizer OpenEnv environment.

    Episode lifecycle:
        reset(task_id) -> Observation
        step(action)   -> (Observation, Reward, done, info)
        state()        -> dict (current internal snapshot)
    """

    def __init__(self) -> None:
        # Active task definition; None until reset() is called.
        self._task: Optional[TaskDef] = None
        # Steps taken so far in the current episode.
        self._step_number: int = 0
        # True once the episode has terminated.
        self._done: bool = False
        # Running total of shaped step rewards (clamped via _strict_score).
        self._cumulative_score: float = 0.0
        # Grader score from the previous step; baseline for reward shaping.
        self._prev_grader_score: float = 0.0
        # Per-step records, exposed through state().
        self._history: list[Dict[str, Any]] = []
        # Most recent grader score, exposed through state().
        self._last_grader_score: float = 0.0

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self, task_id: int = 1) -> Observation:
        """Start a fresh episode for the given task.

        Args:
            task_id: identifier passed to get_task(); defaults to task 1.

        Returns:
            The initial Observation for the selected task.
        """
        self._task = get_task(task_id)
        self._step_number = 0
        self._done = False
        self._cumulative_score = 0.0
        self._prev_grader_score = 0.0
        self._last_grader_score = 0.0
        self._history = []
        return self._make_observation()

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """
        Advance the environment by one step.

        Args:
            action: the agent's proposed rewrite plus its is_done signal.

        Returns:
            observation: next Observation
            reward: Reward for this step
            done: whether the episode has ended
            info: auxiliary dict

        Raises:
            RuntimeError: if called before reset(), or after the episode ended.
        """
        if self._task is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        # An empty / whitespace-only query is invalid: it is not graded and the
        # previous grader score carries over unchanged for this step.
        is_invalid = not action.rewritten_query or not action.rewritten_query.strip()
        grader_result_score, breakdown, feedback = self._grade(action, is_invalid)

        # Shaped reward relative to the previous grader score; the shaping
        # policy (improvement bonus, invalid/step handling) lives in
        # compute_step_reward.
        step_reward = compute_step_reward(
            grader_score=grader_result_score,
            prev_grader_score=self._prev_grader_score,
            step_number=self._step_number,
            max_steps=self._task.max_steps,
            is_done=action.is_done,
            is_invalid=is_invalid,
        )

        # Informational only: surface the late-episode penalty in the
        # breakdown once past the halfway point (the reward value itself
        # comes from compute_step_reward above).
        halfway = math.ceil(self._task.max_steps / 2)
        if self._step_number > halfway and not action.is_done:
            breakdown.step_penalty = -0.02

        self._cumulative_score = _strict_score(self._cumulative_score + step_reward)
        self._prev_grader_score = grader_result_score
        self._last_grader_score = grader_result_score
        self._step_number += 1

        # Episode ends if agent signals done OR max steps reached.
        self._done = action.is_done or self._step_number >= self._task.max_steps

        # History records the 1-indexed step number (incremented above).
        self._history.append(
            {
                "step": self._step_number,
                "rewritten_query": action.rewritten_query,
                "grader_score": grader_result_score,
                "step_reward": step_reward,
                "is_done": action.is_done,
            }
        )

        reward = Reward(
            score=_strict_score(step_reward),
            grader_score=grader_result_score,
            breakdown=breakdown,
            feedback=feedback,
            cumulative_score=self._cumulative_score,
        )
        info = {
            "step_number": self._step_number,
            "grader_score": grader_result_score,
            "cumulative_score": self._cumulative_score,
            "is_invalid": is_invalid,
        }
        return self._make_observation(), reward, self._done, info

    def _grade(self, action: Action, is_invalid: bool) -> Tuple[float, RewardBreakdown, str]:
        """Grade the action; returns (clamped score, breakdown, feedback).

        Invalid actions skip the grader: they inherit the previous grader
        score with an empty breakdown and a fixed feedback message.
        """
        if is_invalid:
            score = self._prev_grader_score
            breakdown = RewardBreakdown()
            feedback = "Empty or invalid query submitted."
        else:
            gr = self._task.grader(action.rewritten_query)
            score = gr.score
            breakdown = RewardBreakdown(
                correctness=gr.correctness,
                performance=gr.performance,
                style=gr.style,
                step_penalty=0.0,
            )
            feedback = gr.feedback
        return _strict_score(score), breakdown, feedback

    # ------------------------------------------------------------------
    # state
    # ------------------------------------------------------------------
    def state(self) -> Dict[str, Any]:
        """Return the current internal state snapshot.

        Returns {"status": "not_started"} before the first reset().
        """
        if self._task is None:
            return {"status": "not_started"}
        return {
            "task_id": self._task.id,
            "task_name": self._task.name,
            "difficulty": self._task.difficulty,
            "step_number": self._step_number,
            "max_steps": self._task.max_steps,
            "done": self._done,
            "cumulative_score": self._cumulative_score,
            "last_grader_score": self._last_grader_score,
            "history": self._history,
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _make_observation(self) -> Observation:
        """Build an Observation from the current task and episode counters."""
        assert self._task is not None
        return Observation(
            task_id=self._task.id,
            task_name=self._task.name,
            task_description=self._task.description,
            query=self._task.query,
            schema_context=self._task.schema_context,
            hint=self._task.hint,
            step_number=self._step_number,
            max_steps=self._task.max_steps,
            done=self._done,
        )