Spaces:
Sleeping
Sleeping
| """OpenEnv adapter around the TraceFix-RL core environment.""" | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from core.environment import TraceFixRLGym | |
| from core.models import CodeAction, CodeObservation | |
| except ImportError: | |
| from core.environment import TraceFixRLGym | |
| from core.models import CodeAction, CodeObservation | |
class TraceFixRLEnvironment(Environment):
    """Environment implementation compatible with OpenEnv's server interface.

    Wraps a ``TraceFixRLGym`` instance and mirrors the episode id and step
    count of its latest observation into an OpenEnv ``State`` after every
    ``reset`` and ``step``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Value written to the gym's ``training_step`` for each difficulty label.
    # NOTE(review): presumably these pick a curriculum stage inside the gym;
    # confirm the numbers against TraceFixRLGym.
    _DIFFICULTY_TRAINING_STEP = {"easy": 1, "medium": 2000, "hard": 6000}

    # Task name for which no lookup in ``tasks.tasks.ALL_TASKS`` is attempted.
    _DEFAULT_TASK_NAME = "tracefix_rl"

    def __init__(self):
        """Create the wrapped gym and an empty initial state."""
        # NOTE(review): the base ``Environment.__init__`` is not called here;
        # confirm the base class needs no initialisation of its own.
        self._gym = TraceFixRLGym()
        self._state = State(episode_id="", step_count=0)

    def reset(self, difficulty: str | None = None, task_name: str | None = None) -> CodeObservation:
        """Start a new episode and return its first observation.

        Args:
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``; sets the gym's
                ``training_step`` to 1, 2000 or 6000 respectively. Any other
                value leaves ``training_step`` untouched (a warning is logged
                for non-empty unknown values).
            task_name: Name of an entry in ``tasks.tasks.ALL_TASKS``. ``None``,
                an empty string or ``"tracefix_rl"`` skips the lookup. If the
                registry cannot be imported or holds no such name, a warning
                is logged and the gym is reset without an explicit task.

        Returns:
            The observation produced by the gym's ``reset``, with the system
            prompt stored under ``metadata["system_prompt"]``.
        """
        import logging

        log = logging.getLogger(__name__)

        if isinstance(difficulty, str) and difficulty in self._DIFFICULTY_TRAINING_STEP:
            self._gym.training_step = self._DIFFICULTY_TRAINING_STEP[difficulty]
        elif difficulty:
            log.warning(
                "Unknown difficulty %r; leaving training_step unchanged.", difficulty
            )

        task_dict = None
        if task_name and task_name != self._DEFAULT_TASK_NAME:
            try:
                from tasks.tasks import ALL_TASKS
            except ImportError:
                # Deliberately best-effort: keep going with task_dict=None,
                # but leave a trace so the fallback is not invisible.
                log.warning(
                    "Task registry unavailable; cannot select task %r.", task_name
                )
            else:
                task_dict = next(
                    (task for task in ALL_TASKS if task.get("name") == task_name),
                    None,
                )
                if task_dict is None:
                    log.warning(
                        "Task %r not found in ALL_TASKS; resetting without it.",
                        task_name,
                    )

        # NOTE(review): the keyword is ``task_index`` but a task dict (or None)
        # is passed; confirm TraceFixRLGym.reset accepts that.
        obs, system_prompt = self._gym.reset(task_index=task_dict)
        self._sync_state(obs)

        # Copy before editing so the dict the gym attached is never mutated.
        metadata = dict(obs.metadata or {})
        metadata["system_prompt"] = system_prompt
        obs.metadata = metadata
        return obs

    def step(self, action: CodeAction) -> CodeObservation:  # type: ignore[override]
        """Apply ``action`` to the gym and return the resulting observation.

        The gym's reward and done flag are written onto the observation, and
        its info mapping is merged into a copy of the observation's metadata.
        """
        obs, reward, done, info = self._gym.step(action)
        obs.reward = reward
        obs.done = done

        metadata = dict(obs.metadata or {})
        # Tolerate a gym that returns no info instead of raising TypeError.
        metadata.update(info or {})
        obs.metadata = metadata

        self._sync_state(obs)
        return obs

    def state(self) -> State:
        """Return the state recorded by the most recent ``reset`` or ``step``."""
        # NOTE(review): OpenEnv's base Environment appears to declare ``state``
        # as a property; confirm whether this should carry ``@property``.
        return self._state

    def _sync_state(self, obs: CodeObservation) -> None:
        """Rebuild ``self._state`` from the episode id and step count of ``obs``."""
        self._state = State(
            episode_id=obs.info.get("episode_id", ""),
            step_count=obs.step_count,
        )