Spaces:
Running
Running
File size: 2,251 Bytes
dcdd52f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | """OpenEnv adapter around the TraceFix-RL core environment."""
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from core.environment import TraceFixRLGym
from core.models import CodeAction, CodeObservation
except ImportError:
from core.environment import TraceFixRLGym
from core.models import CodeAction, CodeObservation
class TraceFixRLEnvironment(Environment):
    """Environment implementation compatible with OpenEnv's server interface.

    Thin adapter that forwards ``reset``/``step`` to a ``TraceFixRLGym``
    instance and mirrors each observation's episode id and step count into an
    OpenEnv ``State`` exposed through :attr:`state`.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Value assigned to ``TraceFixRLGym.training_step`` for each named
    # difficulty. NOTE(review): presumably the gym derives its curriculum
    # stage from training_step — confirm against core.environment.
    _DIFFICULTY_TRAINING_STEP = {
        "easy": 1,
        "medium": 2000,
        "hard": 6000,
    }

    # Task name that selects no specific task: ``task_index=None`` is passed
    # to the gym, exactly as when no task name is given at all.
    _DEFAULT_TASK_NAME = "tracefix_rl"

    def __init__(self):
        self._gym = TraceFixRLGym()
        # Placeholder until the first reset() supplies a real episode id.
        self._state = State(episode_id="", step_count=0)

    def reset(self, difficulty: str | None = None, task_name: str | None = None) -> CodeObservation:
        """Start a new episode and return its first observation.

        Args:
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``; sets the gym's
                ``training_step`` to the mapped value. ``None`` leaves it
                unchanged; any other value is ignored with a logged warning.
            task_name: ``name`` of an entry in ``tasks.tasks.ALL_TASKS`` to
                pass to the gym. ``None``/empty or ``"tracefix_rl"`` passes
                ``task_index=None``. An unknown name, or an unimportable
                ``tasks.tasks`` module, also passes ``None`` (with a warning).

        Returns:
            The gym's initial observation, with the system prompt returned by
            the gym stored under ``metadata["system_prompt"]``.
        """
        if difficulty is not None:
            training_step = self._DIFFICULTY_TRAINING_STEP.get(difficulty)
            if training_step is None:
                self._warn("Unknown difficulty %r ignored; expected one of %s",
                           difficulty, sorted(self._DIFFICULTY_TRAINING_STEP))
            else:
                self._gym.training_step = training_step

        obs, system_prompt = self._gym.reset(task_index=self._find_task(task_name))
        self._sync_state(obs)

        # Copy before mutating so we never alter a dict the gym may still hold.
        metadata = dict(obs.metadata or {})
        metadata["system_prompt"] = system_prompt
        obs.metadata = metadata
        return obs

    def step(self, action: CodeAction) -> CodeObservation:  # type: ignore[override]
        """Apply one action and return the resulting observation.

        The gym's reward and done flag are copied onto the observation, and
        the gym's ``info`` dict is merged into ``obs.metadata`` (``info`` keys
        overwrite existing metadata keys on collision).
        """
        obs, reward, done, info = self._gym.step(action)
        obs.reward = reward
        obs.done = done
        metadata = dict(obs.metadata or {})
        # Tolerate a gym that returns info=None instead of an empty dict.
        metadata.update(info or {})
        obs.metadata = metadata
        self._sync_state(obs)
        return obs

    @property
    def state(self) -> State:
        """The ``State`` captured from the most recent reset/step."""
        return self._state

    def _find_task(self, task_name: str | None):
        """Return the ``ALL_TASKS`` entry whose ``"name"`` equals *task_name*, else ``None``."""
        if not task_name or task_name == self._DEFAULT_TASK_NAME:
            return None
        try:
            from tasks.tasks import ALL_TASKS
        except ImportError:
            # Best effort: fall back to no specific task rather than failing
            # the reset, but make the fallback visible.
            self._warn("tasks.tasks is unavailable; ignoring task_name=%r", task_name)
            return None
        task = next((t for t in ALL_TASKS if t.get("name") == task_name), None)
        if task is None:
            self._warn("Unknown task_name %r; no task passed to the gym", task_name)
        return task

    def _sync_state(self, obs: CodeObservation) -> None:
        """Mirror the observation's episode id and step count into ``self._state``."""
        info = obs.info or {}
        self._state = State(
            episode_id=info.get("episode_id", ""),
            step_count=obs.step_count,
        )

    @staticmethod
    def _warn(msg: str, *args: object) -> None:
        """Log a warning on this module's logger using lazy %-style arguments."""
        import logging

        logging.getLogger(__name__).warning(msg, *args)
|