"""Extract function definitions from source repositories with tree-sitter.

Every immediate sub-directory of the projects root is treated as one project.
Projects are processed in worker processes; every source file whose extension
appears in LANGUAGE_CONFIG is parsed, and one JSON object per function is
written to ``OUTPUT_DIR/<project>/functions.jsonl``.

Usage: python extract_functions.py [PROJECTS_ROOT]
"""
import traceback
import os
import json
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed

from tree_sitter import Language, Parser
import tree_sitter_c
import tree_sitter_cpp
import tree_sitter_java
import tree_sitter_go
import tree_sitter_rust
import tree_sitter_julia
import tree_sitter_python

########################################
# Configuration
########################################

# NOTE(review): legacy path to a compiled grammar bundle; nothing in this file
# reads it any more (grammars come from the tree_sitter_* packages below).
LANG_SO = "build/my-languages.so"
OUTPUT_DIR = "/home/weifengsun/tangou1/step2/step22/dataset"
# Used when no projects root is given on the command line.
DEFAULT_PROJECTS_ROOT = "/home/weifengsun/tangou1/domain_code/src/workdir/repos_filtered"

# Directory names pruned from the walk (VCS metadata, dependencies, build output).
EXCLUDE_DIRS = {
    ".git", "node_modules", "vendor", "third_party",
    "build", "dist", "target", "__pycache__"
}

# Per-language extraction rules:
#   ext            - file extensions routed to this language
#   function_nodes - node types recorded as functions
#   class_nodes    - node types whose name is prefixed onto nested functions
#   name_field     - field holding the function name; "declarator" (C/C++)
#                    means the identifier is nested inside declarator nodes
#   receiver_field - field holding a method receiver (Go)
LANGUAGE_CONFIG = {
    "python": {
        "ext": [".py"],
        "function_nodes": ["function_definition"],
        "class_nodes": ["class_definition"],
        "name_field": "name",
    },
    "c": {
        "ext": [".c", ".h"],
        "function_nodes": ["function_definition"],
        "name_field": "declarator",
    },
    "cpp": {
        "ext": [".cpp", ".cc", ".cxx", ".hpp", ".hh"],
        "function_nodes": ["function_definition"],
        "name_field": "declarator",
    },
    "java": {
        "ext": [".java"],
        "function_nodes": ["method_declaration", "constructor_declaration"],
        "class_nodes": ["class_declaration"],
        "name_field": "name",
    },
    "go": {
        "ext": [".go"],
        "function_nodes": ["function_declaration", "method_declaration"],
        "name_field": "name",
        "receiver_field": "receiver",
    },
    "rust": {
        "ext": [".rs"],
        "function_nodes": ["function_item"],
        "class_nodes": ["impl_item", "trait_item"],
        "name_field": "name",
    },
    "julia": {
        "ext": [".jl"],
        "function_nodes": ["function_definition"],
        "name_field": "name",
    },
}

# File extension -> language key.
EXT_TO_LANG = {
    ext: lang
    for lang, cfg in LANGUAGE_CONFIG.items()
    for ext in cfg["ext"]
}

########################################
# Worker initialization
########################################

LANGUAGES = {
    "python": Language(tree_sitter_python.language()),
    "go": Language(tree_sitter_go.language()),
    "rust": Language(tree_sitter_rust.language()),
    "julia": Language(tree_sitter_julia.language()),
    "c": Language(tree_sitter_c.language()),
    "cpp": Language(tree_sitter_cpp.language()),
    "java": Language(tree_sitter_java.language()),
}

# Language key -> Parser; filled per worker process by init_worker().
PARSERS = {}


def init_worker():
    """Build one tree-sitter Parser per configured language in this process.

    Used as the ProcessPoolExecutor initializer. A language whose parser
    cannot be constructed is reported and left out of PARSERS; files in that
    language are then skipped by process_project.
    """
    global PARSERS
    PARSERS = {}
    for lang in LANGUAGE_CONFIG:
        try:
            PARSERS[lang] = Parser(LANGUAGES[lang])
        except Exception:
            print(f"Failed to load parser for {lang}")
            print(traceback.format_exc())


########################################
# Function extraction logic
########################################

# C/C++ declarator wrappers that nest the inner declarator without exposing it
# through a "declarator" field.
_UNFIELDED_DECLARATORS = {"reference_declarator", "parenthesized_declarator"}


def _node_text(node):
    """Return a node's source text as str, replacing undecodable bytes."""
    # Files are read as raw bytes and are not guaranteed to be UTF-8; a strict
    # decode would raise and drop every function in the file.
    return (node.text or b"").decode("utf-8", errors="replace")


def _innermost_declarator(node):
    """Descend through nested C/C++ declarators to the node naming the function.

    ``function_definition.declarator`` is typically a function_declarator whose
    text is the whole signature (``foo(int x)``), possibly wrapped in a
    pointer_declarator; the function's name is the innermost declarator.
    Returns None only when ``node`` is None.
    """
    while node is not None:
        inner = node.child_by_field_name("declarator")
        if inner is None and node.type in _UNFIELDED_DECLARATORS and node.named_children:
            inner = node.named_children[-1]
        if inner is None:
            return node
        node = inner
    return node


