| | import json |
| | import random |
| | import time |
| | import sys |
| | from typing import List, Dict, Any |
| | from synthetic_data.pipeline import SyntheticDataPipeline |
| | from synthetic_data.validate import validate_synthetic_data |
| |
|
# Sampling weights for each memory category; run_pipeline_batches passes the
# values to random.choices, which treats them as *relative* weights.
# NOTE(review): the values sum to 1.10, not 1.0, so the effective probability
# of each entry is weight / 1.10 (e.g. "none" is drawn ~9.1% of the time, not
# 10%). Confirm whether that is intended.
# Insertion order is significant: it fixes which category each random draw maps to.
CATEGORY_DISTRIBUTION = dict(
    [
        # Company-level memory categories.
        ("company.brand_core", 0.10),
        ("company.strategic_signatures", 0.08),
        ("company.knowledge_artifacts", 0.08),
        ("company.business_priorities", 0.10),
        ("company.tools_config", 0.07),
        ("company.performance_context", 0.09),
        # User-level memory categories.
        ("user.communication_style", 0.10),
        ("user.strategic_approach", 0.09),
        ("user.role_context", 0.07),
        ("user.workflow_patterns", 0.08),
        ("user.session_history", 0.06),
        ("user.interaction_preferences", 0.08),
        # Items with no applicable memory category.
        ("none", 0.10),
    ]
)
| |
|
def run_pipeline_batches(total_items: int = 100, batch_size: int = 10):
    """Generate synthetic conversations in batches, saving and validating each batch.

    For every item, a category is drawn from ``CATEGORY_DISTRIBUTION`` (used as
    relative weights) and sometimes paired with a distractor category. The
    ``SyntheticDataPipeline`` then builds a scenario spec and a conversation
    from it. Each batch is written to ``synthetic_data/batch_NN.json`` and
    passed to ``validate_synthetic_data``, whose result is printed. Finally,
    all items are written together to
    ``synthetic_data/all_generated_data_<total_items>.json``.

    A failed scenario or conversation step is retried with a freshly drawn
    category after a 20 s pause. There is no retry limit, so the call blocks
    until every requested item has been produced. A 15 s pause follows each
    successful item to avoid rate limits.

    Args:
        total_items: Total number of conversations to generate. When it is not
            a multiple of ``batch_size``, the last batch holds the remainder.
        batch_size: Maximum number of conversations per batch file; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``total_items`` < 0.
    """
    import os

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if total_items < 0:
        raise ValueError(f"total_items must be >= 0, got {total_items}")

    output_dir = "synthetic_data"
    # open(..., "w") does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    pipeline = SyntheticDataPipeline()
    categories = list(CATEGORY_DISTRIBUTION.keys())
    weights = list(CATEGORY_DISTRIBUTION.values())

    all_data = []
    # Ceiling division. Floor division silently dropped the remainder
    # (25 items / batch 10 -> only 20 generated), and max(1, ...) forced a
    # full batch even when total_items < batch_size.
    num_batches = -(-total_items // batch_size)

    print(f"Starting generation of {total_items} items in {num_batches} batches (Size: {batch_size})...")

    for batch_num in range(1, num_batches + 1):
        print(f"\n=== Processing Batch {batch_num}/{num_batches} ===")
        # Every batch but possibly the last is full; the last takes what is left.
        target = min(batch_size, total_items - len(all_data))
        batch_data = _generate_batch(pipeline, categories, weights, target)

        batch_filename = f"{output_dir}/batch_{batch_num:02d}.json"
        _write_json(batch_filename, batch_data)
        print(f" Saved batch to {batch_filename}")

        print(" Validating batch...")
        metrics = validate_synthetic_data(batch_filename)
        print(json.dumps(metrics, indent=2))

        all_data.extend(batch_data)

    # Name the combined file after the requested size instead of a fixed "_100".
    # The default run (total_items=100) writes the same path as before.
    combined_filename = f"{output_dir}/all_generated_data_{total_items}.json"
    _write_json(combined_filename, all_data)
    print(f"\nCompleted. Total items generated: {len(all_data)}")
    print(f"Full dataset saved to {combined_filename}")


def _generate_batch(pipeline, categories, weights, target):
    """Produce ``target`` conversations, retrying failed steps until done."""
    batch_data = []
    while len(batch_data) < target:
        category = random.choices(categories, weights=weights, k=1)[0]
        print(f" Generating item {len(batch_data) + 1}/{target} (Category: {category})...")

        distractor = _pick_distractor(category, categories)
        persistence = _get_persistence_for_category(category)
        turns = random.randint(4, 10)

        scenario = pipeline.generate_scenario_spec(
            category=category,
            distractor=distractor,
            persistence=persistence,
            turns=turns,
        )
        if not scenario:
            print(f" Failed to generate scenario for {category}. Retrying...")
            time.sleep(20)
            continue

        conversation = pipeline.generate_conversation(scenario, turn_count=turns)
        if not conversation:
            print(f" Failed to generate conversation for {category}. Retrying...")
            time.sleep(20)
            continue

        batch_data.append(conversation)
        print(f" Generated: {conversation.get('scenario_id', 'Unknown ID')}")
        print(" Sleeping for 15s to avoid rate limits...")
        time.sleep(15)
    return batch_data


def _pick_distractor(category, categories):
    """Return another random non-"none" category 30% of the time, else None."""
    # random.random() is consumed on every call (even for "none") so the RNG
    # sequence matches the inline version this helper was extracted from.
    if random.random() < 0.30 and category != "none":
        candidates = [c for c in categories if c != category and c != "none"]
        if candidates:
            return random.choice(candidates)
    return None


def _write_json(path, data):
    """Serialize ``data`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
| |
|
| | def _get_persistence_for_category(category: str) -> str: |
| | if "brand_core" in category or "strategic_signatures" in category or "knowledge_artifacts" in category or "communication_style" in category or "strategic_approach" in category: |
| | return "long" |
| | elif "tools_config" in category or "role_context" in category or "workflow_patterns" in category: |
| | return "medium" |
| | elif "business_priorities" in category or "session_history" in category: |
| | return "short" |
| | elif "performance_context" in category: |
| | return "rolling" |
| | elif "interaction_preferences" in category: |
| | return "evolving" |
| | elif "none" in category: |
| | return "short" |
| | return "medium" |
| |
|
if __name__ == "__main__":
    # Optional positional CLI arguments: [total_items] [batch_size].
    cli_args = sys.argv[1:]
    total = int(cli_args[0]) if cli_args else 100
    batch = int(cli_args[1]) if len(cli_args) > 1 else 10
    run_pipeline_batches(total_items=total, batch_size=batch)
| |
|