| """
|
| Typed Pydantic models for SQL Arena OpenEnv environment.
|
|
|
| These models define the contract between the agent and environment:
|
| - SQLArenaAction: What the agent sends (a SQL query)
|
| - SQLArenaObservation: What the agent receives (schema, results, feedback)
|
| - SQLArenaState: Internal environment state tracking
|
| """
|
|
|
| from pydantic import BaseModel, Field
|
| from typing import Optional, List
|
|
|
|
|
class SQLArenaAction(BaseModel):
    """
    Action model — the agent submits a SQL query.

    This is what the agent sends to the environment each step.
    The environment will execute this query against the SQLite database
    and return results + feedback.
    """

    # NOTE(review): the docstring wording above is kept verbatim because
    # pydantic presumably surfaces it as the model's schema description —
    # confirm before rewording.

    # Sole field of the action. `default=...` (Ellipsis) marks it as
    # required; the two `examples` are sample queries attached to the field.
    sql_query: str = Field(
        default=...,
        description="SQL query to execute against the database",
        examples=[
            "SELECT name, salary FROM employees WHERE salary > 50000",
            "SELECT department, COUNT(*) FROM employees GROUP BY department",
        ],
    )
|
|
|
|
|
class SQLArenaObservation(BaseModel):
    """
    Observation model — what the agent sees after each step.

    Contains the database schema, the question to answer,
    results from the last query, error messages, and feedback
    with partial credit information.
    """

    # NOTE(review): the docstring wording above is kept verbatim because
    # pydantic presumably surfaces it as the model's schema description —
    # confirm before rewording.

    # --- Required fields (default=... means no default; must be supplied) ---
    schema_description: str = Field(
        default=...,
        description="Human-readable database schema (CREATE TABLE statements)",
    )
    question: str = Field(
        default=...,
        description="Natural language question the agent must answer with SQL",
    )
    # Plain `str`: the three level names are listed in the description only,
    # they are not enforced by the type.
    difficulty: str = Field(
        default=...,
        description="Task difficulty level: basic_select, join_aggregate, or complex_analysis",
    )
    task_id: str = Field(
        default=...,
        description="Unique identifier for this specific problem",
    )
    attempts_remaining: int = Field(
        default=...,
        description="Number of query attempts the agent has left",
    )

    # --- Optional fields about the last query (each defaults to None) ---
    query_result: Optional[str] = Field(
        default=None,
        description="Formatted result table from the last executed query",
    )
    error_message: Optional[str] = Field(
        default=None,
        description="SQL error message if the query failed to execute",
    )
    feedback: Optional[str] = Field(
        default=None,
        description="Detailed feedback on query correctness with partial credit breakdown",
    )

    # --- Optional hint (defaults to None) ---
    expected_columns: Optional[List[str]] = Field(
        default=None,
        description="Expected column names in the correct result (hint)",
    )
|
|
|
|
|
class SQLArenaState(BaseModel):
    """
    Internal state model — tracks the episode progress.

    This is returned by the state() endpoint and contains
    all information about the current episode.
    """

    # NOTE(review): the docstring wording above is kept verbatim because
    # pydantic presumably surfaces it as the model's schema description —
    # confirm before rewording.

    # --- Task identity: both required (default=... means no default) ---
    task_id: str = Field(
        default=...,
        description="Current task identifier",
    )
    difficulty: str = Field(
        default=...,
        description="Current difficulty level",
    )

    # --- Step counters: start at 0, with a default limit of 5 ---
    current_step: int = Field(
        default=0,
        description="Number of steps taken so far",
    )
    max_steps: int = Field(
        default=5,
        description="Maximum steps allowed for this task",
    )

    # --- Scoring: numeric fields start at 0.0 ---
    best_score: float = Field(
        default=0.0,
        description="Best score achieved so far in this episode",
    )
    total_reward: float = Field(
        default=0.0,
        description="Sum of all rewards received",
    )
    # default_factory builds a new empty list per instance rather than
    # using one list literal as the default.
    rewards_history: List[float] = Field(
        default_factory=list,
        description="List of rewards received at each step",
    )

    # --- Termination flag and last error (False / None by default) ---
    done: bool = Field(
        default=False,
        description="Whether the episode has ended",
    )
    last_action_error: Optional[str] = Field(
        default=None,
        description="Error from the last action, if any",
    )