Spaces:
Sleeping
Sleeping
File size: 8,763 Bytes
90fc756 b83c8ad 90fc756 b83c8ad bf2775e 90fc756 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | from __future__ import annotations
import json
from pathlib import Path
from sql_query_reviewer.models import (
IdentifiedIssue,
SQLReviewAction,
SQLReviewObservation,
SQLReviewState,
StepResult,
TaskRecord,
)
from server.grader import grade_episode, match_issue, validate_fix
from server.reward import compute_reward
class SQLReviewEnvironment:
    """Episodic environment in which an agent reviews one SQL query per episode.

    Tasks are loaded from ``*_tasks.json`` files in ``task_directory``. Each
    episode is started with :meth:`reset` and advanced with :meth:`step`, which
    understands the action types ``identify_issue``, ``suggest_fix`` and
    ``approve``; any other action type is answered with schema context.
    """

    def __init__(self, task_directory: Path | None = None) -> None:
        """Load all tasks and prepare an environment with no active episode.

        Args:
            task_directory: Directory containing ``*_tasks.json`` files. Defaults
                to the ``tasks`` directory two levels above this module's file.

        Raises:
            ValueError: If a task file is malformed or a ``task_id`` is duplicated.
            RuntimeError: If no tasks could be loaded.
        """
        self.task_directory = task_directory or Path(__file__).resolve().parent.parent / "tasks"
        self.tasks = self._load_tasks()
        # Deterministic ordering used for round-robin selection in reset().
        self.task_order = sorted(self.tasks)
        self.current_task: TaskRecord | None = None
        self.current_state: SQLReviewState | None = None
        self._reset_index = 0

    def available_task_ids(self) -> list[str]:
        """Return a copy of the known task ids, in sorted order."""
        return list(self.task_order)

    def reset(self, task_id: str | None = None) -> StepResult:
        """Start a new episode.

        Args:
            task_id: Task to load. When falsy, the next task in sorted
                round-robin order is used.

        Returns:
            A ``StepResult`` with the initial observation, zero reward and
            ``done=False``.

        Raises:
            ValueError: If ``task_id`` is not a known task.
        """
        selected_task_id = task_id or self._next_task_id()
        if selected_task_id not in self.tasks:
            raise ValueError(f"Unknown task_id: {selected_task_id}")
        self.current_task = self.tasks[selected_task_id]
        self.current_state = SQLReviewState(task_id=self.current_task.task_id)
        observation = self._build_observation(
            feedback="Review this SQL query and identify correctness, performance, or security issues."
        )
        return StepResult(observation=observation, reward=0.0, done=False, info={})

    def step(self, action: SQLReviewAction) -> StepResult:
        """Apply one agent action to the active episode.

        Every call consumes one step. The episode ends when the agent approves
        or when ``task.max_steps`` is reached; at that point the episode is
        graded and ``info["final_score"]`` is populated.

        Args:
            action: The agent's action; dispatched on ``action.action_type``.

        Returns:
            A ``StepResult`` with the new observation, this step's reward, the
            done flag and an ``info`` dict describing the outcome.

        Raises:
            RuntimeError: If no episode is active or the episode already ended.
        """
        task = self._require_task()
        state = self._require_state()
        if state.done:
            raise RuntimeError("Episode already finished. Call reset() before taking more steps.")
        # Issues found *before* this action; used for duplicate detection and
        # for the approval coverage check.
        found_ids = {issue.issue_id for issue in state.issues_identified}
        reward = 0.0
        info: dict[str, object] = {}
        feedback = "No-op."
        state.step_count += 1

        if action.action_type == "identify_issue":
            # Match without excluding anything first: if that match is an issue
            # the agent already found, the action is a duplicate.
            duplicate_issue, duplicate_score = match_issue(action, task.ground_truth_issues, set())
            if duplicate_issue is not None and duplicate_issue.id in found_ids:
                reward = compute_reward(action, duplicate_issue, duplicate_issue=True)
                feedback = f"Issue '{duplicate_issue.id}' was already identified earlier in the episode."
                info = {"match_score": round(duplicate_score, 3), "match_type": "duplicate", "issue_id": duplicate_issue.id}
            else:
                # Otherwise match only against issues not yet found.
                matched_issue, score = match_issue(action, task.ground_truth_issues, found_ids)
                if matched_issue is None:
                    state.false_positive_count += 1
                    reward = compute_reward(action, None)
                    feedback = "No matching issue found for that description."
                    info = {"match_score": round(score, 3), "match_type": "none"}
                else:
                    fix_valid = validate_fix(action.suggested_fix, matched_issue)
                    state.issues_identified.append(
                        IdentifiedIssue(
                            issue_id=matched_issue.id,
                            category=matched_issue.category,
                            description=matched_issue.description,
                        )
                    )
                    reward = compute_reward(
                        action,
                        matched_issue,
                        fix_valid=fix_valid,
                        issues_found_count=len(state.issues_identified),
                        schema_available=bool(task.schema_info),
                    )
                    remaining = len(task.ground_truth_issues) - len(state.issues_identified)
                    feedback = f"Matched {matched_issue.category} issue '{matched_issue.id}'. {remaining} issue(s) remaining."
                    info = {
                        "match_score": round(score, 3),
                        "match_type": "fuzzy",
                        "severity": matched_issue.severity,
                        "issue_id": matched_issue.id,
                        "all_issues_found": remaining == 0,
                    }
                    # A fix attached to an identify action is recorded only if valid.
                    if fix_valid and action.suggested_fix:
                        state.fixes_suggested.append(action.suggested_fix)
        elif action.action_type == "suggest_fix":
            if not state.issues_identified:
                reward = compute_reward(action, None, has_previous_issue=False)
                feedback = "Identify an issue before suggesting a fix."
            else:
                # A standalone fix always targets the most recently identified issue.
                last_issue_id = state.issues_identified[-1].issue_id
                last_issue = next(issue for issue in task.ground_truth_issues if issue.id == last_issue_id)
                fix_valid = validate_fix(action.suggested_fix, last_issue)
                reward = compute_reward(action, last_issue, fix_valid=fix_valid, has_previous_issue=True)
                feedback = "Fix accepted for the last identified issue." if fix_valid else "Suggested fix did not match the expected remediation."
                info = {"issue_id": last_issue.id, "fix_valid": fix_valid}
                if fix_valid and action.suggested_fix:
                    state.fixes_suggested.append(action.suggested_fix)
        elif action.action_type == "approve":
            remaining_unfound = len(task.ground_truth_issues) - len(found_ids)
            reward = compute_reward(action, None, remaining_unfound=remaining_unfound)
            state.approved = True
            state.done = True
            feedback = (
                "Query approved with full issue coverage."
                if remaining_unfound == 0
                else f"Query approved too early. {remaining_unfound} issue(s) were missed."
            )
            info = {"remaining_unfound": remaining_unfound}
        else:
            # Any other action type is treated as a request for schema context.
            feedback = self._schema_feedback(task)
            reward = compute_reward(action, None, schema_available=bool(task.schema_info))
            info = {"context_shared": bool(task.schema_info)}

        state.total_reward += reward
        # Running out of steps ends the episode even without an approval.
        if state.step_count >= task.max_steps and not state.done:
            state.done = True
            feedback = f"{feedback} Maximum step count reached."
        if state.done:
            state.final_score = grade_episode(
                found_issue_ids={issue.issue_id for issue in state.issues_identified},
                ground_truth_issues=task.ground_truth_issues,
                total_steps=state.step_count,
                max_steps=task.max_steps,
                false_positive_count=state.false_positive_count,
            )
            info["final_score"] = state.final_score
        observation = self._build_observation(feedback=feedback)
        return StepResult(observation=observation, reward=reward, done=state.done, info=info)

    def state(self) -> SQLReviewState:
        """Return a deep copy of the current episode state.

        Callers cannot mutate the environment through the returned object.

        Raises:
            RuntimeError: If no episode is active.
        """
        return self._require_state().model_copy(deep=True)

    def _load_tasks(self) -> dict[str, TaskRecord]:
        """Load and validate every task from ``*_tasks.json`` files.

        Files are read in sorted filename order. Each file must contain a JSON
        array whose elements are validated with ``TaskRecord.model_validate``.

        Returns:
            Mapping of ``task_id`` to its validated ``TaskRecord``.

        Raises:
            ValueError: If a file's top-level JSON value is not an array, or if
                a ``task_id`` is defined more than once. Previously a duplicate
                silently replaced the earlier definition.
            RuntimeError: If no tasks were loaded at all.
        """
        tasks: dict[str, TaskRecord] = {}
        # Remembers which file defined each task id, for duplicate reporting.
        origins: dict[str, Path] = {}
        for file_path in sorted(self.task_directory.glob("*_tasks.json")):
            with file_path.open("r", encoding="utf-8") as handle:
                raw_tasks = json.load(handle)
            if not isinstance(raw_tasks, list):
                raise ValueError(
                    f"Task file {file_path} must contain a JSON array of tasks, "
                    f"got {type(raw_tasks).__name__}"
                )
            for raw_task in raw_tasks:
                task = TaskRecord.model_validate(raw_task)
                if task.task_id in origins:
                    raise ValueError(
                        f"Duplicate task_id {task.task_id!r} in {file_path} "
                        f"(already defined in {origins[task.task_id]})"
                    )
                tasks[task.task_id] = task
                origins[task.task_id] = file_path
        if not tasks:
            raise RuntimeError(f"No tasks found in {self.task_directory} (looked for non-empty *_tasks.json files)")
        return tasks

    def _next_task_id(self) -> str:
        """Return the next task id in round-robin order and advance the cursor."""
        task_id = self.task_order[self._reset_index % len(self.task_order)]
        self._reset_index += 1
        return task_id

    def _build_observation(self, feedback: str) -> SQLReviewObservation:
        """Build the agent-facing observation for the current task and state.

        Args:
            feedback: Message describing the outcome of the latest action.
        """
        task = self._require_task()
        state = self._require_state()
        # Clamp at zero so the count never goes negative.
        remaining_actions = max(task.max_steps - state.step_count, 0)
        return SQLReviewObservation(
            query=task.query,
            schema_info=task.schema_info,
            context=task.context,
            issues_found_so_far=state.issues_identified,
            remaining_actions=remaining_actions,
            difficulty=task.difficulty,
            feedback=feedback,
        )

    def _schema_feedback(self, task: TaskRecord) -> str:
        """Describe which tables have schema context, or say none is available."""
        if not task.schema_info:
            return "No additional schema context is available for this task."
        tables = ", ".join(sorted(task.schema_info))
        return f"Schema context available for: {tables}."

    def _require_task(self) -> TaskRecord:
        """Return the active task.

        Raises:
            RuntimeError: If :meth:`reset` has not been called.
        """
        if self.current_task is None:
            raise RuntimeError("Environment has no active task. Call reset() first.")
        return self.current_task

    def _require_state(self) -> SQLReviewState:
        """Return the active episode state.

        Raises:
            RuntimeError: If :meth:`reset` has not been called.
        """
        if self.current_state is None:
            raise RuntimeError("Environment has no active state. Call reset() first.")
        return self.current_state
|