Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Code Review Environment β agent finds bugs in Python snippets. | |
| 3 tasks: syntax errors (easy) β logic bugs (medium) β security vulns (hard). | |
| """ | |
| from __future__ import annotations | |
| import uuid | |
| import random | |
| from openenv.core.env_server.interfaces import Environment, Action, Observation | |
| from models import ReviewAction, ReviewObservation, ReviewState | |
| # ββ Task bank ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| TASKS = { | |
| "easy": { | |
| "description": ( | |
| "Review the following Python code and identify any syntax or " | |
| "runtime errors. Specify the bug type, the line number where " | |
| "the error occurs, and explain what is wrong." | |
| ), | |
| "snippets": [ | |
| { | |
| "code": """\ | |
| def calculate_average(numbers) | |
| total = 0 | |
| for num in numbers: | |
| total += num | |
| return total / len(numbers) | |
| result = calculate_average([10, 20, 30]) | |
| print(result) | |
| """, | |
| "bug_type": "syntax", | |
| "line_number": 1, | |
| "keywords": ["colon", "missing", "def", "syntax"], | |
| }, | |
| { | |
| "code": """\ | |
| def add_numbers(a, b) | |
| return a + b | |
| result = add_numbers(5, 10) | |
| print(result) | |
| """, | |
| "bug_type": "syntax", | |
| "line_number": 1, | |
| "keywords": ["colon", "missing", "syntax"], | |
| }, | |
| { | |
| "code": """\ | |
| def greet(name): | |
| message = "Hello " + name | |
| print(message) | |
| # Missing return statement here | |
| greet("Alice") | |
| """, | |
| "bug_type": "syntax", | |
| "line_number": 4, | |
| "keywords": ["return", "missing", "runtime"], | |
| }, | |
| { | |
| "code": """\ | |
| def process_data(data): | |
| result = process(data) | |
| return result | |
| result_value = process_data([1, 2, 3]) | |
| print(result_value) | |
| """, | |
| "bug_type": "syntax", | |
| "line_number": 2, | |
| "keywords": ["undefined", "name", "not defined"], | |
| }, | |
| ], | |
| "correct_bug_type": "syntax", | |
| "correct_keywords": ["colon", "missing", "def", "syntax"], | |
| }, | |
| "medium": { | |
| "description": ( | |
| "Review the following Python code. It runs without crashing " | |
| "but produces incorrect output. Identify the logic bug, " | |
| "the line number, and explain why it is wrong." | |
| ), | |
| "snippets": [ | |
| { | |
| "code": """\ | |
| def find_max(numbers): | |
| max_val = numbers[0] | |
| for i in range(len(numbers)): | |
| if numbers[i] > max_val: | |
| max_val = numbers[i - 1] | |
| return max_val | |
| print(find_max([3, 7, 2, 9, 4])) | |
| """, | |
| "bug_type": "logic", | |
| "line_number": 5, | |
| "keywords": ["index", "i - 1", "off by one", "wrong", "logic"], | |
| }, | |
| { | |
| "code": """\ | |
| def count_matches(lst, target): | |
| count = 0 | |
| for item in lst: | |
| if item = target: | |
| count += 1 | |
| return count | |
| print(count_matches([1, 2, 2, 3, 2], 2)) | |
| """, | |
| "bug_type": "logic", | |
| "line_number": 4, | |
| "keywords": ["comparison", "==", "assignment", "operator"], | |
| }, | |
| { | |
| "code": """\ | |
| def reverse_string(s): | |
| result = "" | |
| for i in range(len(s) - 1, -1, -1): | |
| for j in range(0, len(s)): | |
| result += s[j] | |
| return result | |
| print(reverse_string("hello")) | |
| """, | |
| "bug_type": "logic", | |
| "line_number": 3, | |
| "keywords": ["loop", "bounds", "reversed", "range"], | |
| }, | |
| { | |
| "code": """\ | |
| def extract_substring(text): | |
| start = 5 | |
| end = 2 | |
| return text[start:end] | |
| result = extract_substring("Hello World") | |
| print(result) | |
| """, | |
| "bug_type": "logic", | |
| "line_number": 4, | |
| "keywords": ["slice", "index", "string", "wrong"], | |
| }, | |
| ], | |
| "correct_bug_type": "logic", | |
| "correct_keywords": ["index", "i - 1", "off by one", "wrong", "logic"], | |
| }, | |
| "hard": { | |
| "description": ( | |
| "Review the following Python code for security vulnerabilities. " | |
| "Identify the vulnerability type, the line number, and explain " | |
| "the security risk it introduces." | |
| ), | |
| "snippets": [ | |
| { | |
| "code": """\ | |
| import sqlite3 | |
| def get_user(username): | |
| conn = sqlite3.connect('users.db') | |
| cursor = conn.cursor() | |
| query = "SELECT * FROM users WHERE username = '" + username + "'" | |
| cursor.execute(query) | |
| return cursor.fetchone() | |
| user_input = input("Enter username: ") | |
| print(get_user(user_input)) | |
| """, | |
| "bug_type": "security", | |
| "line_number": 6, | |
| "keywords": ["sql injection", "injection", "concatenat", "unsanitized", "parameterized"], | |
| }, | |
| { | |
| "code": """\ | |
| def authenticate(username, password): | |
| admin_password = "SuperSecret123!" | |
| if username == "admin" and password == admin_password: | |
| return True | |
| return False | |
| user_input = input("Username: ") | |
| pass_input = input("Password: ") | |
| print(authenticate(user_input, pass_input)) | |
| """, | |
| "bug_type": "security", | |
| "line_number": 2, | |
| "keywords": ["hardcoded", "password", "credentials", "secret", "plaintext"], | |
| }, | |
| { | |
| "code": """\ | |
| import pickle | |
| def deserialize_data(data): | |
| return pickle.loads(data) | |
| received_data = input("Enter pickled data: ") | |
| result = deserialize_data(received_data.encode()) | |
| print(result) | |
| """, | |
| "bug_type": "security", | |
| "line_number": 4, | |
| "keywords": ["pickle", "untrusted", "deserialize", "arbitrary code"], | |
| }, | |
| { | |
| "code": """\ | |
| def evaluate_expression(expr): | |
| result = eval(expr) | |
| return result | |
| user_expr = input("Enter expression: ") | |
| print(evaluate_expression(user_expr)) | |
| """, | |
| "bug_type": "security", | |
| "line_number": 2, | |
| "keywords": ["eval", "arbitrary code", "user input", "dangerous"], | |
| }, | |
| ], | |
| "correct_bug_type": "security", | |
| "correct_keywords": ["sql injection", "injection", "concatenat", "unsanitized", "parameterized"], | |
| }, | |
| } | |
| MAX_STEPS = 3 | |
| # ββ Reward function βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def compute_reward(action: ReviewAction, task: dict, attempt: int) -> tuple[float, str]: | |
| """ | |
| Partial progress reward β not binary. | |
| Returns (reward_float, feedback_string). | |
| """ | |
| reward = 0.0 | |
| feedback_parts = [] | |
| # Bug type match (+1.0) | |
| if action.bug_type.lower() == task["correct_bug_type"]: | |
| reward += 1.0 | |
| feedback_parts.append("β Correct bug type identified.") | |
| else: | |
| reward -= 0.3 | |
| feedback_parts.append( | |
| f"β Wrong bug type. Got '{action.bug_type}', " | |
| f"expected '{task['correct_bug_type']}'." | |
| ) | |
| # Line number match (+0.5) | |
| if action.line_number == task["correct_line_number"]: | |
| reward += 0.5 | |
| feedback_parts.append("β Correct line number.") | |
| else: | |
| feedback_parts.append( | |
| f"β Wrong line number. Got {action.line_number}, " | |
| f"expected {task['correct_line_number']}." | |
| ) | |
| # Keyword quality check (+0.5) | |
| review_lower = action.review.lower() | |
| matched_keywords = [ | |
| kw for kw in task["correct_keywords"] if kw in review_lower | |
| ] | |
| if matched_keywords: | |
| reward += 0.5 | |
| feedback_parts.append(f"β Good explanation (matched: {matched_keywords}).") | |
| else: | |
| feedback_parts.append("β Explanation missing key concepts.") | |
| # Semantic similarity bonus (+0.25): if review length > 50 chars AND contains correct keyword | |
| if len(action.review) > 50 and matched_keywords: | |
| reward += 0.25 | |
| feedback_parts.append("β Semantic similarity bonus: detailed and accurate explanation.") | |
| # Retry penalty | |
| if attempt > 1: | |
| penalty = 0.1 * (attempt - 1) | |
| reward -= penalty | |
| feedback_parts.append(f"β Retry penalty: -{penalty:.1f}") | |
| # Clamp to 0.0-1.0 (max raw = 2.25, normalize) | |
| normalized = max(0.0, min(1.0, reward / 2.25)) | |
| return round(normalized, 4), " ".join(feedback_parts) | |
| # ββ Environment βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class CodeReviewEnvironment(Environment): | |
| """ | |
| Code Review Environment. | |
| Agent reviews Python snippets across 3 difficulty tasks. | |
| """ | |
| def __init__(self): | |
| self._state = ReviewState() | |
| def reset(self, task_id: str = "easy") -> Observation: | |
| if task_id not in TASKS: | |
| task_id = "easy" | |
| task = TASKS[task_id] | |
| # Randomly select a snippet from the available snippets | |
| selected_snippet = random.choice(task["snippets"]) | |
| self._state = ReviewState( | |
| current_task_id=task_id, | |
| current_snippet=selected_snippet["code"], | |
| correct_bug_type=selected_snippet["bug_type"], | |
| correct_line_number=selected_snippet["line_number"], | |
| correct_keywords=selected_snippet["keywords"], | |
| step_count=0, | |
| task_episode_id=str(uuid.uuid4()), | |
| cumulative_reward=0.0, | |
| total_snippets=len(task["snippets"]), | |
| ) | |
| return ReviewObservation( | |
| code_snippet=selected_snippet["code"], | |
| task_description=task["description"], | |
| task_id=task_id, | |
| attempt_number=0, | |
| previous_feedback="", | |
| done=False, | |
| hint=None, | |
| ) | |
| def step(self, action: Action) -> Observation: | |
| if not isinstance(action, ReviewAction): | |
| raise ValueError(f"Expected ReviewAction, got {type(action)}") | |
| self._state.step_count += 1 | |
| task_base = TASKS[self._state.current_task_id] | |
| # Create task dict with current snippet's correct answers for compute_reward | |
| task = { | |
| "correct_bug_type": self._state.correct_bug_type, | |
| "correct_line_number": self._state.correct_line_number, | |
| "correct_keywords": self._state.correct_keywords, | |
| } | |
| reward, feedback = compute_reward( | |
| action, task, self._state.step_count | |
| ) | |
| self._state.cumulative_reward += reward | |
| done = ( | |
| reward >= 0.75 # good enough answer | |
| or self._state.step_count >= MAX_STEPS | |
| ) | |
| return ReviewObservation( | |
| code_snippet=self._state.current_snippet, | |
| task_description=task_base["description"], | |
| task_id=self._state.current_task_id, | |
| attempt_number=self._state.step_count, | |
| previous_feedback=feedback, | |
| done=done, | |
| hint=None, | |
| ) | |
| def state(self) -> ReviewState: | |
| return self._state | |