Spaces:
Sleeping
Sleeping
File size: 2,907 Bytes
"""
OpenEnv typed models — Observation, Action, Reward.
All models are Pydantic v2 compliant.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class Observation(BaseModel):
    """What the agent sees at each step."""

    # NOTE: the class docstring above is kept verbatim; Pydantic surfaces it
    # as the model description in the generated JSON schema.

    # --- Task identity (all required) -------------------------------------
    task_id: int = Field(
        default=...,
        description="Which task (1=easy, 2=medium, 3=hard)",
    )
    task_name: str = Field(
        default=...,
        description="Human-readable task name",
    )
    task_description: str = Field(
        default=...,
        description="What the agent must accomplish",
    )

    # --- Material the agent works on (required) ---------------------------
    query: str = Field(
        default=...,
        description="The SQL query the agent must fix / optimise",
    )
    schema_context: str = Field(
        default=...,
        description="DDL / schema description relevant to the query",
    )

    # --- Optional guidance; defaults to None when no hint is supplied -----
    hint: Optional[str] = Field(
        default=None,
        description="Optional natural-language hint for the current step",
    )

    # --- Episode progress; defaults describe a fresh, unfinished episode --
    step_number: int = Field(
        default=0,
        description="Current step within the episode (0-indexed)",
    )
    max_steps: int = Field(
        default=5,
        description="Maximum steps allowed per episode",
    )
    done: bool = Field(
        default=False,
        description="Whether the episode has ended",
    )
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class Action(BaseModel):
    """What the agent submits at each step."""

    # NOTE: the class docstring above is kept verbatim; Pydantic surfaces it
    # as the model description in the generated JSON schema.

    # Required: the SQL text proposed by the agent for this step.
    rewritten_query: str = Field(
        default=...,
        description="The agent's rewritten / improved SQL query",
    )

    # Required: free-text rationale accompanying the rewritten query.
    explanation: str = Field(
        default=...,
        description="Natural-language explanation of changes made",
    )

    # Optional flag, False unless the agent explicitly sets it.
    is_done: bool = Field(
        default=False,
        description="Set True when the agent believes the query is fully optimised",
    )
# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------
class RewardBreakdown(BaseModel):
    # Per-dimension components of a reward. Documented with comments rather
    # than a docstring so the generated JSON schema stays unchanged.

    # Each of the three partial scores is validated to lie within [0.0, 1.0]
    # and defaults to 0.0.
    correctness: float = Field(default=0.0, ge=0.0, le=1.0)
    performance: float = Field(default=0.0, ge=0.0, le=1.0)
    style: float = Field(default=0.0, ge=0.0, le=1.0)

    # Only an upper bound is declared: the penalty can never be positive.
    step_penalty: float = Field(default=0.0, le=0.0)
class Reward(BaseModel):
    """Reward returned after each step."""

    # NOTE: the class docstring above is kept verbatim; Pydantic surfaces it
    # as the model description in the generated JSON schema.

    # Required scores, each validated to lie within [0.0, 1.0].
    score: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Aggregate step reward",
    )
    grader_score: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Raw grader score for the submitted query",
    )

    # A fresh RewardBreakdown (all defaults) is built per instance when the
    # caller does not supply one.
    breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown,
        description="Per-dimension partial scores",
    )

    # Empty string when the grader provides no feedback.
    feedback: str = Field(
        default="",
        description="Human-readable feedback from the grader",
    )

    # Validated to [0.0, 1.0] like the per-step scores; starts at 0.0.
    cumulative_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Total score accumulated over episode so far",
    )
|