"""Generate labeled SQL error dataset at scale.""" from __future__ import annotations import argparse import random from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path from typing import Dict, List, Tuple import pandas as pd from tqdm import tqdm from src.categories import load_categories from src.exercises import generate_exercise from src.sql_templates import ERROR_INJECTORS PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "sql_errors_1m.parquet" def generate_dataset( total_samples: int = 1_000_000, output_path: Path = DEFAULT_OUTPUT, batch_size: int = 10_000, workers: int = 8, seed: int = 42, ) -> Path: categories = load_categories() label_ids = [c.id for c in categories] samples_per_class = total_samples // len(label_ids) remainder = total_samples % len(label_ids) # Balanced label schedule: each class gets equal share (+1 for first `remainder` classes) schedule: List[int] = [] for cat in categories: count = samples_per_class + (1 if cat.id < remainder else 0) schedule.extend([cat.id] * count) random.Random(seed).shuffle(schedule) output_path.parent.mkdir(parents=True, exist_ok=True) chunks: List[pd.DataFrame] = [] num_batches = (total_samples + batch_size - 1) // batch_size with ProcessPoolExecutor(max_workers=workers) as executor: futures = [] offset = 0 for batch_idx in range(num_batches): current_batch = min(batch_size, total_samples - offset) batch_labels = schedule[offset : offset + current_batch] futures.append( executor.submit( _generate_batch_with_labels, batch_labels, seed + batch_idx, ) ) offset += current_batch for future in tqdm(as_completed(futures), total=len(futures), desc="Generating"): rows = future.result() chunks.append(pd.DataFrame(rows)) df = pd.concat(chunks, ignore_index=True) df = df.sample(frac=1, random_state=seed).reset_index(drop=True) df.to_parquet(output_path, index=False) print(f"Saved {len(df):,} samples to {output_path}") print("\nClass distribution:") print(df["label_name"].value_counts().sort_index().to_string()) return output_path def _generate_batch_with_labels(label_ids: List[int], seed: int) -> List[Dict]: rng = random.Random(seed) categories = load_categories() rows = [] for label_id in label_ids: exercise = generate_exercise(rng) injector = ERROR_INJECTORS[label_id] query, error_message = injector(rng, exercise) rows.append( { "schema": exercise.schema, "question": exercise.question, "correct_query": exercise.correct_query, "query": query.strip(), "error_message": error_message, "label_id": label_id, "label_name": categories[label_id].name, } ) return rows def main() -> None: parser = argparse.ArgumentParser(description="Generate labeled SQL error dataset") parser.add_argument("--samples", type=int, default=1_000_000, help="Total samples") parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT) parser.add_argument("--batch-size", type=int, default=10_000) parser.add_argument("--workers", type=int, default=8) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() generate_dataset( total_samples=args.samples, output_path=args.output, batch_size=args.batch_size, workers=args.workers, seed=args.seed, ) if __name__ == "__main__": main()