# dataset-builder / data1 / reporting / code_file_stats_fast.py
# Uploaded by SunDou via huggingface_hub (commit f1c06ef, verified)
"""
Stage C: 代码文件级统计(优化版 - 大幅提速)
优化策略:
1. 使用简化的统计方法替代复杂正则匹配
2. 对大文件使用粗略估计
3. 断点续传支持
4. 批量处理减少IPC开销
5. 跳过详细函数参数分析,使用快速计数
"""
import os
import json
import sys
from pathlib import Path
from collections import defaultdict, Counter
from tqdm import tqdm
import statistics
import math
from multiprocessing import Pool, cpu_count
import pandas as pd
import pickle
import hashlib
# ============== Fast statistic helpers (replace complex regex matching) ==============
# Function-definition keywords used for fast per-language counting.
# NOTE(review): plain substring counts also hit matches inside string
# literals and comments -- these are intentional rough estimates, not parses.
FUNC_KEYWORDS = {
    'python': [b'def '],
    'jupyter': [b'def '],
    'java': [b'public ', b'private ', b'protected ', b'void ', b'static '],
    'c/c++': [b'void ', b'int ', b'float ', b'double ', b'char ', b'bool '],
    'go': [b'func '],
    'rust': [b'fn '],
    'r': [b'function(', b'function ('],
    'matlab': [b'function '],
    'shell': [b'function ', b'() {'],
    'fortran': [b'subroutine ', b'function ', b'SUBROUTINE ', b'FUNCTION '],
}
# Comment markers per language (line markers plus block/docstring openers).
COMMENT_MARKERS = {
    'python': (b'#', b'"""', b"'''"),
    'jupyter': (b'#', b'"""', b"'''"),
    'java': (b'//', b'/*'),
    'c/c++': (b'//', b'/*'),
    'go': (b'//', b'/*'),
    'rust': (b'//', b'/*'),
    'r': (b'#',),
    'matlab': (b'%', b'%{'),
    'shell': (b'#',),
    'fortran': (b'!',),
}
# Mapping from lowercased file extension to language label.
# NOTE(review): the uppercase '.F' key can never match a lowercased lookup.
EXT_MAP = {
    '.py': 'python', '.java': 'java', '.c': 'c/c++', '.h': 'c/c++',
    '.hh': 'c/c++', '.hpp': 'c/c++', '.cpp': 'c/c++', '.cc': 'c/c++',
    '.cxx': 'c/c++', '.c++': 'c/c++', '.f': 'fortran', '.f90': 'fortran',
    '.f95': 'fortran', '.F': 'fortran', '.r': 'r', '.m': 'matlab',
    '.sh': 'shell', '.bash': 'shell', '.rs': 'rust', '.go': 'go',
    '.ipynb': 'jupyter'
}
def detect_language_fast(file_path: str) -> str:
    """Return the language label for *file_path* from its extension, or 'unknown'."""
    _, extension = os.path.splitext(file_path)
    return EXT_MAP.get(extension.lower(), 'unknown')
