"""Data models for PatchJudge."""
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Optional
@dataclass
class PatchExample:
    """Unified format for a single patch evaluation example.

    Pairs an AI-generated patch with the human-written reference patch for
    the same issue, plus the context needed to judge it.
    """

    instance_id: str
    repo: str
    problem_statement: str
    gold_patch: str  # Human-written reference patch
    agent_patch: str  # AI-generated patch
    agent_name: str  # Which agent produced this
    test_passed: bool  # Did the agent's patch pass tests?
    base_commit: str
    repo_context: dict = field(default_factory=dict)  # {filename: file_content}
    difficulty: str = ""

    def to_dict(self) -> dict:
        """Serialize to a plain dict (nested containers are deep-copied by asdict)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "PatchExample":
        """Build an example from a dict.

        Unknown keys are silently dropped so that JSON written by a newer
        schema (with extra fields) still loads instead of raising TypeError.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})

    def to_json(self) -> str:
        """Serialize to a pretty-printed JSON string."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, s: str) -> "PatchExample":
        """Parse an example from a JSON string produced by :meth:`to_json`."""
        return cls.from_dict(json.loads(s))
@dataclass
class PatchFeatures:
    """Structured features extracted from a patch."""

    # -- Diff statistics --------------------------------------------------
    num_files_changed: int = 0
    num_lines_added: int = 0
    num_lines_removed: int = 0
    num_hunks: int = 0

    # -- Code structure ---------------------------------------------------
    added_functions: list = field(default_factory=list)
    modified_functions: list = field(default_factory=list)
    has_error_handling: bool = False
    has_edge_case_handling: bool = False

    # -- Issue/patch alignment --------------------------------------------
    issue_keywords_addressed: list = field(default_factory=list)
    issue_components_mentioned: list = field(default_factory=list)
    keyword_coverage_ratio: float = 0.0

    # -- Code-quality signals ---------------------------------------------
    has_todos: bool = False
    has_hardcoded_values: bool = False
    has_debug_statements: bool = False
    follows_project_style: bool = True
    style_violations: list = field(default_factory=list)

    # -- Risk signals -----------------------------------------------------
    modifies_core_files: bool = False
    change_scope: str = "minimal"  # one of: minimal, moderate, extensive
    has_imports_added: bool = False
    new_imports: list = field(default_factory=list)
    touches_tests: bool = False

    # -- Complexity -------------------------------------------------------
    cyclomatic_complexity_delta: int = 0
    nesting_depth_max: int = 0

    def to_dict(self) -> dict:
        """Serialize every feature to a plain dict."""
        return asdict(self)
@dataclass
class DimensionScore:
    """Score assigned to one evaluation dimension."""

    score: int  # integer rating on a 0-10 scale
    reasoning: str  # the judge's textual justification
    flags: list = field(default_factory=list)  # warnings raised for this dimension

    def to_dict(self) -> dict:
        """Serialize this score to a plain dict."""
        return asdict(self)
@dataclass
class JudgeResult:
    """Complete judge evaluation output."""

    merge_score: float  # 0-100 weighted score
    # dim_name -> DimensionScore or an equivalent plain dict; accessors below
    # accept either form (the original code assumed dicts only and would
    # raise AttributeError on DimensionScore values).
    dimension_scores: dict = field(default_factory=dict)
    raw_output: str = ""
    # String forward ref: PatchFeatures is declared alongside this class.
    features: Optional["PatchFeatures"] = None
    model_used: str = ""

    def _dim_entry(self, name: str) -> dict:
        """Return the entry for *name* as a plain dict ({} when absent).

        Duck-typed: anything exposing ``to_dict()`` (e.g. DimensionScore)
        is flattened; plain dicts pass through unchanged.
        """
        entry = self.dimension_scores.get(name, {})
        if hasattr(entry, "to_dict"):
            entry = entry.to_dict()
        return entry

    @property
    def correctness(self) -> int:
        return self._dim_entry("correctness").get("score", 0)

    @property
    def completeness(self) -> int:
        return self._dim_entry("completeness").get("score", 0)

    @property
    def code_quality(self) -> int:
        return self._dim_entry("code_quality").get("score", 0)

    @property
    def non_regression_risk(self) -> int:
        return self._dim_entry("non_regression_risk").get("score", 0)

    @property
    def merge_readiness(self) -> int:
        return self._dim_entry("merge_readiness").get("score", 0)

    def to_dict(self) -> dict:
        """Serialize to a JSON-ready dict.

        DimensionScore entries are flattened to dicts so the result is
        always serializable; ``features`` is included only when set.
        """
        dims = {
            name: (entry.to_dict() if hasattr(entry, "to_dict") else entry)
            for name, entry in self.dimension_scores.items()
        }
        d = {
            "merge_score": self.merge_score,
            "dimension_scores": dims,
            "raw_output": self.raw_output,
            "model_used": self.model_used,
        }
        if self.features:
            d["features"] = self.features.to_dict()
        return d

    def summary(self) -> str:
        """Build a human-readable multi-line summary of all scores and flags."""
        lines = [f"MergeScore: {self.merge_score:.1f}/100"]
        for dim in self.dimension_scores:
            data = self._dim_entry(dim)
            score = data.get("score", "?")
            lines.append(f" {dim}: {score}/10")
            # `or []` also guards against an explicit None flags value.
            for flag in data.get("flags") or []:
                lines.append(f" ⚠ {flag}")
        return "\n".join(lines)
@dataclass
class ValidationResult:
    """Result of validating PatchJudge against ground truth."""

    total_examples: int = 0

    # METR alignment: fraction of test-passing patches scoring below 50
    test_passing_below_50_pct: float = 0.0

    # Correlation between judge scores and resolved status
    score_resolved_correlation: float = 0.0
    mean_score_resolved: float = 0.0
    mean_score_unresolved: float = 0.0

    # Detection of known-bad patches
    known_bad_detected: int = 0
    known_bad_total: int = 0
    known_bad_detection_rate: float = 0.0

    # Overall score distribution
    score_mean: float = 0.0
    score_std: float = 0.0
    score_median: float = 0.0

    # Per-dimension aggregate statistics
    dimension_stats: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize all validation metrics to a plain dict."""
        return asdict(self)