File size: 2,907 Bytes
210535c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
OpenEnv typed models — Observation, Action, Reward.
All models are Pydantic v2 compliant.
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------

class Observation(BaseModel):
    """What the agent sees at each step."""

    # Pydantic copies the class docstring and every `description=` string into
    # the generated JSON schema, so that text is kept verbatim.

    # --- Task identity ------------------------------------------------------
    # NOTE(review): the description advertises ids 1-3, but no range is
    # enforced on this field — confirm whether that is intentional.
    task_id: int = Field(description="Which task (1=easy, 2=medium, 3=hard)")
    task_name: str = Field(description="Human-readable task name")
    task_description: str = Field(description="What the agent must accomplish")

    # --- Problem payload (all required except the hint) ---------------------
    query: str = Field(description="The SQL query the agent must fix / optimise")
    schema_context: str = Field(
        description="DDL / schema description relevant to the query",
    )
    hint: Optional[str] = Field(
        default=None,
        description="Optional natural-language hint for the current step",
    )

    # --- Episode progress (all optional, with defaults) ---------------------
    step_number: int = Field(
        default=0,
        description="Current step within the episode (0-indexed)",
    )
    max_steps: int = Field(
        default=5,
        description="Maximum steps allowed per episode",
    )
    done: bool = Field(default=False, description="Whether the episode has ended")


# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------

class Action(BaseModel):
    """What the agent submits at each step."""

    # Both text fields are required; only the completion flag has a default.
    rewritten_query: str = Field(
        description="The agent's rewritten / improved SQL query",
    )
    explanation: str = Field(
        description="Natural-language explanation of changes made",
    )
    # Optional agent-side stop signal; defaults to False, so an agent may
    # omit it on intermediate submissions.
    is_done: bool = Field(
        default=False,
        description="Set True when the agent believes the query is fully optimised",
    )


# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------

class RewardBreakdown(BaseModel):
    # Per-dimension components of a step reward; every field defaults to 0.0.
    # (No class docstring: Pydantic would publish one in the JSON schema.)

    # Quality scores, each validated to lie in the closed interval [0, 1].
    correctness: float = Field(default=0.0, ge=0.0, le=1.0)
    performance: float = Field(default=0.0, ge=0.0, le=1.0)
    style: float = Field(default=0.0, ge=0.0, le=1.0)

    # Penalty term: bounded above by 0 (it can never act as a bonus) and
    # has no lower bound.
    step_penalty: float = Field(default=0.0, le=0.0)


class Reward(BaseModel):
    """Reward returned after each step."""

    # Required headline numbers, both validated to lie in [0, 1].
    score: float = Field(
        ge=0.0,
        le=1.0,
        description="Aggregate step reward",
    )
    grader_score: float = Field(
        ge=0.0,
        le=1.0,
        description="Raw grader score for the submitted query",
    )

    # When omitted, each Reward gets its own freshly constructed
    # RewardBreakdown (default_factory, so no instance is shared).
    breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown,
        description="Per-dimension partial scores",
    )
    feedback: str = Field(
        default="",
        description="Human-readable feedback from the grader",
    )

    # NOTE(review): capped at 1.0 although it is described as accumulated over
    # the episode. If the caller sums per-step scores rather than averaging
    # them, validation will fail once the total exceeds 1 — confirm.
    cumulative_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Total score accumulated over episode so far",
    )