DouDou committed on
Commit
10ae0ab
·
verified ·
1 Parent(s): 2aa06ec

Upload data3/extract_functions.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data3/extract_functions.py +254 -0
data3/extract_functions.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract individual functions from enhanced_dataset.csv and create a new dataset.
4
+ Each function becomes a separate row in the new dataset.
5
+ """
6
+
7
+ import csv
8
+ import json
9
+ from collections import defaultdict
10
+ import sys
11
+
12
def extract_function_content(text, start_line, end_line):
    """Return the slice of *text* covering the given 1-indexed line range.

    Args:
        text: The full code text as a single string.
        start_line: First line of the function (1-indexed).
        end_line: Last line of the function (1-indexed, inclusive).

    Returns:
        The requested lines joined back together with newlines.
    """
    all_lines = text.split('\n')
    # Clamp the 1-indexed, inclusive bounds into a valid 0-indexed slice.
    lo = start_line - 1
    if lo < 0:
        lo = 0
    hi = end_line
    if hi > len(all_lines):
        hi = len(all_lines)
    return '\n'.join(all_lines[lo:hi])
31
+
32
+
33
def _to_int(value):
    """Best-effort conversion to int; falsy or unparsable values become 0."""
    try:
        return int(value) if value else 0
    except (ValueError, TypeError):
        return 0


def process_dataset(input_file, output_file):
    """Explode enhanced_dataset.csv into one output row per extracted function.

    Reads the input CSV, parses each row's ``function_info`` column (JSON,
    with a Python-literal fallback), slices each function's body out of the
    row's ``text`` column by line range, and writes one output row per
    function sorted by ``relevance_score`` descending. Prints progress,
    a score-distribution histogram, and summary statistics.

    Args:
        input_file: Path to enhanced_dataset.csv.
        output_file: Path to the output CSV file.
    """
    import ast  # hoisted out of the loop: the fallback parser may run per row

    print(f"Reading from: {input_file}")
    print(f"Writing to: {output_file}")

    # Statistics
    total_rows = 0
    total_functions = 0
    score_distribution = defaultdict(int)
    skipped_rows = 0

    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8', newline='') as outfile:

        reader = csv.DictReader(infile)

        # Output schema: original-row metadata plus per-function fields.
        fieldnames = [
            'original_index',    # Original row number
            'function_index',    # Index within the file
            'repo_name',
            'path',
            'language',
            'license',
            'keyword',
            'text_hash',
            'config',
            'split',
            'repo_path',
            'ds_source',
            'function_name',
            'function_start_line',
            'function_end_line',
            'doc_start_line',
            'doc_end_line',
            'relevance_score',
            'relevance_reason',
            'function_content'
        ]

        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()

        # Buffer every function row so we can sort by score before writing.
        all_function_rows = []

        print("\nProcessing rows...")
        for row in reader:
            total_rows += 1

            if total_rows % 100 == 0:
                print(f"Processed {total_rows} rows, extracted {total_functions} functions...", end='\r')

            # Parse the function_info column; skip rows without one.
            function_info_str = row.get('function_info', '[]')
            if not function_info_str or function_info_str.strip() == '':
                skipped_rows += 1
                continue

            try:
                # Normal path: the column holds valid JSON.
                function_info_list = json.loads(function_info_str)
            except (json.JSONDecodeError, ValueError):
                # Fallback: some rows store a Python-literal repr instead of JSON.
                try:
                    function_info_list = ast.literal_eval(function_info_str)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    # Narrowed from a bare `except:` which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    if total_rows <= 20:  # Only print first 20 errors
                        print(f"\nWarning: Failed to parse function_info in row {total_rows}")
                    skipped_rows += 1
                    continue

            # Validate that we got a list
            if not isinstance(function_info_list, list):
                skipped_rows += 1
                continue

            # Get the original text
            text = row.get('text', '')

            # Extract each function
            for func_idx, func_info in enumerate(function_info_list):
                # Validate func_info is a dictionary
                if not isinstance(func_info, dict):
                    continue

                # Coerce line bounds to ints; anything unusable becomes 0.
                start_line = _to_int(func_info.get('function_start_line', 0))
                end_line = _to_int(func_info.get('function_end_line', 0))

                if start_line > 0 and end_line > 0:
                    function_content = extract_function_content(text, start_line, end_line)
                else:
                    function_content = ""

                relevance_score = _to_int(func_info.get('relevance_score', 0))

                # Track score distribution (in buckets of 10)
                score_bucket = (relevance_score // 10) * 10
                score_distribution[score_bucket] += 1

                # Create new row
                all_function_rows.append({
                    'original_index': row.get('Unnamed: 0', row.get('Unnamed: 0.1', total_rows - 1)),
                    'function_index': func_idx,
                    'repo_name': row.get('repo_name', ''),
                    'path': row.get('path', ''),
                    'language': row.get('language', ''),
                    'license': row.get('license', ''),
                    'keyword': row.get('keyword', ''),
                    'text_hash': row.get('text_hash', ''),
                    'config': row.get('config', ''),
                    'split': row.get('split', ''),
                    'repo_path': row.get('repo_path', ''),
                    'ds_source': row.get('ds_source', ''),
                    'function_name': func_info.get('function_name', ''),
                    'function_start_line': start_line,
                    'function_end_line': end_line,
                    'doc_start_line': func_info.get('doc_start_line', ''),
                    'doc_end_line': func_info.get('doc_end_line', ''),
                    'relevance_score': relevance_score,
                    'relevance_reason': func_info.get('relevance_reason', ''),
                    'function_content': function_content
                })
                total_functions += 1

        print(f"\n\nTotal rows processed: {total_rows}")
        print(f"Total functions extracted: {total_functions}")
        print(f"Skipped rows (no valid function_info): {skipped_rows}")

        # Sort by relevance_score (descending - highest first)
        print("\nSorting by relevance score...")
        all_function_rows.sort(key=lambda x: x['relevance_score'], reverse=True)

        # Write sorted rows
        print("Writing sorted data to output file...")
        for out_row in all_function_rows:
            writer.writerow(out_row)

        print(f"\nSuccessfully written {total_functions} functions to {output_file}")

        # Print score distribution
        print("\n" + "=" * 60)
        print("SCORE DISTRIBUTION")
        print("=" * 60)
        print(f"{'Score Range':<20} {'Count':<10} {'Percentage':<10} {'Bar'}")
        print("-" * 60)

        # Sort by score range
        sorted_scores = sorted(score_distribution.items(), reverse=True)

        for score_bucket, count in sorted_scores:
            percentage = (count / total_functions * 100) if total_functions > 0 else 0
            bar = '█' * int(percentage / 2)  # Scale bar to fit
            # Pad the whole "lo-hi" label (the original padded only the upper
            # bound, misaligning every row against the 20-char header).
            label = f"{score_bucket}-{score_bucket + 9}"
            print(f"{label:<20} {count:<10} {percentage:>6.2f}% {bar}")

        print("-" * 60)
        print(f"{'Total':<20} {total_functions:<10} {'100.00%':<10}")
        print("=" * 60)

        # Additional statistics
        if total_functions > 0:
            scores = [r['relevance_score'] for r in all_function_rows]
            avg_score = sum(scores) / len(scores)
            max_score = max(scores)
            min_score = min(scores)

            print(f"\nScore Statistics:")
            print(f"  Average Score: {avg_score:.2f}")
            print(f"  Maximum Score: {max_score}")
            print(f"  Minimum Score: {min_score}")
            print(f"  Total Functions: {total_functions}")
232
+
233
+
234
if __name__ == "__main__":
    # Defaults, overridable positionally: argv[1] = input, argv[2] = output.
    cli_args = sys.argv[1:]
    input_file = cli_args[0] if len(cli_args) >= 1 else "enhanced_dataset.csv"
    output_file = cli_args[1] if len(cli_args) >= 2 else "function_dataset.csv"

    try:
        process_dataset(input_file, output_file)
        print("\n✅ Processing complete!")
    except FileNotFoundError:
        print(f"❌ Error: File '{input_file}' not found.")
        sys.exit(1)
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)