DouDou committed on
Commit
315961d
·
verified ·
1 Parent(s): 8d5d80e

Upload data3/generate_programming_problems.py with huggingface_hub

Browse files
data3/generate_programming_problems.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate programming problems from function_dataset_v2.csv using Gemini API.
4
+ Filters by relevance score and controls API cost.
5
+ """
6
+
7
+ import csv
8
+ import json
9
+ import os
10
+ import sys
11
+ import vertexai
12
+ from vertexai.generative_models import GenerativeModel
13
+ from datetime import datetime
14
+ from typing import Dict, Optional, Tuple, List
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed
16
+ import threading
17
+
18
+
19
+ # Configuration: module-level defaults; the relevance and budget values can be overridden from the CLI
20
+ PROJECT_ID = "tangou"
21
+ MODEL_NAME = "gemini-2.5-flash-lite" # Flash-Lite model, chosen for cost efficiency
22
+ MIN_RELEVANCE_SCORE = 1 # Default threshold: rows with relevance_score below this are skipped (--min-score)
23
+ MAX_BUDGET_USD = 50.0 # Default spending cap in USD (--max-budget)
24
+
25
+ # Token prices in USD per 1M tokens, used only for local cost tracking. NOTE(review): previously labeled "Gemini 2.0 Flash pricing (as of Dec 2024)" although MODEL_NAME is gemini-2.5-flash-lite; confirm the rates against:
26
+ # https://cloud.google.com/vertex-ai/generative-ai/pricing
27
+ INPUT_PRICE_PER_MILLION = 0.1 # USD per 1M input (prompt) tokens
28
+ OUTPUT_PRICE_PER_MILLION = 0.4 # USD per 1M output (candidate) tokens
29
+
30
+ # If using Gemini 1.5 Flash instead:
31
+ # INPUT_PRICE_PER_MILLION = 0.075
32
+ # OUTPUT_PRICE_PER_MILLION = 0.30
33
+
34
+ PROMPT_TEMPLATE = """You are an expert in scientific computing and computational chemistry/biology/physics. Please create a high-quality programming problem inspired by the following code snippet from a real scientific computing project.
35
+
36
+ The problem should focus on scientific computing concepts such as:
37
+ - Numerical algorithms and simulations
38
+ - Data analysis and visualization
39
+ - Mathematical modeling
40
+ - Scientific data processing
41
+ - Computational methods in chemistry, biology, or physics
42
+
43
+ Code snippet for inspiration:
44
+ ```python
45
+ {code}
46
+ ```
47
+
48
+ Present your output in two distinct sections:
49
+
50
+ [Problem Description]
51
+ Create a **completely self-contained** problem description that:
52
+ - Does NOT directly reference the code snippet above
53
+ - Provides all necessary context and background
54
+ - Clearly states what needs to be implemented
55
+ - Specifies input/output format and constraints
56
+ - Is inspired by the scientific computing concepts in the code but creates a NEW, interesting problem
57
+ - Assumes common programming knowledge but explains any domain-specific concepts
58
+
59
+ [Solution]
60
+ Provide a comprehensive, **correct** Python solution that:
61
+ - Accurately solves the problem described
62
+ - Includes clear comments explaining the approach
63
+ - Uses appropriate scientific computing libraries (numpy, scipy, etc.) when relevant
64
+ - Is complete and runnable
65
+ - Follows best practices for scientific computing
66
+
67
+ Remember: The problem should be INSPIRED by the code, not a direct copy. Create something educational and interesting for scientific computing practitioners."""
68
+
69
+
70
class GeminiAPIClient:
    """Client for the Gemini API (via Vertex AI) with token and cost tracking.

    Running totals are updated and read under an internal lock so a single
    client instance can be shared by several worker threads.
    """

    def __init__(self, project_id: str, model_name: str):
        """Initialize Gemini API client.

        Args:
            project_id: Google Cloud project ID
            model_name: Name of the Gemini model to use
        """
        vertexai.init(project=project_id)
        self.model = GenerativeModel(model_name)
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_requests = 0
        self.total_cost = 0.0
        self._lock = threading.Lock()  # Guards the running totals above

    def generate_content(self, prompt: str) -> Tuple[str, Dict]:
        """Generate content using Gemini API and track usage.

        Args:
            prompt: The prompt to send to the API

        Returns:
            Tuple of (response_text, usage_info). usage_info contains:
            input_tokens, output_tokens, total_tokens, input_cost,
            output_cost, request_cost (costs in USD).

        Raises:
            Exception: Any error raised by the API call or while reading the
                response is printed and re-raised unchanged.
        """
        try:
            response = self.model.generate_content(prompt)
            usage_metadata = response.usage_metadata

            input_tokens = usage_metadata.prompt_token_count
            output_tokens = usage_metadata.candidates_token_count

            # Cost is estimated locally from the module-level price table.
            input_cost = (input_tokens / 1_000_000) * INPUT_PRICE_PER_MILLION
            output_cost = (output_tokens / 1_000_000) * OUTPUT_PRICE_PER_MILLION
            request_cost = input_cost + output_cost

            # Totals are updated before response.text is read, so a request
            # whose text cannot be extracted is still counted as spent.
            with self._lock:
                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.total_requests += 1
                self.total_cost += request_cost

            usage_info = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': input_tokens + output_tokens,
                'input_cost': input_cost,
                'output_cost': output_cost,
                'request_cost': request_cost
            }

            return response.text, usage_info

        except Exception as e:
            print(f"Error generating content: {e}")
            raise

    def get_total_usage(self) -> Dict:
        """Get total usage statistics.

        The counters are read under the same lock that protects their
        updates, so the returned values always belong to one consistent
        snapshot even while other threads are issuing requests.

        Returns:
            Dictionary with total_requests, total_input_tokens,
            total_output_tokens, total_tokens and total_cost.
        """
        with self._lock:
            requests = self.total_requests
            input_tokens = self.total_input_tokens
            output_tokens = self.total_output_tokens
            cost = self.total_cost

        return {
            'total_requests': requests,
            'total_input_tokens': input_tokens,
            'total_output_tokens': output_tokens,
            'total_tokens': input_tokens + output_tokens,
            'total_cost': cost
        }

    def print_usage_summary(self):
        """Print a summary of API usage and costs."""
        usage = self.get_total_usage()
        print("\n" + "="*70)
        print("API USAGE SUMMARY")
        print("="*70)
        print(f"Total Requests: {usage['total_requests']}")
        print(f"Total Input Tokens: {usage['total_input_tokens']:,}")
        print(f"Total Output Tokens: {usage['total_output_tokens']:,}")
        print(f"Total Tokens: {usage['total_tokens']:,}")
        print(f"\nTotal Cost: ${usage['total_cost']:.6f}")
        print(f"Budget Remaining: ${MAX_BUDGET_USD - usage['total_cost']:.6f}")
        print("="*70)
159
+
160
+
161
def _load_processed_rows(output_file: str) -> set:
    """Return the row numbers already recorded in an existing JSONL output file."""
    processed_rows = set()
    if not os.path.exists(output_file):
        print(f"No existing output file found. Will create new file.")
        return processed_rows

    print(f"Checking existing output file for already processed rows...")
    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Blank or partially written lines are ignored.
                    continue
                # Only well-formed records carry an integer row number.
                if isinstance(data, dict) and isinstance(data.get('row_number'), int):
                    processed_rows.add(data['row_number'])
        print(f"Found {len(processed_rows)} already processed rows. These will be skipped.")
    except (OSError, UnicodeDecodeError) as e:
        print(f"Warning: Could not read existing output file: {e}")
    return processed_rows


def process_function_dataset(
    input_file: str,
    output_file: str,
    min_score: int = MIN_RELEVANCE_SCORE,
    max_budget: float = MAX_BUDGET_USD,
    max_samples: Optional[int] = None,
    start_from: int = 0,
    max_workers: int = 5
):
    """Process function dataset and generate programming problems.

    Rows already present in ``output_file`` (matched by ``row_number``) are
    skipped, and an existing output file is always appended to rather than
    truncated. Each completed request is written to the output file
    immediately. Once the tracked cost reaches ``max_budget`` no new request
    is started; requests already in flight are allowed to finish and their
    results are still saved, so the final cost can exceed the budget by at
    most the in-flight requests.

    Args:
        input_file: Path to function_dataset_v2.csv
        output_file: Path to output JSONL file
        min_score: Minimum relevance score to process
        max_budget: Maximum budget in USD
        max_samples: Maximum number of samples to process (None for all)
        start_from: Skip first N rows (for resuming)
        max_workers: Maximum number of concurrent workers (default: 5)

    Returns:
        Number of problems successfully generated and written.
    """
    print(f"Starting programming problem generation...")
    print(f"Input: {input_file}")
    print(f"Output: {output_file}")
    print(f"Min Relevance Score: {min_score}")
    print(f"Max Budget: ${max_budget:.2f}")
    print(f"Max Workers: {max_workers}")
    if max_samples:
        print(f"Max Samples: {max_samples}")
    print(f"Starting from row: {start_from}")
    print()

    # Decide the write mode BEFORE anything else touches the output file:
    # opening an existing file with 'w' would destroy the results whose rows
    # are about to be skipped as "already processed".
    resume_existing = os.path.exists(output_file)
    processed_rows = _load_processed_rows(output_file)
    print()

    # Initialize Gemini client
    client = GeminiAPIClient(PROJECT_ID, MODEL_NAME)

    # Statistics
    total_rows = 0
    processed = 0
    skipped_low_score = 0
    skipped_no_code = 0
    skipped_already_processed = 0
    skipped_budget = 0
    errors = 0

    # Prepare tasks to process
    tasks = []

    with open(input_file, 'r', encoding='utf-8') as infile:
        reader = csv.DictReader(infile)

        for row in reader:
            total_rows += 1

            # Skip if resuming
            if total_rows <= start_from:
                continue

            # Skip if already processed
            if total_rows in processed_rows:
                skipped_already_processed += 1
                continue

            # Check if we've reached max samples
            if max_samples and len(tasks) >= max_samples:
                break

            # Filter by relevance score; unparsable or missing scores count as 0
            try:
                relevance_score = int(row.get('relevance_score', 0))
            except (ValueError, TypeError):
                relevance_score = 0

            if relevance_score < min_score:
                skipped_low_score += 1
                continue

            # DictReader fills missing trailing fields with None, so guard
            # against None before stripping.
            function_content = (row.get('function_content') or '').strip()
            if len(function_content) < 50:
                skipped_no_code += 1
                continue

            # Prepare metadata
            metadata = {
                'original_index': row.get('original_index'),
                'function_name': row.get('function_name'),
                'repo_name': row.get('repo_name'),
                'path': row.get('path'),
                'language': row.get('language'),
                'relevance_score': relevance_score,
                'function_start_line': row.get('function_start_line'),
                'function_end_line': row.get('function_end_line'),
            }

            tasks.append({
                'row_number': total_rows,
                'metadata': metadata,
                'prompt': PROMPT_TEMPLATE.format(code=function_content),
                'function_content': function_content
            })

    print(f"Total rows read: {total_rows}")
    print(f"Tasks to process: {len(tasks)}")
    print(f"Skipped (already processed): {skipped_already_processed}")
    print(f"Skipped (low score): {skipped_low_score}")
    print(f"Skipped (no/short code): {skipped_no_code}")
    print(f"\nStarting concurrent processing with {max_workers} workers...\n")

    def process_task(task):
        """Process a single task; never raises, returns a result dict."""
        row_number = task['row_number']
        metadata = task['metadata']
        prompt = task['prompt']

        # Refuse to start a new request once the budget is spent. This stops
        # queued tasks that a worker picked up before they could be cancelled.
        if client.total_cost >= max_budget:
            return {'success': False, 'skipped': True, 'row_number': row_number}

        try:
            print(f"Processing row {row_number} (score={metadata['relevance_score']}, func={metadata['function_name']})...", end=' ')

            response_text, usage_info = client.generate_content(prompt)

            print(f"✓ (${usage_info['request_cost']:.6f}, {usage_info['total_tokens']} tokens)")

            return {
                'success': True,
                'data': {
                    'metadata': metadata,
                    'prompt': prompt,
                    'response': response_text,
                    'usage': usage_info,
                    'timestamp': datetime.now().isoformat(),
                    'row_number': row_number
                }
            }

        except Exception as e:
            print(f"✗ Error: {e}")
            return {
                'success': False,
                'error': str(e),
                'row_number': row_number
            }

    mode = 'a' if (start_from > 0 or resume_existing) else 'w'

    # Process tasks concurrently
    with open(output_file, mode, encoding='utf-8') as outfile:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_task = {executor.submit(process_task, task): task for task in tasks}
            budget_announced = False

            try:
                # Cancelled futures are also yielded by as_completed, so the
                # loop drains every submitted task.
                for future in as_completed(future_to_task):
                    if future.cancelled():
                        skipped_budget += 1
                        continue

                    result = future.result()

                    if result['success']:
                        # Always persist a finished (and paid-for) result,
                        # even if the budget has been reached meanwhile.
                        outfile.write(json.dumps(result['data'], ensure_ascii=False) + '\n')
                        outfile.flush()  # Ensure data is written immediately

                        processed += 1

                        # Print periodic summary
                        if processed % 10 == 0:
                            print(f"\n--- Progress: {processed} problems generated, ${client.total_cost:.6f} spent ---\n")
                    elif result.get('skipped'):
                        skipped_budget += 1
                    else:
                        errors += 1

                    if not budget_announced and client.total_cost >= max_budget:
                        budget_announced = True
                        print(f"\n⚠️ Budget limit reached (${client.total_cost:.6f} >= ${max_budget:.2f})")
                        print(f"Cancelling remaining tasks...")
                        for pending in future_to_task:
                            pending.cancel()
            finally:
                # On an interrupt or unexpected error, drop everything still
                # queued so executor shutdown does not keep issuing requests.
                # cancel() is a no-op for futures that already finished.
                for pending in future_to_task:
                    pending.cancel()

    # Final summary
    print("\n" + "="*70)
    print("PROCESSING COMPLETE")
    print("="*70)
    print(f"Total rows read: {total_rows}")
    print(f"Successfully processed: {processed}")
    print(f"Skipped (already processed): {skipped_already_processed}")
    print(f"Skipped (low score): {skipped_low_score}")
    print(f"Skipped (no/short code): {skipped_no_code}")
    print(f"Skipped (budget reached): {skipped_budget}")
    print(f"Errors: {errors}")

    client.print_usage_summary()

    print(f"\nResults saved to: {output_file}")

    return processed
374
+
375
+
376
if __name__ == "__main__":
    import argparse
    import traceback

    # Command-line options, registered in the order they appear in --help.
    option_table = [
        ('--input', {
            'default': 'function_dataset_v2.csv',
            'help': 'Input CSV file (default: function_dataset_v2.csv)',
        }),
        ('--output', {
            'default': 'programming_problems.jsonl',
            'help': 'Output JSONL file (default: programming_problems.jsonl)',
        }),
        ('--min-score', {
            'type': int,
            'default': MIN_RELEVANCE_SCORE,
            'help': f'Minimum relevance score (default: {MIN_RELEVANCE_SCORE})',
        }),
        ('--max-budget', {
            'type': float,
            'default': MAX_BUDGET_USD,
            'help': f'Maximum budget in USD (default: {MAX_BUDGET_USD})',
        }),
        ('--max-samples', {
            'type': int,
            'default': None,
            'help': 'Maximum number of samples to process (default: no limit)',
        }),
        ('--start-from', {
            'type': int,
            'default': 0,
            'help': 'Start from row N (for resuming, default: 0)',
        }),
        ('--max-workers', {
            'type': int,
            'default': 10,
            'help': 'Maximum number of concurrent workers (default: 10)',
        }),
    ]

    cli = argparse.ArgumentParser(
        description='Generate programming problems from function dataset using Gemini API'
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    cli_args = cli.parse_args()

    # Fail fast, before any API client is created, when the input is missing.
    if not os.path.exists(cli_args.input):
        print(f"Error: Input file not found: {cli_args.input}")
        sys.exit(1)

    try:
        process_function_dataset(
            input_file=cli_args.input,
            output_file=cli_args.output,
            min_score=cli_args.min_score,
            max_budget=cli_args.max_budget,
            max_samples=cli_args.max_samples,
            start_from=cli_args.start_from,
            max_workers=cli_args.max_workers
        )
    except KeyboardInterrupt:
        # A user interrupt is reported as a clean exit, not a failure.
        print("\n\n⚠️ Interrupted by user. Progress has been saved to output file.")
        print(" You can resume by using --start-from <row_number>")
        sys.exit(0)
    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        sys.exit(1)
    else:
        print("\n✅ Success!")