DouDou committed on
Commit
432dc67
·
verified ·
1 Parent(s): 900dd38

Upload data2/instruction_generation/extract_repo_functions.py with huggingface_hub

Browse files
data2/instruction_generation/extract_repo_functions.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Multi-language Function Parsing Script

Scans code files in each repository, uses Qwen to parse dependencies and functions,
generates functions_with_context.csv
"""

import os
import sys
import json
import csv
import asyncio
import argparse
import hashlib
from pathlib import Path
from typing import List, Dict, Optional
from tqdm import tqdm
from dotenv import load_dotenv

# Load .env file (before importing other modules, so util/schemas see the env vars)
env_file = Path(__file__).parent / ".env"
if env_file.exists():
    load_dotenv(env_file)
elif (Path(__file__).parent.parent / ".env").exists():
    # If not in current directory, try loading from project root
    load_dotenv(Path(__file__).parent.parent / ".env")

# Add current directory to path (for importing schemas)
sys.path.insert(0, str(Path(__file__).parent))
# Add domain_code/src to path for reusing util functions
sys.path.insert(0, str(Path(__file__).parent.parent / "domain_code" / "src"))
from util import call_llm, init_logger, logger, CODE_EXTENSIONS
from schemas import FileParseResult

# Exclude markdown files (should not be parsed as code files)
PARSEABLE_CODE_EXTENSIONS = CODE_EXTENSIONS - {".md", ".markdown"}


# Default output filenames (written back into each repository directory)
CSV_FILENAME = "functions_with_context.csv"
# Expected per-repo README summary used as project context for the LLM prompt
SUMMARY_FILENAME = "README_SUMMARY.md"
42
+
43
+
44
def detect_language(file_path: Path) -> str:
    """
    Detect the programming language of a file from its extension.

    The lookup is case-insensitive: the suffix is lowercased before the map
    lookup, so only lowercase keys are needed. (The previous version also
    carried ``.F`` and ``.R`` entries, which were unreachable dead keys for
    exactly this reason — uppercase suffixes are covered by ``.f``/``.r``.)

    Args:
        file_path: File path

    Returns:
        Programming language name (e.g., python, cpp, java). For an unmapped
        extension the bare extension (without the dot) is returned; files
        with no extension yield "unknown".
    """
    ext_map = {
        ".py": "python",
        ".ipynb": "python",
        ".java": "java",
        ".c": "c",
        ".cpp": "cpp",
        ".cc": "cpp",
        ".cxx": "cpp",
        ".h": "cpp",
        ".hpp": "cpp",
        ".hh": "cpp",
        ".f": "fortran",
        ".f90": "fortran",
        ".f95": "fortran",
        ".r": "r",
        ".m": "matlab",
        ".sh": "shell",
        ".bash": "shell",
        ".rs": "rust",
        ".go": "go",
        ".jl": "julia",
    }

    ext = file_path.suffix.lower()
    return ext_map.get(ext, ext.lstrip(".") if ext else "unknown")
81
+
82
+
83
def read_readme_summary(repo_dir: Path) -> Optional[str]:
    """
    Load the repository's README_SUMMARY.md to serve as project context.

    Args:
        repo_dir: Repository root directory

    Returns:
        Stripped summary text, or None when the file is absent or unreadable.
    """
    summary_path = repo_dir / SUMMARY_FILENAME
    if not summary_path.exists():
        return None

    try:
        text = summary_path.read_text(encoding="utf-8", errors="ignore")
    except Exception as e:
        logger.warning(f"Unable to read README summary file {summary_path}: {e}")
        return None
    return text.strip()
103
+
104
+
105
def find_code_files(repo_dir: Path, max_file_chars: int = 200000) -> List[Path]:
    """
    Collect all parseable code files under a repository.

    Args:
        repo_dir: Repository root directory
        max_file_chars: Maximum file size (chars); larger files are skipped

    Returns:
        Sorted list of code file paths
    """
    excluded_dirs = {"__pycache__", "node_modules", ".git"}
    found: List[Path] = []

    for root, dirs, files in os.walk(repo_dir):
        # Prune hidden directories and common non-source directories in place
        dirs[:] = [d for d in dirs if not d.startswith(".") and d not in excluded_dirs]

        for name in files:
            candidate = Path(root) / name
            # PARSEABLE_CODE_EXTENSIONS excludes markdown files
            if candidate.suffix.lower() not in PARSEABLE_CODE_EXTENSIONS:
                continue
            try:
                size = candidate.stat().st_size
            except Exception as e:
                logger.warning(f"Unable to get file size {candidate}: {e}")
                continue
            # Rough estimate: ~1 byte per char for UTF-8 source files
            if size <= max_file_chars:
                found.append(candidate)
            else:
                logger.debug(f"Skipping large file: {candidate} ({size} bytes)")

    return sorted(found)
138
+
139
+
140
def read_code_file(file_path: Path) -> Optional[str]:
    """
    Read a code file's text content, tolerating encoding errors.

    Args:
        file_path: File path

    Returns:
        File content, or None if the file could not be read.
    """
    try:
        # errors="ignore" drops undecodable bytes rather than raising
        return Path(file_path).read_text(encoding="utf-8", errors="ignore")
    except Exception as e:
        logger.warning(f"Unable to read file {file_path}: {e}")
        return None
156
+
157
+
158
def compute_file_hash(file_path: Path, content: str) -> str:
    """
    Compute the SHA1 hash of a file's content.

    Note: only ``content`` participates in the digest; ``file_path`` is kept
    in the signature for interface compatibility with callers.

    Args:
        file_path: File path (not part of the hash)
        content: File content

    Returns:
        SHA1 hash (hex string)
    """
    digest = hashlib.sha1()
    digest.update(content.encode("utf-8"))
    return digest.hexdigest()
170
+
171
+
172
def compute_function_hash(repo_name: str, path: str, start_line: int, end_line: int, body: str) -> str:
    """
    Compute a per-function SHA1 hash used for deduplication.

    The hash key is the colon-joined tuple
    ``repo_name:path:start_line:end_line:body``.

    Args:
        repo_name: Repository name
        path: Relative file path
        start_line: Function start line number
        end_line: Function end line number
        body: Function body

    Returns:
        SHA1 hash (hex string)
    """
    parts = (repo_name, path, str(start_line), str(end_line), body)
    return hashlib.sha1(":".join(parts).encode("utf-8")).hexdigest()
188
+
189
+
190
async def parse_code_file(
    file_path: Path,
    repo_dir: Path,
    project_context: str,
    base_url: str,
    model: str,
    api_key: str,
    log_file: str,
) -> Optional[Dict]:
    """
    Ask the LLM to parse one code file and extract its dependencies and
    function information.

    Args:
        file_path: Code file path
        repo_dir: Repository root directory
        project_context: Project context (README summary)
        base_url: LLM API base URL
        model: Model name
        api_key: API key
        log_file: Log file path

    Returns:
        Parse result dict, or None on any failure (unreadable file, missing
        prompt template, LLM error, or unparseable response).
    """
    # Read the source; skip files that cannot be read (or are empty)
    code_content = read_code_file(file_path)
    if not code_content:
        return None

    language = detect_language(file_path)
    rel_path = str(file_path.relative_to(repo_dir))

    # Load the prompt template that ships next to this script
    template_file = Path(__file__).parent / "prompts" / "function_extract.txt"
    try:
        with open(template_file, "r", encoding="utf-8") as f:
            template = f.read()
    except Exception as e:
        logger.error(f"Unable to read prompt template: {e}")
        return None

    prompt = template.format(
        project_context=project_context or "(No project context)",
        file_path=rel_path,
        language=language,
        code_content=code_content,
    )

    try:
        result = await call_llm(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            base_url=base_url,
            api_key=api_key,
            pydantic_object=FileParseResult,
            log_file=log_file,
        )

        if result is None:
            logger.warning(f"LLM call returned None, skipping file: {rel_path}")
            return None

        # Some backends return raw JSON text instead of a parsed object
        if isinstance(result, str):
            try:
                result = json.loads(result)
            except json.JSONDecodeError:
                logger.warning(f"Unable to parse JSON from LLM response: {result[:200]}")
                return None

        # Stamp the relative path and language so downstream rows stay consistent
        if isinstance(result, dict):
            result["file_path"] = rel_path
            result["language"] = language

        return result
    except Exception as e:
        logger.error(f"LLM call failed (file: {rel_path}): {e}")
        return None
276
+
277
+
278
def extract_repo_name(repo_dir: Path) -> str:
    """
    Derive the repository name from its directory name
    (``owner___repo`` -> ``owner/repo``).

    Args:
        repo_dir: Repository root directory

    Returns:
        Repository name in owner/repo format
    """
    return repo_dir.name.replace("___", "/")
290
+
291
+
292
async def process_single_repo(
    repo_dir: Path,
    base_url: str,
    model: str,
    api_key: str,
    log_file: str,
    max_file_chars: int = 200000,
    max_concurrency: int = 8,
    overwrite: bool = False,
) -> Dict[str, any]:
    # NOTE(review): ``any`` above is the builtin, not ``typing.Any`` — it works
    # at runtime as an annotation but is incorrect for type checkers; confirm
    # before changing since Any is not currently imported.
    """
    Process function parsing for a single repository.

    Pipeline: skip if output exists -> load README summary (required) ->
    discover code files -> parse them concurrently via the LLM ->
    write one CSV row per extracted function.

    Args:
        repo_dir: Repository root directory
        base_url: LLM API base URL
        model: Model name
        api_key: API key
        log_file: Log file path
        max_file_chars: Maximum file size (chars)
        max_concurrency: Maximum concurrency for per-file LLM calls
        overwrite: Whether to overwrite existing CSV file

    Returns:
        Processing result dictionary; always contains "repo" and "status",
        plus "reason" on failure/skip or file/function counts on success.
    """
    repo_name = repo_dir.name
    csv_file = repo_dir / CSV_FILENAME

    # Check if CSV file already exists (idempotent unless --overwrite)
    if csv_file.exists() and not overwrite:
        return {
            "repo": repo_name,
            "status": "skipped",
            "reason": "CSV file already exists",
        }

    # Read README summary as project context; a repo without one is skipped
    project_context = read_readme_summary(repo_dir)
    if not project_context:
        logger.warning(f"Repository {repo_name} has no README_SUMMARY.md, skipping")
        return {
            "repo": repo_name,
            "status": "no_summary",
            "reason": "README_SUMMARY.md not found",
        }

    # Find code files
    code_files = find_code_files(repo_dir, max_file_chars=max_file_chars)
    if not code_files:
        return {
            "repo": repo_name,
            "status": "no_code",
            "reason": "No code files found",
        }

    logger.info(f"Repository {repo_name}: found {len(code_files)} code files")

    # Parse all code files; the semaphore caps concurrent LLM calls
    semaphore = asyncio.Semaphore(max_concurrency)

    async def parse_with_semaphore(file_path: Path):
        # One slot per in-flight LLM request
        async with semaphore:
            return await parse_code_file(
                file_path=file_path,
                repo_dir=repo_dir,
                project_context=project_context,
                base_url=base_url,
                model=model,
                api_key=api_key,
                log_file=log_file,
            )

    # Parse all files concurrently; as_completed yields results in finish
    # order, so CSV row order is nondeterministic across runs
    tasks = [parse_with_semaphore(file_path) for file_path in code_files]
    parse_results = []

    for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"Parsing {repo_name}", leave=False):
        result = await task
        if result:
            parse_results.append(result)

    if not parse_results:
        return {
            "repo": repo_name,
            "status": "parse_failed",
            "reason": "All files failed to parse",
        }

    # Generate CSV file; repo name is normalized back to owner/repo form
    repo_name_normalized = extract_repo_name(repo_dir)

    # CSV fields (schema consumed by downstream instruction generation)
    fieldnames = [
        "repo_name",
        "readme_summary_path",
        "readme_summary_text",
        "path",
        "language",
        "dependencies",
        "function_name",
        "function_start_line",
        "function_end_line",
        "function_body",
        "doc_start_line",
        "doc_end_line",
        "file_size_bytes",
        "file_sha1",
        "function_hash",
        "ds_source",
    ]

    # Write CSV
    try:
        with open(csv_file, "w", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()

            function_count = 0
            for parse_result in parse_results:
                file_path = parse_result["file_path"]
                language = parse_result["language"]
                dependencies = parse_result.get("dependencies", [])
                functions = parse_result.get("functions", [])

                # Re-read file content (for hash and size calculation);
                # empty/unreadable files fall back to size 0 and empty hash
                full_file_path = repo_dir / file_path
                file_content = read_code_file(full_file_path)
                file_size = len(file_content.encode("utf-8")) if file_content else 0
                file_sha1 = compute_file_hash(full_file_path, file_content) if file_content else ""

                # Write a row for each function
                for func in functions:
                    function_name = func.get("function_name", "")
                    function_start_line = func.get("function_start_line", 0)
                    function_end_line = func.get("function_end_line", 0)
                    function_body = func.get("function_body", "")
                    doc_start_line = func.get("doc_start_line")
                    doc_end_line = func.get("doc_end_line")

                    # Dedup key: repo + path + line span + body
                    function_hash = compute_function_hash(
                        repo_name_normalized,
                        file_path,
                        function_start_line,
                        function_end_line,
                        function_body,
                    )

                    # Truncate project_context if too long
                    # NOTE(review): loop-invariant — could be hoisted above
                    # the loops; left in place to keep behavior identical
                    context_text = project_context[:5000] if len(project_context) > 5000 else project_context

                    row = {
                        "repo_name": repo_name_normalized,
                        "readme_summary_path": SUMMARY_FILENAME,
                        "readme_summary_text": context_text,
                        "path": file_path,
                        "language": language,
                        "dependencies": json.dumps(dependencies, ensure_ascii=False),
                        "function_name": function_name,
                        "function_start_line": function_start_line,
                        "function_end_line": function_end_line,
                        "function_body": function_body,
                        # Falsy doc line numbers (None or 0) become empty cells
                        "doc_start_line": doc_start_line if doc_start_line else "",
                        "doc_end_line": doc_end_line if doc_end_line else "",
                        "file_size_bytes": file_size,
                        "file_sha1": file_sha1,
                        "function_hash": function_hash,
                        "ds_source": "repos_filtered",
                    }

                    writer.writerow(row)
                    function_count += 1

        logger.info(f"Repository {repo_name}: wrote {function_count} functions to {csv_file}")

        return {
            "repo": repo_name,
            "status": "success",
            "csv_file": str(csv_file),
            "file_count": len(code_files),
            "function_count": function_count,
        }
    except Exception as e:
        logger.error(f"Unable to write CSV file {csv_file}: {e}")
        return {
            "repo": repo_name,
            "status": "write_failed",
            "reason": str(e),
        }
481
+
482
+
483
async def process_all_repos(
    repos_dir: Path,
    base_url: str,
    model: str,
    api_key: str,
    log_file: str,
    max_file_chars: int = 200000,
    max_concurrency: int = 8,
    overwrite: bool = False,
) -> List[Dict]:
    """
    Run function parsing across every repository under ``repos_dir``.

    Args:
        repos_dir: Repository root directory
        base_url: LLM API base URL
        model: Model name
        api_key: API key
        log_file: Log file path
        max_file_chars: Maximum file size (chars)
        max_concurrency: Maximum concurrency
        overwrite: Whether to overwrite existing CSV files

    Returns:
        List of per-repository processing result dictionaries
    """
    # Collect visible (non-hidden) repository directories in sorted order
    repo_dirs = sorted(
        d for d in repos_dir.iterdir() if d.is_dir() and not d.name.startswith(".")
    )

    logger.info(f"Found {len(repo_dirs)} repositories, starting processing...")

    # Repositories are handled one after another; concurrency lives at the
    # file level inside process_single_repo
    results: List[Dict] = []
    for repo_dir in tqdm(repo_dirs, desc="Processing repos"):
        outcome = await process_single_repo(
            repo_dir=repo_dir,
            base_url=base_url,
            model=model,
            api_key=api_key,
            log_file=log_file,
            max_file_chars=max_file_chars,
            max_concurrency=max_concurrency,
            overwrite=overwrite,
        )
        results.append(outcome)

    return results
535
+
536
+
537
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-language Function Parsing Tool")
    parser.add_argument(
        "--repos_dir",
        type=str,
        default="/home/weifengsun/tangou1/domain_code/src/workdir/repos_filtered",
        help="Repository root directory path",
    )
    parser.add_argument(
        "--base_url",
        type=str,
        default=os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1"),
        help="LLM API base URL (default: http://localhost:8000/v1)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen3",
        help="Model name (default: Qwen3)",
    )
    parser.add_argument(
        "--api_key_env",
        type=str,
        default="OPENAI_API_KEY",
        help="API key environment variable name (default: OPENAI_API_KEY)",
    )
    parser.add_argument(
        "--max_concurrency",
        type=int,
        default=8,
        help="Maximum concurrency (default: 8)",
    )
    parser.add_argument(
        "--max_file_chars",
        type=int,
        default=200000,
        help="Maximum file size in chars (default: 200000)",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing CSV files",
    )
    parser.add_argument(
        "--log_file",
        type=str,
        default="instruction_generation/workdir/logs/extract.log",
        help="Log file path",
    )

    args = parser.parse_args()

    # Create the log directory BEFORE initializing the logger — previously the
    # mkdir happened after init_logger, so the first run on a clean checkout
    # could fail when the logger tried to open a file in a missing directory.
    log_file_path = Path(args.log_file)
    log_file_path.parent.mkdir(parents=True, exist_ok=True)

    # Initialize logger
    init_logger(args.log_file, level="INFO")

    # Get API key from the configured environment variable ("none" fallback
    # suits local servers that ignore auth)
    api_key = os.getenv(args.api_key_env, "none")

    # Validate the repository root before doing any work
    repos_dir = Path(args.repos_dir)
    if not repos_dir.exists():
        logger.error(f"Repository directory does not exist: {repos_dir}")
        sys.exit(1)

    # Run main logic
    results = asyncio.run(
        process_all_repos(
            repos_dir=repos_dir,
            base_url=args.base_url,
            model=args.model,
            api_key=api_key,
            log_file=str(log_file_path),
            max_file_chars=args.max_file_chars,
            max_concurrency=args.max_concurrency,
            overwrite=args.overwrite,
        )
    )

    # Aggregate per-status counts and the total number of extracted functions
    status_counts = {}
    total_functions = 0
    for result in results:
        status = result["status"]
        status_counts[status] = status_counts.get(status, 0) + 1
        if "function_count" in result:
            total_functions += result["function_count"]

    logger.info("\n" + "=" * 80)
    logger.info("Processing complete!")
    logger.info("=" * 80)
    logger.info(f"Total: {len(results)} repositories")
    logger.info(f"Total: {total_functions} functions")
    for status, count in status_counts.items():
        logger.info(f"  {status}: {count}")
    logger.info("=" * 80)
+ logger.info("=" * 80)