sql-error-classifier / src /generate_dataset.py
nishu08's picture
Deploy CodeBERT inference Space
8a3099e verified
Raw
History Blame Contribute Delete
3.87 kB
"""Generate labeled SQL error dataset at scale."""
from __future__ import annotations
import argparse
import random
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
from tqdm import tqdm
from src.categories import load_categories
from src.exercises import generate_exercise
from src.sql_templates import ERROR_INJECTORS
# Repository root: two directory levels above this (resolved) source file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Where generate_dataset() writes its Parquet output unless told otherwise.
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
def generate_dataset(
    total_samples: int = 1_000_000,
    output_path: Path = DEFAULT_OUTPUT,
    batch_size: int = 10_000,
    workers: int = 8,
    seed: int = 42,
) -> Path:
    """Generate a class-balanced, labeled SQL error dataset and save it as Parquet.

    Each category from ``load_categories()`` receives
    ``total_samples // n_classes`` rows; the first ``total_samples % n_classes``
    categories (in load order) receive one extra row. The shuffled label
    schedule is split into batches of ``batch_size`` that are generated in
    parallel worker processes, concatenated in batch order, shuffled once more
    and written to ``output_path`` (missing parent directories are created).

    Args:
        total_samples: Total number of rows to generate; must be positive.
        output_path: Destination Parquet file.
        batch_size: Rows generated per worker task; must be positive.
        workers: Maximum number of worker processes.
        seed: Seed for the label shuffle, the per-batch RNGs
            (``seed + batch_index``) and the final row shuffle. The same
            arguments produce the same file regardless of the order in which
            worker batches finish.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``total_samples`` or ``batch_size`` is not positive, or
            ``load_categories()`` returns no categories.
    """
    if total_samples <= 0:
        raise ValueError(f"total_samples must be positive, got {total_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    label_ids = [c.id for c in load_categories()]
    if not label_ids:
        raise ValueError("load_categories() returned no categories")

    schedule = _build_label_schedule(label_ids, total_samples, seed)
    batches = _split_into_batches(schedule, batch_size)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    chunks = _run_batches(batches, workers, seed)

    df = pd.concat(chunks, ignore_index=True)
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    df.to_parquet(output_path, index=False)
    print(f"Saved {len(df):,} samples to {output_path}")
    print("\nClass distribution:")
    print(df["label_name"].value_counts().sort_index().to_string())
    return output_path


def _build_label_schedule(label_ids: List[int], total_samples: int, seed: int) -> List[int]:
    """Return a shuffled list of exactly ``total_samples`` balanced label ids."""
    per_class, remainder = divmod(total_samples, len(label_ids))
    schedule: List[int] = []
    for position, label_id in enumerate(label_ids):
        # Leftover samples go to the first `remainder` classes by *position*,
        # so the total is exact even when category ids are not 0..n-1.
        extra = 1 if position < remainder else 0
        schedule.extend([label_id] * (per_class + extra))
    random.Random(seed).shuffle(schedule)
    return schedule


def _split_into_batches(schedule: List[int], batch_size: int) -> List[List[int]]:
    """Split ``schedule`` into consecutive slices of at most ``batch_size`` ids."""
    return [schedule[start : start + batch_size] for start in range(0, len(schedule), batch_size)]


def _run_batches(batches: List[List[int]], workers: int, seed: int) -> List[pd.DataFrame]:
    """Generate every batch in a process pool; return DataFrames in batch order."""
    results: Dict[int, pd.DataFrame] = {}
    with ProcessPoolExecutor(max_workers=workers) as executor:
        batch_index_of = {
            executor.submit(_generate_batch_with_labels, labels, seed + batch_idx): batch_idx
            for batch_idx, labels in enumerate(batches)
        }
        # as_completed() only drives the progress bar; results are stored by
        # batch index so the concatenation order (and thus the seeded final
        # shuffle) does not depend on worker scheduling.
        for future in tqdm(as_completed(batch_index_of), total=len(batch_index_of), desc="Generating"):
            results[batch_index_of[future]] = pd.DataFrame(future.result())
    return [results[batch_idx] for batch_idx in range(len(batches))]
def _generate_batch_with_labels(label_ids: List[int], seed: int) -> List[Dict]:
    """Generate one labeled SQL error row per entry of ``label_ids``.

    For each label a fresh exercise is drawn from ``generate_exercise`` and the
    matching injector in ``ERROR_INJECTORS`` corrupts it. A single
    ``random.Random(seed)`` drives both, so a batch is reproducible from its
    labels and seed. This function is submitted to worker processes by
    ``generate_dataset`` and loads the categories itself.

    Args:
        label_ids: Category ids to generate, in output order.
        seed: Seed for this batch's private RNG.

    Returns:
        A list of row dicts with keys ``schema``, ``question``,
        ``correct_query``, ``query``, ``error_message``, ``label_id`` and
        ``label_name``.
    """
    rng = random.Random(seed)
    # Resolve names by category id rather than list position, so label_name
    # stays correct even if categories are not ordered by id.
    name_by_id = {category.id: category.name for category in load_categories()}
    rows: List[Dict] = []
    for label_id in label_ids:
        exercise = generate_exercise(rng)
        injector = ERROR_INJECTORS[label_id]
        query, error_message = injector(rng, exercise)
        rows.append(
            {
                "schema": exercise.schema,
                "question": exercise.question,
                "correct_query": exercise.correct_query,
                "query": query.strip(),
                "error_message": error_message,
                "label_id": label_id,
                "label_name": name_by_id[label_id],
            }
        )
    return rows
def main() -> None:
    """Command-line entry point: parse options and call ``generate_dataset``."""
    cli = argparse.ArgumentParser(description="Generate labeled SQL error dataset")
    cli.add_argument("--samples", type=int, default=1_000_000, help="Total samples")
    cli.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
    # Remaining integer knobs share the same shape; registered in CLI order.
    for flag, fallback in (("--batch-size", 10_000), ("--workers", 8), ("--seed", 42)):
        cli.add_argument(flag, type=int, default=fallback)
    opts = cli.parse_args()

    generate_dataset(
        total_samples=opts.samples,
        output_path=opts.output,
        batch_size=opts.batch_size,
        workers=opts.workers,
        seed=opts.seed,
    )


if __name__ == "__main__":
    main()