from __future__ import annotations

from typing import Any

from env.models import FlakySleuthAction, FlakySleuthObservation
from env.sandbox import Sandbox
from env.task_loader import TaskLoader
from graders import grade_action
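
# Substrings that suggest flakiness-prone code: timing, randomness, concurrency,
# shared or global state, and external I/O. Searching for one of these earns a
# slightly larger progress bonus than an arbitrary search (see _search_relevance_reward).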
FLAKY_SIGNAL_PATTERNS = [
    "sleep",
    "random",
    "time",
    "datetime",
    "thread",
    "asyncio",
    "fixture",
    "setup",
    "teardown",
    "global",
    "shared",
    "singleton",
    "os.environ",
    "socket",
    "timeout",
    "retry",
    "mock",
    "patch",
]
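
# Actions that end the episode; the final answer is scored by grade_action.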
TERMINAL_ACTIONS = ("classify_flakiness", "classify_root_cause", "propose_fix")


class FlakySleuthEnv:
    """Sandboxed environment for investigating and classifying flaky tests."""

    def __init__(self, dataset_path: str = "dataset/py_tasks.csv", max_steps: int = 20):
        self.loader = TaskLoader(dataset_path)
        self.sandbox: Sandbox | None = None
        self.current_task: dict[str, Any] | None = None
        self.step_count = 0
        self.max_steps = max_steps
        self.cumulative_progress = 0.0
        self.files_read: set[str] = set()
        self.episode_actions: list[FlakySleuthAction] = []
        self.search_pattern_counts: dict[str, int] = {}
        self.search_context_counts: dict[str, int] = {}
        self.consecutive_searches = 0

    def reset(self) -> FlakySleuthObservation:
        """Sample a new task, rebuild the sandbox, and clear all episode state."""
        if self.sandbox:
            self.sandbox.cleanup()
        self.current_task = self.loader.sample()
        self.current_task.setdefault("label", "flaky")
        self.sandbox = Sandbox(self.current_task)
        self.sandbox.setup()
        self.current_task["sandbox_root"] = self.sandbox.tmpdir or ""
        test_file = self.current_task.get("test_file", "")
        if test_file and self.sandbox.tmpdir:
            self.current_task["sandbox_test_path"] = f"{self.sandbox.tmpdir}/{test_file}"
        self.step_count = 0
        self.cumulative_progress = 0.0
        self.files_read = set()
        self.episode_actions = []
        self.search_pattern_counts = {}
        self.search_context_counts = {}
        self.consecutive_searches = 0
        return self._make_obs()

    def step(
        self, action: FlakySleuthAction
    ) -> tuple[FlakySleuthObservation, float, bool, dict[str, Any]]:
        """Apply one action and return (observation, reward, done, info)."""
        if not self.current_task or not self.sandbox:
            raise RuntimeError("Environment is not initialized. Call reset() first.")
        self.step_count += 1
        self.episode_actions.append(action)
        tool_output: str | None = None
        reward = 0.0
        done = False
        info: dict[str, Any] = {}
        if action.action_type in TERMINAL_ACTIONS:
            terminal_score = grade_action(action, self.current_task)
            # Penalize answers given after step 15 and "stable" verdicts on known-flaky tasks.
            late_penalty = max(0, self.step_count - 15) * 0.05
            wrong_dir_penalty = 0.0
            if (
                action.action_type == "classify_flakiness"
                and action.argument.strip().lower() == "stable"
                and str(self.current_task.get("label", "flaky")).lower() == "flaky"
            ):
                wrong_dir_penalty = 0.2
            reward = min(
                0.999,
                max(
                    0.001,
                    self.cumulative_progress + terminal_score - late_penalty - wrong_dir_penalty,
                ),
            )
            done = True
            info = {
                "terminal_score": terminal_score,
                "progress_score": self.cumulative_progress,
                "late_penalty": late_penalty,
                "task_type": self.current_task.get("task_type"),
                "category": self.current_task.get("category"),
            }
        else:
            tool_output, progress = self._execute_exploration(action)
            self.cumulative_progress = min(0.30, max(0.0, self.cumulative_progress + progress))
            reward = progress
        if not done and self.step_count >= self.max_steps:
            done = True
            info = {
                "terminal_score": 0.001,
                "progress_score": self.cumulative_progress,
                "late_penalty": max(0, self.step_count - 15) * 0.05,
                "timeout": True,
                "task_type": self.current_task.get("task_type"),
                "category": self.current_task.get("category"),
            }
        obs = self._make_obs(tool_output)
        return obs, reward, done, info

    def state(self) -> dict[str, Any]:
        if not self.current_task:
            return {
                "repo_url": None,
                "test_name": None,
                "task_type": None,
                "step_count": self.step_count,
                "files_read": [],
                "cumulative_progress": self.cumulative_progress,
            }
        return {
            "repo_url": self.current_task.get("repo_url"),
            "test_name": self.current_task.get("test_name"),
            "task_type": self.current_task.get("task_type"),
            "step_count": self.step_count,
            "files_read": sorted(self.files_read),
            "cumulative_progress": self.cumulative_progress,
        }

    def close(self) -> None:
        if self.sandbox:
            self.sandbox.cleanup()
            self.sandbox = None

    def _execute_exploration(self, action: FlakySleuthAction) -> tuple[str, float]:
        assert self.current_task is not None
        assert self.sandbox is not None
        progress = 0.0
        output = ""
        if action.action_type != "search_code":
            self.consecutive_searches = 0
        if action.action_type == "read_file":
            content = self.sandbox.read_file(action.argument)
            if content is None:
                output = f"ERROR: File not found: {action.argument}"
                progress = -0.05
            elif action.argument in self.files_read:
                # Re-reading a file returns the content again but earns no extra progress.
                output = content
                progress = 0.0
            else:
                self.files_read.add(action.argument)
                output = content
                progress = self._file_relevance_reward(action.argument)
        elif action.action_type == "search_code":
            self.consecutive_searches += 1
            output = self.sandbox.grep(action.argument)
            base_progress = self._search_relevance_reward(action.argument)
            spam_penalty, warnings = self._search_spam_penalty(action.argument, output)
            progress = max(-0.25, base_progress - spam_penalty)
            if warnings:
                output = f"{output}\n\nWARNING: {' '.join(warnings)}"
        elif action.action_type == "run_test":
            output = self.sandbox.run_test(self.current_task.get("test_name", ""))
            category = str(self.current_task.get("category", "")).strip()
            # No rerun bonus for order-dependent (OD*) categories.
            if category not in ("OD", "OD-Brit", "OD-Vic"):
                progress = 0.05
        else:
            output = f"ERROR: Unsupported action_type {action.action_type}"
            progress = -0.05
        return output, progress

    def _file_relevance_reward(self, filepath: str) -> float:
        assert self.current_task is not None
        test_file = str(self.current_task.get("test_file", ""))
        if test_file and test_file in filepath:
            return 0.07
        if filepath.endswith(".py"):
            return 0.03
        return 0.01

    def _search_relevance_reward(self, pattern: str) -> float:
        pattern_lower = pattern.lower()
        if any(signal in pattern_lower for signal in FLAKY_SIGNAL_PATTERNS):
            return 0.04
        return 0.01

    def _search_spam_penalty(self, pattern: str, output: str) -> tuple[float, list[str]]:
        penalty = 0.0
        warnings: list[str] = []
        pattern_key = " ".join(pattern.lower().split())
        if pattern_key:
            pattern_count = self.search_pattern_counts.get(pattern_key, 0) + 1
            self.search_pattern_counts[pattern_key] = pattern_count
            if pattern_count > 1:
                repeat_penalty = min(0.02 * (pattern_count - 1), 0.12)
                penalty += repeat_penalty
                warnings.append(
                    f"Repeated search pattern ({pattern_count}x) penalty={repeat_penalty:.2f}."
                )
        context_hits = self._extract_search_hits(output)
        context_key = f"{pattern_key}::{','.join(context_hits)}"
        context_count = self.search_context_counts.get(context_key, 0) + 1
        self.search_context_counts[context_key] = context_count
        if context_count > 1:
            context_penalty = min(0.03 * (context_count - 1), 0.15)
            penalty += context_penalty
            warnings.append(
                f"Same search context repeated ({context_count}x) penalty={context_penalty:.2f}."
            )
        if self.consecutive_searches > 3:
            streak_penalty = min(0.02 * (self.consecutive_searches - 3), 0.20)
            penalty += streak_penalty
            warnings.append(
                f"Search-only streak={self.consecutive_searches} penalty={streak_penalty:.2f}."
            )
        return min(penalty, 0.35), warnings

    def _extract_search_hits(self, output: str) -> tuple[str, ...]:
        files: list[str] = []
        seen: set[str] = set()
        for raw_line in output.splitlines():
            line = raw_line.strip()
            if not line or line.startswith("No matches found") or line.startswith("Search "):
                continue
            filepath = line.split(":", 1)[0].strip()
            if filepath.startswith("./"):
                filepath = filepath[2:]
            if not filepath.endswith(".py"):
                continue
            if filepath not in seen:
                seen.add(filepath)
                files.append(filepath)
            if len(files) >= 5:
                break
        return tuple(files)

    def _make_obs(self, tool_output: str | None = None) -> FlakySleuthObservation:
        if not self.current_task:
            raise RuntimeError("No current task available")
        return FlakySleuthObservation(
            repo_url=str(self.current_task.get("repo_url", "")),
            test_name=str(self.current_task.get("test_name", "")),
            test_code=str(self.current_task.get("test_code", ""))[:2000],
            file_tree=self.sandbox.file_tree if self.sandbox else [],
            tool_output=tool_output,
            task_type=str(self.current_task.get("task_type", "classify")),
            task_description=str(
                self.current_task.get("task_description", "Investigate the flaky test.")
            ),
            step_count=self.step_count,
            done=False,
            reward=None,
        )
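

# --- Illustrative usage sketch -------------------------------------------------
# A minimal smoke-test driver, not part of the environment API. It assumes the
# dataset CSV and sandbox tooling are available locally, and that FlakySleuthAction
# can be constructed with action_type/argument keyword arguments, matching how
# those fields are read above.
if __name__ == "__main__":
    env = FlakySleuthEnv()
    obs = env.reset()
    print(f"task={obs.test_name} type={obs.task_type} files_in_tree={len(obs.file_tree)}")

    # One exploration step: grep for a common flakiness signal.
    obs, reward, done, info = env.step(
        FlakySleuthAction(action_type="search_code", argument="sleep")
    )
    print(f"search_code reward={reward:.3f}")

    # One terminal step: commit to a flakiness verdict and collect the graded reward.
    obs, reward, done, info = env.step(
        FlakySleuthAction(action_type="classify_flakiness", argument="flaky")
    )
    print(f"final reward={reward:.3f} done={done} info={info}")
    env.close()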