sql-error-classifier / train_space_app.py
nishu08's picture
Deploy CodeBERT inference Space
8a3099e verified
Raw
History Blame Contribute Delete
7.72 kB
"""
Hugging Face Space — CodeBERT SQL Error Classifier Training UI.
Deploy as a Gradio Space with app_file: train_space_app.py
Set hardware to GPU (t4-small recommended).
Add HF_TOKEN secret to push trained models to your Hub account.
"""
from __future__ import annotations

import json
import os
import shutil
import tempfile
import traceback
from pathlib import Path

import gradio as gr
import pandas as pd

from src.hf_train_codebert import train
# All paths are anchored to the directory containing this file.
PROJECT_ROOT = Path(__file__).parent
# Fallback dataset when a dropdown choice is not a key of BUNDLED_DATASETS.
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet"
# Where the fine-tuned model and metrics.json are written; the training
# handler wipes this directory before each run.
OUTPUT_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"
# Dropdown label -> parquet path (as str) for datasets shipped in data/.
BUNDLED_DATASETS = {
    "Dev (15K samples)": str(DEFAULT_DATA),
    "Full (1M samples)": str(DEFAULT_DATA.with_name("sql_errors_1m.parquet")),
}
def _format_metrics(metrics: dict) -> str:
val = metrics.get("validation", {})
test = metrics.get("test", {})
lines = [
"## Training complete",
"",
f"- Train samples: **{metrics.get('train_samples', 0):,}**",
f"- Val samples: **{metrics.get('val_samples', 0):,}**",
f"- Test samples: **{metrics.get('test_samples', 0):,}**",
"",
"### Validation",
f"- F1 macro: **{val.get('eval_f1_macro', 0):.4f}**",
f"- F1 micro: **{val.get('eval_f1_micro', 0):.4f}**",
"",
"### Test",
f"- F1 macro: **{test.get('f1_macro', 0):.4f}**",
f"- F1 micro: **{test.get('f1_micro', 0):.4f}**",
f"- Subset accuracy: **{test.get('subset_accuracy', 0):.4f}**",
"",
f"Model saved to `{OUTPUT_DIR}`",
]
if metrics.get("hub_url"):
lines.append(f"\n**Hub model:** {metrics['hub_url']}")
return "\n".join(lines)
def run_training(
    dataset_choice: str,
    uploaded_file,
    max_samples: int,
    epochs: float,
    batch_size: int,
    learning_rate: float,
    max_length: int,
    fp16: bool,
    push_to_hub: bool,
    hub_model_id: str,
    progress=gr.Progress(),
):
    """Validate the UI inputs, run CodeBERT training and report the outcome.

    Bound to the "Start Training" button; the positional parameters mirror the
    input components, in the order passed to ``train_btn.click``.

    Args:
        dataset_choice: Key of ``BUNDLED_DATASETS``, used when nothing was
            uploaded (unknown keys fall back to ``DEFAULT_DATA``).
        uploaded_file: ``gr.File`` value: ``None``, a filepath string (Gradio
            4+ default) or a tempfile-like object exposing ``.name`` (Gradio 3).
        max_samples: Row cap forwarded to ``train``; ``0``/empty means all rows.
        epochs, batch_size, learning_rate, max_length, fp16: Hyper-parameters
            forwarded to ``train``.
        push_to_hub: Whether ``train`` should upload the model to the Hub.
        hub_model_id: Target Hub repo id; required when ``push_to_hub``.
        progress: Progress tracker injected by Gradio.

    Returns:
        ``(markdown_summary, metrics_json_path_or_None, output_dir_or_None)``,
        matching the ``[result, metrics_file, model_dir]`` outputs. Validation
        and training failures are reported in the summary instead of raised.
    """
    progress(0, desc="Preparing dataset...")
    if uploaded_file is not None:
        # Gradio 4+ passes the upload as a filepath string; Gradio 3.x passed a
        # tempfile wrapper with the path in ``.name``. Accept either.
        data_path = Path(getattr(uploaded_file, "name", uploaded_file))
    else:
        data_path = Path(BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA))
    if not data_path.exists():
        return (
            f"Dataset not found: `{data_path}`. "
            "Upload a parquet file or include data/ in the Space repo.",
            None,
            None,
        )

    hub_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    # The textbox value may be None (e.g. when called through the API), and the
    # id is forwarded to ``train`` even when not pushing, so normalise it once.
    hub_id = (hub_model_id or "").strip()
    if push_to_hub and not hub_token:
        return (
            "Add `HF_TOKEN` to Space secrets to push models to the Hub.",
            None,
            None,
        )
    if push_to_hub and not hub_id:
        return "Enter a Hub model id (e.g. `your-username/sql-codebert-classifier`).", None, None

    # Remove any previous run's output (including an earlier model) first.
    if OUTPUT_DIR.exists():
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # 0 / empty means "no cap".
    samples = int(max_samples) if max_samples and max_samples > 0 else None
    progress(0.1, desc="Starting CodeBERT training...")
    try:
        metrics = train(
            data_path=data_path,
            output_dir=OUTPUT_DIR,
            epochs=epochs,
            # Sliders may deliver floats (e.g. 8.0); batch size and token
            # length are integer quantities, so cast defensively.
            batch_size=int(batch_size),
            learning_rate=learning_rate,
            max_length=int(max_length),
            max_samples=samples,
            fp16=fp16,
            save_strategy="no",
            push_to_hub=push_to_hub,
            hub_model_id=hub_id or None,
            hub_token=hub_token,
        )
    except Exception as exc:  # UI boundary: report the failure, don't crash the app
        # Keep the full traceback in the Space logs; the UI gets a summary that
        # includes the exception type (some exceptions have an empty message).
        traceback.print_exc()
        return f"Training failed:\n\n```\n{type(exc).__name__}: {exc}\n```", None, None

    progress(1.0, desc="Done")
    if push_to_hub and hub_id:
        metrics["hub_url"] = f"https://huggingface.co/{hub_id}"
    metrics_path = OUTPUT_DIR / "metrics.json"
    summary = _format_metrics(metrics)
    return summary, str(metrics_path) if metrics_path.exists() else None, str(OUTPUT_DIR)
