sql-arena / src /sql_arena /environment.py
rahul2124's picture
Upload folder using huggingface_hub
ac49ad8 verified
Raw
History Blame Contribute Delete
6.66 kB
"""
Core SQL Arena Environment.
Implements the OpenEnv step()/reset()/state() interface.
"""
from typing import Optional, Dict, Any, List
from .models import SQLArenaAction, SQLArenaObservation, SQLArenaState
from .database import DatabaseManager
from .tasks import SQLTask, get_task, list_tasks, TASK_BY_ID
from .graders import grade_result, generate_hint
class StepResult:
    """Bundle returned by ``reset()`` and ``step()``.

    Attributes:
        observation: The observation handed back to the agent.
        reward: Reward assigned for the step.
        done: Whether the episode has finished.
        info: Extra diagnostic key/value pairs; always a dict, never None.
    """

    def __init__(
        self,
        observation: SQLArenaObservation,
        reward: float,
        done: bool,
        info: Optional[Dict[str, Any]] = None,
    ):
        """Store the step outcome.

        Args:
            observation: Observation produced by the environment.
            reward: Reward for this step.
            done: True when the episode is over.
            info: Optional diagnostics. A missing or empty mapping is
                replaced by a fresh empty dict.
        """
        self.observation = observation
        self.reward = reward
        self.done = done
        # A fresh dict per instance avoids sharing one mutable default.
        self.info = info if info else {}
class SQLArenaEnvironment:
    """
    SQL Arena: An interactive SQL query challenge environment.

    The agent receives a database schema and a natural language question,
    then iteratively writes SQL queries. The environment provides
    execution results, feedback, and partial credit scoring.

    Call ``reset()`` to load a task, then ``step()`` until the returned
    ``StepResult.done`` is true. ``close()`` releases the database.
    """

    def __init__(self):
        """Create the database manager; no task is loaded until ``reset()``."""
        self.db = DatabaseManager()
        self.current_task: Optional[SQLTask] = None
        self._state: Optional[SQLArenaState] = None
        self._last_observation: Optional[SQLArenaObservation] = None

    def reset(
        self,
        difficulty: str = "basic_select",
        task_id: Optional[str] = None,
    ) -> StepResult:
        """
        Reset the environment with a new task.

        Args:
            difficulty: 'basic_select', 'join_aggregate', or 'complex_analysis'
            task_id: Optional specific task ID

        Returns:
            StepResult with the initial observation, a reward of 0.0,
            ``done=False`` and the task title under ``info["task_title"]``.
        """
        # Get the task
        self.current_task = get_task(difficulty, task_id)
        task = self.current_task

        # Setup database
        self.db.create_database(task.setup_sql)

        # Initialize state
        self._state = SQLArenaState(
            task_id=task.task_id,
            difficulty=task.difficulty,
            current_step=0,
            max_steps=task.max_steps,
            best_score=0.0,
            total_reward=0.0,
            rewards_history=[],
            done=False,
            last_action_error=None,
        )

        # Create initial observation
        self._last_observation = SQLArenaObservation(
            schema_description=task.schema_description,
            question=task.question,
            query_result=None,
            error_message=None,
            feedback="Welcome to SQL Arena! Write a SQL query to answer the question above.",
            expected_columns=task.expected_columns,
            attempts_remaining=task.max_steps,
            difficulty=task.difficulty,
            task_id=task.task_id,
        )

        return StepResult(
            observation=self._last_observation,
            reward=0.0,
            done=False,
            info={"task_title": task.title},
        )

    @staticmethod
    def _shape_reward(score: float, prev_best: float) -> float:
        """Turn a grader score into the step reward.

        Half of the reward is the raw score and half is the improvement over
        the best score seen on earlier steps, so re-submitting a query that
        does not beat the previous best earns only the score half. On the
        first step ``prev_best`` is 0.0 and the reward equals the score.

        Args:
            score: Grader score for the current query.
            prev_best: Best grader score from earlier steps of the episode.

        Returns:
            The reward rounded to 4 decimals and kept strictly between
            0 and 1 (0.01 at the low end, 0.99 at the high end).
        """
        improvement = max(0.0, score - prev_best)
        reward = 0.5 * score + 0.5 * improvement
        reward = round(min(max(reward, 0.0), 1.0), 4)

        # Clamp to strictly between 0 and 1
        if reward <= 0.0:
            return 0.01
        if reward >= 1.0:
            return 0.99
        return reward

    def step(self, action: SQLArenaAction) -> StepResult:
        """
        Execute the agent's SQL query and return feedback.

        Args:
            action: SQLArenaAction containing the SQL query

        Returns:
            StepResult with observation, reward, and done flag. ``info``
            carries ``score``, ``best_score``, ``step`` and ``is_perfect``.

        Raises:
            RuntimeError: If ``reset()`` has not been called, or the current
                episode is already done.
        """
        if self._state is None or self.current_task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        task = self.current_task
        state = self._state

        # Increment step counter
        state.current_step += 1

        # Execute the query
        success, result, error = self.db.execute_query(action.sql_query)

        # Grade the result
        score, feedback = grade_result(task, success, result, error)

        # Remember the best score *before* this step: improvement must be
        # measured against earlier scores, not against earlier shaped rewards
        # (which are at most the score and would overstate the improvement).
        prev_best = state.best_score
        state.best_score = max(prev_best, score)

        # Calculate step reward
        reward = self._shape_reward(score, prev_best)
        state.rewards_history.append(reward)
        state.total_reward += reward

        # Add progressive hints
        hint = generate_hint(task, state.current_step, score)
        if hint and score < 1.0:
            feedback += f"\n\n{hint}"

        # Check if done: either a perfect score or the step budget is spent
        attempts_remaining = task.max_steps - state.current_step
        is_perfect = score >= 1.0
        is_out_of_steps = attempts_remaining <= 0
        state.done = is_perfect or is_out_of_steps
        state.last_action_error = error

        # Format query result for observation (left as None on failure or
        # when the query produced a falsy/empty result)
        query_result_str = None
        if success and result:
            query_result_str = self.db.format_result(result)

        # Build observation
        self._last_observation = SQLArenaObservation(
            schema_description=task.schema_description,
            question=task.question,
            query_result=query_result_str,
            error_message=error,
            feedback=feedback,
            expected_columns=task.expected_columns,
            attempts_remaining=attempts_remaining,
            difficulty=task.difficulty,
            task_id=task.task_id,
        )

        return StepResult(
            observation=self._last_observation,
            reward=reward,
            done=state.done,
            info={
                "score": score,
                "best_score": state.best_score,
                "step": state.current_step,
                "is_perfect": is_perfect,
            },
        )

    def state(self) -> SQLArenaState:
        """Return the current environment state.

        Raises:
            RuntimeError: If ``reset()`` has not been called.
        """
        if self._state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self._state

    def close(self) -> None:
        """Clean up resources and forget the current task and state."""
        self.db.close()
        self.current_task = None
        self._state = None
        self._last_observation = None

    def get_available_tasks(self) -> Dict:
        """Return all available tasks."""
        return list_tasks()