| |
| |
| """ |
| 执行数据平衡的主脚本 |
| |
| 结合你的具体需求: |
| - 新叶古村-新叶古村门票: 1 -> 5 (+4) |
| - 大慈岩-大慈岩索道: 2 -> 5 (+3) |
| - 其他低频资源也会被相应增强 |
| """ |
|
|
| import json |
| import sys |
| import os |
| from advanced_data_augmentation import AdvancedDataAugmenter |
|
|
def load_training_data(file_path: str):
    """Read a JSON training-data file and return its parsed contents."""
    with open(file_path, encoding='utf-8') as fh:
        return json.load(fh)
|
|
def merge_enhanced_samples(original_data, enhanced_samples):
    """Return a new list with the original samples followed by the augmented ones."""
    return [*original_data, *enhanced_samples]
|
|
def analyze_final_distribution(data):
    """Print the per-resource frequency distribution of the dataset.

    Each item's 'output' field is expected to hold a JSON string containing a
    'resource_names' list; entries that are missing, unparsable, or not shaped
    that way are skipped. Every occurrence of a resource name is counted
    (duplicates within one sample count multiple times).

    Args:
        data: list of sample dicts (original + augmented, merged).
    """
    from collections import Counter

    resource_counts = Counter()

    for item in data:
        if 'output' not in item:
            continue
        try:
            output_data = json.loads(item['output'])
        except (json.JSONDecodeError, TypeError):
            # was a bare `except:` — narrow to the errors json.loads can
            # actually raise so real bugs are no longer silently swallowed
            continue
        if isinstance(output_data, dict):
            resources = output_data.get('resource_names')
            if isinstance(resources, list):
                resource_counts.update(resources)

    print("📊 最终数据分布:")
    print("-" * 60)

    # Low-frequency resources this balancing pass specifically targets.
    focus_resources = [
        "新叶古村-新叶古村门票",
        "大慈岩-大慈岩索道",
        "灵栖洞-灵栖洞西游魔毯",
        "宿江公司-江清月近人实景演艺门票"
    ]

    print("🎯 重点关注的资源:")
    for resource in focus_resources:
        count = resource_counts.get(resource, 0)
        print(f" {resource}: {count}")

    print(f"\n📈 所有资源分布 (总计 {len(resource_counts)} 种资源):")
    for resource, count in resource_counts.most_common():
        # 5+ samples is treated as "sufficiently represented"
        status = "✅" if count >= 5 else "⚠️"
        print(f" {status} {resource}: {count}")
|
|
def main(input_files=None, output_file="balanced_training_data.json"):
    """Run the data-balancing pipeline: load, augment, merge, save, report.

    Generalized from hard-coded paths to backward-compatible keyword
    parameters — calling `main()` with no arguments behaves as before.

    Args:
        input_files: Candidate paths for the raw training data; the first
            existing file wins. Defaults to the original hard-coded path.
        output_file: Path where the merged (balanced) dataset is written.
    """
    if input_files is None:
        input_files = [
            "/home/ziqiang/LLaMA-Factory/data/ocr_text_orders_08_14_test_v4.json"
        ]

    # Use the first candidate path that actually exists on disk.
    training_file = next((p for p in input_files if os.path.exists(p)), None)

    if not training_file:
        print("❌ 未找到训练数据文件,请检查路径:")
        for file_path in input_files:
            print(f" {file_path}")
        return

    print(f"📂 使用训练数据文件: {training_file}")
    print("=" * 60)

    print("📥 加载原始训练数据...")
    original_data = load_training_data(training_file)
    print(f" 原始样本数: {len(original_data)}")

    # Generate synthetic samples for under-represented resources.
    print("\n🔄 生成增强样本...")
    augmenter = AdvancedDataAugmenter()
    enhanced_samples = augmenter.generate_all_samples()

    print(f"\n🔗 合并原始数据和增强样本...")
    balanced_data = merge_enhanced_samples(original_data, enhanced_samples)
    print(f" 合并后样本数: {len(balanced_data)}")
    print(f" 新增样本数: {len(enhanced_samples)}")

    print(f"\n💾 保存平衡后的数据到: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(balanced_data, f, ensure_ascii=False, indent=2)

    print(f"\n📊 分析最终数据分布...")
    analyze_final_distribution(balanced_data)

    print(f"\n🎉 数据平衡完成!")
    print("📋 建议的下一步:")
    print(" 1. 使用 balanced_training_data.json 重新训练模型")
    print(" 2. 在验证集上测试性能改进")
    print(" 3. 特别关注新叶古村、大慈岩索道等低频资源的识别效果")
|
|
# Entry point: run the balancing pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|