FlakyTestSleuthOpenEnvRL / env /environment.py
vedkdev's picture
Upload folder using huggingface_hub
dc990fa verified
from __future__ import annotations
from typing import Any
from env.models import FlakySleuthAction, FlakySleuthObservation
from env.sandbox import Sandbox
from env.task_loader import TaskLoader
from graders import grade_action
FLAKY_SIGNAL_PATTERNS = [
"sleep",
"random",
"time",
"datetime",
"thread",
"asyncio",
"fixture",
"setup",
"teardown",
"global",
"shared",
"singleton",
"os.environ",
"socket",
"timeout",
"retry",
"mock",
"patch",
]
TERMINAL_ACTIONS = ("classify_flakiness", "classify_root_cause", "propose_fix")
class FlakySleuthEnv:
def __init__(self, dataset_path: str = "dataset/py_tasks.csv", max_steps: int = 20):
self.loader = TaskLoader(dataset_path)
self.sandbox: Sandbox | None = None
self.current_task: dict[str, Any] | None = None
self.step_count = 0
self.max_steps = max_steps
self.cumulative_progress = 0.0
self.files_read: set[str] = set()
self.episode_actions: list[FlakySleuthAction] = []
self.search_pattern_counts: dict[str, int] = {}
self.search_context_counts: dict[str, int] = {}
self.consecutive_searches = 0
def reset(self) -> FlakySleuthObservation:
if self.sandbox:
self.sandbox.cleanup()
self.current_task = self.loader.sample()
self.current_task.setdefault("label", "flaky")
self.sandbox = Sandbox(self.current_task)
self.sandbox.setup()
self.current_task["sandbox_root"] = self.sandbox.tmpdir or ""
test_file = self.current_task.get("test_file", "")
if test_file and self.sandbox.tmpdir:
self.current_task["sandbox_test_path"] = f"{self.sandbox.tmpdir}/{test_file}"
self.step_count = 0
self.cumulative_progress = 0.0
self.files_read = set()
self.episode_actions = []
self.search_pattern_counts = {}
self.search_context_counts = {}
self.consecutive_searches = 0
return self._make_obs()
def step(self, action: FlakySleuthAction):
if not self.current_task or not self.sandbox:
raise RuntimeError("Environment is not initialized. Call reset() first.")
self.step_count += 1
self.episode_actions.append(action)
tool_output: str | None = None
reward = 0.0
done = False
info: dict[str, Any] = {}
if action.action_type in TERMINAL_ACTIONS:
terminal_score = grade_action(action, self.current_task)
late_penalty = max(0, self.step_count - 15) * 0.05
wrong_dir_penalty = 0.0
if (
action.action_type == "classify_flakiness"
and action.argument.strip().lower() == "stable"
and str(self.current_task.get("label", "flaky")).lower() == "flaky"
):
wrong_dir_penalty = 0.2
reward = min(
0.999,
max(
0.001,
self.cumulative_progress + terminal_score - late_penalty - wrong_dir_penalty,
),
)
done = True
info = {
"terminal_score": terminal_score,
"progress_score": self.cumulative_progress,
"late_penalty": late_penalty,
"task_type": self.current_task.get("task_type"),
"category": self.current_task.get("category"),
}
else:
tool_output, progress = self._execute_exploration(action)
self.cumulative_progress = min(0.30, max(0.0, self.cumulative_progress + progress))
reward = progress
if not done and self.step_count >= self.max_steps:
done = True
info = {
"terminal_score": 0.001,
"progress_score": self.cumulative_progress,
"late_penalty": max(0, self.step_count - 15) * 0.05,
"timeout": True,
"task_type": self.current_task.get("task_type"),
"category": self.current_task.get("category"),
}
obs = self._make_obs(tool_output)
return obs, reward, done, info
def state(self) -> dict[str, Any]:
if not self.current_task:
return {
"repo_url": None,
"test_name": None,
"task_type": None,
"step_count": self.step_count,
"files_read": [],
"cumulative_progress": self.cumulative_progress,
}
return {
"repo_url": self.current_task.get("repo_url"),
"test_name": self.current_task.get("test_name"),
"task_type": self.current_task.get("task_type"),
"step_count": self.step_count,
"files_read": sorted(self.files_read),
"cumulative_progress": self.cumulative_progress,
}
def close(self) -> None:
if self.sandbox:
self.sandbox.cleanup()
self.sandbox = None
def _execute_exploration(self, action: FlakySleuthAction) -> tuple[str, float]:
assert self.current_task is not None
assert self.sandbox is not None
progress = 0.0
output = ""
if action.action_type != "search_code":
self.consecutive_searches = 0
if action.action_type == "read_file":
content = self.sandbox.read_file(action.argument)
if content is None:
output = f"ERROR: File not found: {action.argument}"
progress = -0.05
elif action.argument in self.files_read:
output = content
progress = 0.0
else:
self.files_read.add(action.argument)
output = content
progress = self._file_relevance_reward(action.argument)
elif action.action_type == "search_code":
self.consecutive_searches += 1
output = self.sandbox.grep(action.argument)
base_progress = self._search_relevance_reward(action.argument)
spam_penalty, warnings = self._search_spam_penalty(action.argument, output)
progress = max(-0.25, base_progress - spam_penalty)
if warnings:
output = f"{output}\n\nWARNING: {' '.join(warnings)}"
elif action.action_type == "run_test":
output = self.sandbox.run_test(self.current_task.get("test_name", ""))
category = str(self.current_task.get("category", "")).strip()
if category not in ("OD", "OD-Brit", "OD-Vic"):
progress = 0.05
else:
output = f"ERROR: Unsupported action_type {action.action_type}"
progress = -0.05
return output, progress
def _file_relevance_reward(self, filepath: str) -> float:
assert self.current_task is not None
test_file = str(self.current_task.get("test_file", ""))
if test_file and test_file in filepath:
return 0.07
if filepath.endswith(".py"):
return 0.03
return 0.01
def _search_relevance_reward(self, pattern: str) -> float:
pattern_lower = pattern.lower()
if any(signal in pattern_lower for signal in FLAKY_SIGNAL_PATTERNS):
return 0.04
return 0.01
def _search_spam_penalty(self, pattern: str, output: str) -> tuple[float, list[str]]:
penalty = 0.0
warnings: list[str] = []
pattern_key = " ".join(pattern.lower().split())
if pattern_key:
pattern_count = self.search_pattern_counts.get(pattern_key, 0) + 1
self.search_pattern_counts[pattern_key] = pattern_count
if pattern_count > 1:
repeat_penalty = min(0.02 * (pattern_count - 1), 0.12)
penalty += repeat_penalty
warnings.append(
f"Repeated search pattern ({pattern_count}x) penalty={repeat_penalty:.2f}."
)
context_hits = self._extract_search_hits(output)
context_key = f"{pattern_key}::{','.join(context_hits)}"
context_count = self.search_context_counts.get(context_key, 0) + 1
self.search_context_counts[context_key] = context_count
if context_count > 1:
context_penalty = min(0.03 * (context_count - 1), 0.15)
penalty += context_penalty
warnings.append(
f"Same search context repeated ({context_count}x) penalty={context_penalty:.2f}."
)
if self.consecutive_searches > 3:
streak_penalty = min(0.02 * (self.consecutive_searches - 3), 0.20)
penalty += streak_penalty
warnings.append(
f"Search-only streak={self.consecutive_searches} penalty={streak_penalty:.2f}."
)
return min(penalty, 0.35), warnings
def _extract_search_hits(self, output: str) -> tuple[str, ...]:
files: list[str] = []
seen: set[str] = set()
for raw_line in output.splitlines():
line = raw_line.strip()
if not line or line.startswith("No matches found") or line.startswith("Search "):
continue
filepath = line.split(":", 1)[0].strip()
if filepath.startswith("./"):
filepath = filepath[2:]
if not filepath.endswith(".py"):
continue
if filepath not in seen:
seen.add(filepath)
files.append(filepath)
if len(files) >= 5:
break
return tuple(files)
def _make_obs(self, tool_output: str | None = None) -> FlakySleuthObservation:
if not self.current_task:
raise RuntimeError("No current task available")
return FlakySleuthObservation(
repo_url=str(self.current_task.get("repo_url", "")),
test_name=str(self.current_task.get("test_name", "")),
test_code=str(self.current_task.get("test_code", ""))[:2000],
file_tree=self.sandbox.file_tree if self.sandbox else [],
tool_output=tool_output,
task_type=str(self.current_task.get("task_type", "classify")),
task_description=str(self.current_task.get("task_description", "Investigate the flaky test.")),
step_count=self.step_count,
done=False,
reward=None,
)