def fast_analyze_file(file_path: Path, repo_name: str, max_file_size_bytes: int = 2*1024*1024) -> dict:
    """Quickly analyze a single code file using raw byte operations.

    All metrics are deliberately rough substring-count heuristics (they
    also match inside strings and comments); accuracy is traded for speed.

    Args:
        file_path: path to the candidate file.
        repo_name: repository name recorded in the result.
        max_file_size_bytes: files larger than this are skipped.

    Returns:
        A per-file stats dict, or None when the file is oversized,
        unreadable, or analysis fails.
    """
    try:
        file_size = file_path.stat().st_size
        if file_size > max_file_size_bytes:
            return None  # skip oversized files entirely
        ext = file_path.suffix.lower()
        # Notebooks are JSON and need dedicated handling
        if ext == '.ipynb':
            return fast_analyze_notebook(file_path, repo_name, file_size)
        # Read in binary mode (faster; avoids decode errors)
        try:
            with open(file_path, 'rb') as f:
                content = f.read()
        except OSError:  # was bare except: only swallow real I/O failures
            return None
        lang = detect_language_fast(str(file_path))
        # Line count: newline count + 1 (over-counts by 1 for files ending in \n)
        lines = content.count(b'\n') + 1
        # Comment-line estimate: count markers, assume each marks one line
        comment_lines = 0
        if lang in COMMENT_MARKERS:
            for marker in COMMENT_MARKERS[lang]:
                comment_lines += content.count(marker)
        comment_lines = min(comment_lines, lines // 2)  # cap at half the file
        # Fast function count via per-language definition keywords
        functions = 0
        if lang in FUNC_KEYWORDS:
            for keyword in FUNC_KEYWORDS[lang]:
                functions += content.count(keyword)
        # Token estimate: whitespace-separated byte chunks
        tokens = len(content.split())
        # Blank-line estimate via consecutive newline pairs
        empty_lines = content.count(b'\n\n') + content.count(b'\r\n\r\n')
        code_lines = max(0, lines - empty_lines - comment_lines)
        return {
            'repo_name': repo_name,
            'file_path': str(file_path.name),  # name only, to save memory
            'file_size_bytes': file_size,
            'language': lang,
            'total_lines': lines,
            'comment_lines': comment_lines,
            'code_lines': code_lines,
            'tokens': tokens,
            'functions': functions,
            'parameters': functions * 2,  # rough: ~2 params per function
        }
    except Exception:
        # Best-effort scanner: any unexpected failure just skips the file.
        return None
def fast_analyze_notebook(file_path: Path, repo_name: str, file_size: int) -> dict:
    """Quickly analyze a Jupyter notebook via raw substring counts.

    Args:
        file_path: path to the .ipynb file.
        repo_name: repository name recorded in the result.
        file_size: size in bytes (already known by the caller).

    Returns:
        A stats dict shaped like fast_analyze_file's, or None on failure.
    """
    try:
        with open(file_path, 'rb') as f:
            content = f.read()
        # Count cells by type, tolerating both spaced and compact JSON
        code_cell_count = content.count(b'"cell_type": "code"') + content.count(b'"cell_type":"code"')
        markdown_cell_count = content.count(b'"cell_type": "markdown"') + content.count(b'"cell_type":"markdown"')
        # Raw JSON line count, not rendered notebook lines
        lines = content.count(b'\n') + 1
        code_lines = code_cell_count * 10  # rough: ~10 code lines per cell
        func_count = content.count(b'def ')
        return {
            'repo_name': repo_name,
            'file_path': str(file_path.name),
            'file_size_bytes': file_size,
            'language': 'jupyter',
            'total_lines': lines,
            # Fix: count markdown cells as "comments" -- the original stored
            # code_cell_count here, contradicting its own stated intent
            'comment_lines': markdown_cell_count,
            'code_lines': code_lines,
            'tokens': len(content.split()),
            'functions': func_count,
            'parameters': func_count * 2,  # rough: ~2 params per function
        }
    except Exception:  # was bare except; keep best-effort behavior
        return None
def _default_repo_stats():
"""Factory function for defaultdict"""
return {
'total_files': 0,
'total_lines': 0,
'total_code_lines': 0,
'total_comment_lines': 0,
'total_tokens': 0,
'total_functions': 0,
'total_parameters': 0,
'languages': Counter(),
'file_sizes': [],
}
# Directory names pruned during os.walk (VCS, dependency, build, cache dirs).
SKIP_DIRS = {
    '.git', 'node_modules', 'vendor', 'dist', 'build', '__pycache__',
    '.pytest_cache', '.ipynb_checkpoints', 'venv', 'env', '.venv',
    'target', '.idea', '.vscode', '.mypy_cache', '.tox', '.eggs',
    'site-packages', 'lib', 'libs', 'third_party', 'external'
}
# Extensions treated as code files.
# NOTE(review): scan_repo_fast compares against suffix.lower(), so the
# uppercase '.F' entry can never match -- confirm whether case-sensitive
# Fortran handling was intended.
CODE_EXTENSIONS = {
    '.py', '.java', '.c', '.h', '.hh', '.hpp', '.cpp', '.cc', '.cxx', '.c++',
    '.f', '.f90', '.f95', '.F', '.r', '.m', '.sh', '.bash', '.rs', '.go',
    '.ipynb'
}
def scan_repo_fast(args):
    """Scan one repository for code files (multiprocessing worker).

    Args:
        args: (repo_path, max_file_size_bytes, max_files_per_repo) tuple,
            packed so the function works with Pool.imap_unordered.

    Returns:
        A list of per-file stats dicts (possibly empty).
    """
    repo_dir, size_cap, file_cap = args
    name = repo_dir.name
    collected = []
    try:
        for current_root, subdirs, filenames in os.walk(repo_dir):
            # Prune unwanted directories in place so os.walk never descends
            subdirs[:] = [d for d in subdirs if d not in SKIP_DIRS]
            for filename in filenames:
                if len(collected) >= file_cap:
                    break
                candidate = Path(current_root) / filename
                # Only analyze recognized code extensions
                if candidate.suffix.lower() not in CODE_EXTENSIONS:
                    continue
                stats = fast_analyze_file(candidate, name, size_cap)
                if stats is not None:
                    collected.append(stats)
            if len(collected) >= file_cap:
                break
    except Exception:
        pass  # best-effort: return whatever was gathered before the failure
    return collected
class CodeFileStatsFast:
    """Fast, checkpointed code-file statistics collector (Stage C).

    Walks every repository under ``repos_dir`` with a process pool,
    gathers per-file heuristic metrics, aggregates them per repository,
    and writes CSV/JSON reports to ``output_dir``. Progress is
    checkpointed so an interrupted run can resume.
    """

    def __init__(self, repos_dir, output_dir, top_n=None, max_file_size_mb=2, max_files_per_repo=500):
        self.repos_dir = Path(repos_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # None means "all repos"; otherwise only the first top_n (name-sorted).
        self.top_n = top_n
        self.max_file_size_bytes = max_file_size_mb * 1024 * 1024
        # Cap on files analyzed per repository, to bound pathological repos.
        self.max_files_per_repo = max_files_per_repo
        self.file_stats = []
        self.repo_stats = defaultdict(_default_repo_stats)
        # Resume support: names of repos already scanned.
        self.checkpoint_file = self.output_dir / 'checkpoint.pkl'
        self.processed_repos = set()

    def load_checkpoint(self):
        """Load a previous run's progress. Returns True if one was loaded."""
        if self.checkpoint_file.exists():
            try:
                with open(self.checkpoint_file, 'rb') as f:
                    data = pickle.load(f)
                self.processed_repos = data.get('processed_repos', set())
                self.file_stats = data.get('file_stats', [])
                print(f"Loaded checkpoint: {len(self.processed_repos)} repos already processed")
                return True
            except Exception:  # was bare except
                # Corrupt/partial checkpoint: start fresh rather than crash.
                pass
        return False

    def save_checkpoint(self):
        """Persist progress to disk; best-effort (failures are ignored)."""
        try:
            with open(self.checkpoint_file, 'wb') as f:
                pickle.dump({
                    'processed_repos': self.processed_repos,
                    'file_stats': self.file_stats,
                }, f)
        except Exception:  # was bare except
            pass

    def scan_all_repos(self, num_workers=None):
        """Scan all selected repositories with a multiprocessing pool.

        Args:
            num_workers: pool size; defaults to min(cpu_count(), 48).
        """
        if num_workers is None:
            num_workers = min(cpu_count(), 48)
        self.load_checkpoint()
        all_repos = sorted([d for d in self.repos_dir.iterdir() if d.is_dir()])
        if self.top_n is None:
            selected_repos = all_repos
        else:
            selected_repos = all_repos[:self.top_n]
        # Skip repos already recorded in the checkpoint.
        # NOTE(review): repos yielding zero code files never enter
        # processed_repos (the worker returns only file dicts), so they are
        # re-scanned on every resume -- harmless but wasteful; confirm.
        repos_to_process = [r for r in selected_repos if r.name not in self.processed_repos]
        print(f"Total repos: {len(selected_repos)} ({'all' if self.top_n is None else f'top {self.top_n}'}), Already processed: {len(self.processed_repos)}, To process: {len(repos_to_process)}")
        print(f"Using {num_workers} workers...")
        if not repos_to_process:
            print("All repos already processed!")
            return
        args_list = [(repo, self.max_file_size_bytes, self.max_files_per_repo) for repo in repos_to_process]
        # Larger chunksize reduces IPC overhead for many small tasks.
        chunksize = max(1, len(repos_to_process) // (num_workers * 10))
        processed_count = 0
        checkpoint_interval = 500  # checkpoint every 500 repos
        with Pool(processes=num_workers) as pool:
            for repo_files in tqdm(
                pool.imap_unordered(scan_repo_fast, args_list, chunksize=chunksize),
                total=len(repos_to_process),
                desc="Scanning repos"
            ):
                if repo_files:  # merged the original's duplicated check
                    self.file_stats.extend(repo_files)
                    self.processed_repos.add(repo_files[0]['repo_name'])
                processed_count += 1
                if processed_count % checkpoint_interval == 0:
                    self.save_checkpoint()
                    print(f"\nCheckpoint saved: {len(self.processed_repos)} repos processed, {len(self.file_stats)} files found")
        # Final checkpoint after the pool drains.
        self.save_checkpoint()
        print(f"Found {len(self.file_stats)} code files from {len(self.processed_repos)} repos")

    def aggregate_repo_stats(self):
        """Aggregate file-level stats into repo-level rows (original-compatible).

        Not idempotent: it accumulates into self.repo_stats, so calling it
        twice double-counts. Returns a list of per-repo dicts.
        """
        for file_stat in self.file_stats:
            repo = file_stat['repo_name']
            self.repo_stats[repo]['total_files'] += 1
            self.repo_stats[repo]['total_lines'] += file_stat['total_lines']
            self.repo_stats[repo]['total_code_lines'] += file_stat['code_lines']
            self.repo_stats[repo]['total_comment_lines'] += file_stat['comment_lines']
            self.repo_stats[repo]['total_tokens'] += file_stat['tokens']
            self.repo_stats[repo]['total_functions'] += file_stat['functions']
            self.repo_stats[repo]['total_parameters'] += file_stat['parameters']
            self.repo_stats[repo]['languages'][file_stat['language']] += 1
            self.repo_stats[repo]['file_sizes'].append(file_stat['file_size_bytes'])
        # Convert accumulators into serializable row dicts.
        repo_stats_list = []
        for repo, stats in self.repo_stats.items():
            total_files = stats['total_files']
            if total_files == 0:
                continue
            stats_dict = {
                'repo_name': repo,
                # Repo dirs are named owner___name; restore owner/name.
                'full_name': repo.replace('___', '/'),
                'total_files': total_files,
                'total_lines': stats['total_lines'],
                'total_code_lines': stats['total_code_lines'],
                'total_comment_lines': stats['total_comment_lines'],
                'total_tokens': stats['total_tokens'],
                'total_functions': stats['total_functions'],
                'total_parameters': stats['total_parameters'],
                'language_count': len(stats['languages']),
                'primary_language': stats['languages'].most_common(1)[0][0] if stats['languages'] else 'unknown',
                'primary_language_files': stats['languages'].most_common(1)[0][1] if stats['languages'] else 0,
            }
            # Derived ratios (guard against division by zero).
            if stats['total_lines'] > 0:
                stats_dict['comment_ratio'] = stats['total_comment_lines'] / stats['total_lines']
            else:
                stats_dict['comment_ratio'] = 0
            if stats['total_functions'] > 0:
                stats_dict['avg_func_length'] = stats['total_code_lines'] / stats['total_functions']
                stats_dict['avg_params_per_func'] = stats['total_parameters'] / stats['total_functions']
            else:
                stats_dict['avg_func_length'] = 0
                stats_dict['avg_params_per_func'] = 0
            # Language diversity as Shannon entropy over file counts.
            if stats['languages']:
                total_lang_files = sum(stats['languages'].values())
                entropy = 0
                for count in stats['languages'].values():
                    p = count / total_lang_files
                    if p > 0:
                        entropy -= p * math.log2(p)
                stats_dict['language_entropy'] = entropy
            else:
                stats_dict['language_entropy'] = 0
            # File-size summary stats (original-compatible).
            if stats['file_sizes']:
                stats_dict['avg_file_size_kb'] = statistics.mean(stats['file_sizes']) / 1024
                stats_dict['max_file_size_mb'] = max(stats['file_sizes']) / (1024 * 1024)
            else:
                stats_dict['avg_file_size_kb'] = 0
                stats_dict['max_file_size_mb'] = 0
            # Share of files in the dominant language.
            if stats['languages']:
                primary_lang_count = stats['languages'].most_common(1)[0][1]
                stats_dict['primary_language_ratio'] = primary_lang_count / total_files
            else:
                stats_dict['primary_language_ratio'] = 0
            repo_stats_list.append(stats_dict)
        return repo_stats_list

    def save_results(self):
        """Write file-level sample, repo-level metrics, and a JSON summary."""
        # File-level stats: sample to keep the CSV small.
        file_df = pd.DataFrame(self.file_stats)
        if len(file_df) > 10000:
            file_df_sample = file_df.sample(n=10000, random_state=42)
        else:
            file_df_sample = file_df
        # Same filename as the original version for downstream compatibility.
        file_df_sample.to_csv(self.output_dir / 'file_level_metrics_sampled.csv', index=False)
        # Repo-level stats (filename reflects top_n selection).
        repo_stats_list = self.aggregate_repo_stats()
        repo_df = pd.DataFrame(repo_stats_list)
        top_n_suffix = f"_top{self.top_n}" if self.top_n else ""
        repo_df.to_csv(self.output_dir / f'repo_level_metrics{top_n_suffix}.csv', index=False)
        # Overall summary.
        summary = {
            'total_files': len(self.file_stats),
            'total_repos': len(self.repo_stats),
            'avg_files_per_repo': len(self.file_stats) / len(self.repo_stats) if self.repo_stats else 0,
        }
        lang_counter = Counter(f['language'] for f in self.file_stats)
        summary['files_by_language'] = dict(lang_counter.most_common(20))
        with open(self.output_dir / 'code_stats_summary.json', 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        # A completed run no longer needs the resume checkpoint.
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()

    def run(self, num_workers=None):
        """Execute the full pipeline: scan, aggregate, save."""
        print("Stage C (Fast): Analyzing code files...")
        self.scan_all_repos(num_workers=num_workers)
        # Aggregation actually happens inside save_results().
        print("Aggregating repo-level stats...")
        print("Saving results...")
        self.save_results()
        print(f"Code file stats complete! Results saved to {self.output_dir}")
if __name__ == "__main__":
    # Hard-coded workdir paths for the standard pipeline layout.
    source_repos = "/home/weifengsun/tangou1/domain_code/src/workdir/repos_filtered"
    stats_output = "/home/weifengsun/tangou1/domain_code/src/workdir/reporting/code_stats"
    # Run the optimized analyzer over the top 15k repos.
    analyzer = CodeFileStatsFast(
        source_repos,
        stats_output,
        top_n=15000,
        max_file_size_mb=2,
        max_files_per_repo=500,  # cap files analyzed per repository
    )
    analyzer.run(num_workers=48)  # heavy parallelism for the big scan