File size: 2,251 Bytes
dcdd52f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""OpenEnv adapter around the TraceFix-RL core environment."""

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from core.environment import TraceFixRLGym
    from core.models import CodeAction, CodeObservation
except ImportError:
    from core.environment import TraceFixRLGym
    from core.models import CodeAction, CodeObservation


class TraceFixRLEnvironment(Environment):
    """Environment implementation compatible with OpenEnv's server interface.

    Wraps a ``TraceFixRLGym`` instance and, after every ``reset`` / ``step``,
    mirrors the observation's episode id and step count into an OpenEnv
    ``State`` object exposed through the ``state`` property.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Value written to ``TraceFixRLGym.training_step`` for each named difficulty.
    # NOTE(review): presumably these select a curriculum stage inside the gym
    # -- confirm against core.environment.
    _DIFFICULTY_TRAINING_STEP = {"easy": 1, "medium": 2000, "hard": 6000}

    def __init__(self):
        """Create the wrapped gym and an empty initial ``State``."""
        # Run the base-class initialiser so any attributes it defines exist on
        # this instance before the server touches them.
        super().__init__()
        self._gym = TraceFixRLGym()
        self._state = State(episode_id="", step_count=0)

    @staticmethod
    def _find_task(task_name: str):
        """Return the entry of ``tasks.tasks.ALL_TASKS`` named *task_name*, or ``None``.

        The lookup is best-effort: a missing ``tasks`` package or an unknown
        name yields ``None`` (with a logged warning) instead of raising.
        """
        import logging

        log = logging.getLogger(__name__)
        try:
            from tasks.tasks import ALL_TASKS
        except ImportError:
            log.warning(
                "tasks.tasks could not be imported; ignoring task_name=%r",
                task_name,
            )
            return None

        for task in ALL_TASKS:
            if task.get("name") == task_name:
                return task

        log.warning(
            "No task named %r in ALL_TASKS; resetting without an explicit task",
            task_name,
        )
        return None

    def _sync_state(self, obs) -> None:
        """Rebuild ``self._state`` from the episode id and step count of *obs*."""
        self._state = State(
            episode_id=obs.info.get("episode_id", ""),
            step_count=obs.step_count,
        )

    def reset(self, difficulty: str | None = None, task_name: str | None = None) -> CodeObservation:
        """Start a new episode on the wrapped gym.

        Args:
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``. When given, the
                gym's ``training_step`` is overwritten with the matching value
                before the reset. Any other value leaves it untouched.
            task_name: Name of a task to look up in ``tasks.tasks.ALL_TASKS``.
                ``None``, an empty string and ``"tracefix_rl"`` skip the lookup.

        Returns:
            The observation returned by the gym, with the gym's system prompt
            stored under ``metadata["system_prompt"]``.
        """
        training_step = self._DIFFICULTY_TRAINING_STEP.get(difficulty)
        if training_step is not None:
            self._gym.training_step = training_step

        task = None
        if task_name and task_name != "tracefix_rl":
            task = self._find_task(task_name)

        # NOTE(review): the keyword is ``task_index`` but the value is the task
        # dict itself (or None) -- confirm this matches TraceFixRLGym.reset.
        obs, system_prompt = self._gym.reset(task_index=task)
        self._sync_state(obs)

        # Copy before mutating so a metadata dict owned by the gym is not
        # modified in place.
        metadata = dict(obs.metadata or {})
        metadata["system_prompt"] = system_prompt
        obs.metadata = metadata
        return obs

    def step(self, action: CodeAction) -> CodeObservation:  # type: ignore[override]
        """Apply *action* to the wrapped gym and return the resulting observation.

        The gym's ``reward`` and ``done`` values are copied onto the
        observation, and its ``info`` mapping is merged into the observation's
        metadata.
        """
        obs, reward, done, info = self._gym.step(action)
        obs.reward = reward
        obs.done = done

        metadata = dict(obs.metadata or {})
        # Tolerate a gym that reports no info for this step.
        metadata.update(info or {})
        obs.metadata = metadata

        self._sync_state(obs)
        return obs

    @property
    def state(self) -> State:
        """The ``State`` recorded after the most recent ``reset`` or ``step``."""
        return self._state