"""Feature extractor for PatchJudge.

Extracts structured features from code patches using:
- Unified diff parsing (hunk-aware)
- Regex-based structural heuristics for Python code
- Keyword/entity extraction for issue-patch alignment
- Code quality signal detection
"""

import ast
import re
import textwrap
from collections import Counter
from typing import Optional

from patchjudge.models import PatchExample, PatchFeatures


class FeatureExtractor:
    """Extracts structured features from a patch for LLM evaluation.

    All analysis is heuristic: the patch text is parsed as a unified diff and
    inspected with regular expressions; no code is parsed, imported or run.
    """

    # Core/infrastructure files that are risky to modify. Each pattern is
    # anchored to a path-segment boundary (start of path, '/', '_', '.' or
    # '-') so that e.g. "database.py" is not mistaken for "base.py", while
    # "global_settings.py" still counts as a settings module.
    CORE_FILE_PATTERNS = [
        r'(?:^|[/_.-])__init__\.py$',
        r'(?:^|[/_.-])settings\.py$',
        r'(?:^|[/_.-])conf\.py$',
        r'(?:^|[/_.-])config\.py$',
        r'(?:^|[/_.-])setup\.py$',
        r'(?:^|[/_.-])setup\.cfg$',
        r'(?:^|[/_.-])manage\.py$',
        r'(?:^|[/_.-])urls\.py$',
        r'(?:^|[/_.-])wsgi\.py$',
        r'(?:^|[/_.-])asgi\.py$',
        r'(?:^|[/_.-])migrations/',
        r'(?:^|[/_.-])base\.py$',
    ]

    # Patterns that indicate hardcoded values
    HARDCODE_PATTERNS = [
        r'return\s+["\'].*["\']',  # return "hardcoded string"
        r'return\s+\d+\b',  # return 42
        r'==\s*["\'][^"\']{20,}',  # comparing with long hardcoded string
        r'if\s+.*==\s*\d{3,}',  # comparing with specific large numbers
    ]

    # Debug statement patterns
    DEBUG_PATTERNS = [
        r'\bprint\s*\(',
        r'\bpdb\b',
        r'\bbreakpoint\s*\(',
        r'\bIPython\b',
        r'\bipdb\b',
        r'\bconsole\.log\b',
        r'\bdebugger\b',
    ]

    # TODO/FIXME patterns. These are conventional upper-case markers and are
    # matched case-sensitively (see extract) so that e.g. a `temp` variable
    # is not reported as a TEMP marker.
    TODO_PATTERNS = [
        r'\bTODO\b',
        r'\bFIXME\b',
        r'\bHACK\b',
        r'\bXXX\b',
        r'\bTEMP\b',
        r'\bWORKAROUND\b',
    ]

    # "diff --git a/<old> b/<new>" file header; group 1 is the new path.
    _GIT_HEADER_RE = re.compile(r'^diff --git a/.+ b/(.+)$', re.MULTILINE)
    # "@@ -start[,count] +start[,count] @@" hunk header; groups are the
    # old/new line counts (an omitted count means 1).
    _HUNK_HEADER_RE = re.compile(
        r'^@@\s+-\d+(?:,(\d+))?\s+\+\d+(?:,(\d+))?\s+@@'
    )
    # Keywords counted by the complexity estimate (one hit per line).
    _COMPLEXITY_RE = re.compile(
        r'\b(?:if|elif|else|for|while|try|except|and|or|with|assert)\b'
    )

    def extract(self, example: PatchExample) -> PatchFeatures:
        """Extract all features from a PatchExample.

        Args:
            example: Example whose ``agent_patch`` (a unified diff),
                ``problem_statement`` and ``repo_context`` are analysed.

        Returns:
            A populated ``PatchFeatures``. Code-structure features are only
            filled in when the patch touches at least one ``.py`` file.
        """
        patch = example.agent_patch
        features = PatchFeatures()

        # --- Diff statistics ---
        added_lines, removed_lines = self._parse_diff_lines(patch)
        features.num_files_changed = self._count_files_changed(patch)
        features.num_lines_added = len(added_lines)
        features.num_lines_removed = len(removed_lines)
        features.num_hunks = self._count_hunks(patch)

        # --- Code structure (regex-based, Python only) ---
        if self._is_python_patch(patch):
            features.added_functions = self._extract_added_functions(added_lines)
            features.modified_functions = self._extract_modified_functions(
                patch, example.repo_context
            )
            features.has_error_handling = self._check_error_handling(added_lines)
            features.has_edge_case_handling = self._check_edge_cases(added_lines)
            features.cyclomatic_complexity_delta = self._estimate_complexity_delta(
                added_lines, removed_lines
            )
            features.nesting_depth_max = self._estimate_max_nesting(added_lines)

        # --- Issue-patch alignment ---
        issue_keywords = self._extract_issue_keywords(example.problem_statement)
        patch_text = '\n'.join(added_lines + removed_lines)
        addressed = self._match_keywords(issue_keywords, patch_text)
        features.issue_keywords_addressed = addressed
        features.issue_components_mentioned = self._extract_components(
            example.problem_statement
        )
        if issue_keywords:
            features.keyword_coverage_ratio = len(addressed) / len(issue_keywords)

        # --- Code quality signals ---
        added_text = '\n'.join(added_lines)
        # Case-sensitive on purpose: TODO/FIXME/TEMP markers are upper-case by
        # convention, and lower-case "temp"/"hack" are ordinary identifiers.
        features.has_todos = self._check_patterns(
            added_text, self.TODO_PATTERNS, flags=0
        )
        features.has_hardcoded_values = self._check_patterns(
            added_text, self.HARDCODE_PATTERNS
        )
        features.has_debug_statements = self._check_patterns(
            added_text, self.DEBUG_PATTERNS
        )
        features.style_violations = self._check_style_basic(added_lines)
        features.follows_project_style = len(features.style_violations) == 0

        # --- Risk signals ---
        changed_files = self._get_changed_files(patch)
        features.modifies_core_files = self._check_core_files(changed_files)
        features.change_scope = self._assess_scope(
            features.num_files_changed,
            features.num_lines_added + features.num_lines_removed,
        )
        new_imports = self._extract_new_imports(added_lines)
        features.has_imports_added = len(new_imports) > 0
        features.new_imports = new_imports
        features.touches_tests = any(self._is_test_path(f) for f in changed_files)

        return features

    # =========================================================================
    # Diff parsing
    # =========================================================================

    def _parse_diff_lines(self, diff: str) -> tuple[list[str], list[str]]:
        """Parse a unified diff into added and removed lines (without +/- prefix).

        Inside a hunk, the line counts from the ``@@`` header decide where the
        hunk ends. Content lines that themselves start with ``++``/``--``
        (e.g. a removed ``-------`` heading underline) are therefore kept
        instead of being mistaken for ``+++``/``---`` file headers. Outside a
        well-formed hunk (e.g. hand-written diffs with bad headers), the plain
        prefix heuristic is used.
        """
        added: list[str] = []
        removed: list[str] = []
        old_left = new_left = 0  # lines still expected in the current hunk

        for line in diff.split('\n'):
            # Hunk body lines start with ' ', '+', '-' or '\', so these
            # header prefixes can never be hunk content.
            if line.startswith('@@') or line.startswith('diff --git '):
                hunk = self._HUNK_HEADER_RE.match(line)
                if hunk:
                    old_left = 1 if hunk.group(1) is None else int(hunk.group(1))
                    new_left = 1 if hunk.group(2) is None else int(hunk.group(2))
                else:
                    old_left = new_left = 0
                continue

            if old_left > 0 or new_left > 0:
                if line.startswith('+'):
                    added.append(line[1:])
                    new_left -= 1
                elif line.startswith('-'):
                    removed.append(line[1:])
                    old_left -= 1
                elif not line.startswith('\\'):
                    # Context line (an empty line is a context line whose
                    # leading space was stripped). "\ No newline at end of
                    # file" annotations consume no lines.
                    old_left -= 1
                    new_left -= 1
                continue

            # Outside any hunk: skip "---"/"+++" file headers.
            if line.startswith('+') and not line.startswith('+++'):
                added.append(line[1:])
            elif line.startswith('-') and not line.startswith('---'):
                removed.append(line[1:])

        return added, removed

    def _count_files_changed(self, diff: str) -> int:
        """Count the number of distinct files changed in the diff."""
        return len(self._get_changed_files(diff))

    def _count_hunks(self, diff: str) -> int:
        """Count number of hunks (@@ markers) in the diff."""
        return len(re.findall(r'^@@\s', diff, re.MULTILINE))

    def _get_changed_files(self, diff: str) -> list[str]:
        """Return the sorted, de-duplicated list of changed file paths.

        Paths come from ``diff --git`` headers. If there are none (a plain
        unified diff), ``--- old`` / ``+++ new`` header pairs are used instead.
        """
        files = {m.group(1) for m in self._GIT_HEADER_RE.finditer(diff)}
        if not files:
            files = self._files_from_unified_headers(diff)
        return sorted(files)

    def _files_from_unified_headers(self, diff: str) -> set[str]:
        """Collect paths from adjacent ``--- old`` / ``+++ new`` header lines.

        The new path is used unless it is ``/dev/null`` (file deletion), in
        which case the old path is used.
        """
        files: set[str] = set()
        lines = diff.split('\n')
        for old_line, new_line in zip(lines, lines[1:]):
            if not (old_line.startswith('--- ') and new_line.startswith('+++ ')):
                continue
            old_path = self._header_path(old_line[4:])
            new_path = self._header_path(new_line[4:])
            path = new_path if new_path != '/dev/null' else old_path
            if path and path != '/dev/null':
                files.add(path)
        return files

    @staticmethod
    def _header_path(raw: str) -> str:
        """Normalise a ``---``/``+++`` header path.

        Drops a tab-separated timestamp and a leading ``a/``/``b/`` prefix.
        """
        path = raw.split('\t', 1)[0].strip()
        if path.startswith(('a/', 'b/')):
            path = path[2:]
        return path

    def _is_python_patch(self, diff: str) -> bool:
        """Check if the patch modifies Python files."""
        return any(f.endswith('.py') for f in self._get_changed_files(diff))

    # =========================================================================
    # Code structure analysis
    # =========================================================================

    def _extract_added_functions(self, added_lines: list[str]) -> list[str]:
        """Find function/method and class definitions in added lines.

        Returns function names as-is (including ``async def``) and class
        names prefixed with ``"class:"``, in order of appearance.
        """
        functions = []
        for line in added_lines:
            func = re.match(r'\s*(?:async\s+)?def\s+(\w+)\s*\(', line)
            if func:
                functions.append(func.group(1))
            cls = re.match(r'\s*class\s+(\w+)', line)
            if cls:
                functions.append(f"class:{cls.group(1)}")
        return functions

    def _extract_modified_functions(
        self, diff: str, repo_context: dict
    ) -> list[str]:
        """Find functions/classes named in hunk-header context.

        Git appends the nearest enclosing definition line after the second
        ``@@`` of a hunk header; that text is inspected for ``def``/``class``.
        ``repo_context`` is currently unused.

        Returns:
            De-duplicated names in first-seen order; classes are prefixed with
            ``"class:"``.
        """
        modified = []
        # [ \t] rather than \s so that an empty header context cannot spill
        # over onto the following diff line.
        for match in re.finditer(
            r'^@@[ \t]+.*?[ \t]+@@[ \t]*(.*)$', diff, re.MULTILINE
        ):
            context = match.group(1).strip()
            func_match = re.match(r'(?:async\s+)?def\s+(\w+)', context)
            if func_match:
                modified.append(func_match.group(1))
            class_match = re.match(r'class\s+(\w+)', context)
            if class_match:
                modified.append(f"class:{class_match.group(1)}")
        # dict.fromkeys: deterministic, order-preserving de-duplication.
        return list(dict.fromkeys(modified))

    def _check_error_handling(self, added_lines: list[str]) -> bool:
        """Check if added code includes error/exception handling."""
        text = '\n'.join(added_lines)
        patterns = [
            r'\btry\s*:',
            r'\bexcept\b',
            r'\braise\b',
            r'\bValueError\b',
            r'\bTypeError\b',
            r'\bKeyError\b',
            r'\bAssertionError\b',
            r'\bRuntimeError\b',
            r'\bif\s+.*\bis\s+None\b',
            r'\bif\s+not\b',
        ]
        return any(re.search(p, text) for p in patterns)

    def _check_edge_cases(self, added_lines: list[str]) -> bool:
        """Check if the patch handles edge cases (at least two signal kinds)."""
        text = '\n'.join(added_lines)
        patterns = [
            r'\bif\s+len\(',  # Length checks
            r'\bif\s+not\s+\w+\s*:',  # Empty checks
            r'\bif\s+\w+\s+is\s+None',  # None checks
            r'\bif\s+.*<=?\s*0',  # Zero/negative checks
            r'\bif\s+isinstance\(',  # Type checks
            r'\bif\s+hasattr\(',  # Attribute checks
            r'\bor\s+\[\]',  # Default empty list
            r'\bor\s+\{\}',  # Default empty dict
            r'\bor\s+""',  # Default empty string
            r'\.get\(',  # Dict .get() with default
        ]
        return sum(1 for p in patterns if re.search(p, text)) >= 2

    def _estimate_complexity_delta(
        self, added_lines: list[str], removed_lines: list[str]
    ) -> int:
        """Estimate the change in cyclomatic complexity.

        Counts lines containing at least one branching/boolean keyword in
        the added lines minus the same count in the removed lines. Blank and
        comment-only lines are ignored.
        """

        def count_complexity(lines: list[str]) -> int:
            count = 0
            for line in lines:
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue  # comments are prose, not control flow
                if self._COMPLEXITY_RE.search(stripped):
                    count += 1  # each line counts at most once
            return count

        return count_complexity(added_lines) - count_complexity(removed_lines)

    def _estimate_max_nesting(self, added_lines: list[str]) -> int:
        """Estimate the maximum nesting depth in added code.

        Depth is leading indentation divided by 4, with tabs expanded to
        4-column stops so tab-indented code is measured too.
        """
        max_depth = 0
        for line in added_lines:
            if not line.strip():
                continue
            expanded = line.expandtabs(4)
            indent = len(expanded) - len(expanded.lstrip())
            max_depth = max(max_depth, indent // 4)
        return max_depth

    # =========================================================================
    # Issue-patch alignment
    # =========================================================================

    def _extract_issue_keywords(self, problem_statement: str) -> list[str]:
        """Extract meaningful keywords from the issue description.

        Returns at most 30 de-duplicated keywords. The order is
        deterministic: identifiers, then error names, then method names.
        """
        # Remove code blocks
        text = re.sub(r'```[\s\S]*?```', '', problem_statement)
        text = re.sub(r'`[^`]+`', '', text)
        # Remove URLs
        text = re.sub(r'https?://\S+', '', text)

        # Extract potential identifiers (CamelCase, snake_case, etc.)
        identifiers = re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', text)  # CamelCase
        identifiers += re.findall(r'\b\w+_\w+\b', text)  # snake_case
        identifiers += re.findall(r'\b[A-Z]{2,}\b', text)  # CONSTANTS

        # Extract error message keywords
        errors = re.findall(r'\b\w*(?:Error|Exception|Warning|Failure)\b', text)

        # Extract method/function names mentioned
        methods = re.findall(r'\.(\w+)\(', text)
        methods += re.findall(r'def\s+(\w+)', text)
        methods += re.findall(r'class\s+(\w+)', text)

        # Order-preserving de-duplication. A set's iteration order changes
        # between processes (string hashing is randomized), which made the
        # [:30] cap keep a different subset on every run.
        keywords = list(dict.fromkeys(identifiers + errors + methods))

        # Filter out common words
        stopwords = {
            'the', 'is', 'in', 'it', 'to', 'and', 'or', 'not', 'this', 'that',
            'with', 'for', 'are', 'was', 'has', 'have', 'had', 'when', 'would',
            'should', 'could', 'will', 'can', 'may', 'one', 'two', 'use',
            'used', 'using', 'see', 'also',
        }
        keywords = [k for k in keywords if k.lower() not in stopwords and len(k) > 2]

        return keywords[:30]  # Cap at 30 keywords

    def _match_keywords(self, keywords: list[str], patch_text: str) -> list[str]:
        """Return the issue keywords that occur (case-insensitively) in the patch."""
        patch_lower = patch_text.lower()
        return [k for k in keywords if k.lower() in patch_lower]

    def _extract_components(self, problem_statement: str) -> list[str]:
        """Extract software component names from the issue.

        Returns at most 20 de-duplicated items in deterministic order: file
        paths, then imported modules, then ``name.method`` call patterns.
        """
        # Look for file paths
        files = re.findall(r'[\w/]+\.py\b', problem_statement)
        # Look for module/class references
        modules = re.findall(r'(?:from|import)\s+([\w.]+)', problem_statement)
        # Look for class.method patterns
        class_methods = re.findall(r'(\w+\.\w+)\(', problem_statement)

        return list(dict.fromkeys(files + modules + class_methods))[:20]

    # =========================================================================
    # Code quality signals
    # =========================================================================

    def _check_patterns(
        self, text: str, patterns: list[str], flags: int = re.IGNORECASE
    ) -> bool:
        """Check if any of the patterns match in the text.

        Args:
            text: Text to search.
            patterns: Regular expressions to try.
            flags: ``re`` flags for every search; case-insensitive by default.
        """
        return any(re.search(p, text, flags) for p in patterns)

    def _check_style_basic(self, added_lines: list[str]) -> list[str]:
        """Basic style checks without external tools.

        Returns at most one ``"<violation_type>:<line_index>"`` entry per
        violation type (the first occurrence). The index is into
        ``added_lines``.
        """
        violations = []
        for i, line in enumerate(added_lines):
            # Line too long (PEP 8 recommends 79; only lines over 120 are
            # flagged here to stay lenient).
            if len(line) > 120:
                violations.append(f"line_too_long:{i}")
            # Trailing whitespace
            if line != line.rstrip():
                violations.append(f"trailing_whitespace:{i}")
            # Mixed tabs and spaces -- only in the indentation itself; a
            # tab-indented line with spaces around operators is fine.
            indent = line[:len(line) - len(line.lstrip())]
            if '\t' in indent and ' ' in indent:
                violations.append(f"mixed_indentation:{i}")
            # Multiple statements on one line (except for comprehensions)
            if ';' in line and 'for' not in line and 'import' not in line:
                violations.append(f"multiple_statements:{i}")

        # Keep only the first violation of each type.
        types_seen = set()
        unique = []
        for v in violations:
            vtype = v.split(':')[0]
            if vtype not in types_seen:
                types_seen.add(vtype)
                unique.append(v)
        return unique

    def _check_core_files(self, changed_files: list[str]) -> bool:
        """Check if any changed files match core/infrastructure patterns."""
        return any(
            re.search(pattern, f)
            for f in changed_files
            for pattern in self.CORE_FILE_PATTERNS
        )

    @staticmethod
    def _is_test_path(path: str) -> bool:
        """Return True if *path* looks like a test file or test directory.

        The path is split into words on separators and camelCase boundaries.
        A word that starts with ``test`` (``tests/``, ``test_x.py``,
        ``x_test.py``, ``FooTest.java``) or equals ``conftest`` counts. This
        avoids the substring false positives of ``latest.py``/``contest.py``.
        """
        words = re.findall(r'[A-Z]?[a-z0-9]+|[A-Z]+(?![a-z])', path)
        return any(
            w.lower().startswith('test') or w.lower() == 'conftest'
            for w in words
        )

    def _assess_scope(self, num_files: int, total_lines: int) -> str:
        """Classify the change as "minimal", "moderate" or "extensive"."""
        if num_files <= 1 and total_lines <= 20:
            return "minimal"
        elif num_files <= 3 and total_lines <= 100:
            return "moderate"
        else:
            return "extensive"

    def _extract_new_imports(self, added_lines: list[str]) -> list[str]:
        """Extract newly added import statements (stripped).

        ``from`` lines must look like ``from <module> import ...``, so prose
        lines that merely begin with "from" (docstrings, text) are ignored.
        """
        imports = []
        for line in added_lines:
            stripped = line.strip()
            if stripped.startswith('import ') or re.match(
                r'from\s+\S+\s+import\b', stripped
            ):
                imports.append(stripped)
        return imports


def extract_features_batch(
    examples: list[PatchExample],
    show_progress: bool = True,
) -> list[tuple[PatchExample, PatchFeatures]]:
    """Extract features for a batch of examples.

    Prints a progress line every 50 examples and a final summary when
    ``show_progress`` is true. Returns ``(example, features)`` pairs in input
    order.
    """
    extractor = FeatureExtractor()
    results = []
    for i, ex in enumerate(examples):
        if show_progress and (i + 1) % 50 == 0:
            print(f"  Extracted features for {i+1}/{len(examples)} examples")
        features = extractor.extract(ex)
        results.append((ex, features))
    if show_progress:
        print(f"  Done: {len(results)} examples processed")
    return results


if __name__ == "__main__":
    import json
    from patchjudge.data_loader import SWEBenchLoader

    loader = SWEBenchLoader()
    examples = loader.build_dataset(sources=["coderforge"])

    extractor = FeatureExtractor()

    # Extract features for first 5 examples
    for ex in examples[:5]:
        features = extractor.extract(ex)
        print(f"\n{'='*60}")
        print(f"Instance: {ex.instance_id}")
        print(f"Agent: {ex.agent_name}")
        print(f"Test passed: {ex.test_passed}")
        print(f"Files changed: {features.num_files_changed}")
        print(f"Lines +{features.num_lines_added}/-{features.num_lines_removed}")
        print(f"Hunks: {features.num_hunks}")
        print(f"Scope: {features.change_scope}")
        print(f"Error handling: {features.has_error_handling}")
        print(f"Edge cases: {features.has_edge_case_handling}")
        print(f"TODOs: {features.has_todos}")
        print(f"Debug stmts: {features.has_debug_statements}")
        print(f"Hardcoded: {features.has_hardcoded_values}")
        print(f"Core files: {features.modifies_core_files}")
        print(f"New imports: {features.new_imports}")
        print(f"Issue keywords addressed: {features.issue_keywords_addressed[:5]}")
        print(f"Keyword coverage: {features.keyword_coverage_ratio:.2f}")
        print(f"Style violations: {features.style_violations}")
        print(f"Added functions: {features.added_functions}")
        print(f"Modified functions: {features.modified_functions}")