DouDou committed on
Commit
be0bdd8
·
verified ·
1 Parent(s): 10ae0ab

Upload data3/extract_functions_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data3/extract_functions_v2.py +308 -0
data3/extract_functions_v2.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract individual functions from enhanced_dataset.csv and create a new dataset.
4
+ Each function becomes a separate row in the new dataset.
5
+ Version 2: Better handling of malformed CSV/JSON
6
+ """
7
+
8
+ import csv
9
+ import json
10
+ import re
11
+ from collections import defaultdict
12
+ import sys
13
+
14
def clean_json_string(json_str):
    """
    Normalize a JSON string whose key names were mangled by CSV formatting.

    Repairs a fixed set of known-corrupted key spellings, strips markdown
    bold markers, then makes a generic pass that glues back key names that
    were split into two or three fragments by stray whitespace.
    """
    # Known keys whose names may have been split by inserted whitespace.
    repairs = (
        (r'"\s*function_nam\s*e\s*"', '"function_name"'),
        (r'"\s*function_start_line\s*"', '"function_start_line"'),
        (r'"\s*function_end_line\s*"', '"function_end_line"'),
        (r'"\s*relevance_score\s*"', '"relevance_score"'),
        (r'"\s*relevance_reason\s*"', '"relevance_reason"'),
        (r'"\s*doc_start_line\s*"', '"doc_start_line"'),
        (r'"\s*doc_end_line\s*"', '"doc_end_line"'),
    )
    for pattern, replacement in repairs:
        json_str = re.sub(pattern, replacement, json_str)

    # Drop markdown bold markers that sometimes leak into the payload.
    json_str = json_str.replace('**', '')

    # Generic pass: rejoin key fragments. NOTE(review): fragments are
    # concatenated with no separator, so a key split at a word boundary
    # (e.g. "relevance score") loses its underscore — pre-existing behavior.
    return re.sub(
        r'"\s*([a-z_]+)\s*([a-z_]+)\s*([a-z_]*)\s*":',
        lambda m: '"{}{}{}":'.format(m.group(1), m.group(2), m.group(3) or ''),
        json_str,
    )
39
+
40
+
41
def extract_function_content(text, start_line, end_line):
    """
    Return the slice of *text* covering lines start_line..end_line (inclusive).

    Args:
        text: The full code text as a single string.
        start_line: First line of the function, 1-indexed.
        end_line: Last line of the function, 1-indexed (inclusive).

    Returns:
        The selected lines rejoined with newlines; an empty string when the
        range is empty or falls entirely outside the text.
    """
    all_lines = text.split('\n')
    # Clamp the 1-indexed inclusive range into valid 0-indexed slice bounds;
    # end_line is inclusive, so it maps directly to the exclusive slice stop.
    first = max(0, start_line - 1)
    last = min(len(all_lines), end_line)
    return '\n'.join(all_lines[first:last])
60
+
61
+
62
def process_dataset(input_file, output_file):
    """
    Explode each row's ``function_info`` JSON list into one output row per
    function, sorted by relevance score (descending), and print statistics.

    Args:
        input_file: Path to enhanced_dataset.csv (must contain a
            ``function_info`` column with a JSON list of dicts, and a
            ``text`` column holding the full source code).
        output_file: Path of the CSV file to create.

    Raises:
        FileNotFoundError: If *input_file* does not exist.
    """
    # Hoisted out of the parse-error path: importing inside the loop paid
    # the import-machinery cost on every malformed row.
    import ast

    print(f"Reading from: {input_file}")
    print(f"Writing to: {output_file}")

    # Statistics
    total_rows = 0
    total_functions = 0
    score_distribution = defaultdict(int)
    skipped_rows = 0
    parse_errors = 0
    empty_function_info = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8', newline='') as outfile:

        reader = csv.DictReader(infile)

        # Define output columns
        fieldnames = [
            'original_index',    # Original row number
            'function_index',    # Index within the file
            'repo_name',
            'path',
            'language',
            'license',
            'keyword',
            'text_hash',
            'config',
            'split',
            'repo_path',
            'ds_source',
            'function_name',
            'function_start_line',
            'function_end_line',
            'doc_start_line',
            'doc_end_line',
            'relevance_score',
            'relevance_reason',
            'function_content'
        ]

        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()

        # Buffer every output row so the file can be written sorted by score.
        all_function_rows = []

        print("\nProcessing rows...")
        for row in reader:
            total_rows += 1

            if total_rows % 1000 == 0:
                print(f"Processed {total_rows} rows, extracted {total_functions} functions, errors: {parse_errors}...", end='\r')

            # Parse function_info JSON
            function_info_str = row.get('function_info', '[]')
            if not function_info_str or function_info_str.strip() == '':
                empty_function_info += 1
                skipped_rows += 1
                continue

            # Repair common CSV-induced corruption before parsing.
            function_info_str = clean_json_string(function_info_str)

            try:
                # First try strict JSON parsing
                function_info_list = json.loads(function_info_str)
            except (json.JSONDecodeError, ValueError):
                # Fall back to Python-literal syntax (e.g. single quotes).
                # Narrowed from a bare except: literal_eval signals failure
                # with ValueError/SyntaxError; a bare except also swallowed
                # KeyboardInterrupt and SystemExit.
                try:
                    function_info_list = ast.literal_eval(function_info_str)
                except (ValueError, SyntaxError):
                    parse_errors += 1
                    if parse_errors <= 5:  # Only print first 5 errors
                        print(f"\nWarning: Failed to parse function_info in row {total_rows}")
                    if parse_errors == 5:
                        print("(Suppressing further parse error messages...)")
                    skipped_rows += 1
                    continue

            # Only a list of function dicts is usable.
            if not isinstance(function_info_list, list):
                skipped_rows += 1
                continue

            # Get the original text
            text = row.get('text', '')

            # Extract each function
            for func_idx, func_info in enumerate(function_info_list):
                # Ignore malformed entries that are not dictionaries.
                if not isinstance(func_info, dict):
                    continue

                start_line = func_info.get('function_start_line', 0)
                end_line = func_info.get('function_end_line', 0)

                # Coerce line numbers to int; anything unusable becomes 0.
                try:
                    start_line = int(start_line) if start_line else 0
                    end_line = int(end_line) if end_line else 0
                except (ValueError, TypeError):
                    start_line = 0
                    end_line = 0

                if start_line > 0 and end_line > 0:
                    function_content = extract_function_content(text, start_line, end_line)
                else:
                    function_content = ""

                # Coerce the relevance score to an int (0 on failure).
                relevance_score = func_info.get('relevance_score', 0)
                try:
                    relevance_score = int(relevance_score) if relevance_score else 0
                except (ValueError, TypeError):
                    relevance_score = 0

                # Track score distribution (in buckets of 10)
                score_bucket = (relevance_score // 10) * 10
                score_distribution[score_bucket] += 1

                # Create new row
                new_row = {
                    'original_index': row.get('Unnamed: 0', row.get('Unnamed: 0.1', total_rows - 1)),
                    'function_index': func_idx,
                    'repo_name': row.get('repo_name', ''),
                    'path': row.get('path', ''),
                    'language': row.get('language', ''),
                    'license': row.get('license', ''),
                    'keyword': row.get('keyword', ''),
                    'text_hash': row.get('text_hash', ''),
                    'config': row.get('config', ''),
                    'split': row.get('split', ''),
                    'repo_path': row.get('repo_path', ''),
                    'ds_source': row.get('ds_source', ''),
                    'function_name': func_info.get('function_name', ''),
                    'function_start_line': start_line,
                    'function_end_line': end_line,
                    'doc_start_line': func_info.get('doc_start_line', ''),
                    'doc_end_line': func_info.get('doc_end_line', ''),
                    'relevance_score': relevance_score,
                    'relevance_reason': func_info.get('relevance_reason', ''),
                    'function_content': function_content
                }

                all_function_rows.append(new_row)
                total_functions += 1

        print(f"\n\nTotal rows processed: {total_rows}")
        print(f"Total functions extracted: {total_functions}")
        print(f"Skipped rows:")
        print(f" - Empty function_info: {empty_function_info}")
        print(f" - Parse errors: {parse_errors}")
        print(f" - Total skipped: {skipped_rows}")

        # Sort by relevance_score (descending - highest first)
        print("\nSorting by relevance score...")
        all_function_rows.sort(key=lambda x: x['relevance_score'], reverse=True)

        # Write sorted rows. Renamed from `row` to avoid shadowing the
        # reader's loop variable above.
        print("Writing sorted data to output file...")
        for out_row in all_function_rows:
            writer.writerow(out_row)

        print(f"\nSuccessfully written {total_functions} functions to {output_file}")

        # Print score distribution
        print("\n" + "="*70)
        print("SCORE DISTRIBUTION")
        print("="*70)
        print(f"{'Score Range':<15} {'Count':<12} {'Percentage':<12} {'Visualization'}")
        print("-"*70)

        # Sort by score range (descending)
        sorted_scores = sorted(score_distribution.items(), reverse=True)

        # Separate plausible buckets from anomalous (negative) ones.
        normal_scores = [(k, v) for k, v in sorted_scores if k >= 0]
        anomalous_scores = [(k, v) for k, v in sorted_scores if k < 0]

        for score_bucket, count in normal_scores:
            percentage = (count / total_functions * 100) if total_functions > 0 else 0
            bar = '█' * min(50, int(percentage / 2))  # Scale bar to fit
            print(f"{score_bucket:>3}-{score_bucket+9:<9} {count:<12} {percentage:>6.2f}% {bar}")

        if anomalous_scores:
            print("\nAnomalous scores (negative or out of range):")
            for score_bucket, count in anomalous_scores:
                percentage = (count / total_functions * 100) if total_functions > 0 else 0
                print(f"{score_bucket:>15} {count:<12} {percentage:>6.2f}%")

        print("-"*70)
        print(f"{'Total':<15} {total_functions:<12} {'100.00%':<12}")
        print("="*70)

        # Additional statistics
        if total_functions > 0:
            # Only scores in the expected 0-100 range feed the summary stats.
            valid_scores = [r['relevance_score'] for r in all_function_rows
                            if 0 <= r['relevance_score'] <= 100]

            if valid_scores:
                avg_score = sum(valid_scores) / len(valid_scores)
                max_score = max(valid_scores)
                min_score = min(valid_scores)

                print(f"\nScore Statistics (valid scores 0-100 only):")
                print(f" Average Score: {avg_score:.2f}")
                print(f" Maximum Score: {max_score}")
                print(f" Minimum Score: {min_score}")
                print(f" Valid Functions: {len(valid_scores)} / {total_functions}")
286
+
287
+
288
def _main():
    """Command-line entry point: resolve paths from argv and run the extraction."""
    input_file = "enhanced_dataset.csv"
    output_file = "function_dataset_v2.csv"

    # Optional positional overrides: input path first, then output path.
    args = sys.argv[1:]
    if args:
        input_file = args[0]
    if len(args) > 1:
        output_file = args[1]

    try:
        process_dataset(input_file, output_file)
        print("\n✅ Processing complete!")
    except FileNotFoundError:
        print(f"❌ Error: File '{input_file}' not found.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    _main()