File size: 3,868 Bytes
8a3099e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Generate a large, class-balanced dataset of labeled SQL errors."""

from __future__ import annotations

import argparse
import random
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
from tqdm import tqdm

from src.categories import load_categories
from src.exercises import generate_exercise
from src.sql_templates import ERROR_INJECTORS

# Directory two levels above this file (the grandparent of this script's
# absolute path), used as the project root for default output locations.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Where the generated parquet dataset is written unless overridden.
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")


def generate_dataset(
    total_samples: int = 1_000_000,
    output_path: Path = DEFAULT_OUTPUT,
    batch_size: int = 10_000,
    workers: int = 8,
    seed: int = 42,
) -> Path:
    """Generate a class-balanced labeled SQL error dataset and save it as parquet.

    Every category returned by ``load_categories()`` receives
    ``total_samples // num_categories`` rows, and the first
    ``total_samples % num_categories`` categories (by position) receive one
    extra row, so exactly ``total_samples`` rows are produced. Labels are
    shuffled with ``seed`` and handed to worker processes in batches of at most
    ``batch_size``; batch ``i`` is generated with seed ``seed + i``.

    Batches are reassembled in submission order (not completion order) before
    the final seeded shuffle, so the row order depends only on ``seed`` and
    not on worker scheduling.

    Args:
        total_samples: Total number of rows to generate. Must be positive.
        output_path: Destination parquet file (``Path`` or ``str``); missing
            parent directories are created.
        batch_size: Maximum number of rows per worker task. Must be positive.
        workers: Maximum number of worker processes.
        seed: Base seed for the label schedule, the per-batch RNGs and the
            final row shuffle.

    Returns:
        The path the parquet file was written to.

    Raises:
        ValueError: If ``total_samples`` or ``batch_size`` is not positive, or
            if no categories are available.
    """
    if total_samples <= 0:
        raise ValueError(f"total_samples must be positive, got {total_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    output_path = Path(output_path)

    categories = list(load_categories())
    if not categories:
        raise ValueError("load_categories() returned no categories")
    samples_per_class, remainder = divmod(total_samples, len(categories))

    # Balanced label schedule: each class gets an equal share, and the first
    # `remainder` classes get one extra sample. Using the list position rather
    # than `cat.id` keeps the total exact even if ids are not 0..N-1.
    schedule: List[int] = []
    for position, cat in enumerate(categories):
        count = samples_per_class + (1 if position < remainder else 0)
        schedule.extend([cat.id] * count)
    random.Random(seed).shuffle(schedule)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    num_batches = (total_samples + batch_size - 1) // batch_size
    chunks: Dict[int, pd.DataFrame] = {}

    with ProcessPoolExecutor(max_workers=workers) as executor:
        # Remember each future's batch index so results can be put back in
        # submission order regardless of which worker finishes first.
        future_to_batch = {
            executor.submit(
                _generate_batch_with_labels,
                schedule[start : start + batch_size],
                seed + batch_idx,
            ): batch_idx
            for batch_idx, start in enumerate(range(0, total_samples, batch_size))
        }
        for future in tqdm(
            as_completed(future_to_batch),
            total=len(future_to_batch),
            desc="Generating",
        ):
            # Convert each batch as it arrives to keep peak memory low.
            chunks[future_to_batch[future]] = pd.DataFrame(future.result())

    # Concatenate in batch order: as_completed() order varies between runs,
    # which would otherwise make the seeded shuffle below non-reproducible.
    df = pd.concat([chunks[idx] for idx in range(num_batches)], ignore_index=True)
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    df.to_parquet(output_path, index=False)

    print(f"Saved {len(df):,} samples to {output_path}")
    print("\nClass distribution:")
    print(df["label_name"].value_counts().sort_index().to_string())
    return output_path


def _generate_batch_with_labels(label_ids: List[int], seed: int) -> List[Dict]:
    """Generate one dataset row per entry of ``label_ids``.

    Intended to run in a worker process. Each row draws a fresh exercise from
    ``generate_exercise`` and corrupts it with ``ERROR_INJECTORS[label_id]``,
    which returns the faulty query and its error message. All randomness comes
    from a private ``random.Random(seed)``. Results are therefore reproducible
    for a given seed and label list, assuming the exercise generator and
    injectors draw only from the RNG passed to them.

    Args:
        label_ids: Category ids to generate, in order.
        seed: Seed for this batch's RNG.

    Returns:
        A list of dicts with keys ``schema``, ``question``, ``correct_query``,
        ``query`` (whitespace-stripped), ``error_message``, ``label_id`` and
        ``label_name``.
    """
    rng = random.Random(seed)
    # Look names up by category id, not list position, so labels stay correct
    # even if load_categories() is not ordered by id.
    names_by_id = {cat.id: cat.name for cat in load_categories()}
    rows = []
    for label_id in label_ids:
        # The order of RNG use (exercise first, then injector) determines the
        # output, so keep it fixed for reproducibility.
        exercise = generate_exercise(rng)
        injector = ERROR_INJECTORS[label_id]
        query, error_message = injector(rng, exercise)
        rows.append(
            {
                "schema": exercise.schema,
                "question": exercise.question,
                "correct_query": exercise.correct_query,
                "query": query.strip(),
                "error_message": error_message,
                "label_id": label_id,
                "label_name": names_by_id[label_id],
            }
        )
    return rows


def main() -> None:
    """Command-line entry point: parse arguments and call ``generate_dataset``.

    Non-positive ``--samples``, ``--batch-size`` or ``--workers`` values are
    rejected with an argparse usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Generate labeled SQL error dataset")
    parser.add_argument("--samples", type=int, default=1_000_000, help="Total samples")
    parser.add_argument(
        "--output", type=Path, default=DEFAULT_OUTPUT, help="Destination parquet file"
    )
    parser.add_argument(
        "--batch-size", type=int, default=10_000, help="Rows generated per worker task"
    )
    parser.add_argument("--workers", type=int, default=8, help="Number of worker processes")
    parser.add_argument("--seed", type=int, default=42, help="Base random seed")
    args = parser.parse_args()

    # Catch bad sizes here with a clear usage message, instead of letting them
    # surface as tracebacks from deep inside dataset generation.
    for flag, value in (
        ("--samples", args.samples),
        ("--batch-size", args.batch_size),
        ("--workers", args.workers),
    ):
        if value <= 0:
            parser.error(f"{flag} must be a positive integer, got {value}")

    generate_dataset(
        total_samples=args.samples,
        output_path=args.output,
        batch_size=args.batch_size,
        workers=args.workers,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()