def load_preview(dataset_choice: str, uploaded_file) -> str:
try:
if uploaded_file is not None:
df = pd.read_parquet(uploaded_file.name)
else:
path = BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA)
if not Path(path).exists():
return f"Dataset not found: {path}"
df = pd.read_parquet(path)
cols = list(df.columns)
sample = df.head(2).to_dict(orient="records")
return f"**Rows:** {len(df):,}\n\n**Columns:** `{cols}`\n\n**Sample:**\n```json\n{json.dumps(sample, indent=2)[:2000]}\n```"
except Exception as exc:
return f"Could not load preview: {exc}"
# Gradio UI. The layout is defined purely by statement order and nesting: a
# header, one row with a left "inputs" column and a right "results" column,
# then the button wiring and a setup/usage note.
with gr.Blocks(title="SQL Error Classifier — Train") as demo:
    gr.Markdown(
        """
# SQL Error Classifier — CodeBERT Training
Train **microsoft/codebert-base** as a cross-encoder on this Space.
**Input format:** `QUESTION` + `SCHEMA` + `STUDENT_SQL` + `CORRECT_SQL` (single sequence)
**GPU recommended** — upgrade Space hardware to `t4-small` or better.
"""
    )
    with gr.Row():
        # Left column: data source, dataset preview, hyper-parameters, trigger.
        with gr.Column(scale=1):
            # Choices are the BUNDLED_DATASETS keys; an uploaded file takes
            # precedence over this choice in run_training and load_preview.
            dataset_choice = gr.Dropdown(
                choices=list(BUNDLED_DATASETS.keys()),
                value="Dev (15K samples)",
                label="Bundled dataset",
            )
            uploaded = gr.File(
                label="Or upload parquet",
                file_types=[".parquet"],
            )
            preview_btn = gr.Button("Preview dataset")
            preview_out = gr.Markdown()
            # run_training turns 0 / empty into "no row cap".
            max_samples = gr.Number(
                label="Max samples (0 = all)",
                value=5000,
                precision=0,
            )
            epochs = gr.Slider(1, 10, value=2, step=1, label="Epochs")
            batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch size")
            learning_rate = gr.Number(label="Learning rate", value=2e-5)
            max_length = gr.Slider(128, 512, value=512, step=64, label="Max length")
            fp16 = gr.Checkbox(label="FP16 (GPU only)", value=True)
            # Pushing also requires an HF_TOKEN secret and a non-empty
            # hub_model_id; run_training checks both before training.
            push_to_hub = gr.Checkbox(label="Push to Hugging Face Hub", value=False)
            hub_model_id = gr.Textbox(
                label="Hub model id",
                placeholder="your-username/sql-codebert-classifier",
            )
            train_btn = gr.Button("Start Training", variant="primary")
        # Right column: outputs populated by run_training.
        with gr.Column(scale=1):
            result = gr.Markdown(label="Results")
            metrics_file = gr.File(label="metrics.json")
            model_dir = gr.Textbox(label="Model output path", interactive=False)
    preview_btn.click(load_preview, [dataset_choice, uploaded], preview_out)
    # The input list order must match run_training's positional parameters,
    # and the outputs match its (summary, metrics path, model dir) tuple.
    train_btn.click(
        run_training,
        [
            dataset_choice,
            uploaded,
            max_samples,
            epochs,
            batch_size,
            learning_rate,
            max_length,
            fp16,
            push_to_hub,
            hub_model_id,
        ],
        [result, metrics_file, model_dir],
    )
    gr.Markdown(
        """
### Space setup
1. Create a Gradio Space and push this repo
2. Set **Hardware → GPU (t4-small)**
3. Add secret `HF_TOKEN` (write token) to push models
4. Include `data/sql_errors_dev.parquet` in the repo (or upload at runtime)
### After training
Use the saved model with:
```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
clf = CodeBERTSQLErrorClassifier("models/codebert-cross-encoder")
```
"""
    )
if __name__ == "__main__":
    # run_training reports progress via gr.Progress and can run for a long
    # time; both need the request queue. The queue is on by default from
    # Gradio 4 but must be enabled explicitly on 3.x, and calling queue() is
    # harmless on newer versions. queue() returns the Blocks, so launch chains.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)