"""
Typed Pydantic models for SQL Arena OpenEnv environment.
These models define the contract between the agent and environment:
- SQLArenaAction: What the agent sends (a SQL query)
- SQLArenaObservation: What the agent receives (schema, results, feedback)
- SQLArenaState: Internal environment state tracking
"""
from pydantic import BaseModel, Field
from typing import Optional, List
class SQLArenaAction(BaseModel):
    """
    Action model — the agent submits a SQL query.
    This is what the agent sends to the environment each step.
    The environment will execute this query against the SQLite database
    and return results + feedback.
    """

    # NOTE(review): Pydantic typically publishes the class docstring as the
    # model's JSON-schema description, so its wording is kept stable here —
    # confirm before rephrasing it.

    # The single payload field. The Ellipsis default means no fallback value
    # is supplied, so the caller must always provide the query text.
    sql_query: str = Field(
        default=...,
        description="SQL query to execute against the database",
        examples=[
            "SELECT name, salary FROM employees WHERE salary > 50000",
            "SELECT department, COUNT(*) FROM employees GROUP BY department",
        ],
    )
class SQLArenaObservation(BaseModel):
    """
    Observation model — what the agent sees after each step.
    Contains the database schema, the question to answer,
    results from the last query, error messages, and feedback
    with partial credit information.
    """

    # NOTE(review): Pydantic typically publishes the class docstring as the
    # model's JSON-schema description, so its wording is kept stable here —
    # confirm before rephrasing it.

    # ------------------------------------------------------------------
    # Task context: every one of these uses an Ellipsis default, i.e. no
    # fallback value exists and the field must be supplied.
    # ------------------------------------------------------------------
    schema_description: str = Field(
        default=...,
        description="Human-readable database schema (CREATE TABLE statements)",
    )
    question: str = Field(
        default=...,
        description="Natural language question the agent must answer with SQL",
    )
    difficulty: str = Field(
        default=...,
        description="Task difficulty level: basic_select, join_aggregate, or complex_analysis",
    )
    task_id: str = Field(
        default=...,
        description="Unique identifier for this specific problem",
    )
    attempts_remaining: int = Field(
        default=...,
        description="Number of query attempts the agent has left",
    )

    # ------------------------------------------------------------------
    # Per-step outcome: these default to None and are meant to be filled
    # in once a query has been submitted via step().
    # ------------------------------------------------------------------
    query_result: Optional[str] = Field(
        default=None,
        description="Formatted result table from the last executed query",
    )
    error_message: Optional[str] = Field(
        default=None,
        description="SQL error message if the query failed to execute",
    )
    feedback: Optional[str] = Field(
        default=None,
        description="Detailed feedback on query correctness with partial credit breakdown",
    )

    # ------------------------------------------------------------------
    # Optional hint for the agent; None when no hint is provided.
    # ------------------------------------------------------------------
    expected_columns: Optional[List[str]] = Field(
        default=None,
        description="Expected column names in the correct result (hint)",
    )
class SQLArenaState(BaseModel):
    """
    Internal state model — tracks the episode progress.
    This is returned by the state() endpoint and contains
    all information about the current episode.
    """

    # NOTE(review): Pydantic typically publishes the class docstring as the
    # model's JSON-schema description, so its wording is kept stable here —
    # confirm before rephrasing it.

    # Identity of the task being played; both fields have an Ellipsis default
    # and therefore must be supplied when the state is created.
    task_id: str = Field(default=..., description="Current task identifier")
    difficulty: str = Field(default=..., description="Current difficulty level")

    # Step budget: the counter starts at 0 and the cap defaults to 5.
    current_step: int = Field(default=0, description="Number of steps taken so far")
    max_steps: int = Field(default=5, description="Maximum steps allowed for this task")

    # Reward bookkeeping, all starting from zero / empty.
    best_score: float = Field(
        default=0.0,
        description="Best score achieved so far in this episode",
    )
    total_reward: float = Field(default=0.0, description="Sum of all rewards received")
    # default_factory builds a fresh list per instance rather than sharing one.
    rewards_history: List[float] = Field(
        default_factory=list,
        description="List of rewards received at each step",
    )

    # Termination flag and the most recent action error (None when absent).
    done: bool = Field(default=False, description="Whether the episode has ended")
    last_action_error: Optional[str] = Field(
        default=None,
        description="Error from the last action, if any",
    )