Spaces:
Sleeping
Sleeping
| """Generate labeled SQL error dataset at scale.""" | |
| from __future__ import annotations | |
| import argparse | |
| import random | |
| from concurrent.futures import ProcessPoolExecutor, as_completed | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from src.categories import load_categories | |
| from src.exercises import generate_exercise | |
| from src.sql_templates import ERROR_INJECTORS | |
# Directory two levels above this file's resolved (absolute, symlink-free) path.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default Parquet destination used when the caller does not pass an output path.
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
def generate_dataset(
    total_samples: int = 1_000_000,
    output_path: Path = DEFAULT_OUTPUT,
    batch_size: int = 10_000,
    workers: int = 8,
    seed: int = 42,
) -> Path:
    """Generate a class-balanced labeled dataset and write it to Parquet.

    A label schedule of exactly ``total_samples`` entries is built so every
    category id returned by ``load_categories()`` appears an equal number of
    times (the first ``total_samples % n_classes`` categories, in list order,
    get one extra sample). The schedule is shuffled with ``seed``, cut into
    batches of at most ``batch_size`` labels, and each batch is rendered by
    ``_generate_batch_with_labels`` in a worker process seeded with
    ``seed + batch_index``. The batch frames are concatenated in batch order,
    shuffled once more with ``seed``, and saved to ``output_path``.

    Args:
        total_samples: Number of rows to generate; must be positive.
        output_path: Destination Parquet file; parent directories are created.
        batch_size: Maximum number of rows per worker task; must be positive.
        workers: ``max_workers`` passed to ``ProcessPoolExecutor``.
        seed: Base seed for the schedule shuffle, per-batch RNGs and the
            final row shuffle.

    Returns:
        ``output_path``, after the file has been written.

    Raises:
        ValueError: If ``total_samples`` or ``batch_size`` is not positive,
            or if ``load_categories()`` yields no categories.
    """
    if total_samples <= 0:
        raise ValueError(f"total_samples must be positive, got {total_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    categories = load_categories()
    label_ids = [c.id for c in categories]
    if not label_ids:
        raise ValueError("load_categories() returned no categories")

    samples_per_class, remainder = divmod(total_samples, len(label_ids))

    # Balanced label schedule: each class gets an equal share (+1 for the
    # first `remainder` classes). The extra sample is assigned by list
    # position rather than by comparing the id itself, so the schedule always
    # has exactly `total_samples` entries even if ids are not 0..n-1.
    schedule: List[int] = []
    for position, label_id in enumerate(label_ids):
        count = samples_per_class + (1 if position < remainder else 0)
        schedule.extend([label_id] * count)
    random.Random(seed).shuffle(schedule)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    frames: Dict[int, pd.DataFrame] = {}
    with ProcessPoolExecutor(max_workers=workers) as executor:
        # Map each future back to its batch index so results can be
        # reassembled deterministically regardless of completion order.
        batch_of_future = {
            executor.submit(
                _generate_batch_with_labels,
                schedule[start : start + batch_size],
                seed + batch_idx,
            ): batch_idx
            for batch_idx, start in enumerate(range(0, total_samples, batch_size))
        }
        for future in tqdm(
            as_completed(batch_of_future),
            total=len(batch_of_future),
            desc="Generating",
        ):
            frames[batch_of_future[future]] = pd.DataFrame(future.result())

    # Concatenate in batch order: as_completed() yields futures in whatever
    # order workers finish, and the seeded shuffle below is only reproducible
    # if the rows it permutes start out in a fixed order.
    df = pd.concat([frames[idx] for idx in range(len(frames))], ignore_index=True)
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    df.to_parquet(output_path, index=False)

    print(f"Saved {len(df):,} samples to {output_path}")
    print("\nClass distribution:")
    print(df["label_name"].value_counts().sort_index().to_string())
    return output_path
def _generate_batch_with_labels(label_ids: List[int], seed: int) -> List[Dict]:
    """Build one dataset row per requested label id.

    For every id in ``label_ids`` a fresh exercise is drawn with
    ``generate_exercise(rng)`` and the injector registered under that id in
    ``ERROR_INJECTORS`` is called as ``injector(rng, exercise)``, returning
    the ``(query, error_message)`` pair stored in the row. All randomness
    comes from a single ``random.Random(seed)`` consumed in label order, so a
    given ``(label_ids, seed)`` pair always drives the RNG identically.

    Args:
        label_ids: Category ids to generate rows for, in output order.
        seed: Seed for this batch's private RNG.

    Returns:
        A list of dicts with keys ``schema``, ``question``, ``correct_query``,
        ``query``, ``error_message``, ``label_id`` and ``label_name``.

    Raises:
        KeyError: If a label id does not match any loaded category.
    """
    rng = random.Random(seed)
    # Resolve names by category id rather than by list position, so the
    # label name stays correct even if ids are not contiguous 0..n-1 indices.
    name_by_id = {cat.id: cat.name for cat in load_categories()}

    rows: List[Dict] = []
    for label_id in label_ids:
        # Order matters: the exercise is drawn before the injector runs so
        # RNG consumption matches across runs with the same seed.
        exercise = generate_exercise(rng)
        query, error_message = ERROR_INJECTORS[label_id](rng, exercise)
        rows.append(
            {
                "schema": exercise.schema,
                "question": exercise.question,
                "correct_query": exercise.correct_query,
                "query": query.strip(),
                "error_message": error_message,
                "label_id": label_id,
                "label_name": name_by_id[label_id],
            }
        )
    return rows
def main() -> None:
    """Parse command-line options and run ``generate_dataset`` with them."""
    cli = argparse.ArgumentParser(description="Generate labeled SQL error dataset")
    cli.add_argument("--samples", type=int, default=1_000_000, help="Total samples")
    # The remaining flags carry no help text; register them from a small table
    # in the same order they have always been declared.
    for flag, converter, fallback in (
        ("--output", Path, DEFAULT_OUTPUT),
        ("--batch-size", int, 10_000),
        ("--workers", int, 8),
        ("--seed", int, 42),
    ):
        cli.add_argument(flag, type=converter, default=fallback)

    opts = cli.parse_args()
    generate_dataset(
        total_samples=opts.samples,
        output_path=opts.output,
        batch_size=opts.batch_size,
        workers=opts.workers,
        seed=opts.seed,
    )
# Script entry point: run the CLI only when this file is executed directly.
# generate_dataset() uses a ProcessPoolExecutor, and worker processes may
# import this module, so main() must not run as an import side effect.
if __name__ == "__main__":
    main()