| |
| |
| """ |
| 训练数据平衡脚本 - 针对旅游资源名称抽取任务 |
| |
| 主要功能: |
| 1. 分析当前数据分布 |
| 2. 对低频资源进行上采样 |
| 3. 生成数据增强样本 |
| 4. 输出平衡后的训练集 |
| """ |
|
|
import argparse
import copy
import json
import random
import re
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional
|
|
class DataBalancer:
    """Balance a resource-name extraction training set by up-sampling rare resources.

    The training file must contain a JSON array of samples. A sample's
    ``output`` is a JSON-encoded object (an already-decoded dict is also
    accepted) whose ``resource_names`` list names the resources the sample
    mentions. Samples whose output cannot be read that way are ignored when
    counting.
    """

    # Dates searched for in a sample's ``input``; only the first one found is
    # replaced (all of its occurrences) during augmentation.
    SOURCE_DATES = ("7月17日", "7月18日", "7月20日", "7月28日",
                    "7月29日", "7月30日", "7月31日")
    # Candidate replacement dates.
    REPLACEMENT_DATES = ("7月15日", "7月16日", "7月19日", "7月21日", "7月22日",
                         "7月25日", "7月26日", "8月1日", "8月2日", "8月5日")
    # Candidate replacements for every "<digits>人" head count in the input.
    PEOPLE_COUNTS = ("15人", "25人", "35人", "45人", "55人",
                     "8人", "12人", "18人", "22人", "28人")
    # Candidate replacements for the last four digits of a phone number.
    PHONE_ENDINGS = ("1234", "5678", "9012", "3456", "7890",
                     "2468", "1357", "9753", "8642", "0246")

    PEOPLE_PATTERN = re.compile(r'\d+人')
    # 11 digits starting with 13-19 (presumably a mobile number). The digit
    # look-arounds stop it from matching an 11-digit slice of a longer digit
    # run such as an order number.
    PHONE_PATTERN = re.compile(r'(?<!\d)1[3-9]\d{9}(?!\d)')

    # Random source. Defaults to the ``random`` module itself so that a global
    # random.seed() keeps controlling the output; __init__ replaces it with a
    # private generator when an explicit seed is given.
    _rng = random

    def __init__(self, input_file: str, seed: Optional[int] = None):
        """Load ``input_file`` and analyse its resource distribution.

        Args:
            input_file: path to the JSON training file.
            seed: optional seed for a private random generator, making the
                augmentation reproducible without touching global state.
        """
        self.input_file = input_file
        if seed is not None:
            self._rng = random.Random(seed)
        self.data = self.load_data()
        # Both are (re)built by analyze_distribution().
        self.resource_counts: Counter = Counter()
        self.resource_samples: Dict[str, List[int]] = defaultdict(list)
        self.analyze_distribution()

    def load_data(self) -> List[Dict]:
        """Load the training samples from ``self.input_file``.

        Raises:
            ValueError: if the top-level JSON value is not an array.
        """
        with open(self.input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(
                f"{self.input_file}: expected a JSON array of samples, "
                f"got {type(data).__name__}")
        return data

    @staticmethod
    def _extract_resources(item: Any) -> List[str]:
        """Return the distinct resource names listed in one sample's output.

        Returns an empty list when the sample has no parsable ``output`` or
        the output has no ``resource_names`` list. Order of first appearance
        is preserved; non-string entries are ignored.
        """
        if not isinstance(item, dict) or 'output' not in item:
            return []
        raw = item['output']
        if isinstance(raw, dict):
            output_data = raw
        else:
            try:
                output_data = json.loads(raw)
            except (TypeError, ValueError):  # JSONDecodeError is a ValueError
                return []
        if not isinstance(output_data, dict):
            return []
        resources = output_data.get('resource_names')
        if not isinstance(resources, list):
            return []
        # dict.fromkeys de-duplicates while keeping order, so a name repeated
        # inside one sample counts as one sample, not several.
        return list(dict.fromkeys(r for r in resources if isinstance(r, str)))

    def analyze_distribution(self):
        """Count, per resource name, how many samples mention it.

        Rebuilds ``self.resource_counts`` (resource -> number of samples) and
        ``self.resource_samples`` (resource -> indices into ``self.data``)
        from scratch, so calling it again does not double-count.
        """
        self.resource_counts = Counter()
        self.resource_samples = defaultdict(list)
        for idx, item in enumerate(self.data):
            for resource in self._extract_resources(item):
                self.resource_counts[resource] += 1
                self.resource_samples[resource].append(idx)

    def get_balance_strategy(self, target_min_samples: int = 5) -> Dict[str, int]:
        """Compute, and print, how many samples each rare resource is missing.

        Args:
            target_min_samples: desired minimum number of samples per resource.

        Returns:
            Mapping of resource name -> number of samples needed to reach the
            target, for resources currently below it (most frequent first).
        """
        balance_strategy = {}

        print("📊 当前资源分布分析:")
        print("-" * 50)

        for resource, count in self.resource_counts.most_common():
            if count < target_min_samples:
                needed = target_min_samples - count
                balance_strategy[resource] = needed
                print(f"❌ {resource}: {count} -> {target_min_samples} (需要+{needed})")
            else:
                print(f"✅ {resource}: {count}")

        return balance_strategy

    def create_augmented_sample(self, original_idx: int, target_resource: str) -> Dict[str, Any]:
        """Build a perturbed copy of ``self.data[original_idx]``.

        Only ``input`` is changed; ``instruction`` and ``output`` are copied
        unchanged, so every resource of the source sample stays labelled.
        The changes are:
          1. the first date from SOURCE_DATES that occurs is replaced (all of
             its occurrences) by a random REPLACEMENT_DATES entry;
          2. every "<digits>人" is independently replaced by a random
             PEOPLE_COUNTS entry;
          3. the last four digits of each phone-like number are replaced by a
             random PHONE_ENDINGS entry.

        Args:
            original_idx: index of the source sample in ``self.data``.
            target_resource: the resource being up-sampled. It is not used by
                the transformation and is kept for interface compatibility.

        Returns:
            A new sample dict; ``self.data`` is not modified.
        """
        new_sample = copy.deepcopy(self.data[original_idx])
        input_text = new_sample['input']
        rng = self._rng

        for date in self.SOURCE_DATES:
            if date in input_text:
                input_text = input_text.replace(date, rng.choice(self.REPLACEMENT_DATES))
                break

        # re.sub replaces each original match exactly once. Sequential
        # str.replace calls could instead hit a value inserted by an earlier
        # replacement.
        input_text = self.PEOPLE_PATTERN.sub(
            lambda _match: rng.choice(self.PEOPLE_COUNTS), input_text)
        input_text = self.PHONE_PATTERN.sub(
            lambda match: match.group()[:-4] + rng.choice(self.PHONE_ENDINGS),
            input_text)

        new_sample['input'] = input_text
        return new_sample

    def balance_data(self, target_min_samples: int = 5) -> List[Dict[str, Any]]:
        """Return a new dataset in which every resource reaches the target.

        Augmented copies carry all resources of their source sample. For that
        reason, running counts are updated after every addition. A resource
        that was already topped up by a co-occurring resource's copies is not
        augmented again.

        Args:
            target_min_samples: desired minimum number of samples per resource.

        Returns:
            A deep copy of the original samples followed by the augmented ones.
        """
        balance_strategy = self.get_balance_strategy(target_min_samples)

        if not balance_strategy:
            print("✅ 数据已经平衡,无需调整")
            # Copy, like the balancing path, so callers never alias self.data.
            return copy.deepcopy(self.data)

        print(f"\n🔄 开始数据平衡,目标最小样本数: {target_min_samples}")
        print("-" * 50)

        balanced_data = copy.deepcopy(self.data)
        current_counts = Counter(self.resource_counts)

        for resource in balance_strategy:
            print(f"📈 正在增强 '{resource}' 的样本...")

            needed_count = max(target_min_samples - current_counts[resource], 0)
            original_samples = self.resource_samples[resource]

            for _ in range(needed_count):
                source_idx = self._rng.choice(original_samples)
                balanced_data.append(self.create_augmented_sample(source_idx, resource))
                # The copy contains every resource of its source sample.
                current_counts.update(self._extract_resources(self.data[source_idx]))

            print(f" ✅ 已添加 {needed_count} 个增强样本")

        return balanced_data

    def save_balanced_data(self, balanced_data: List[Dict], output_file: str):
        """Write ``balanced_data`` to ``output_file`` as UTF-8 JSON and print a summary."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(balanced_data, f, ensure_ascii=False, indent=2)

        print(f"\n💾 已保存平衡后的数据到: {output_file}")
        print(f" 原始样本数: {len(self.data)}")
        print(f" 平衡后样本数: {len(balanced_data)}")
        print(f" 新增样本数: {len(balanced_data) - len(self.data)}")
|
|
def main():
    """Command-line entry point.

    Parses ``--input``, ``--output`` and ``--min-samples``. It then balances
    the input training file with DataBalancer and writes the result to the
    output file.
    """
    cli = argparse.ArgumentParser(description='平衡旅游资源训练数据')
    argument_specs = (
        ('--input', {'required': True, 'help': '输入的训练数据文件'}),
        ('--output', {'required': True, 'help': '输出的平衡数据文件'}),
        ('--min-samples', {'type': int, 'default': 5, 'help': '目标最小样本数 (默认: 5)'}),
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    print("🚀 开始数据平衡流程...")
    print("=" * 60)

    # Load + analyse, up-sample the rare resources, then persist the result.
    balancer = DataBalancer(opts.input)
    result = balancer.balance_data(opts.min_samples)
    balancer.save_balanced_data(result, opts.output)

    print("\n🎉 数据平衡完成!")
|
|
# Run the command-line flow only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|