def extract_functions(tree, file_path, language):
    """Collect every function definition in a parsed tree.

    Args:
        tree: tree-sitter Tree produced by the parser for ``language``.
        file_path: path recorded verbatim in each result's "file" key.
        language: key into LANGUAGE_CONFIG.

    Returns:
        A list of dicts with keys language, name, qualified_name, file,
        start_line and end_line (1-based, inclusive), in source pre-order.
        qualified_name joins the enclosing class/impl/trait names (and, for
        Go methods, the receiver text) and the function name with ".".
    """
    cfg = LANGUAGE_CONFIG[language]
    function_nodes = set(cfg["function_nodes"])
    class_nodes = set(cfg.get("class_nodes", ()))
    name_field = cfg["name_field"]
    receiver_field = cfg.get("receiver_field")
    results = []

    def record(node, name, qualified_name):
        results.append({
            "language": language,
            "name": name,
            "qualified_name": qualified_name,
            "file": file_path,
            "start_line": node.start_point[0] + 1,
            "end_line": node.end_point[0] + 1,
        })

    # Explicit stack instead of recursion: deeply nested (e.g. generated) code
    # would otherwise exceed Python's recursion limit and lose the whole file.
    # Each entry carries its own immutable scope tuple, so a scope added for
    # one subtree can never leak into its siblings.
    stack = [(tree.root_node, ())]
    while stack:
        node, scope = stack.pop()

        # class / impl / trait scope
        if node.type in class_nodes:
            name_node = node.child_by_field_name("name")
            if name_node is None:
                name_node = node.child_by_field_name("type")
            if name_node is not None:
                scope = scope + (_node_text(name_node),)

        # function definitions
        if node.type in function_nodes:
            # Methods with a receiver (Go) are qualified by the receiver text.
            if receiver_field:
                recv = node.child_by_field_name(receiver_field)
                if recv is not None:
                    scope = scope + (_node_text(recv),)
            name_node = node.child_by_field_name(name_field)
            if name_field == "declarator":
                name_node = _innermost_declarator(name_node)
            if name_node is not None:
                name = _node_text(name_node)
                record(node, name, ".".join(scope + (name,)))

        # Julia short-form function: foo(x) = ...
        if language == "julia" and node.type == "assignment":
            left = node.child(0)
            if left is not None and left.type == "call_expression":
                fn = left.child_by_field_name("function")
                if fn is not None:
                    name = _node_text(fn)
                    record(node, name, name)

        # Push children reversed so they are popped left-to-right (source order).
        stack.extend((child, scope) for child in reversed(node.children))

    return results


########################################
# Project processing
########################################

def process_project(project_path):
    """Extract functions from every supported file under ``project_path``.

    Writes one JSON line per function to
    ``OUTPUT_DIR/<basename of project_path>/functions.jsonl`` (UTF-8,
    non-ASCII characters kept as-is), replacing any previous file. Must run
    in a process where init_worker() has populated PARSERS. Files that cannot
    be read or parsed are reported on stderr and skipped; errors writing the
    output file propagate to the caller.
    """
    project_name = os.path.basename(os.path.normpath(project_path))
    project_out_dir = os.path.join(OUTPUT_DIR, project_name)
    # main() only creates OUTPUT_DIR; without this the open() below raises
    # FileNotFoundError for every project.
    os.makedirs(project_out_dir, exist_ok=True)
    output_path = os.path.join(project_out_dir, "functions.jsonl")

    with open(output_path, "w", encoding="utf-8") as out:
        for root, dirs, files in os.walk(project_path):
            # Prune in place so os.walk does not descend into excluded dirs.
            dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
            for fname in files:
                lang = EXT_TO_LANG.get(os.path.splitext(fname)[1])
                parser = PARSERS.get(lang)
                if parser is None:
                    # Unsupported extension, or this language's parser failed to load.
                    continue
                path = os.path.join(root, fname)
                try:
                    with open(path, "rb") as src:
                        code = src.read()
                    funcs = extract_functions(parser.parse(code), path, lang)
                except Exception as exc:
                    # Best-effort: one unreadable/unparsable file must not
                    # abort the whole project, but it should not vanish silently.
                    print(f"[SKIP] {path}: {type(exc).__name__}: {exc}", file=sys.stderr)
                    continue
                out.writelines(
                    json.dumps(fn, ensure_ascii=False) + "\n" for fn in funcs
                )


########################################
# Main entry point
########################################

def load_projects(root):
    """Return the paths of the immediate sub-directories of ``root``, sorted."""
    return sorted(
        path
        for path in (os.path.join(root, d) for d in os.listdir(root))
        if os.path.isdir(path)
    )


def main():
    """Extract functions for every project under the projects root in parallel.

    The projects root is the first command-line argument if given, otherwise
    DEFAULT_PROJECTS_ROOT. Prints one [OK]/[FAIL] line per project.
    """
    projects_root = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PROJECTS_ROOT

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    projects = load_projects(projects_root)

    # os.cpu_count() returns None when the CPU count cannot be determined.
    max_workers = min(os.cpu_count() or 1, 32)
    with ProcessPoolExecutor(max_workers=max_workers, initializer=init_worker) as pool:
        futures = {pool.submit(process_project, p): p for p in projects}
        for fut in as_completed(futures):
            proj = futures[fut]
            try:
                fut.result()
                print(f"[OK] {proj}")
            except Exception as e:
                print(f"[FAIL] {proj}: {e}")


if __name__ == "__main__":
    main()