Commit cf4ce1e · verified · 1 Parent(s): 09e3e7a
ncncomplete committed

Upload folder using huggingface_hub

Files changed (5):
  1. inference.py +79 -0
  2. models.py +2 -0
  3. openenv.yaml +2 -0
  4. server/app.py +32 -0
  5. server/code_review_env_environment.py +179 -15
inference.py ADDED
@@ -0,0 +1,79 @@
+"""
+Baseline inference script for code_review_env.
+Uses OpenAI API to run an agent against all 3 tasks.
+"""
+import os
+import json
+import requests
+from openai import OpenAI
+
+BASE_URL = os.getenv("ENV_URL", "http://localhost:8000")
+client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", "dummy-key"))
+
+TASKS = ["easy", "medium", "hard"]
+
+def run_task(task_id: str) -> float:
+    # Reset environment
+    reset_resp = requests.post(f"{BASE_URL}/reset", json={"task_id": task_id})
+    obs = reset_resp.json()["observation"]
+
+    code_snippet = obs["code_snippet"]
+    task_description = obs["task_description"]
+
+    # Call LLM
+    prompt = f"""You are a code reviewer. {task_description}
+
+Code to review:
+```python
+{code_snippet}
+```
+
+Respond ONLY with valid JSON, no markdown:
+{{
+    "review": "your detailed analysis",
+    "bug_type": "syntax or logic or security or none",
+    "line_number": <integer>,
+    "confidence": <float 0.0-1.0>
+}}"""
+
+    try:
+        response = client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[{"role": "user", "content": prompt}],
+            temperature=0.0,
+        )
+        raw = response.choices[0].message.content.strip()
+        raw = raw.replace("```json", "").replace("```", "").strip()
+        action = json.loads(raw)
+    except Exception as e:
+        print(f"LLM error for {task_id}: {e}")
+        action = {"review": "unknown", "bug_type": "none", "line_number": -1, "confidence": 0.0}
+
+    # Step
+    step_resp = requests.post(f"{BASE_URL}/step", json={"action": action})
+    step_data = step_resp.json()
+
+    # Get grader score
+    grader_resp = requests.get(f"{BASE_URL}/grader?task_id={task_id}&episode_id=baseline")
+    score = grader_resp.json().get("score", 0.0)
+
+    print(f"Task: {task_id} | Score: {score} | Feedback: {step_data['observation'].get('previous_feedback', '')}")
+    return score
+
+def main():
+    scores = {}
+    for task_id in TASKS:
+        scores[task_id] = run_task(task_id)
+
+    average = sum(scores.values()) / len(scores)
+    scores["average"] = round(average, 4)
+
+    print(f"\nBaseline Results: {json.dumps(scores, indent=2)}")
+
+    with open("baseline_scores.json", "w") as f:
+        json.dump(scores, f, indent=2)
+
+    return scores
+
+if __name__ == "__main__":
+    main()
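
Before running the full baseline, it helps to smoke-test the server with the same endpoints inference.py uses. A minimal sketch, assuming the environment is already serving on port 8000 (the ENV_URL default) and that /reset and /step accept the payloads shown above:

```python
# Minimal smoke test for the code_review_env HTTP API (assumes a local server).
import requests

BASE_URL = "http://localhost:8000"

# Reset to the easy task and inspect the first observation.
obs = requests.post(f"{BASE_URL}/reset", json={"task_id": "easy"}).json()["observation"]
print(obs["task_description"])
print(obs["code_snippet"])

# Submit a deliberately weak action; the next observation should carry
# corrective feedback in previous_feedback.
action = {"review": "looks fine to me", "bug_type": "none", "line_number": -1, "confidence": 0.1}
step = requests.post(f"{BASE_URL}/step", json={"action": action}).json()
print(step["observation"]["previous_feedback"])
```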
models.py CHANGED
@@ -30,6 +30,7 @@ class ReviewObservation(Observation):
     attempt_number: int  # how many steps taken so far
     previous_feedback: str  # feedback from last step, empty on reset
     done: bool  # whether episode is complete
+    hint: Optional[str] = None  # optional hint for the agent
 
 
 class ReviewState(State):
@@ -42,3 +43,4 @@ class ReviewState(State):
     step_count: int = 0
     task_episode_id: str = ""
     cumulative_reward: float = 0.0
+    total_snippets: int = 4
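
For orientation, the two models read roughly as follows after this change. This is a sketch: the fields and comments come from the diff, the base classes mirror the imports in server/code_review_env_environment.py, and anything outside the shown hunks is omitted. Note that the diff adds no typing import, so Optional is presumably already imported in models.py (worth verifying).

```python
from typing import Optional
from openenv.core.env_server.interfaces import Observation, State  # assumed path, mirroring the environment file

class ReviewObservation(Observation):
    # ...fields above the hunk (code_snippet, task_description, task_id, ...) omitted...
    attempt_number: int           # how many steps taken so far
    previous_feedback: str        # feedback from last step, empty on reset
    done: bool                    # whether episode is complete
    hint: Optional[str] = None    # new: optional hint for the agent

class ReviewState(State):
    # ...fields above the hunk omitted...
    step_count: int = 0
    task_episode_id: str = ""
    cumulative_reward: float = 0.0
    total_snippets: int = 4       # new: snippet pool size; reset() sets it to len(task["snippets"])
```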
openenv.yaml CHANGED
@@ -4,4 +4,6 @@ type: space
 runtime: fastapi
 app: server.app:app
 port: 8000
+version: "1.0.0"
+description: "AI agent environment for Python code review across syntax, logic, and security bug detection"
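
For reference, the manifest after this commit reads as follows for the shown range (lines 1-3, including the type: space entry from the hunk context, sit above and are unchanged):

```yaml
runtime: fastapi
app: server.app:app
port: 8000
version: "1.0.0"
description: "AI agent environment for Python code review across syntax, logic, and security bug detection"
```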
 
server/app.py CHANGED
@@ -34,6 +34,12 @@ def list_tasks():
                 "bug_type": "string - syntax | logic | security | none",
                 "line_number": "int - line with the bug, -1 if unknown",
                 "confidence": "float - your confidence 0.0 to 1.0"
+            },
+            "example_action": {
+                "review": "Line 1 is missing a colon after the function definition. This is a syntax error.",
+                "bug_type": "syntax",
+                "line_number": 1,
+                "confidence": 0.95
             }
         },
         {
@@ -45,6 +51,12 @@ def list_tasks():
                 "bug_type": "string - syntax | logic | security | none",
                 "line_number": "int - line with the bug, -1 if unknown",
                 "confidence": "float - your confidence 0.0 to 1.0"
+            },
+            "example_action": {
+                "review": "Line 5 has an index error: it should be max_val = numbers[i], not numbers[i - 1]. This is a logic bug.",
+                "bug_type": "logic",
+                "line_number": 5,
+                "confidence": 0.95
             }
         },
         {
@@ -56,11 +68,31 @@ def list_tasks():
                 "bug_type": "string - syntax | logic | security | none",
                 "line_number": "int - line with the bug, -1 if unknown",
                 "confidence": "float - your confidence 0.0 to 1.0"
+            },
+            "example_action": {
+                "review": "Line 6 has a SQL injection vulnerability because the username is concatenated directly into the query without parameterized statements.",
+                "bug_type": "security",
+                "line_number": 6,
+                "confidence": 0.95
             }
         }
     ]
 }
 
+@app.get("/info")
+def info():
+    """
+    Return information about the Code Review Environment: name, version,
+    description, number of tasks, and supported difficulty levels.
+    """
+    return {
+        "name": "code_review_env",
+        "version": "1.0.0",
+        "description": "AI agent environment for Python code review across syntax, logic, and security bug detection",
+        "num_tasks": 3,
+        "difficulty_levels": ["easy", "medium", "hard"]
+    }
+
 @app.get("/grader")
 def grader(task_id: str = Query("easy"), episode_id: str = Query(None)):
     """
server/code_review_env_environment.py CHANGED
@@ -11,6 +11,7 @@ Code Review Environment — agent finds bugs in Python snippets.
 
 from __future__ import annotations
 import uuid
+import random
 from openenv.core.env_server.interfaces import Environment, Action, Observation
 from models import ReviewAction, ReviewObservation, ReviewState
 # ── Task bank ────────────────────────────────────────────────────────────────
@@ -22,7 +23,9 @@ TASKS = {
             "runtime errors. Specify the bug type, the line number where "
             "the error occurs, and explain what is wrong."
         ),
-        "snippet": """\
+        "snippets": [
+            {
+                "code": """\
 def calculate_average(numbers)
     total = 0
     for num in numbers:
@@ -32,8 +35,50 @@ def calculate_average(numbers)
 result = calculate_average([10, 20, 30])
 print(result)
 """,
+                "bug_type": "syntax",
+                "line_number": 1,
+                "keywords": ["colon", "missing", "def", "syntax"],
+            },
+            {
+                "code": """\
+def add_numbers(a, b)
+    return a + b
+
+result = add_numbers(5, 10)
+print(result)
+""",
+                "bug_type": "syntax",
+                "line_number": 1,
+                "keywords": ["colon", "missing", "syntax"],
+            },
+            {
+                "code": """\
+def greet(name):
+    message = "Hello " + name
+    print(message)
+# Missing return statement here
+
+greet("Alice")
+""",
+                "bug_type": "syntax",
+                "line_number": 4,
+                "keywords": ["return", "missing", "runtime"],
+            },
+            {
+                "code": """\
+def process_data(data):
+    result = process(data)
+    return result
+
+result_value = process_data([1, 2, 3])
+print(result_value)
+""",
+                "bug_type": "syntax",
+                "line_number": 2,
+                "keywords": ["undefined", "name", "not defined"],
+            },
+        ],
         "correct_bug_type": "syntax",
-        "correct_line_number": 1,
         "correct_keywords": ["colon", "missing", "def", "syntax"],
     },
     "medium": {
@@ -42,7 +87,9 @@ print(result)
             "but produces incorrect output. Identify the logic bug, "
             "the line number, and explain why it is wrong."
         ),
-        "snippet": """\
+        "snippets": [
+            {
+                "code": """\
 def find_max(numbers):
     max_val = numbers[0]
     for i in range(len(numbers)):
@@ -52,8 +99,56 @@ def find_max(numbers):
 
 print(find_max([3, 7, 2, 9, 4]))
 """,
+                "bug_type": "logic",
+                "line_number": 5,
+                "keywords": ["index", "i - 1", "off by one", "wrong", "logic"],
+            },
+            {
+                "code": """\
+def count_matches(lst, target):
+    count = 0
+    for item in lst:
+        if item = target:
+            count += 1
+    return count
+
+print(count_matches([1, 2, 2, 3, 2], 2))
+""",
+                "bug_type": "logic",
+                "line_number": 4,
+                "keywords": ["comparison", "==", "assignment", "operator"],
+            },
+            {
+                "code": """\
+def reverse_string(s):
+    result = ""
+    for i in range(len(s) - 1, -1, -1):
+        for j in range(0, len(s)):
+            result += s[j]
+    return result
+
+print(reverse_string("hello"))
+""",
+                "bug_type": "logic",
+                "line_number": 3,
+                "keywords": ["loop", "bounds", "reversed", "range"],
+            },
+            {
+                "code": """\
+def extract_substring(text):
+    start = 5
+    end = 2
+    return text[start:end]
+
+result = extract_substring("Hello World")
+print(result)
+""",
+                "bug_type": "logic",
+                "line_number": 4,
+                "keywords": ["slice", "index", "string", "wrong"],
+            },
+        ],
         "correct_bug_type": "logic",
-        "correct_line_number": 5,
         "correct_keywords": ["index", "i - 1", "off by one", "wrong", "logic"],
     },
     "hard": {
@@ -62,7 +157,9 @@ print(find_max([3, 7, 2, 9, 4]))
             "Identify the vulnerability type, the line number, and explain "
             "the security risk it introduces."
         ),
-        "snippet": """\
+        "snippets": [
+            {
+                "code": """\
 import sqlite3
 
 def get_user(username):
@@ -75,8 +172,57 @@ def get_user(username):
 user_input = input("Enter username: ")
 print(get_user(user_input))
 """,
+                "bug_type": "security",
+                "line_number": 6,
+                "keywords": ["sql injection", "injection", "concatenat", "unsanitized", "parameterized"],
+            },
+            {
+                "code": """\
+def authenticate(username, password):
+    admin_password = "SuperSecret123!"
+
+    if username == "admin" and password == admin_password:
+        return True
+    return False
+
+user_input = input("Username: ")
+pass_input = input("Password: ")
+print(authenticate(user_input, pass_input))
+""",
+                "bug_type": "security",
+                "line_number": 2,
+                "keywords": ["hardcoded", "password", "credentials", "secret", "plaintext"],
+            },
+            {
+                "code": """\
+import pickle
+
+def deserialize_data(data):
+    return pickle.loads(data)
+
+received_data = input("Enter pickled data: ")
+result = deserialize_data(received_data.encode())
+print(result)
+""",
+                "bug_type": "security",
+                "line_number": 4,
+                "keywords": ["pickle", "untrusted", "deserialize", "arbitrary code"],
+            },
+            {
+                "code": """\
+def evaluate_expression(expr):
+    result = eval(expr)
+    return result
+
+user_expr = input("Enter expression: ")
+print(evaluate_expression(user_expr))
+""",
+                "bug_type": "security",
+                "line_number": 2,
+                "keywords": ["eval", "arbitrary code", "user input", "dangerous"],
+            },
+        ],
         "correct_bug_type": "security",
-        "correct_line_number": 6,
         "correct_keywords": ["sql injection", "injection", "concatenat", "unsanitized", "parameterized"],
     },
 }
@@ -125,14 +271,19 @@ def compute_reward(action: ReviewAction, task: dict, attempt: int) -> tuple[floa
     else:
         feedback_parts.append("✗ Explanation missing key concepts.")
 
+    # Semantic similarity bonus (+0.25): if review length > 50 chars AND contains correct keyword
+    if len(action.review) > 50 and matched_keywords:
+        reward += 0.25
+        feedback_parts.append("✓ Semantic similarity bonus: detailed and accurate explanation.")
+
     # Retry penalty
     if attempt > 1:
         penalty = 0.1 * (attempt - 1)
         reward -= penalty
         feedback_parts.append(f"⚠ Retry penalty: -{penalty:.1f}")
 
-    # Clamp to 0.0-1.0 (max raw = 2.0, normalize)
-    normalized = max(0.0, min(1.0, reward / 2.0))
+    # Clamp to 0.0-1.0 (max raw = 2.25, normalize)
+    normalized = max(0.0, min(1.0, reward / 2.25))
     return round(normalized, 4), " ".join(feedback_parts)
 
 
@@ -151,23 +302,28 @@ class CodeReviewEnvironment(Environment):
         if task_id not in TASKS:
             task_id = "easy"
         task = TASKS[task_id]
+        # Randomly select a snippet from the available snippets
+        selected_snippet = random.choice(task["snippets"])
+
         self._state = ReviewState(
             current_task_id=task_id,
-            current_snippet=task["snippet"],
-            correct_bug_type=task["correct_bug_type"],
-            correct_line_number=task["correct_line_number"],
-            correct_keywords=task["correct_keywords"],
+            current_snippet=selected_snippet["code"],
+            correct_bug_type=selected_snippet["bug_type"],
+            correct_line_number=selected_snippet["line_number"],
+            correct_keywords=selected_snippet["keywords"],
             step_count=0,
             task_episode_id=str(uuid.uuid4()),
             cumulative_reward=0.0,
+            total_snippets=len(task["snippets"]),
         )
         return ReviewObservation(
-            code_snippet=task["snippet"],
+            code_snippet=selected_snippet["code"],
             task_description=task["description"],
             task_id=task_id,
            attempt_number=0,
             previous_feedback="",
             done=False,
+            hint=None,
         )
 
     def step(self, action: Action) -> Observation:
@@ -175,7 +331,14 @@ class CodeReviewEnvironment(Environment):
             raise ValueError(f"Expected ReviewAction, got {type(action)}")
 
         self._state.step_count += 1
-        task = TASKS[self._state.current_task_id]
+        task_base = TASKS[self._state.current_task_id]
+
+        # Create task dict with current snippet's correct answers for compute_reward
+        task = {
+            "correct_bug_type": self._state.correct_bug_type,
+            "correct_line_number": self._state.correct_line_number,
+            "correct_keywords": self._state.correct_keywords,
+        }
 
         reward, feedback = compute_reward(
             action, task, self._state.step_count
@@ -189,11 +352,12 @@
 
         return ReviewObservation(
             code_snippet=self._state.current_snippet,
-            task_description=task["description"],
+            task_description=task_base["description"],
             task_id=self._state.current_task_id,
             attempt_number=self._state.step_count,
             previous_feedback=feedback,
             done=done,
+            hint=None,
         )
 
     @property
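
A worked check of the new reward normalization. The only weights visible in this diff are the +0.25 semantic bonus and the 0.1-per-retry penalty; the old comment's "max raw = 2.0" implies the pre-existing components (bug type, line number, keywords) sum to 2.0, so a perfect first attempt now reaches a raw score of 2.25:

```python
def normalize(raw: float) -> float:
    # Mirrors the new clamp in compute_reward: max raw = 2.25.
    return round(max(0.0, min(1.0, raw / 2.25)), 4)

print(normalize(2.25))        # 1.0    -> perfect first attempt with the semantic bonus
print(normalize(2.25 - 0.1))  # 0.9556 -> same answer on the second attempt (one retry penalty)
print(normalize(2.0))         # 0.8889 -> otherwise-perfect answer whose review misses the bonus
```

One consequence worth noting: an answer that scored 1.0 under the old 2.0 divisor now normalizes to about 0.8889 unless its review is long enough (and keyword-matched) to earn the bonus, so baseline scores from before this commit are not directly comparable.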