| |
| |
| """ |
| 数据集集成脚本:将身份证数据集成到OCR文本训练数据中 |
| 同时改进身份证数据的instruction格式 |
| """ |
|
|
| import json |
| import os |
| from typing import List, Dict, Any |
|
|
|
|
| def load_json_data(file_path: str) -> List[Dict[str, Any]]: |
| """加载JSON数据""" |
| try: |
| with open(file_path, 'r', encoding='utf-8') as f: |
| data = json.load(f) |
| print(f"成功加载 {file_path},包含 {len(data)} 条记录") |
| return data |
| except Exception as e: |
| print(f"加载文件失败 {file_path}: {e}") |
| return [] |
|
|
|
|
| def improve_idcard_instruction(original_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
| """ |
| 改进身份证数据的instruction格式,参考OCR文本数据的详细写法 |
| """ |
| improved_data = [] |
| |
| |
| detailed_instruction = """请从OCR文本中抽取旅行订单中的游客身份证信息。 |
| |
| 提取字段及说明: |
| name (姓名): 游客的真实姓名 |
| idcard (身份证号): 18位身份证号码,支持末位为X的格式 |
| gender (性别): 根据身份证号倒数第二位数字判断(奇数为男,偶数为女) |
| phone (电话号码): 游客联系电话,如果OCR文本中没有则为null |
| |
| 严格按照以下JSON格式输出: |
| [ |
| { |
| "name": "姓名", |
| "idcard": "身份证号", |
| "gender": "男/女", |
| "phone": "电话号码或null" |
| } |
| ] |
| |
| 注意事项: |
| 1. 确保身份证号码格式正确(18位数字,末位可为X) |
| 2. 性别根据身份证号自动判断,不依赖OCR文本中的性别信息 |
| 3. 如果姓名缺失或无法识别,name字段设为"无" |
| 4. 电话号码如果在OCR文本中不存在,设置为null |
| 5. 输出必须是有效的JSON数组格式""" |
|
|
| for item in original_data: |
| improved_item = item.copy() |
| improved_item['instruction'] = detailed_instruction |
| improved_data.append(improved_item) |
| |
| print(f"改进了 {len(improved_data)} 条身份证数据的instruction格式") |
| return improved_data |
|
|
|
|
| def integrate_datasets(ocr_text_data: List[Dict[str, Any]], |
| idcard_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
| """ |
| 集成两个数据集 |
| """ |
| print("正在集成数据集...") |
| |
| |
| integrated_data = ocr_text_data + idcard_data |
| |
| print(f"集成完成:") |
| print(f"- OCR文本数据:{len(ocr_text_data)} 条") |
| print(f"- 身份证数据:{len(idcard_data)} 条") |
| print(f"- 总计:{len(integrated_data)} 条") |
| |
| return integrated_data |
|
|
|
|
| def save_integrated_dataset(data: List[Dict[str, Any]], output_path: str): |
| """ |
| 保存集成后的数据集 |
| """ |
| try: |
| with open(output_path, 'w', encoding='utf-8') as f: |
| json.dump(data, f, ensure_ascii=False, indent=2) |
| print(f"集成数据集已保存到: {output_path}") |
| except Exception as e: |
| print(f"保存文件失败: {e}") |
|
|
|
|
| def validate_dataset(data: List[Dict[str, Any]]) -> bool: |
| """ |
| 验证数据集格式正确性 |
| """ |
| print("正在验证数据集格式...") |
| |
| required_fields = ['instruction', 'input', 'output'] |
| issues = [] |
| |
| for i, item in enumerate(data): |
| |
| for field in required_fields: |
| if field not in item: |
| issues.append(f"记录 {i}: 缺少字段 '{field}'") |
| |
| |
| if 'instruction' in item and not item['instruction'].strip(): |
| issues.append(f"记录 {i}: instruction字段为空") |
| |
| |
| if 'output' in item: |
| try: |
| json.loads(item['output']) |
| except json.JSONDecodeError: |
| issues.append(f"记录 {i}: output字段不是有效的JSON格式") |
| |
| if issues: |
| print(f"发现 {len(issues)} 个问题:") |
| for issue in issues[:10]: |
| print(f" - {issue}") |
| if len(issues) > 10: |
| print(f" - ... 还有 {len(issues) - 10} 个问题") |
| return False |
| else: |
| print("数据集格式验证通过!") |
| return True |
|
|
|
|
| def main(): |
| """ |
| 主函数 |
| """ |
| print("开始数据集集成任务...") |
| |
| |
| ocr_text_file = "/home/ziqiang/LLaMA-Factory/data/ocr_text_orders_08_18.json" |
| idcard_file = "/home/ziqiang/LLaMA-Factory/data/ocr_idcards_orders.json" |
| output_file = "/home/ziqiang/LLaMA-Factory/data/text_idcards_8_18.json" |
| |
| |
| print("\n1. 加载原始数据集...") |
| ocr_text_data = load_json_data(ocr_text_file) |
| idcard_data = load_json_data(idcard_file) |
| |
| if not ocr_text_data or not idcard_data: |
| print("数据加载失败,终止执行") |
| return |
| |
| |
| print("\n2. 改进身份证数据的instruction格式...") |
| improved_idcard_data = improve_idcard_instruction(idcard_data) |
| |
| |
| print("\n3. 集成数据集...") |
| integrated_data = integrate_datasets(ocr_text_data, improved_idcard_data) |
| |
| |
| print("\n4. 验证数据集格式...") |
| if not validate_dataset(integrated_data): |
| print("数据集验证失败,请检查数据格式") |
| return |
| |
| |
| print("\n5. 保存集成数据集...") |
| save_integrated_dataset(integrated_data, output_file) |
| |
| print(f"\n✅ 数据集集成完成!") |
| print(f"集成后的数据集保存在: {output_file}") |
| print(f"总记录数: {len(integrated_data)}") |
| |
| |
| ocr_count = len(ocr_text_data) |
| idcard_count = len(improved_idcard_data) |
| print(f"\n数据构成:") |
| print(f"- OCR文本数据: {ocr_count} 条 ({ocr_count/len(integrated_data)*100:.1f}%)") |
| print(f"- 身份证数据: {idcard_count} 条 ({idcard_count/len(integrated_data)*100:.1f}%)") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|