VD10 committed on
Commit
0fc6b71
·
verified ·
1 Parent(s): 3bb15c1

Upload patchjudge/data_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. patchjudge/data_loader.py +380 -0
patchjudge/data_loader.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data loading pipeline for PatchJudge.
2
+
3
+ Loads SWE-bench Verified gold patches and agent-generated patches from:
4
+ 1. HuggingFace datasets (AlexCuadron O1, CoderForge)
5
+ 2. SWE-bench S3 bucket (139 verified agent submissions)
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import re
11
+ import logging
12
+ from pathlib import Path
13
+ from typing import Optional
14
+ from collections import defaultdict
15
+
16
+ from datasets import load_dataset
17
+ from patchjudge.models import PatchExample
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class SWEBenchLoader:
    """Loads SWE-bench Verified gold data and agent-generated patches.

    Gold instances come from the ``princeton-nlp/SWE-bench_Verified``
    HuggingFace dataset.  Agent patches come from HuggingFace trajectory
    datasets (CoderForge, OpenHands+O1) and the public SWE-bench
    submissions S3 bucket.
    """

    def __init__(self, cache_dir: str = "data"):
        """Create a loader that reads/writes cached artifacts under *cache_dir*."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._gold_data = None  # Lazily populated by load_gold_data()

    def load_gold_data(self) -> dict:
        """Load SWE-bench Verified dataset. Returns {instance_id: row_dict}.

        The result is cached on the instance, so repeated calls are cheap.
        """
        if self._gold_data is not None:
            return self._gold_data

        logger.info("Loading SWE-bench Verified dataset...")
        ds = load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
        self._gold_data = {}
        for row in ds:
            self._gold_data[row["instance_id"]] = {
                "instance_id": row["instance_id"],
                "repo": row["repo"],
                "problem_statement": row["problem_statement"],
                "gold_patch": row["patch"],
                "base_commit": row["base_commit"],
                "test_patch": row["test_patch"],
                # Optional columns: fall back to "" when absent.
                "difficulty": row.get("difficulty", ""),
                "hints_text": row.get("hints_text", ""),
            }
        logger.info(f"Loaded {len(self._gold_data)} SWE-bench Verified instances")
        return self._gold_data

    def load_coderforge_patches(self) -> list[PatchExample]:
        """Load agent patches from CoderForge (Qwen3-Coder-32B, 500 instances).

        Returns one PatchExample per trajectory row whose instance is present
        in SWE-bench Verified and whose output patch is non-empty.
        """
        logger.info("Loading CoderForge agent patches...")
        ds = load_dataset(
            "togethercomputer/CoderForge-Preview-32B-SWE-Bench-Verified-Evaluation-trajectories",
            "trajectory", split="train"
        )
        gold = self.load_gold_data()
        examples = []

        for row in ds:
            # Extract instance_id from the row's "ds" JSON metadata field.
            try:
                ds_info = json.loads(row["ds"])
                instance_id = ds_info["instance_id"]
            except (json.JSONDecodeError, KeyError):
                # Fall back to parsing trajectory_id, e.g. "<instance>_run3".
                tid = row.get("trajectory_id", "")
                instance_id = tid.rsplit("_run", 1)[0] if "_run" in tid else tid

            if instance_id not in gold:
                continue

            agent_patch = row.get("output_patch", "")
            if not agent_patch or agent_patch.strip() == "":
                continue  # Skip trajectories that produced no patch.

            g = gold[instance_id]
            ex = PatchExample(
                instance_id=instance_id,
                repo=g["repo"],
                problem_statement=g["problem_statement"],
                gold_patch=g["gold_patch"],
                agent_patch=agent_patch,
                agent_name="CoderForge-Qwen3-32B",
                # reward == 1.0 marks a resolved (test-passing) trajectory.
                test_passed=row.get("reward", 0.0) == 1.0,
                base_commit=g["base_commit"],
                difficulty=g["difficulty"],
            )
            examples.append(ex)

        logger.info(f"Loaded {len(examples)} CoderForge patches "
                    f"({sum(1 for e in examples if e.test_passed)} passed)")
        return examples

    def load_o1_patches(self) -> list[PatchExample]:
        """Load agent patches from OpenHands+O1 (500 instances).

        Returns one PatchExample per result row whose instance is present in
        SWE-bench Verified and whose patch is non-empty.
        """
        logger.info("Loading OpenHands+O1 agent patches...")
        ds = load_dataset(
            "AlexCuadron/SWE-Bench-Verified-O1-native-tool-calling-reasoning-high-results",
            split="test"
        )
        gold = self.load_gold_data()
        examples = []

        for row in ds:
            issue_name = row.get("issue_name", "")
            # issue_name format: "django__django-16454" — same as instance_id.
            instance_id = issue_name

            if instance_id not in gold:
                continue

            agent_patch = row.get("patch", "")
            if not agent_patch or agent_patch.strip() == "":
                continue  # Skip rows that produced no patch.

            g = gold[instance_id]
            ex = PatchExample(
                instance_id=instance_id,
                repo=g["repo"],
                problem_statement=g["problem_statement"],
                gold_patch=g["gold_patch"],
                agent_patch=agent_patch,
                agent_name="OpenHands-O1-reasoning-high",
                test_passed=row.get("resolved", False),
                base_commit=g["base_commit"],
                difficulty=g["difficulty"],
            )
            examples.append(ex)

        logger.info(f"Loaded {len(examples)} O1 patches "
                    f"({sum(1 for e in examples if e.test_passed)} passed)")
        return examples

    def load_s3_agent_patches(
        self,
        agents: Optional[list[str]] = None,
        max_per_agent: int = 500,
    ) -> list[PatchExample]:
        """Load agent patches from SWE-bench S3 bucket.

        Args:
            agents: List of agent directory names in S3. Defaults to a curated set.
            max_per_agent: Max patches per agent.

        Returns:
            PatchExamples with test_passed derived from the published
            results.json resolve lists (best-effort; defaults to False when
            the resolve list cannot be fetched).
        """
        try:
            import boto3
            from botocore import UNSIGNED
            from botocore.config import Config
            import requests
        except ImportError:
            # S3 access is optional; degrade gracefully without boto3.
            logger.warning("boto3 not available, skipping S3 patches")
            return []

        if agents is None:
            agents = [
                "20250225_sweagent_claude-3-7-sonnet",
                "20241029_OpenHands-CodeAct-2.1-sonnet-20241022",
                "20241028_agentless-1.5_gpt4o",
                "20241108_autocoderover-v2.0-claude-3-5-sonnet-20241022",
                "20240620_sweagent_claude3.5sonnet",
            ]

        gold = self.load_gold_data()
        # The submissions bucket is public; use unsigned (anonymous) requests.
        s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
        BUCKET = 'swe-bench-submissions'
        examples = []

        for agent_dir in agents:
            logger.info(f"Loading patches from S3 agent: {agent_dir}")

            # Get resolve labels from the SWE-bench experiments repo on GitHub.
            resolved_ids = set()
            try:
                url = (
                    f"https://raw.githubusercontent.com/SWE-bench/experiments/"
                    f"main/evaluation/verified/{agent_dir}/results/results.json"
                )
                r = requests.get(url, timeout=10)
                if r.status_code == 200:
                    resolved_ids = set(r.json().get("resolved", []))
                    logger.info(f" {agent_dir}: {len(resolved_ids)} resolved")
            except Exception as e:
                # Best-effort: without labels, test_passed stays False.
                logger.warning(f" Could not load resolve labels for {agent_dir}: {e}")

            # List instance directories under the agent's logs/ prefix.
            paginator = s3.get_paginator('list_objects_v2')
            count = 0
            try:
                for page in paginator.paginate(
                    Bucket=BUCKET,
                    Prefix=f'verified/{agent_dir}/logs/',
                    Delimiter='/'
                ):
                    if count >= max_per_agent:
                        break  # Stop fetching further pages once the cap is hit.
                    for prefix_info in page.get('CommonPrefixes', []):
                        if count >= max_per_agent:
                            break

                        prefix = prefix_info['Prefix']
                        instance_id = prefix.rstrip('/').split('/')[-1]

                        if instance_id not in gold:
                            continue

                        # Download the agent's patch.diff for this instance.
                        try:
                            obj = s3.get_object(
                                Bucket=BUCKET,
                                Key=f'verified/{agent_dir}/logs/{instance_id}/patch.diff'
                            )
                            agent_patch = obj['Body'].read().decode('utf-8')
                        except Exception:
                            continue  # Missing patch.diff for this instance.

                        if not agent_patch.strip():
                            continue

                        g = gold[instance_id]
                        ex = PatchExample(
                            instance_id=instance_id,
                            repo=g["repo"],
                            problem_statement=g["problem_statement"],
                            gold_patch=g["gold_patch"],
                            agent_patch=agent_patch,
                            agent_name=agent_dir,
                            test_passed=instance_id in resolved_ids,
                            base_commit=g["base_commit"],
                            difficulty=g["difficulty"],
                        )
                        examples.append(ex)
                        count += 1
            except Exception as e:
                logger.warning(f" Error loading from S3 for {agent_dir}: {e}")

            logger.info(f" Loaded {count} patches from {agent_dir}")

        logger.info(f"Total S3 patches: {len(examples)} "
                    f"({sum(1 for e in examples if e.test_passed)} passed)")
        return examples

    def build_dataset(
        self,
        sources: Optional[list[str]] = None,
        min_examples: int = 100,
        include_repo_context: bool = False,
        s3_agents: Optional[list[str]] = None,
    ) -> list[PatchExample]:
        """Build the unified PatchExample dataset from multiple sources.

        Args:
            sources: List of sources to use. Options: 'coderforge', 'o1', 's3'.
                Defaults to ['coderforge', 'o1'].
            min_examples: Minimum examples to collect (warn-only threshold).
            include_repo_context: If True, attempt to clone repos and gather
                context.  NOTE(review): currently unused by this method —
                kept for interface compatibility; confirm intended behavior.
            s3_agents: Agent list for S3 source.
        """
        if sources is None:
            sources = ["coderforge", "o1"]

        all_examples = []

        if "coderforge" in sources:
            all_examples.extend(self.load_coderforge_patches())

        if "o1" in sources:
            all_examples.extend(self.load_o1_patches())

        if "s3" in sources:
            all_examples.extend(self.load_s3_agent_patches(agents=s3_agents))

        # Deduplicate by (instance_id, agent_name), keeping first occurrence.
        seen = set()
        unique = []
        for ex in all_examples:
            key = (ex.instance_id, ex.agent_name)
            if key not in seen:
                seen.add(key)
                unique.append(ex)

        logger.info(f"Total unique examples: {len(unique)} "
                    f"(passed: {sum(1 for e in unique if e.test_passed)}, "
                    f"failed: {sum(1 for e in unique if not e.test_passed)})")

        if len(unique) < min_examples:
            logger.warning(
                f"Only {len(unique)} examples collected, "
                f"below minimum of {min_examples}. "
                f"Consider adding more sources."
            )

        return unique

    def save_dataset(self, examples: list[PatchExample], filename: str = "patch_examples.jsonl"):
        """Save examples to JSONL under the cache directory; returns the path."""
        path = self.cache_dir / filename
        with open(path, 'w', encoding='utf-8') as f:
            for ex in examples:
                f.write(json.dumps(ex.to_dict()) + "\n")
        logger.info(f"Saved {len(examples)} examples to {path}")
        return path

    def load_saved_dataset(self, filename: str = "patch_examples.jsonl") -> list[PatchExample]:
        """Load previously saved examples from JSONL in the cache directory."""
        path = self.cache_dir / filename
        examples = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    examples.append(PatchExample.from_dict(json.loads(line)))
        logger.info(f"Loaded {len(examples)} examples from {path}")
        return examples
315
+
316
+
317
def extract_repo_context_from_diff(diff: str) -> list[str]:
    """Extract the file paths touched by a unified diff.

    Paths are read from ``diff --git`` headers (the ``b/`` side) and from
    ``--- a/`` source headers; ``/dev/null`` entries (new files) contribute
    no source path.  Returns a sorted, de-duplicated list so the output is
    deterministic across runs.
    """
    files = set()
    for line in diff.split('\n'):
        if line.startswith('diff --git'):
            # Anchor on " b/": a bare "b/" search would match inside the
            # a/ path for names like "lib/x.py" and truncate it mid-path.
            match = re.search(r' b/(.+)$', line)
            if match:
                files.add(match.group(1))
        elif line.startswith('--- a/'):
            # "--- /dev/null" (new file) never matches this prefix.
            files.add(line[len('--- a/'):])
    return sorted(files)
331
+
332
+
333
def get_diff_stats(diff: str) -> dict:
    """Summarize a unified diff.

    Returns a dict with ``lines_added``, ``lines_removed``,
    ``files_changed``, and ``hunks``.
    """
    added = 0
    removed = 0
    hunk_count = 0
    for raw in diff.split('\n'):
        # File headers ("+++", "---") must not count as content changes;
        # "@@" marks a hunk header and never overlaps with +/- lines.
        if raw.startswith('@@'):
            hunk_count += 1
        elif raw.startswith('+') and not raw.startswith('+++'):
            added += 1
        elif raw.startswith('-') and not raw.startswith('---'):
            removed += 1
    return {
        "lines_added": added,
        "lines_removed": removed,
        "files_changed": len(extract_repo_context_from_diff(diff)),
        "hunks": hunk_count,
    }
346
+
347
+
348
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    loader = SWEBenchLoader()

    # Build from the HuggingFace-backed sources only (no S3 dependency).
    examples = loader.build_dataset(sources=["coderforge", "o1"])

    # Summary statistics over the collected examples.
    passed = sum(1 for e in examples if e.test_passed)
    failed = len(examples) - passed
    repos = {e.repo for e in examples}
    agents = {e.agent_name for e in examples}

    print(f"\n{'='*60}")
    print(f"PatchJudge Dataset Summary")
    print(f"{'='*60}")
    print(f"Total examples: {len(examples)}")
    print(f" Test passed: {passed}")
    print(f" Test failed: {failed}")
    print(f"Unique instances: {len(set(e.instance_id for e in examples))}")
    print(f"Unique repos: {len(repos)}")
    print(f"Agent sources: {agents}")
    print(f"\nDifficulty distribution:")

    # Tally examples per difficulty label and print in sorted order.
    diff_counts: dict[str, int] = {}
    for e in examples:
        diff_counts[e.difficulty] = diff_counts.get(e.difficulty, 0) + 1
    for level, count in sorted(diff_counts.items()):
        print(f" {level}: {count}")

    # Persist the dataset to the cache directory as JSONL.
    path = loader.save_dataset(examples)
    print(f"\nSaved to: {path}")