Spaces:
Sleeping
Sleeping
Deploy CodeBERT training Space
Browse files- .gitignore +13 -0
- README.md +41 -8
- README_HF_SPACE.md +17 -0
- README_TRAIN_SPACE.md +46 -0
- app.py +156 -0
- config/codebert_labels.yaml +29 -0
- config/error_categories.yaml +46 -0
- hub/CODEBERT_MODEL_CARD.md +65 -0
- hub/MODEL_CARD.md +97 -0
- requirements.txt +15 -0
- scripts/create_hf_package.py +35 -0
- scripts/deploy_train_space.sh +54 -0
- scripts/push_to_hub.py +87 -0
- scripts/run_codebert_training.sh +17 -0
- scripts/run_pipeline.sh +16 -0
- src/__init__.py +1 -0
- src/categories.py +35 -0
- src/codebert_dataset.py +117 -0
- src/codebert_formatting.py +28 -0
- src/codebert_labels.py +82 -0
- src/cross_encoder_model.py +312 -0
- src/evaluate.py +114 -0
- src/exercises.py +228 -0
- src/generate_dataset.py +115 -0
- src/hf_eval_codebert.py +69 -0
- src/hf_metrics.py +52 -0
- src/hf_predict_codebert.py +161 -0
- src/hf_train_codebert.py +226 -0
- src/huggingface.py +210 -0
- src/model.py +321 -0
- src/multi_tower_model.py +175 -0
- src/predict.py +132 -0
- src/sql_features.py +81 -0
- src/sql_templates.py +258 -0
- src/train.py +198 -0
- train_space_app.py +230 -0
.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.venv/
|
| 4 |
+
venv/
|
| 5 |
+
.env
|
| 6 |
+
data/*.parquet
|
| 7 |
+
data/*.csv
|
| 8 |
+
models/*.joblib
|
| 9 |
+
models/evaluation/
|
| 10 |
+
.DS_Store
|
| 11 |
+
*.egg-info/
|
| 12 |
+
dist/
|
| 13 |
+
build/
|
README.md
CHANGED
|
@@ -1,13 +1,46 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
|
| 9 |
-
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SQL Error Classifier Training
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: train_space_app.py
|
|
|
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
hardware: t4-small
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# SQL Error Classifier — CodeBERT Training Space
|
| 15 |
+
|
| 16 |
+
Train `microsoft/codebert-base` as a **cross-encoder** for multi-label SQL error classification.
|
| 17 |
+
|
| 18 |
+
## Setup
|
| 19 |
+
|
| 20 |
+
1. **Hardware:** Settings → Hardware → **GPU t4-small** (recommended)
|
| 21 |
+
2. **Secrets:** Settings → Secrets → add `HF_TOKEN` (Hugging Face write token) to push models to your account
|
| 22 |
+
3. **Data:** Include `data/sql_errors_dev.parquet` in this Space repo, or upload parquet at runtime
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
1. Choose bundled dataset or upload your own parquet
|
| 27 |
+
2. Set epochs, batch size, max samples
|
| 28 |
+
3. Click **Start Training**
|
| 29 |
+
4. Optionally enable **Push to Hub** with model id `your-username/sql-codebert-classifier`
|
| 30 |
+
|
| 31 |
+
## Dataset columns
|
| 32 |
+
|
| 33 |
+
Required (aliases supported):
|
| 34 |
+
|
| 35 |
+
| Column | Aliases |
|
| 36 |
+
|--------|---------|
|
| 37 |
+
| `question` | — |
|
| 38 |
+
| `schema` | — |
|
| 39 |
+
| `student_sql` | `query` |
|
| 40 |
+
| `correct_sql` | `correct_query` |
|
| 41 |
+
| `error_labels` | `label_name` |
|
| 42 |
+
|
| 43 |
+
## Labels (9-class multi-label)
|
| 44 |
+
|
| 45 |
+
`JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
|
| 46 |
+
`SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
|
README_HF_SPACE.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SQL Error Classifier
|
| 3 |
+
emoji: 🗄️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# SQL Error Classifier
|
| 14 |
+
|
| 15 |
+
Demo Space for the multi-tower SQL error classification model.
|
| 16 |
+
|
| 17 |
+
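Set the `SPACE_MODEL_ID` environment variable in the Space settings to your model repo (e.g. `username/sql-error-classifier`). If it is not set, `app.py` falls back to the local `models/hf_package` directory.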
|
README_TRAIN_SPACE.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SQL Error Classifier Training
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: train_space_app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
hardware: t4-small
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# SQL Error Classifier — CodeBERT Training Space
|
| 15 |
+
|
| 16 |
+
Train `microsoft/codebert-base` as a **cross-encoder** for multi-label SQL error classification.
|
| 17 |
+
|
| 18 |
+
## Setup
|
| 19 |
+
|
| 20 |
+
1. **Hardware:** Settings → Hardware → **GPU t4-small** (recommended)
|
| 21 |
+
2. **Secrets:** Settings → Secrets → add `HF_TOKEN` (Hugging Face write token) to push models to your account
|
| 22 |
+
3. **Data:** Include `data/sql_errors_dev.parquet` in this Space repo, or upload parquet at runtime
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
1. Choose bundled dataset or upload your own parquet
|
| 27 |
+
2. Set epochs, batch size, max samples
|
| 28 |
+
3. Click **Start Training**
|
| 29 |
+
4. Optionally enable **Push to Hub** with model id `your-username/sql-codebert-classifier`
|
| 30 |
+
|
| 31 |
+
## Dataset columns
|
| 32 |
+
|
| 33 |
+
Required (aliases supported):
|
| 34 |
+
|
| 35 |
+
| Column | Aliases |
|
| 36 |
+
|--------|---------|
|
| 37 |
+
| `question` | — |
|
| 38 |
+
| `schema` | — |
|
| 39 |
+
| `student_sql` | `query` |
|
| 40 |
+
| `correct_sql` | `correct_query` |
|
| 41 |
+
| `error_labels` | `label_name` |
|
| 42 |
+
|
| 43 |
+
## Labels (9-class multi-label)
|
| 44 |
+
|
| 45 |
+
`JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
|
| 46 |
+
`SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
|
app.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio app for Hugging Face Spaces.
|
| 3 |
+
|
| 4 |
+
Deploy: create a Space with sdk=gradio and point app_file to this file.
|
| 5 |
+
Set SPACE_MODEL_ID env var to your HF model repo (e.g. username/sql-error-classifier).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
|
| 16 |
+
from src.huggingface import SQLLErrorClassifierHF
|
| 17 |
+
|
| 18 |
+
MODEL_ID = os.getenv("SPACE_MODEL_ID", "models/hf_package")
|
| 19 |
+
LOCAL_PACKAGE = Path(__file__).parent / "models" / "hf_package"
|
| 20 |
+
|
| 21 |
+
EXAMPLE = {
|
| 22 |
+
"question": "What is the average score of students in each department?",
|
| 23 |
+
"schema": "students(id, name, score, department_id) | departments(id, name)",
|
| 24 |
+
"correct_query": (
|
| 25 |
+
"SELECT department_id, AVG(score) FROM students GROUP BY department_id"
|
| 26 |
+
),
|
| 27 |
+
"student_query": (
|
| 28 |
+
"SELECT department_id, SUM(score) FROM students GROUP BY department_id"
|
| 29 |
+
),
|
| 30 |
+
"error_message": "query executes but produces incorrect result set",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _load_classifier() -> SQLLErrorClassifierHF:
    """Load the classifier that backs the demo.

    A package directory next to this file (``LOCAL_PACKAGE``) is used when it
    exists and contains a ``config.json``; otherwise the classifier is loaded
    from ``MODEL_ID``.
    """
    bundled_config = LOCAL_PACKAGE / "config.json"
    source = MODEL_ID
    if LOCAL_PACKAGE.exists() and bundled_config.exists():
        source = LOCAL_PACKAGE
    return SQLLErrorClassifierHF.from_pretrained(source)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
clf = _load_classifier()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def classify(
    question: str,
    schema: str,
    correct_query: str,
    student_query: str,
    error_message: str,
) -> tuple[str, str, str]:
    """Classify one student submission and format the result as Markdown.

    Args:
        question: Natural-language task given to the student.
        schema: Description of the available tables and columns.
        correct_query: Reference SQL solution.
        student_query: SQL submitted by the student.
        error_message: Optional database error text; a blank value is passed
            to the classifier as ``None``.

    Returns:
        A ``(summary, top_k, diagnostics)`` tuple of Markdown strings: the
        predicted label with its confidence, a bullet list of the top
        predictions, and a bullet list of diagnostic scores (or a placeholder
        when the result carries none).
    """
    result = clf.predict(
        question=question.strip(),
        schema=schema.strip(),
        correct_query=correct_query.strip(),
        student_query=student_query.strip(),
        error_message=error_message.strip() or None,
    )

    # Two trailing spaces before the newline make a Markdown hard line break,
    # so the confidence is rendered on its own line below the label.
    summary = (
        f"**{result['label_name']}**  \n"
        f"Confidence: **{result['confidence']:.1%}**"
    )
    top_k = "\n".join(
        f"- {item['label_name']}: {item['confidence']:.1%}"
        for item in result["top_k"]
    )

    # The result may expose its scores under either key; use whichever is
    # present and non-empty.
    scores = result.get("similarities") or result.get("pair_scores") or {}
    diagnostics = "\n".join(
        f"- **{name.replace('_', ' ').title()}**: {value:.3f}"
        for name, value in scores.items()
    )
    if not diagnostics:
        diagnostics = "_No diagnostic scores for this model type._"
    return summary, top_k, diagnostics
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Build the demo page. Components are created inside the Blocks context in the
# order they are declared below.
# NOTE(review): the nesting here was reconstructed from a flattened listing
# that had lost its indentation — confirm against the repository copy.
with gr.Blocks(title="SQL Error Classifier") as demo:
    gr.Markdown(
        """
# SQL Error Classifier
Classify **which mistake area** a student is struggling with, using:
**question**, **schema**, **correct query**, and the **student's query**.

Powered by a multi-tower MiniLM architecture on Hugging Face.
"""
    )

    with gr.Row():
        # Input column: the five text fields, pre-filled from EXAMPLE.
        with gr.Column():
            question = gr.Textbox(
                label="Question",
                lines=2,
                value=EXAMPLE["question"],
            )
            schema = gr.Textbox(
                label="Schema",
                lines=2,
                value=EXAMPLE["schema"],
            )
            correct_query = gr.Textbox(
                label="Correct Query",
                lines=3,
                value=EXAMPLE["correct_query"],
            )
            student_query = gr.Textbox(
                label="Student Query",
                lines=3,
                value=EXAMPLE["student_query"],
            )
            error_message = gr.Textbox(
                label="DB Error Message (optional)",
                lines=2,
                value=EXAMPLE["error_message"],
            )
            run_btn = gr.Button("Classify", variant="primary")

        # Output column: one Markdown panel per string returned by classify().
        with gr.Column():
            prediction = gr.Markdown(label="Prediction")
            top_k = gr.Markdown(label="Top 3")
            diagnostics = gr.Markdown(label="Semantic Diagnostics")

    # Clicking the button calls classify() with the five inputs and writes its
    # three return values to the three output panels, in order.
    run_btn.click(
        classify,
        inputs=[question, schema, correct_query, student_query, error_message],
        outputs=[prediction, top_k, diagnostics],
    )

    # Example rows; each row lists values in the same order as `inputs`.
    gr.Examples(
        examples=[
            [
                EXAMPLE["question"],
                EXAMPLE["schema"],
                EXAMPLE["correct_query"],
                EXAMPLE["student_query"],
                EXAMPLE["error_message"],
            ],
            [
                "Find students who have not provided an email address.",
                "students(id, name, email, phone)",
                "SELECT name FROM students WHERE email IS NULL",
                "SELECT name FROM students WHERE email = NULL",
                "use IS NULL or IS NOT NULL to test for null values",
            ],
            [
                "List each student's name along with their department name.",
                "students(id, name, department_id) | departments(id, name)",
                "SELECT students.name, departments.name FROM students "
                "INNER JOIN departments ON students.department_id = departments.id",
                "SELECT students.name, departments.name FROM students JOIN departments",
                "missing ON clause or invalid join condition",
            ],
        ],
        inputs=[question, schema, correct_query, student_query, error_message],
    )

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
config/codebert_labels.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Primary labels for CodeBERT cross-encoder training
|
| 2 |
+
labels:
|
| 3 |
+
- JOIN_ERROR
|
| 4 |
+
- AGGREGATION_ERROR
|
| 5 |
+
- FILTER_ERROR
|
| 6 |
+
- WINDOW_FUNCTION_ERROR
|
| 7 |
+
- SUBQUERY_ERROR
|
| 8 |
+
- NULL_HANDLING_ERROR
|
| 9 |
+
- PERFORMANCE_ERROR
|
| 10 |
+
- LOGICAL_ERROR
|
| 11 |
+
- SYNTAX_ERROR
|
| 12 |
+
|
| 13 |
+
# Map dataset label_name values → one or more CodeBERT labels (multi-label)
|
| 14 |
+
alias_map:
|
| 15 |
+
JOIN_ERROR: [JOIN_ERROR]
|
| 16 |
+
AGGREGATION_ERROR: [AGGREGATION_ERROR]
|
| 17 |
+
HAVING_WHERE_ERROR: [AGGREGATION_ERROR]
|
| 18 |
+
FILTERING_ERROR: [FILTER_ERROR]
|
| 19 |
+
WINDOW_FUNCTION_ERROR: [WINDOW_FUNCTION_ERROR]
|
| 20 |
+
SUBQUERY_ERROR: [SUBQUERY_ERROR]
|
| 21 |
+
NULL_HANDLING_ERROR: [NULL_HANDLING_ERROR]
|
| 22 |
+
PERFORMANCE_ERROR: [PERFORMANCE_ERROR]
|
| 23 |
+
LOGICAL_QUERY_ERROR: [LOGICAL_ERROR]
|
| 24 |
+
SYNTAX_ERROR: [SYNTAX_ERROR]
|
| 25 |
+
DATE_FUNCTION_ERROR: [SYNTAX_ERROR]
|
| 26 |
+
COLUMN_REFERENCE_ERROR: [SYNTAX_ERROR]
|
| 27 |
+
TABLE_REFERENCE_ERROR: [SYNTAX_ERROR]
|
| 28 |
+
DATA_TYPE_ERROR: [SYNTAX_ERROR]
|
| 29 |
+
DUPLICATE_RECORD_ERROR: [FILTER_ERROR]
|
config/error_categories.yaml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
categories:
|
| 2 |
+
- id: 0
|
| 3 |
+
name: SYNTAX_ERROR
|
| 4 |
+
description: Missing comma, bracket, quote
|
| 5 |
+
- id: 1
|
| 6 |
+
name: JOIN_ERROR
|
| 7 |
+
description: Missing ON, wrong join condition
|
| 8 |
+
- id: 2
|
| 9 |
+
name: AGGREGATION_ERROR
|
| 10 |
+
description: Missing GROUP BY
|
| 11 |
+
- id: 3
|
| 12 |
+
name: HAVING_WHERE_ERROR
|
| 13 |
+
description: Using aggregate in WHERE
|
| 14 |
+
- id: 4
|
| 15 |
+
name: SUBQUERY_ERROR
|
| 16 |
+
description: Multiple rows returned
|
| 17 |
+
- id: 5
|
| 18 |
+
name: WINDOW_FUNCTION_ERROR
|
| 19 |
+
description: Incorrect OVER/PARTITION BY
|
| 20 |
+
- id: 6
|
| 21 |
+
name: NULL_HANDLING_ERROR
|
| 22 |
+
description: "= NULL instead of IS NULL"
|
| 23 |
+
- id: 7
|
| 24 |
+
name: DATE_FUNCTION_ERROR
|
| 25 |
+
description: Incorrect date format/function
|
| 26 |
+
- id: 8
|
| 27 |
+
name: COLUMN_REFERENCE_ERROR
|
| 28 |
+
description: Column doesn't exist
|
| 29 |
+
- id: 9
|
| 30 |
+
name: TABLE_REFERENCE_ERROR
|
| 31 |
+
description: Table doesn't exist
|
| 32 |
+
- id: 10
|
| 33 |
+
name: DATA_TYPE_ERROR
|
| 34 |
+
description: Comparing integer with string
|
| 35 |
+
- id: 11
|
| 36 |
+
name: DUPLICATE_RECORD_ERROR
|
| 37 |
+
description: Missing DISTINCT
|
| 38 |
+
- id: 12
|
| 39 |
+
name: LOGICAL_QUERY_ERROR
|
| 40 |
+
description: Query runs but answer is wrong
|
| 41 |
+
- id: 13
|
| 42 |
+
name: PERFORMANCE_ERROR
|
| 43 |
+
description: "SELECT *, inefficient joins"
|
| 44 |
+
- id: 14
|
| 45 |
+
name: FILTERING_ERROR
|
| 46 |
+
description: Incorrect WHERE clause
|
hub/CODEBERT_MODEL_CARD.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: mit
|
| 4 |
+
tags:
|
| 5 |
+
- codebert
|
| 6 |
+
- sql
|
| 7 |
+
- education
|
| 8 |
+
- text-classification
|
| 9 |
+
- cross-encoder
|
| 10 |
+
base_model: microsoft/codebert-base
|
| 11 |
+
pipeline_tag: text-classification
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# SQL CodeBERT Cross-Encoder
|
| 15 |
+
|
| 16 |
+
Multi-label SQL error classifier using **microsoft/codebert-base** as a cross-encoder.
|
| 17 |
+
|
| 18 |
+
## Input Format
|
| 19 |
+
|
| 20 |
+
All fields are concatenated into one sequence:
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
QUESTION:
|
| 24 |
+
{question}
|
| 25 |
+
|
| 26 |
+
SCHEMA:
|
| 27 |
+
{schema}
|
| 28 |
+
|
| 29 |
+
STUDENT_SQL:
|
| 30 |
+
{student_sql}
|
| 31 |
+
|
| 32 |
+
CORRECT_SQL:
|
| 33 |
+
{correct_sql}
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Labels
|
| 37 |
+
|
| 38 |
+
`JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
|
| 39 |
+
`SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
|
| 40 |
+
|
| 41 |
+
## Training
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
python -m src.hf_train_codebert \
|
| 45 |
+
--data data/sql_errors_1m.parquet \
|
| 46 |
+
--output-dir models/codebert-cross-encoder \
|
| 47 |
+
--epochs 3 \
|
| 48 |
+
--push-to-hub \
|
| 49 |
+
--hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Inference
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
|
| 56 |
+
|
| 57 |
+
clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")
|
| 58 |
+
result = clf.predict(
|
| 59 |
+
question="What is the average score per department?",
|
| 60 |
+
schema="students(id, score, department_id)",
|
| 61 |
+
student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
|
| 62 |
+
correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
|
| 63 |
+
)
|
| 64 |
+
print(result["error_labels"])
|
| 65 |
+
```
|
hub/MODEL_CARD.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: mit
|
| 4 |
+
tags:
|
| 5 |
+
- sql
|
| 6 |
+
- education
|
| 7 |
+
- text-classification
|
| 8 |
+
- sentence-transformers
|
| 9 |
+
- multi-tower
|
| 10 |
+
pipeline_tag: text-classification
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# SQL Error Classifier (Multi-Tower)
|
| 14 |
+
|
| 15 |
+
Lightweight classifier that identifies **which SQL mistake area** a student is struggling with, given:
|
| 16 |
+
|
| 17 |
+
- **Question** — natural-language task
|
| 18 |
+
- **Schema** — available tables and columns
|
| 19 |
+
- **Correct query** — reference solution
|
| 20 |
+
- **Student query** — what the student submitted
|
| 21 |
+
- **Error message** *(optional)* — database error text
|
| 22 |
+
|
| 23 |
+
## Architecture
|
| 24 |
+
|
| 25 |
+
Multi-tower semantic comparison using `sentence-transformers/all-MiniLM-L6-v2`:
|
| 26 |
+
|
| 27 |
+
1. **Intent tower** — question + schema
|
| 28 |
+
2. **Reference tower** — correct query
|
| 29 |
+
3. **Student tower** — student query (+ error)
|
| 30 |
+
4. **Comparison layer** — embedding diff, interaction, cosine similarities, SQL rule features
|
| 31 |
+
5. **Linear head** — 15 error categories
|
| 32 |
+
|
| 33 |
+
## Error Categories (15)
|
| 34 |
+
|
| 35 |
+
| ID | Category |
|
| 36 |
+
|----|----------|
|
| 37 |
+
| 0 | SYNTAX_ERROR |
|
| 38 |
+
| 1 | JOIN_ERROR |
|
| 39 |
+
| 2 | AGGREGATION_ERROR |
|
| 40 |
+
| 3 | HAVING_WHERE_ERROR |
|
| 41 |
+
| 4 | SUBQUERY_ERROR |
|
| 42 |
+
| 5 | WINDOW_FUNCTION_ERROR |
|
| 43 |
+
| 6 | NULL_HANDLING_ERROR |
|
| 44 |
+
| 7 | DATE_FUNCTION_ERROR |
|
| 45 |
+
| 8 | COLUMN_REFERENCE_ERROR |
|
| 46 |
+
| 9 | TABLE_REFERENCE_ERROR |
|
| 47 |
+
| 10 | DATA_TYPE_ERROR |
|
| 48 |
+
| 11 | DUPLICATE_RECORD_ERROR |
|
| 49 |
+
| 12 | LOGICAL_QUERY_ERROR |
|
| 50 |
+
| 13 | PERFORMANCE_ERROR |
|
| 51 |
+
| 14 | FILTERING_ERROR |
|
| 52 |
+
|
| 53 |
+
## Usage
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
from src.huggingface import SQLLErrorClassifierHF
|
| 57 |
+
|
| 58 |
+
clf = SQLLErrorClassifierHF.from_pretrained("YOUR_USERNAME/sql-error-classifier")
|
| 59 |
+
|
| 60 |
+
result = clf.predict(
|
| 61 |
+
question="What is the average score of students in each department?",
|
| 62 |
+
schema="students(id, name, score, department_id) | departments(id, name)",
|
| 63 |
+
correct_query="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
|
| 64 |
+
student_query="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
print(result["label_name"]) # LOGICAL_QUERY_ERROR
|
| 68 |
+
print(result["confidence"]) # 0.94
|
| 69 |
+
print(result["similarities"]) # semantic alignment scores
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
## Gradio Demo
|
| 73 |
+
|
| 74 |
+
Deploy as a [Hugging Face Space](https://huggingface.co/docs/hub/spaces) with `app.py` from this repository.
|
| 75 |
+
|
| 76 |
+
## Model Details
|
| 77 |
+
|
| 78 |
+
- **Encoder**: `sentence-transformers/all-MiniLM-L6-v2` (loaded from Hub, not bundled)
|
| 79 |
+
- **Head**: scikit-learn SGDClassifier + StandardScaler
|
| 80 |
+
- **Size**: ~5 MB classifier head (encoder ~80 MB, cached separately)
|
| 81 |
+
- **Inference**: ~100–200 ms on CPU
|
| 82 |
+
|
| 83 |
+
## Training Data
|
| 84 |
+
|
| 85 |
+
Synthetically generated from exercise templates with per-category error injectors.
|
| 86 |
+
1M balanced samples across 15 classes.
|
| 87 |
+
|
| 88 |
+
## Citation
|
| 89 |
+
|
| 90 |
+
```bibtex
|
| 91 |
+
@misc{sql-error-classifier,
|
| 92 |
+
title = {SQL Error Classifier - Multi-Tower},
|
| 93 |
+
author = {SQLErrorClassification},
|
| 94 |
+
year = {2025},
|
| 95 |
+
publisher = {Hugging Face},
|
| 96 |
+
}
|
| 97 |
+
```
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.24.0
|
| 2 |
+
pandas>=2.0.0
|
| 3 |
+
scikit-learn>=1.3.0
|
| 4 |
+
joblib>=1.3.0
|
| 5 |
+
pyarrow>=14.0.0
|
| 6 |
+
tqdm>=4.66.0
|
| 7 |
+
pyyaml>=6.0
|
| 8 |
+
matplotlib>=3.7.0
|
| 9 |
+
sentence-transformers>=2.2.0
|
| 10 |
+
torch>=2.0.0
|
| 11 |
+
transformers>=4.36.0
|
| 12 |
+
accelerate>=0.25.0
|
| 13 |
+
datasets>=2.16.0
|
| 14 |
+
huggingface_hub>=0.20.0
|
| 15 |
+
gradio>=4.44.0
|
scripts/create_hf_package.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Create a local Hugging Face Hub package from a trained model."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 11 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 12 |
+
|
| 13 |
+
from src.huggingface import package_for_hub
|
| 14 |
+
|
| 15 |
+
def main() -> None:
    """Parse the CLI options, call ``package_for_hub`` and list the result.

    ``--model`` and ``--output`` default to ``multi_tower_dev.joblib`` and
    ``hf_package`` under ``PROJECT_ROOT / "models"``.
    """
    models_dir = PROJECT_ROOT / "models"

    cli = argparse.ArgumentParser(description="Create HF Hub package locally")
    cli.add_argument(
        "--model", type=Path, default=models_dir / "multi_tower_dev.joblib"
    )
    cli.add_argument("--output", type=Path, default=models_dir / "hf_package")
    options = cli.parse_args()

    package_path = package_for_hub(options.model, options.output)
    print(f"Package ready at {package_path}")
    file_names = [entry.name for entry in package_path.iterdir()]
    print("Files:", file_names)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|
scripts/deploy_train_space.sh
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Deploy training Space to Hugging Face.
# Usage: ./scripts/deploy_train_space.sh YOUR_HF_USERNAME/sql-error-classifier-train
#
# Reads the access token from HF_TOKEN (or HUGGING_FACE_HUB_TOKEN).
# Needs rsync and a Python interpreter with huggingface_hub installed.
set -euo pipefail

SPACE_ID="${1:-}"
if [[ -z "${SPACE_ID}" ]]; then
  echo "Usage: $0 YOUR_USERNAME/sql-error-classifier-train"
  exit 1
fi

ROOT="$(cd "$(dirname "$0")/.." && pwd)"
TOKEN="${HF_TOKEN:-${HUGGING_FACE_HUB_TOKEN:-}}"

if [[ -z "${TOKEN}" ]]; then
  echo "Set HF_TOKEN before deploying."
  exit 1
fi

# Stage the upload in a throwaway directory that is removed on exit.
WORKDIR=$(mktemp -d)
trap 'rm -rf "${WORKDIR}"' EXIT

echo "==> Preparing Space files in ${WORKDIR}..."
# Everything staged here is uploaded to the Space, so local-only and private
# files (.env, virtualenvs, VCS data, trained models) are left out.
rsync -a \
  --exclude '.venv' \
  --exclude 'venv' \
  --exclude '.env' \
  --exclude 'models' \
  --exclude '__pycache__' \
  --exclude '.git' \
  "${ROOT}/" "${WORKDIR}/"

# The Space reads its metadata from README.md, so use the training-Space one.
cp "${ROOT}/README_TRAIN_SPACE.md" "${WORKDIR}/README.md"

echo "==> Creating / updating Space ${SPACE_ID}..."
# Values are handed to Python through the environment and the heredoc is
# quoted, so the token and Space id are never spliced into Python source.
DEPLOY_TOKEN="${TOKEN}" DEPLOY_SPACE_ID="${SPACE_ID}" python - <<'PY'
import os

from huggingface_hub import HfApi

api = HfApi(token=os.environ["DEPLOY_TOKEN"])
api.create_repo(
    os.environ["DEPLOY_SPACE_ID"],
    repo_type="space",
    space_sdk="gradio",
    exist_ok=True,
)
PY

echo "==> Uploading to Hugging Face Space..."
DEPLOY_TOKEN="${TOKEN}" DEPLOY_SPACE_ID="${SPACE_ID}" DEPLOY_WORKDIR="${WORKDIR}" python - <<'PY'
import os

from huggingface_hub import HfApi

api = HfApi(token=os.environ["DEPLOY_TOKEN"])
api.upload_folder(
    folder_path=os.environ["DEPLOY_WORKDIR"],
    repo_id=os.environ["DEPLOY_SPACE_ID"],
    repo_type="space",
    commit_message="Deploy CodeBERT training Space",
)
PY

echo "==> Done: https://huggingface.co/spaces/${SPACE_ID}"
echo "Next: Space Settings → Hardware → GPU t4-small"
echo " Space Settings → Secrets → HF_TOKEN"
|
scripts/push_to_hub.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Package and push the SQL error classifier to Hugging Face Hub."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 12 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import HfApi, create_repo
|
| 15 |
+
|
| 16 |
+
from src.huggingface import package_for_hub
|
| 17 |
+
DEFAULT_MODEL = PROJECT_ROOT / "models" / "multi_tower_dev.joblib"
|
| 18 |
+
DEFAULT_PACKAGE = PROJECT_ROOT / "models" / "hf_package"
|
| 19 |
+
MODEL_CARD = PROJECT_ROOT / "hub" / "MODEL_CARD.md"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def push(
    model_path: Path = DEFAULT_MODEL,
    package_dir: Path = DEFAULT_PACKAGE,
    repo_id: str = "",
    private: bool = False,
    token: str | None = None,
) -> str:
    """Package the model and upload it to a Hugging Face model repo.

    Args:
        model_path: Trained model file handed to ``package_for_hub``.
        package_dir: Directory the package is written to and uploaded from.
        repo_id: Target repo id, e.g. ``your-username/sql-error-classifier``.
        private: Whether the repo is created as private.
        token: Access token; falls back to the ``HF_TOKEN`` environment
            variable when not given.

    Returns:
        The URL of the model repo.

    Raises:
        ValueError: If ``repo_id`` is empty.
    """
    if not repo_id:
        raise ValueError("--repo-id is required (e.g. your-username/sql-error-classifier)")

    token = token or os.getenv("HF_TOKEN")
    api = HfApi(token=token)
    # Keyword arguments shared by both upload calls below.
    destination = {"repo_id": repo_id, "repo_type": "model", "token": token}

    print(f"Packaging model from {model_path}...")
    package_for_hub(model_path, package_dir)

    print(f"Creating repo {repo_id}...")
    create_repo(repo_id, repo_type="model", private=private, exist_ok=True, token=token)

    print("Uploading model files...")
    api.upload_folder(folder_path=str(package_dir), **destination)

    # The model card, when present, is uploaded as the repo's README.md.
    if MODEL_CARD.exists():
        api.upload_file(
            path_or_fileobj=str(MODEL_CARD),
            path_in_repo="README.md",
            **destination,
        )

    url = f"https://huggingface.co/{repo_id}"
    print(f"Done: {url}")
    return url
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main() -> None:
    """Parse the command-line options and forward them to ``push``."""
    cli = argparse.ArgumentParser(description="Push SQL error classifier to HF Hub")
    cli.add_argument("--model", type=Path, default=DEFAULT_MODEL)
    cli.add_argument("--package-dir", type=Path, default=DEFAULT_PACKAGE)
    cli.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Hugging Face repo id, e.g. nishantgupta/sql-error-classifier",
    )
    cli.add_argument("--private", action="store_true")
    cli.add_argument("--token", type=str, default=None)
    opts = cli.parse_args()

    # Map the CLI option names onto push()'s keyword arguments.
    push_kwargs = {
        "model_path": opts.model,
        "package_dir": opts.package_dir,
        "repo_id": opts.repo_id,
        "private": opts.private,
        "token": opts.token,
    }
    push(**push_kwargs)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
main()
|
scripts/run_codebert_training.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run src.hf_train_codebert.
# Usage: run_codebert_training.sh [DATA] [OUTPUT_DIR] [MAX_SAMPLES]
set -euo pipefail

DATA="${1:-data/sql_errors_dev.parquet}"
OUTPUT="${2:-models/codebert-cross-encoder}"
SAMPLES="${3:-}"

# --max-samples is passed only when a third argument was given.
train_args=(--data "${DATA}" --output-dir "${OUTPUT}")
[[ -z "${SAMPLES}" ]] || train_args+=(--max-samples "${SAMPLES}")

echo "==> Training CodeBERT cross-encoder..."
python -m src.hf_train_codebert "${train_args[@]}"

echo "==> Done. Model at ${OUTPUT}"
|
scripts/run_pipeline.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Full pipeline: generate the labeled dataset, train the classifier, evaluate it.
#
# Usage: run_pipeline.sh [SAMPLES] [WORKERS]
#   SAMPLES  number of samples to generate   (default: 1000000)
#   WORKERS  value passed as --workers       (default: 8)
#
# NOTE(review): modules are run with `python -m src...`, so this presumably
# has to be started from the repository root - confirm.
set -euo pipefail

SAMPLES="${1:-1000000}"
WORKERS="${2:-8}"

echo "==> Generating ${SAMPLES} labeled SQL samples..."
python -m src.generate_dataset --samples "${SAMPLES}" --workers "${WORKERS}"

# Training and evaluation run with their own built-in defaults.
echo "==> Training classifier..."
python -m src.train

echo "==> Evaluating..."
python -m src.evaluate

echo "==> Done. Model at models/sql_error_classifier.joblib"
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""SQL error classification package."""
|
src/categories.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
CONFIG_PATH = Path(__file__).resolve().parent.parent / "config" / "error_categories.yaml"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
class ErrorCategory:
    """Immutable record for one error category read from the YAML config."""

    # Values come from the "id", "name" and "description" keys of each
    # entry under "categories" in the config file.
    id: int
    name: str
    description: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_categories(config_path: Path = CONFIG_PATH) -> List[ErrorCategory]:
    """Load the error-category definitions from a YAML config file.

    The file must contain a top-level ``categories`` list whose entries
    each provide ``id``, ``name`` and ``description`` keys.

    Args:
        config_path: Path of the YAML file; defaults to ``CONFIG_PATH``.

    Returns:
        One ``ErrorCategory`` per entry, in file order.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If ``categories`` or a required entry key is missing.
    """
    # YAML is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return [
        ErrorCategory(id=c["id"], name=c["name"], description=c["description"])
        for c in data["categories"]
    ]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def id_to_name(categories: List[ErrorCategory] | None = None) -> Dict[int, str]:
    """Build a mapping from category id to category name.

    Args:
        categories: Categories to index. Only ``None`` triggers loading the
            default config; an explicitly passed empty list yields ``{}``.

    Returns:
        ``{category.id: category.name}`` for every given category.
    """
    # `is None` rather than `or`: an empty list is a valid, explicit input
    # and must not silently fall back to reading the config from disk.
    cats = load_categories() if categories is None else categories
    return {c.id: c.name for c in cats}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def name_to_id(categories: List[ErrorCategory] | None = None) -> Dict[str, int]:
    """Build a mapping from category name to category id.

    Args:
        categories: Categories to index. Only ``None`` triggers loading the
            default config; an explicitly passed empty list yields ``{}``.

    Returns:
        ``{category.name: category.id}`` for every given category.
    """
    # `is None` rather than `or`: an empty list is a valid, explicit input
    # and must not silently fall back to reading the config from disk.
    cats = load_categories() if categories is None else categories
    return {c.name: c.id for c in cats}
|
src/codebert_dataset.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch Dataset and preprocessing for CodeBERT Trainer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
from transformers import PreTrainedTokenizerBase
|
| 12 |
+
|
| 13 |
+
from src.codebert_formatting import format_cross_encoder_input
|
| 14 |
+
from src.codebert_labels import label_to_multihot, load_codebert_labels
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def normalize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* whose columns follow the canonical training schema.

    Project-specific column names (``query``, ``correct_query``,
    ``label_name``) are renamed to their canonical counterparts. The input
    frame is left untouched.

    Raises:
        ValueError: If any canonical column is still absent after renaming.
    """
    aliases = {
        "query": "student_sql",
        "correct_query": "correct_sql",
        "label_name": "error_labels",
    }
    # Rename only the aliases that actually occur in the input.
    present = {old: new for old, new in aliases.items() if old in df.columns}
    out = df.rename(columns=present).copy()

    required = ["question", "schema", "student_sql", "correct_sql", "error_labels"]
    missing = [name for name in required if name not in out.columns]
    if not missing:
        return out
    raise ValueError(
        f"Dataset missing required columns: {missing}. "
        f"Expected {required} (or aliases query/correct_query/label_name)."
    )
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SQLCodeBERTDataset(Dataset):
    """Map-style dataset that tokenizes SQL error rows on access.

    Each item is the tokenizer output for the formatted cross-encoder text
    of one row, plus a ``labels`` entry holding the row's multi-hot label
    vector as a plain list.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer: PreTrainedTokenizerBase,
        label_list: Optional[List[str]] = None,
        max_length: int = 512,
    ):
        # Canonical column names and a clean 0..n-1 index for iloc access.
        frame = normalize_dataframe(df)
        self.df = frame.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Fall back to the configured label set when none is supplied.
        self.label_list = label_list or load_codebert_labels()
        self.num_labels = len(self.label_list)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.df.iloc[idx]
        text_fields = ("question", "schema", "student_sql", "correct_sql")
        text = format_cross_encoder_input(
            **{field: str(record[field]) for field in text_fields}
        )
        # No padding here: batches are padded later by the data collator.
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors=None,
        )
        multihot = label_to_multihot(str(record["error_labels"]), self.label_list)
        encoded["labels"] = multihot.tolist()
        return encoded
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class SQLCodeBERTDataCollator:
    """Pad batches dynamically for Trainer.

    The tokenizer's ``pad`` method pads the tokenized fields to the longest
    sequence in the batch; the multi-hot ``labels`` are stacked separately
    into a float tensor.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate *features* into one padded batch.

        Args:
            features: Per-example mappings, each containing a ``labels`` list
                in addition to the tokenizer fields.

        Returns:
            The padded batch with an added float ``labels`` tensor.

        Raises:
            KeyError: If an example has no ``labels`` entry.
        """
        labels = [f["labels"] for f in features]
        # Build label-free copies instead of popping from the caller's dicts:
        # the inputs stay intact, so the same features can be collated again.
        unlabeled = [
            {key: value for key, value in f.items() if key != "labels"}
            for f in features
        ]
        batch = self.tokenizer.pad(unlabeled, padding=True, return_tensors="pt")
        batch["labels"] = torch.tensor(labels, dtype=torch.float)
        return batch
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def prepare_datasets(
    df: pd.DataFrame,
    tokenizer: PreTrainedTokenizerBase,
    test_size: float = 0.1,
    val_size: float = 0.1,
    max_length: int = 512,
    seed: int = 42,
) -> tuple[SQLCodeBERTDataset, SQLCodeBERTDataset, SQLCodeBERTDataset]:
    """Split *df* into stratified train/validation/test datasets.

    The test split is carved off first; the validation split is then taken
    from the remainder, with its fraction rescaled so that ``val_size`` is
    relative to the full dataset. Both splits stratify on ``error_labels``.

    Returns:
        ``(train, validation, test)`` datasets, in that order.
    """
    from sklearn.model_selection import train_test_split

    frame = normalize_dataframe(df)
    remainder, test_df = train_test_split(
        frame,
        test_size=test_size,
        random_state=seed,
        stratify=frame["error_labels"],
    )
    # Express val_size as a share of what is left after removing the test set.
    val_fraction = val_size / (1 - test_size)
    train_df, val_df = train_test_split(
        remainder,
        test_size=val_fraction,
        random_state=seed,
        stratify=remainder["error_labels"],
    )

    train_ds, val_ds, test_ds = (
        SQLCodeBERTDataset(part, tokenizer, max_length=max_length)
        for part in (train_df, val_df, test_df)
    )
    return train_ds, val_ds, test_ds
|
src/codebert_formatting.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cross-encoder input formatting for CodeBERT."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
QUESTION_TAG = "QUESTION:"
|
| 6 |
+
SCHEMA_TAG = "SCHEMA:"
|
| 7 |
+
STUDENT_TAG = "STUDENT_SQL:"
|
| 8 |
+
CORRECT_TAG = "CORRECT_SQL:"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def format_cross_encoder_input(
    question: str,
    schema: str,
    student_sql: str,
    correct_sql: str,
) -> str:
    """
    Build the single text sequence fed to CodeBERT.

    Each field is stripped of surrounding whitespace and placed on the line
    after its tag; sections are separated by one blank line, in the order
    question, schema, student SQL, reference SQL.
    """
    sections = (
        (QUESTION_TAG, question),
        (SCHEMA_TAG, schema),
        (STUDENT_TAG, student_sql),
        (CORRECT_TAG, correct_sql),
    )
    return "\n\n".join(f"{tag}\n{body.strip()}" for tag, body in sections)
|
src/codebert_labels.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Label utilities for CodeBERT multi-label classification."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, List, Sequence, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
CONFIG_PATH = (
|
| 12 |
+
Path(__file__).resolve().parent.parent / "config" / "codebert_labels.yaml"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_codebert_labels(config_path: Path = CONFIG_PATH) -> List[str]:
    """Return the ordered label names from the ``labels`` key of the YAML config.

    Args:
        config_path: Path of the YAML file; defaults to ``CONFIG_PATH``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the file has no ``labels`` key.
    """
    # YAML is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return list(data["labels"])
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_alias_map(config_path: Path = CONFIG_PATH) -> Dict[str, List[str]]:
    """Return the alias table from the ``alias_map`` key of the YAML config.

    Each alias maps to a list of label names.

    Args:
        config_path: Path of the YAML file; defaults to ``CONFIG_PATH``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the file has no ``alias_map`` key.
    """
    # YAML is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {k: list(v) for k, v in data["alias_map"].items()}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def label_to_multihot(
    error_labels: Union[str, Sequence[str]],
    label_list: List[str] | None = None,
    alias_map: Dict[str, List[str]] | None = None,
) -> np.ndarray:
    """
    Convert error label(s) to multi-hot vector.

    Accepts:
    - comma-separated string: "JOIN_ERROR,AGGREGATION_ERROR"
    - list of label strings
    - single dataset label_name (resolved via alias_map)

    Every item, in a string or in a sequence, is first looked up in the
    alias map and expanded to its target labels; aliases take precedence
    over canonical label names.

    Args:
        error_labels: Comma-separated string or sequence of label/alias names.
        label_list: Ordered canonical labels; loaded from config when empty
            or ``None``.
        alias_map: Alias -> labels table; loaded from config only when
            ``None`` (an empty dict means "no aliases").

    Returns:
        ``float32`` vector of ``len(label_list)`` with 1.0 at each label.

    Raises:
        ValueError: If a string item (or an alias target) is not a known
            label, or if no valid label remains.
    """
    # An empty label list can never produce a valid vector, so any falsy
    # value falls back to the configured labels.
    labels = label_list or load_codebert_labels()
    aliases = load_alias_map() if alias_map is None else alias_map
    index = {name: i for i, name in enumerate(labels)}
    vec = np.zeros(len(labels), dtype=np.float32)

    names: List[str] = []
    if isinstance(error_labels, str):
        for item in (s.strip() for s in error_labels.split(",")):
            if not item:
                continue
            # Unknown items are kept so the validation below reports them.
            names.extend(aliases[item] if item in aliases else [item])
    else:
        for item in error_labels:
            if item in aliases:
                names.extend(aliases[item])
            elif item in index:
                names.append(item)
            # Unknown sequence entries are ignored (existing behaviour).

    for name in names:
        if name not in index:
            raise ValueError(f"Unknown label '{name}'. Expected one of {labels}")
        vec[index[name]] = 1.0

    if vec.sum() == 0:
        raise ValueError(f"No valid labels found in {error_labels}")
    return vec
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def multihot_to_label_names(
    vec: np.ndarray,
    label_list: List[str] | None = None,
    threshold: float = 0.5,
) -> List[str]:
    """Return the names of all labels whose score is at least *threshold*.

    Names are returned in label-list order. The configured label list is
    loaded when *label_list* is empty or ``None``.
    """
    names = label_list or load_codebert_labels()
    # Positions along the first axis where the score reaches the threshold.
    hits = np.nonzero(vec >= threshold)[0]
    return [names[pos] for pos in hits]
|
src/cross_encoder_model.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cross-encoder architecture for SQL error classification."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from sklearn.linear_model import LogisticRegression
|
| 11 |
+
from sklearn.preprocessing import StandardScaler
|
| 12 |
+
|
| 13 |
+
from src.multi_tower_model import QueryContext, contexts_from_dataframe
|
| 14 |
+
from src.sql_features import extract_sql_features
|
| 15 |
+
|
| 16 |
+
DEFAULT_CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L6-v2"
|
| 17 |
+
DEFAULT_FINETUNED_CE = "cross-encoder/ms-marco-MiniLM-L6-v2"
|
| 18 |
+
|
| 19 |
+
PAIR_NAMES = (
|
| 20 |
+
"intent_vs_student",
|
| 21 |
+
"reference_vs_student",
|
| 22 |
+
"intent_vs_reference",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass(frozen=True)
class CrossEncoderPair:
    """Immutable named pair of texts to be scored together by a cross-encoder."""

    # Pair identifier; build_pairs uses the values listed in PAIR_NAMES.
    name: str
    # First and second text of the pair, in that order.
    text_a: str
    text_b: str
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _intent_text(ctx: QueryContext) -> str:
|
| 34 |
+
return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _reference_text(ctx: QueryContext) -> str:
|
| 38 |
+
return f"REFERENCE: {ctx.correct_query}"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _student_text(ctx: QueryContext) -> str:
|
| 42 |
+
parts = [f"STUDENT: {ctx.student_query}"]
|
| 43 |
+
if ctx.error_message:
|
| 44 |
+
parts.append(f"ERROR: {ctx.error_message}")
|
| 45 |
+
return " ".join(parts)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _context_text(ctx: QueryContext) -> str:
|
| 49 |
+
"""Full task context for fine-tuned cross-encoder."""
|
| 50 |
+
return (
|
| 51 |
+
f"QUESTION: {ctx.question} "
|
| 52 |
+
f"SCHEMA: {ctx.schema} "
|
| 53 |
+
f"REFERENCE: {ctx.correct_query}"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def build_pairs(ctx: QueryContext) -> List[CrossEncoderPair]:
    """Build the three text pairs scored for one query context.

    The pairs are returned in the order intent/student, reference/student,
    intent/reference.
    """
    intent = _intent_text(ctx)
    reference = _reference_text(ctx)
    student = _student_text(ctx)
    # Insertion order of the dict fixes the order of the returned pairs.
    layout = {
        "intent_vs_student": (intent, student),
        "reference_vs_student": (reference, student),
        "intent_vs_reference": (intent, reference),
    }
    return [CrossEncoderPair(name, a, b) for name, (a, b) in layout.items()]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class CrossEncoderClassifier:
    """
    Hybrid classifier: frozen cross-encoder pair scores feeding a linear head.

    For every query context three text pairs are scored by the cross-encoder:
      1. intent vs student    - does the query address the question?
      2. reference vs student - how far is the student from the answer?
      3. intent vs reference  - task-answer alignment baseline

    The three scores, five values derived from them and the rule-based SQL
    features are standardized and classified by a LogisticRegression
    (15 classes per the original design notes).
    """

    def __init__(
        self,
        cross_encoder_name: str = DEFAULT_CROSS_ENCODER,
        batch_size: int = 32,
        max_length: int = 512,
    ):
        self.cross_encoder_name = cross_encoder_name
        self.batch_size = batch_size
        self.max_length = max_length
        # The cross-encoder is created lazily on first use.
        self.cross_encoder = None
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        self.classes_: Optional[np.ndarray] = None

    def _load_cross_encoder(self):
        """Instantiate the cross-encoder once; later calls are no-ops."""
        if self.cross_encoder is not None:
            return
        from sentence_transformers import CrossEncoder

        self.cross_encoder = CrossEncoder(
            self.cross_encoder_name,
            max_length=self.max_length,
        )

    def _pair_batches(self, contexts: List[QueryContext]) -> List[List[Tuple[str, str]]]:
        """Return one list of (text_a, text_b) tuples per pair type."""
        columns: List[List[Tuple[str, str]]] = [[], [], []]
        for ctx in contexts:
            for slot, pair in enumerate(build_pairs(ctx)):
                columns[slot].append((pair.text_a, pair.text_b))
        return columns

    def _score_pairs(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Score every pair type and stack the results column-wise."""
        self._load_cross_encoder()
        columns = [
            np.asarray(
                self.cross_encoder.predict(
                    batch,
                    batch_size=self.batch_size,
                    show_progress_bar=show_progress,
                ),
                dtype=np.float64,
            ).reshape(-1, 1)
            for batch in self._pair_batches(contexts)
        ]
        return np.hstack(columns)  # (n, 3)

    def _build_features(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Assemble raw (unscaled) features: pair scores, derived, SQL rules."""
        scores = self._score_pairs(contexts, show_progress=show_progress)
        intent_student = scores[:, 0]
        ref_student = scores[:, 1]
        intent_ref = scores[:, 2]

        derived = np.column_stack(
            [
                ref_student - intent_student,  # reference closer than intent?
                intent_student - intent_ref,  # student-intent gap vs baseline
                ref_student - intent_ref,  # student-reference gap vs baseline
                intent_student * ref_student,  # interaction
                np.abs(ref_student - intent_student),  # intent-reference disagreement
            ]
        )

        rule_rows = [
            extract_sql_features(ctx.student_query, ctx.correct_query)
            for ctx in contexts
        ]
        sql_feats = np.array(rule_rows, dtype=np.float64)

        return np.hstack([scores, derived, sql_feats])

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Scale features with the fitted scaler and replace non-finite values."""
        scaled = self.scaler.transform(self._build_features(contexts))
        return np.nan_to_num(scaled, nan=0.0, posinf=1e3, neginf=-1e3)

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "CrossEncoderClassifier":
        """Fit the scaler and the linear head on *contexts* with targets *y*."""
        raw = self._build_features(contexts, show_progress=True)
        scaled = self.scaler.fit_transform(raw)
        cleaned = np.nan_to_num(scaled, nan=0.0, posinf=1e3, neginf=-1e3)
        self.clf.fit(cleaned, y)
        self.classes_ = self.clf.classes_
        return self

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predict one class per context."""
        features = self._prepare_features(contexts)
        return self.clf.predict(features)

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predict class probabilities per context."""
        features = self._prepare_features(contexts)
        return self.clf.predict_proba(features)

    def explain_pair_scores(self, ctx: QueryContext) -> dict:
        """Return the raw cross-encoder score of each named pair for *ctx*."""
        row = self._score_pairs([ctx])[0]
        return {name: float(value) for name, value in zip(PAIR_NAMES, row)}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class FineTunedCrossEncoderClassifier:
    """
    End-to-end fine-tuned cross-encoder (highest accuracy).

    Single cross-attention pass over [task_context | student_query] with
    num_labels=15. Slower to train; best on smaller high-quality datasets.

    The model head is sized to the number of distinct training labels and
    therefore works with contiguous indices ``0..k-1``. ``fit`` maps the
    original label values to those indices and ``predict`` maps them back
    through ``classes_``, so label sets with gaps (for example a sample
    that lacks some categories) are handled correctly. When the labels
    already are ``0..k-1`` the behaviour is unchanged.
    """

    # Written next to the saved model so ``load`` can restore ``classes_``.
    _CLASSES_FILE = "classes.json"

    def __init__(
        self,
        cross_encoder_name: str = DEFAULT_FINETUNED_CE,
        batch_size: int = 16,
        max_length: int = 512,
        num_labels: int = 15,
    ):
        self.cross_encoder_name = cross_encoder_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_labels = num_labels
        # Created lazily by _load_model.
        self.model = None
        self.classes_: Optional[np.ndarray] = None

    def _load_model(self, num_labels: Optional[int] = None):
        """Create the cross-encoder once; later calls are no-ops."""
        if self.model is None:
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                self.cross_encoder_name,
                num_labels=num_labels or self.num_labels,
                max_length=self.max_length,
            )

    def _to_examples(self, contexts: List[QueryContext], labels: Optional[np.ndarray] = None):
        """Wrap contexts (and optional labels) as ``InputExample`` objects."""
        from sentence_transformers import InputExample

        examples = []
        for i, ctx in enumerate(contexts):
            label = float(labels[i]) if labels is not None else 0.0
            examples.append(
                InputExample(
                    texts=[_context_text(ctx), _student_text(ctx)],
                    label=label,
                )
            )
        return examples

    def _predict_logits(self, contexts: List[QueryContext]) -> np.ndarray:
        """Run the model on the (context, student) pairs of *contexts*."""
        self._load_model()
        pairs = [[_context_text(c), _student_text(c)] for c in contexts]
        return self.model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
        )

    def fit(
        self,
        contexts: List[QueryContext],
        y: np.ndarray,
        epochs: int = 1,
        warmup_steps: int = 100,
        output_path: Optional[Path] = None,
    ) -> "FineTunedCrossEncoderClassifier":
        """Fine-tune the cross-encoder on *contexts* with targets *y*.

        Args:
            contexts: Training examples.
            y: One label per context; values need not be contiguous.
            epochs: Number of training epochs.
            warmup_steps: Upper bound for the warm-up length; the effective
                value is capped by a tenth of the dataset (at least 10).
            output_path: Optional directory passed to the model's ``fit``.

        Returns:
            ``self``.
        """
        from torch.utils.data import DataLoader

        # Sorted distinct labels plus, for every sample, its index into them.
        classes, y_index = np.unique(np.asarray(y).ravel(), return_inverse=True)
        self._load_model(num_labels=len(classes))
        train_examples = self._to_examples(contexts, y_index)
        loader = DataLoader(
            train_examples,
            shuffle=True,
            batch_size=self.batch_size,
        )
        self.model.fit(
            train_dataloader=loader,
            epochs=epochs,
            warmup_steps=min(warmup_steps, max(10, len(train_examples) // 10)),
            show_progress_bar=True,
            output_path=str(output_path) if output_path else None,
        )
        self.classes_ = classes
        return self

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predict one label per context, expressed in the original label values."""
        logits = np.asarray(self._predict_logits(contexts))
        if logits.ndim == 1:
            return logits.astype(int)
        best = logits.argmax(axis=1)
        # Translate head indices back to label values when the mapping is known.
        if self.classes_ is not None and len(self.classes_) == logits.shape[1]:
            return np.asarray(self.classes_)[best]
        return best

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return per-class probabilities; column ``j`` belongs to ``classes_[j]``."""
        logits = np.asarray(self._predict_logits(contexts), dtype=np.float64)
        if logits.ndim == 1:
            # binary fallback
            probs = np.zeros((len(contexts), len(self.classes_)))
            for i, pred in enumerate(logits.astype(int)):
                idx = np.where(self.classes_ == pred)[0][0]
                probs[i, idx] = 1.0
            return probs
        # softmax, shifted by the row maximum for numerical stability
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)

    def save(self, path: Path) -> Path:
        """Save the model (and the label mapping, when known) under *path*."""
        path.mkdir(parents=True, exist_ok=True)
        self._load_model()
        self.model.save(str(path))
        if self.classes_ is not None:
            import json

            (path / self._CLASSES_FILE).write_text(
                json.dumps(np.asarray(self.classes_).tolist()),
                encoding="utf-8",
            )
        return path

    @classmethod
    def load(cls, path: Path) -> "FineTunedCrossEncoderClassifier":
        """Load a model saved by ``save``.

        The label mapping is restored from the side file when present;
        otherwise labels are assumed to be ``0..num_labels-1``.
        """
        import json

        from sentence_transformers import CrossEncoder

        instance = cls()
        instance.model = CrossEncoder(str(path))
        classes_file = Path(path) / cls._CLASSES_FILE
        if classes_file.is_file():
            instance.classes_ = np.asarray(
                json.loads(classes_file.read_text(encoding="utf-8"))
            )
        else:
            instance.classes_ = np.arange(instance.model.num_labels)
        return instance
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate trained model with confusion matrix and per-class metrics."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use("Agg")
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
|
| 15 |
+
|
| 16 |
+
from src.categories import id_to_name, load_categories
|
| 17 |
+
from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
|
| 18 |
+
from src.cross_encoder_model import (
|
| 19 |
+
CrossEncoderClassifier,
|
| 20 |
+
FineTunedCrossEncoderClassifier,
|
| 21 |
+
)
|
| 22 |
+
from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe
|
| 23 |
+
|
| 24 |
+
CONTEXT_MODELS = (
|
| 25 |
+
CrossEncoderClassifier,
|
| 26 |
+
FineTunedCrossEncoderClassifier,
|
| 27 |
+
MultiTowerClassifier,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 31 |
+
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
|
| 32 |
+
DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "evaluation"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def evaluate(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    output_dir: Path = DEFAULT_OUTPUT,
    sample_size: int = 100_000,
    use_error_message: bool = True,
    seed: int = 42,
) -> dict:
    """Evaluate a saved model and write a report and a confusion matrix.

    Reads the parquet dataset, optionally down-samples it, predicts with the
    loaded model, then writes ``classification_report.json`` and
    ``confusion_matrix.png`` into *output_dir*.

    Args:
        data_path: Parquet file with at least a ``label_id`` column.
        model_path: Saved model to load.
        output_dir: Directory for the outputs (created if missing).
        sample_size: Maximum number of rows to evaluate.
        use_error_message: Whether the error message is used as a feature.
        seed: Random seed for down-sampling.

    Returns:
        The classification report as a dict; always has an ``accuracy`` key.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_parquet(data_path)
    if len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=seed)

    labels = df["label_id"].values
    model = load_model(model_path)

    if isinstance(model, CONTEXT_MODELS):
        # Context models read the error message from the frame itself, so
        # disabling it means removing the column.
        if not use_error_message and "error_message" in df.columns:
            df = df.drop(columns=["error_message"])
        preds = model.predict(contexts_from_dataframe(df))
    else:
        # Guard the optional column like schema/question below instead of
        # raising KeyError when the dataset has no error messages.
        has_errors = use_error_message and "error_message" in df.columns
        texts = combine_features(
            queries=df["query"].tolist(),
            error_messages=df["error_message"].tolist() if has_errors else None,
            schemas=df["schema"].tolist() if "schema" in df.columns else None,
            questions=df["question"].tolist() if "question" in df.columns else None,
        )
        preds = model.predict(texts)

    categories = load_categories()
    # Pass the label ids explicitly so names stay aligned with ids even when
    # the evaluated sample (or the predictions) lack some categories;
    # without this sklearn rejects the target_names length mismatch.
    label_ids = [c.id for c in categories]
    target_names = [c.name for c in categories]
    report = classification_report(
        labels,
        preds,
        labels=label_ids,
        target_names=target_names,
        output_dict=True,
        zero_division=0,
    )
    # With explicit labels some scikit-learn versions report "micro avg"
    # instead of "accuracy"; keep the key callers rely on.
    report.setdefault(
        "accuracy", float(np.mean(np.asarray(labels) == np.asarray(preds)))
    )

    with open(output_dir / "classification_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    cm = ConfusionMatrixDisplay.from_predictions(
        labels,
        preds,
        labels=label_ids,
        display_labels=target_names,
        xticks_rotation=90,
        cmap="Blues",
        colorbar=False,
    )
    fig = cm.figure_
    fig.set_size_inches(14, 12)
    fig.tight_layout()
    fig.savefig(output_dir / "confusion_matrix.png", dpi=150)
    plt.close(fig)

    print(f"Accuracy: {report['accuracy']:.4f}")
    print(f"Reports saved to {output_dir}")
    return report
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def main() -> None:
    """Command-line entry point: parse the options and run ``evaluate``."""
    cli = argparse.ArgumentParser(description="Evaluate SQL error classifier")
    # Typed options that differ only in flag, type and default.
    typed_options = (
        ("--data", Path, DEFAULT_DATA),
        ("--model", Path, DEFAULT_MODEL_PATH),
        ("--output", Path, DEFAULT_OUTPUT),
        ("--sample-size", int, 100_000),
    )
    for flag, kind, default in typed_options:
        cli.add_argument(flag, type=kind, default=default)
    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    evaluate(
        data_path=opts.data,
        model_path=opts.model,
        output_dir=opts.output,
        sample_size=opts.sample_size,
        # The flag is negative on the command line, positive in the API.
        use_error_message=not opts.no_error_message,
        seed=opts.seed,
    )


if __name__ == "__main__":
    main()
|
src/exercises.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate playground exercises: schema + question + correct SQL."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Callable, Dict, List
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(frozen=True)
class Exercise:
    """Immutable description of one playground exercise.

    Instances are produced by the exercise builder functions in this module.
    """

    # Schema text shown with the exercise; builders pass ``_fmt_schema(...)``.
    schema: str
    # Natural-language task statement.
    question: str
    # Reference SQL that answers ``question``.
    correct_query: str
    # Names of the tables in the schema (builders pass the schema dict's keys).
    tables: tuple[str, ...]
    # Column names chosen by each builder.
    # NOTE(review): how consumers use this field is not visible here.
    columns: tuple[str, ...]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _fmt_schema(tables: Dict[str, List[str]]) -> str:
|
| 20 |
+
parts = [f"{name}({', '.join(cols)})" for name, cols in tables.items()]
|
| 21 |
+
return " | ".join(parts)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Signature shared by every exercise factory: takes a ``random.Random`` and
# returns a fully populated ``Exercise``.
ExerciseBuilder = Callable[[random.Random], Exercise]

# Registry of exercise factories; ``_register`` appends to it, so the order
# follows the order in which decorated builders are defined.
EXERCISE_BUILDERS: List[ExerciseBuilder] = []


def _register(builder: ExerciseBuilder) -> ExerciseBuilder:
    """Decorator that adds ``builder`` to ``EXERCISE_BUILDERS``.

    The function is returned unchanged, so the decorated name still refers
    to the original builder.
    """
    EXERCISE_BUILDERS.append(builder)
    return builder
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@_register
def exercise_avg_by_department(rng: random.Random) -> Exercise:
    """Build the "average score per department" GROUP BY exercise.

    ``rng`` is accepted only to match the ``ExerciseBuilder`` signature;
    nothing in this exercise is randomized.
    """
    layout = {
        "students": ["id", "name", "email", "score", "department_id"],
        "departments": ["id", "name", "city"],
    }
    reference_sql = (
        "SELECT department_id, AVG(score) FROM students GROUP BY department_id"
    )
    return Exercise(
        schema=_fmt_schema(layout),
        question="What is the average score of students in each department?",
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("department_id", "score"),
    )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@_register
def exercise_student_department_names(rng: random.Random) -> Exercise:
    """Build the INNER JOIN exercise pairing students with department names.

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {
        "students": ["id", "name", "department_id"],
        "departments": ["id", "name"],
    }
    reference_sql = " ".join(
        [
            "SELECT students.name, departments.name",
            "FROM students",
            "INNER JOIN departments ON students.department_id = departments.id",
        ]
    )
    return Exercise(
        schema=_fmt_schema(layout),
        question="List each student's name along with their department name.",
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("name", "department_id"),
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@_register
def exercise_high_scoring_students(rng: random.Random) -> Exercise:
    """Build a two-condition WHERE exercise with a randomized score cut-off.

    The cut-off is drawn from ``rng.randint(70, 90)`` and appears in both the
    question and the reference query.
    """
    min_score = rng.randint(70, 90)
    layout = {"students": ["id", "name", "age", "score", "status"]}
    task = f"Find names of students older than 18 with a score above {min_score}."
    reference_sql = f"SELECT name FROM students WHERE age > 18 AND score > {min_score}"
    return Exercise(
        schema=_fmt_schema(layout),
        question=task,
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("name", "age", "score", "status"),
    )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@_register
def exercise_unique_cities(rng: random.Random) -> Exercise:
    """Build the SELECT DISTINCT exercise over student cities.

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {"students": ["id", "name", "city", "country"]}
    schema_text = _fmt_schema(layout)
    table_names = tuple(layout.keys())
    return Exercise(
        schema=schema_text,
        question="List the unique cities where students live.",
        correct_query="SELECT DISTINCT city FROM students",
        tables=table_names,
        columns=("city",),
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@_register
def exercise_top_scorer(rng: random.Random) -> Exercise:
    """Build the scalar-subquery exercise for finding the top scorer(s).

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {"students": ["id", "name", "score"], "grades": ["id", "score"]}
    # NOTE(review): the subquery takes MAX(score) from ``grades`` while the
    # outer query filters ``students.score`` -- confirm this is intended and
    # not meant to read from ``students``.
    reference_sql = (
        "SELECT name FROM students WHERE score = (SELECT MAX(score) FROM grades)"
    )
    return Exercise(
        schema=_fmt_schema(layout),
        question="Find students whose score equals the highest score in the class.",
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("name", "score"),
    )
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@_register
def exercise_departments_over_budget(rng: random.Random) -> Exercise:
    """Build a GROUP BY / HAVING exercise with a randomized head-count limit.

    The limit is drawn from ``rng.randint(3, 8)`` and used in both the
    question and the HAVING clause.
    """
    headcount_limit = rng.randint(3, 8)
    layout = {"employees": ["id", "name", "department_id", "salary"]}
    clauses = [
        "SELECT department_id, COUNT(*) AS cnt",
        "FROM employees GROUP BY department_id",
        f"HAVING COUNT(*) > {headcount_limit}",
    ]
    return Exercise(
        schema=_fmt_schema(layout),
        question=f"Which departments have more than {headcount_limit} employees?",
        correct_query=" ".join(clauses),
        tables=tuple(layout.keys()),
        columns=("department_id", "salary"),
    )
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@_register
def exercise_recent_orders(rng: random.Random) -> Exercise:
    """Build a date-comparison exercise with a randomized starting year.

    The year is drawn from ``rng.randint(2020, 2024)``; the reference query
    compares ``order_date`` against January 1 of that year.
    """
    start_year = rng.randint(2020, 2024)
    layout = {"orders": ["id", "customer_id", "amount", "order_date", "status"]}
    task = f"Show orders placed on or after January 1, {start_year}."
    reference_sql = (
        f"SELECT id, amount FROM orders WHERE order_date >= DATE '{start_year}-01-01'"
    )
    return Exercise(
        schema=_fmt_schema(layout),
        question=task,
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("order_date", "amount", "status"),
    )
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@_register
def exercise_missing_email(rng: random.Random) -> Exercise:
    """Build the IS NULL exercise for students without an email address.

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {"students": ["id", "name", "email", "phone"]}
    schema_text = _fmt_schema(layout)
    table_names = tuple(layout.keys())
    return Exercise(
        schema=schema_text,
        question="Find students who have not provided an email address.",
        correct_query="SELECT name FROM students WHERE email IS NULL",
        tables=table_names,
        columns=("email", "name"),
    )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@_register
def exercise_rank_by_score(rng: random.Random) -> Exercise:
    """Build the window-function exercise ranking students per department.

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {"students": ["id", "name", "score", "department_id"]}
    window = "RANK() OVER (PARTITION BY department_id ORDER BY score DESC) AS rnk"
    reference_sql = f"SELECT name, score, {window} FROM students"
    return Exercise(
        schema=_fmt_schema(layout),
        question="Rank students by score within each department.",
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("name", "score", "department_id"),
    )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@_register
def exercise_course_enrollment_count(rng: random.Random) -> Exercise:
    """Build the JOIN + GROUP BY exercise counting enrollments per course.

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {
        "courses": ["id", "title"],
        "enrollments": ["id", "course_id", "student_id"],
    }
    clauses = [
        "SELECT courses.title, COUNT(enrollments.student_id) AS enrolled",
        "FROM courses",
        "INNER JOIN enrollments ON courses.id = enrollments.course_id",
        "GROUP BY courses.title",
    ]
    return Exercise(
        schema=_fmt_schema(layout),
        question="How many students are enrolled in each course?",
        correct_query=" ".join(clauses),
        tables=tuple(layout.keys()),
        columns=("title", "student_id", "course_id"),
    )
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@_register
def exercise_active_employees(rng: random.Random) -> Exercise:
    """Build the SUM-with-filter exercise over active employees' salaries.

    ``rng`` is unused; the exercise text is fixed.
    """
    layout = {"employees": ["id", "name", "salary", "status", "hire_date"]}
    reference_sql = "SELECT SUM(salary) FROM employees WHERE status = 'active'"
    return Exercise(
        schema=_fmt_schema(layout),
        question="What is the total salary paid to active employees?",
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("salary", "status"),
    )
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@_register
def exercise_product_price_filter(rng: random.Random) -> Exercise:
    """Build a BETWEEN exercise with randomized price bounds.

    The lower bound comes from ``rng.randint(10, 50)`` and the upper bound
    from ``rng.randint(100, 500)``, drawn in that order.
    """
    lower = rng.randint(10, 50)
    upper = rng.randint(100, 500)
    layout = {"products": ["id", "name", "price", "category"]}
    task = f"List products priced between {lower} and {upper}."
    reference_sql = (
        f"SELECT name, price FROM products WHERE price BETWEEN {lower} AND {upper}"
    )
    return Exercise(
        schema=_fmt_schema(layout),
        question=task,
        correct_query=reference_sql,
        tables=tuple(layout.keys()),
        columns=("name", "price", "category"),
    )
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def generate_exercise(rng: random.Random) -> Exercise:
    """Pick one registered builder with ``rng.choice`` and run it with ``rng``."""
    chosen_builder = rng.choice(EXERCISE_BUILDERS)
    return chosen_builder(rng)
|
src/generate_dataset.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate labeled SQL error dataset at scale."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import random
|
| 7 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from src.categories import load_categories
|
| 15 |
+
from src.exercises import generate_exercise
|
| 16 |
+
from src.sql_templates import ERROR_INJECTORS
|
| 17 |
+
|
| 18 |
+
# Directory two levels above this file (the parent of the directory that
# contains this module).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default destination of the generated parquet dataset.
DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def generate_dataset(
    total_samples: int = 1_000_000,
    output_path: Path = DEFAULT_OUTPUT,
    batch_size: int = 10_000,
    workers: int = 8,
    seed: int = 42,
) -> Path:
    """Generate a class-balanced SQL error dataset and write it to parquet.

    A label schedule with an (almost) equal share per category is shuffled
    with ``seed``, cut into batches, and generated in worker processes. The
    combined frame is shuffled once more and written to ``output_path``.

    Args:
        total_samples: Total number of rows to generate (must be >= 1).
        output_path: Destination parquet file; parent directories are created.
        batch_size: Number of rows generated per worker task (must be >= 1).
        workers: Number of worker processes.
        seed: Base seed; batch ``i`` is generated with ``seed + i``.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``total_samples`` or ``batch_size`` is not positive, or
            no categories are configured.
    """
    if total_samples < 1:
        raise ValueError("total_samples must be a positive integer")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    categories = load_categories()
    if not categories:
        raise ValueError("No error categories are configured")

    samples_per_class, remainder = divmod(total_samples, len(categories))

    # Balanced label schedule: every class gets an equal share, and the first
    # `remainder` classes (by position, not by id value) get one extra sample
    # so the schedule always holds exactly `total_samples` labels.
    schedule: List[int] = []
    for position, cat in enumerate(categories):
        count = samples_per_class + (1 if position < remainder else 0)
        schedule.extend([cat.id] * count)
    random.Random(seed).shuffle(schedule)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    batches = [
        schedule[start : start + batch_size]
        for start in range(0, total_samples, batch_size)
    ]

    # Frames are keyed by batch index: workers finish in arbitrary order, and
    # concatenating in completion order would make the output differ between
    # runs even with a fixed seed.
    frames: Dict[int, pd.DataFrame] = {}
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {
            executor.submit(_generate_batch_with_labels, batch_labels, seed + batch_idx): batch_idx
            for batch_idx, batch_labels in enumerate(batches)
        }
        for future in tqdm(as_completed(futures), total=len(futures), desc="Generating"):
            frames[futures[future]] = pd.DataFrame(future.result())

    df = pd.concat([frames[idx] for idx in range(len(batches))], ignore_index=True)
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    df.to_parquet(output_path, index=False)

    print(f"Saved {len(df):,} samples to {output_path}")
    print("\nClass distribution:")
    print(df["label_name"].value_counts().sort_index().to_string())
    return output_path
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _generate_batch_with_labels(label_ids: List[int], seed: int) -> List[Dict]:
    """Generate one row per requested label.

    For each label id a random exercise is drawn and the matching entry of
    ``ERROR_INJECTORS`` produces the faulty query and its error message.

    Args:
        label_ids: Category ids to generate, one output row each, in order.
        seed: Seed for this batch's private ``random.Random``.

    Returns:
        A list of row dicts with keys ``schema``, ``question``,
        ``correct_query``, ``query``, ``error_message``, ``label_id`` and
        ``label_name``.

    Raises:
        KeyError: If a label id has no configured category.
    """
    rng = random.Random(seed)
    # Resolve names through the category id rather than by list position, so
    # the lookup stays correct even if ids are not a dense 0..n-1 range.
    name_by_id = {cat.id: cat.name for cat in load_categories()}
    rows = []
    for label_id in label_ids:
        exercise = generate_exercise(rng)
        injector = ERROR_INJECTORS[label_id]
        query, error_message = injector(rng, exercise)
        rows.append(
            {
                "schema": exercise.schema,
                "question": exercise.question,
                "correct_query": exercise.correct_query,
                "query": query.strip(),
                "error_message": error_message,
                "label_id": label_id,
                "label_name": name_by_id[label_id],
            }
        )
    return rows
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def main() -> None:
    """Command-line entry point: parse the flags and run ``generate_dataset``."""
    cli = argparse.ArgumentParser(description="Generate labeled SQL error dataset")
    cli.add_argument("--samples", type=int, default=1_000_000, help="Total samples")
    cli.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
    cli.add_argument("--batch-size", type=int, default=10_000)
    cli.add_argument("--workers", type=int, default=8)
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    # Map CLI option names onto generate_dataset's keyword arguments.
    settings = {
        "total_samples": opts.samples,
        "output_path": opts.output,
        "batch_size": opts.batch_size,
        "workers": opts.workers,
        "seed": opts.seed,
    }
    generate_dataset(**settings)


if __name__ == "__main__":
    main()
|
src/hf_eval_codebert.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate a trained CodeBERT cross-encoder."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
|
| 11 |
+
|
| 12 |
+
from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe
|
| 13 |
+
from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
|
| 14 |
+
|
| 15 |
+
# Directory two levels above this file (the parent of the directory that
# contains this module).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default parquet file to evaluate on.
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet"
# Default directory the tokenizer and model are loaded from.
DEFAULT_MODEL = PROJECT_ROOT / "models" / "codebert-cross-encoder"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def evaluate(
    model_dir: Path = DEFAULT_MODEL,
    data_path: Path = DEFAULT_DATA,
    sample_size: int = 10_000,
    threshold: float = 0.5,
    seed: int = 42,
) -> dict:
    """Score a saved classifier on a parquet dataset and print the metrics.

    Args:
        model_dir: Directory holding the saved tokenizer and model.
        data_path: Parquet file with the evaluation rows.
        sample_size: Upper bound on evaluated rows; larger datasets are
            randomly down-sampled to this size.
        threshold: Probability cut-off passed to the metric functions.
        seed: Random state used for the down-sampling.

    Returns:
        The metrics dict produced by ``compute_multilabel_metrics``.
    """
    frame = normalize_dataframe(pd.read_parquet(data_path))
    if len(frame) > sample_size:
        frame = frame.sample(n=sample_size, random_state=seed)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    eval_set = SQLCodeBERTDataset(frame, tokenizer)

    shared_kwargs = {
        "model": model,
        "data_collator": SQLCodeBERTDataCollator(tokenizer),
        "compute_metrics": build_compute_metrics(threshold=threshold),
    }
    # Try the `processing_class` keyword first and fall back to `tokenizer`
    # when the installed Trainer rejects it with a TypeError.
    try:
        trainer = Trainer(processing_class=tokenizer, **shared_kwargs)
    except TypeError:
        trainer = Trainer(tokenizer=tokenizer, **shared_kwargs)

    prediction = trainer.predict(eval_set)
    metrics = compute_multilabel_metrics(
        prediction.predictions,
        prediction.label_ids,
        threshold=threshold,
    )
    print(json.dumps(metrics, indent=2))
    return metrics
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main() -> None:
    """Command-line entry point: parse the flags and run ``evaluate``.

    ``--seed`` controls the random down-sampling done by ``evaluate`` when
    the dataset has more rows than ``--sample-size``; it defaults to the same
    value as ``evaluate``'s own default, so existing invocations behave as
    before.
    """
    parser = argparse.ArgumentParser(description="Evaluate CodeBERT SQL classifier")
    parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL)
    parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
    parser.add_argument("--sample-size", type=int, default=10_000)
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    evaluate(
        model_dir=args.model_dir,
        data_path=args.data,
        sample_size=args.sample_size,
        threshold=args.threshold,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
|
src/hf_metrics.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation metrics for multi-label SQL error classification."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sklearn.metrics import (
|
| 9 |
+
accuracy_score,
|
| 10 |
+
f1_score,
|
| 11 |
+
hamming_loss,
|
| 12 |
+
precision_score,
|
| 13 |
+
recall_score,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    The value is computed from ``exp(-|x|)``, which never exceeds 1, so large
    positive or negative logits cannot overflow ``exp`` the way the direct
    formula does.

    Args:
        x: Array (or array-like) of logits.

    Returns:
        Array of probabilities with the same shape as ``x``.
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + e^-x); for x < 0 the equivalent e^x / (1 + e^x).
    return np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def compute_multilabel_metrics(
    logits: np.ndarray,
    labels: np.ndarray,
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Turn logits into thresholded predictions and score them.

    Args:
        logits: Raw model outputs; passed through ``sigmoid`` before
            thresholding.
        labels: Ground-truth indicators; cast to ``int`` before scoring.
            The ``axis=1`` reduction below requires 2-D arrays.
        threshold: A label counts as predicted when its probability is
            greater than or equal to this value.

    Returns:
        Dict with keys ``accuracy``, ``f1_macro``, ``f1_micro``,
        ``precision_macro``, ``recall_macro``, ``hamming_loss`` and
        ``subset_accuracy``, all plain floats.
    """
    y_pred = (sigmoid(logits) >= threshold).astype(int)
    y_true = labels.astype(int)

    def macro(metric) -> float:
        # Macro-averaged score; undefined classes count as 0 instead of warning.
        return float(metric(y_true, y_pred, average="macro", zero_division=0))

    results: Dict[str, float] = {}
    results["accuracy"] = float(accuracy_score(y_true, y_pred))
    results["f1_macro"] = macro(f1_score)
    results["f1_micro"] = float(
        f1_score(y_true, y_pred, average="micro", zero_division=0)
    )
    results["precision_macro"] = macro(precision_score)
    results["recall_macro"] = macro(recall_score)
    results["hamming_loss"] = float(hamming_loss(y_true, y_pred))
    # Fraction of rows where every label matches.
    exact_rows = (y_pred == y_true).all(axis=1)
    results["subset_accuracy"] = float(exact_rows.mean())
    return results
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build_compute_metrics(threshold: float = 0.5):
    """Factory for Hugging Face Trainer compute_metrics callback.

    The returned callable unpacks its argument into ``(logits, labels)`` and
    scores them with ``compute_multilabel_metrics`` at the bound
    ``threshold``.
    """

    def _callback(eval_pred) -> Dict[str, float]:
        raw_scores, targets = eval_pred
        return compute_multilabel_metrics(raw_scores, targets, threshold=threshold)

    return _callback
|
src/hf_predict_codebert.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference for CodeBERT SQL error cross-encoder."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 13 |
+
|
| 14 |
+
from src.codebert_formatting import format_cross_encoder_input
|
| 15 |
+
from src.codebert_labels import load_codebert_labels, multihot_to_label_names
|
| 16 |
+
from src.hf_metrics import sigmoid
|
| 17 |
+
|
| 18 |
+
# Directory two levels above this file (the parent of the directory that
# contains this module).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default directory the classifier loads its tokenizer, model and optional
# label_config.json from.
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CodeBERTSQLErrorClassifier:
    """CodeBERT cross-encoder inference wrapper.

    Loads a tokenizer and sequence-classification model from ``model_dir``
    and turns (question, schema, student SQL, correct SQL) inputs into
    per-label probabilities via a sigmoid over the model logits.

    If ``model_dir`` contains ``label_config.json``, its ``labels``,
    ``threshold`` and ``max_length`` entries take precedence over the
    constructor arguments and built-in defaults.
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = DEFAULT_MODEL_DIR,
        threshold: float = 0.5,
        device: Optional[str] = None,
    ):
        """Load configuration, tokenizer and model.

        Args:
            model_dir: Directory with the saved model and tokenizer.
            threshold: Default decision threshold; replaced by the config
                file's ``threshold`` when that key is present.
            device: Torch device string; defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.
        """
        self.model_dir = Path(model_dir)
        self.threshold = threshold
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        cfg: dict = {}
        config_path = self.model_dir / "label_config.json"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                cfg = json.load(f)

        # Only fall back to the project-wide label list when the saved config
        # does not provide one; the fallback is not loaded needlessly.
        configured_labels = cfg.get("labels")
        self.label_list = (
            configured_labels
            if configured_labels is not None
            else load_codebert_labels()
        )
        self.threshold = cfg.get("threshold", threshold)
        self.max_length = cfg.get("max_length", 512)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_dir
        ).to(self.device)
        self.model.eval()

    def _score(self, texts: List[str]) -> np.ndarray:
        """Return sigmoid probabilities, one row per input text."""
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**encoded).logits.cpu().numpy()
        return sigmoid(logits)

    def predict(
        self,
        question: str,
        schema: str,
        student_sql: str,
        correct_sql: str,
        threshold: Optional[float] = None,
        top_k: int = 5,
    ) -> dict:
        """Classify a single example.

        Args:
            question: Natural-language task statement.
            schema: Schema text for the exercise.
            student_sql: The query to diagnose.
            correct_sql: The reference query.
            threshold: Per-call decision threshold; ``None`` uses the
                instance threshold.
            top_k: Number of highest-probability labels reported in
                ``probabilities`` and ``top_k``.

        Returns:
            Dict with ``error_labels`` (labels at or above the threshold, as
            returned by ``multihot_to_label_names``), ``probabilities`` and
            ``top_k`` (the ``top_k`` best labels), and ``primary_label`` /
            ``primary_confidence`` (the single best label overall).
        """
        text = format_cross_encoder_input(
            question=question,
            schema=schema,
            student_sql=student_sql,
            correct_sql=correct_sql,
        )
        probs = self._score([text])[0]
        thr = threshold if threshold is not None else self.threshold
        predicted = multihot_to_label_names(probs, self.label_list, threshold=thr)

        ranking = sorted(
            zip(self.label_list, probs.tolist()),
            key=lambda x: x[1],
            reverse=True,
        )
        ranked = ranking[:top_k]
        # The primary label comes from the full ranking so that top_k=0 still
        # yields a result instead of raising IndexError.
        best_label, best_prob = ranking[0]

        return {
            "error_labels": predicted,
            "probabilities": {name: float(p) for name, p in ranked},
            "top_k": [
                {"label": name, "probability": float(p)} for name, p in ranked
            ],
            "primary_label": best_label,
            "primary_confidence": float(best_prob),
        }

    def predict_batch(
        self,
        examples: List[dict],
        batch_size: int = 16,
        threshold: Optional[float] = None,
    ) -> List[dict]:
        """Classify many examples in mini-batches.

        Args:
            examples: Dicts with keys ``question``, ``schema``,
                ``student_sql`` and ``correct_sql``.
            batch_size: Number of examples per forward pass (must be >= 1).
            threshold: Per-call decision threshold; ``None`` uses the
                instance threshold.

        Returns:
            One dict per example, in input order, with ``error_labels``,
            ``primary_label`` and ``primary_confidence``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        thr = threshold if threshold is not None else self.threshold

        results = []
        for i in range(0, len(examples), batch_size):
            chunk = examples[i : i + batch_size]
            texts = [
                format_cross_encoder_input(
                    question=x["question"],
                    schema=x["schema"],
                    student_sql=x["student_sql"],
                    correct_sql=x["correct_sql"],
                )
                for x in chunk
            ]
            for probs in self._score(texts):
                results.append(
                    {
                        "error_labels": multihot_to_label_names(
                            probs, self.label_list, thr
                        ),
                        "primary_label": self.label_list[int(np.argmax(probs))],
                        "primary_confidence": float(np.max(probs)),
                    }
                )
        return results
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def main() -> None:
    """Command-line entry point: classify one example and print JSON.

    ``--threshold`` is forwarded to ``predict`` as well as to the
    constructor. The constructor lets a ``threshold`` stored in
    ``label_config.json`` replace its argument, so passing the flag only
    there would silently ignore an explicit command-line value. When the
    flag is omitted, the configured (or default 0.5) threshold applies.
    """
    parser = argparse.ArgumentParser(description="CodeBERT SQL error inference")
    parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--schema", type=str, required=True)
    parser.add_argument("--student-sql", type=str, required=True)
    parser.add_argument("--correct-sql", type=str, required=True)
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Decision threshold (default: value from label_config.json, else 0.5)",
    )
    args = parser.parse_args()

    clf = CodeBERTSQLErrorClassifier(
        args.model_dir,
        threshold=0.5 if args.threshold is None else args.threshold,
    )
    result = clf.predict(
        question=args.question,
        schema=args.schema,
        student_sql=args.student_sql,
        correct_sql=args.correct_sql,
        # None keeps the classifier's own threshold; a value overrides it.
        threshold=args.threshold,
    )
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()
|
src/hf_train_codebert.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train CodeBERT cross-encoder for SQL error classification with HF Trainer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoModelForSequenceClassification,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
EarlyStoppingCallback,
|
| 16 |
+
Trainer,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from src.codebert_dataset import (
|
| 21 |
+
SQLCodeBERTDataCollator,
|
| 22 |
+
prepare_datasets,
|
| 23 |
+
)
|
| 24 |
+
from src.codebert_labels import load_codebert_labels
|
| 25 |
+
from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
|
| 26 |
+
|
| 27 |
+
# Directory two levels above this file (the parent of the directory that
# contains this module).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default parquet file with the training data.
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
# Default output directory for the training run.
DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "codebert-cross-encoder"
# Default name of the pretrained checkpoint to start from.
DEFAULT_MODEL = "microsoft/codebert-base"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def train(
    data_path: Path | None = DEFAULT_DATA,
    dataframe: pd.DataFrame | None = None,
    output_dir: Path = DEFAULT_OUTPUT,
    model_name: str = DEFAULT_MODEL,
    epochs: float = 3.0,
    batch_size: int = 16,
    eval_batch_size: int = 32,
    learning_rate: float = 2e-5,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.06,
    max_length: int = 512,
    max_samples: int | None = None,
    test_size: float = 0.1,
    val_size: float = 0.1,
    threshold: float = 0.5,
    seed: int = 42,
    push_to_hub: bool = False,
    hub_model_id: str | None = None,
    fp16: bool = False,
    save_strategy: str = "no",
    hub_token: str | None = None,
) -> dict:
    """Fine-tune a sequence-classification checkpoint as a multi-label SQL error classifier.

    The data comes from ``dataframe`` when it is given (it is copied, never
    mutated); otherwise it is read from the parquet file at ``data_path``.
    The frame is optionally down-sampled to ``max_samples`` rows, split into
    train/validation/test sets by ``prepare_datasets`` and trained with the
    Hugging Face ``Trainer``.

    Files written to ``output_dir``: ``label_config.json``, ``metrics.json``,
    the saved model and the tokenizer.

    Args:
        data_path: Parquet file to load when ``dataframe`` is ``None``.
        dataframe: In-memory data; takes precedence over ``data_path``.
        output_dir: Destination directory (``str`` or ``Path``); created if missing.
        model_name: Checkpoint name passed to ``from_pretrained``.
        epochs, batch_size, eval_batch_size, learning_rate, weight_decay,
            warmup_ratio, seed: Forwarded to ``TrainingArguments``.
        max_length: Forwarded to ``prepare_datasets`` and recorded in
            ``label_config.json``.
        max_samples: When truthy and smaller than the dataset, a random sample
            of that many rows (seeded with ``seed``) is used.
        test_size, val_size: Split sizes forwarded to ``prepare_datasets``.
        threshold: Decision threshold passed to the metric functions and
            recorded in ``label_config.json``.
        push_to_hub, hub_model_id, hub_token: Forwarded to ``TrainingArguments``.
        fp16: Requests mixed precision; only honoured when CUDA is available.
        save_strategy: ``"epoch"`` additionally enables best-model reloading
            and early stopping (patience 2); any other value disables both.

    Returns:
        A dict with ``train_samples``, ``val_samples``, ``test_samples``,
        ``train_runtime``, ``validation`` (Trainer evaluation output) and
        ``test`` (output of ``compute_multilabel_metrics``).

    Raises:
        ValueError: If both ``data_path`` and ``dataframe`` are ``None``.
    """
    import inspect  # local import: only needed to choose the Trainer keyword below

    # Accept plain strings as well as Path objects for the destination.
    output_dir = Path(output_dir)

    if dataframe is not None:
        df = dataframe.copy()
        print(f"Loaded dataframe with {len(df):,} rows")
    elif data_path is not None:
        print(f"Loading dataset from {data_path}...")
        df = pd.read_parquet(data_path)
    else:
        raise ValueError("Either data_path or dataframe must be provided")
    if max_samples and len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=seed)

    label_list = load_codebert_labels()
    num_labels = len(label_list)
    print(f"Labels ({num_labels}): {label_list}")
    print(f"Samples: {len(df):,}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="multi_label_classification",
        id2label={i: name for i, name in enumerate(label_list)},
        label2id={name: i for i, name in enumerate(label_list)},
    )

    train_ds, val_ds, test_ds = prepare_datasets(
        df,
        tokenizer,
        test_size=test_size,
        val_size=val_size,
        max_length=max_length,
        seed=seed,
    )
    print(f"Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")

    output_dir.mkdir(parents=True, exist_ok=True)
    label_info = {
        "labels": label_list,
        "model_name": model_name,
        "architecture": "codebert-cross-encoder",
        "input_format": "QUESTION + SCHEMA + STUDENT_SQL + CORRECT_SQL",
        "max_length": max_length,
        "threshold": threshold,
    }
    with open(output_dir / "label_config.json", "w", encoding="utf-8") as f:
        json.dump(label_info, f, indent=2)

    # Best-model reloading and early stopping both need per-epoch checkpoints.
    checkpoint_each_epoch = save_strategy == "epoch"

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=eval_batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        eval_strategy="epoch",
        save_strategy=save_strategy,
        logging_strategy="steps",
        logging_steps=50,
        load_best_model_at_end=checkpoint_each_epoch,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        save_total_limit=1,
        seed=seed,
        report_to="none",
        fp16=fp16 and torch.cuda.is_available(),
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        hub_token=hub_token,
    )

    callbacks = []
    if checkpoint_each_epoch:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=2))

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=SQLCodeBERTDataCollator(tokenizer),
        compute_metrics=build_compute_metrics(threshold=threshold),
        callbacks=callbacks,
    )
    # Pick the tokenizer keyword from the Trainer signature instead of catching
    # TypeError: a blanket except would also swallow unrelated TypeErrors raised
    # while the Trainer is being constructed and then retry with the wrong keyword.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = (
        "processing_class" if "processing_class" in trainer_params else "tokenizer"
    )
    trainer = Trainer(**{tokenizer_kwarg: tokenizer}, **trainer_kwargs)

    print("Starting CodeBERT cross-encoder training...")
    train_result = trainer.train()

    print("Evaluating on validation set...")
    val_metrics = trainer.evaluate()

    print("Evaluating on held-out test set...")
    test_output = trainer.predict(test_ds)
    test_metrics = compute_multilabel_metrics(
        test_output.predictions,
        test_output.label_ids,
        threshold=threshold,
    )

    metrics = {
        "train_samples": len(train_ds),
        "val_samples": len(val_ds),
        "test_samples": len(test_ds),
        "train_runtime": train_result.metrics.get("train_runtime"),
        "validation": val_metrics,
        "test": test_metrics,
    }
    # Written before save_model so the file is already in output_dir if saving
    # triggers a Hub upload (push_to_hub=True).
    # NOTE(review): confirm the upload behaviour against the installed transformers version.
    with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
        # default=float converts NumPy scalars that json cannot serialise natively.
        json.dump(metrics, f, indent=2, default=float)

    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"\nValidation F1 (macro): {val_metrics.get('eval_f1_macro', 0):.4f}")
    print(f"Test F1 (macro): {test_metrics['f1_macro']:.4f}")
    print(f"Test subset accuracy: {test_metrics['subset_accuracy']:.4f}")
    print(f"Model saved to {output_dir}")
    return metrics
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def main() -> None:
    """Parse command-line options and run :func:`train`.

    Every option mirrors a ``train()`` parameter of the same name and uses the
    same default, so omitting a flag behaves exactly like omitting the
    keyword argument.
    """
    parser = argparse.ArgumentParser(
        description="Train CodeBERT cross-encoder with Hugging Face Trainer"
    )
    parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT)
    parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL)
    parser.add_argument("--epochs", type=float, default=3.0)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--eval-batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    # Optimiser and split settings that train() accepts; exposed here so they
    # can be tuned without editing code.
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--warmup-ratio", type=float, default=0.06)
    parser.add_argument("--max-length", type=int, default=512)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--test-size", type=float, default=0.1)
    parser.add_argument("--val-size", type=float, default=0.1)
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push-to-hub", action="store_true")
    parser.add_argument("--hub-model-id", type=str, default=None)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument(
        "--save-strategy",
        choices=["no", "epoch"],
        default="no",
        help="Use 'no' to save only final model (saves disk space)",
    )
    args = parser.parse_args()

    train(
        data_path=args.data,
        output_dir=args.output_dir,
        model_name=args.model_name,
        epochs=args.epochs,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        max_length=args.max_length,
        max_samples=args.max_samples,
        test_size=args.test_size,
        val_size=args.val_size,
        threshold=args.threshold,
        seed=args.seed,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        fp16=args.fp16,
        save_strategy=args.save_strategy,
    )


if __name__ == "__main__":
    main()
|
src/huggingface.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Hub integration for the SQL error classifier."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Union
|
| 8 |
+
|
| 9 |
+
import joblib
|
| 10 |
+
|
| 11 |
+
from src.categories import load_categories
|
| 12 |
+
from src.cross_encoder_model import CrossEncoderClassifier
|
| 13 |
+
from src.model import DEFAULT_ENCODER, load_model
|
| 14 |
+
from src.multi_tower_model import MultiTowerClassifier, QueryContext
|
| 15 |
+
|
| 16 |
+
# Parent of the directory that contains this module; used as the repository root.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# File names that make up a published model directory.
CONFIG_NAME = "config.json"
CLASSIFIER_NAME = "classifier.joblib"
CATEGORIES_NAME = "categories.json"

# Model classes accepted by package_for_hub() for publishing.
SUPPORTED_CONTEXT_MODELS = (
    CrossEncoderClassifier,
    MultiTowerClassifier,
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SQLLErrorClassifierHF:
    """
    Hugging Face–compatible wrapper for SQL error classifiers.

    Wraps a ``CrossEncoderClassifier`` or ``MultiTowerClassifier`` together
    with a mapping from numeric label id to label name, and knows how to
    save/load that pair as three files: ``config.json``,
    ``classifier.joblib`` and ``categories.json``.

    Usage:
        clf = SQLLErrorClassifierHF.from_pretrained("username/sql-error-classifier")
        result = clf.predict(
            question="...", schema="...", correct_query="...", student_query="..."
        )
    """

    def __init__(self, model, label_map: Dict[int, str]):
        """Store the wrapped classifier and its ``label id -> name`` mapping."""
        self.model = model
        self.label_map = label_map

    def predict(
        self,
        question: str,
        schema: str,
        correct_query: str,
        student_query: str,
        error_message: Optional[str] = None,
        top_k: int = 3,
    ) -> Dict[str, Any]:
        """Classify one student query.

        Returns:
            A dict with ``label_id``, ``label_name`` and ``confidence`` for the
            highest-probability class, a ``top_k`` list holding the same three
            fields for the best ``top_k`` classes, and one diagnostics entry:
            ``pair_scores`` for cross-encoder models or ``similarities`` for
            multi-tower models.
        """
        ctx = QueryContext(
            question=question,
            schema=schema,
            correct_query=correct_query,
            student_query=student_query,
            error_message=error_message,
        )
        proba = self.model.predict_proba([ctx])[0]
        classes = self.model.classes_
        # Highest probability first.
        ranked = sorted(zip(classes, proba), key=lambda x: x[1], reverse=True)
        best_id = int(ranked[0][0])

        diagnostics: Dict[str, Any] = {}
        if isinstance(self.model, CrossEncoderClassifier):
            diagnostics["pair_scores"] = self.model.explain_pair_scores(ctx)
        elif isinstance(self.model, MultiTowerClassifier):
            diagnostics["similarities"] = self.model.explain_similarities(ctx)

        return {
            "label_id": best_id,
            "label_name": self.label_map[best_id],
            "confidence": float(ranked[0][1]),
            "top_k": [
                {
                    "label_id": int(cls),
                    "label_name": self.label_map[int(cls)],
                    "confidence": float(p),
                }
                for cls, p in ranked[:top_k]
            ],
            **diagnostics,
        }

    def save_pretrained(self, save_directory: Union[str, Path]) -> Path:
        """Save model artifacts in Hugging Face Hub layout.

        Writes ``classifier.joblib`` (fitted scaler/classifier plus encoder
        settings), ``config.json`` (JSON-serialisable summary) and
        ``categories.json`` (from ``load_categories()``) into
        ``save_directory``, creating it if needed.

        Raises:
            ValueError: If the wrapped model is neither a cross-encoder nor a
                multi-tower classifier.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(self.model, CrossEncoderClassifier):
            payload = {
                "model_type": "cross_encoder",
                "cross_encoder_name": self.model.cross_encoder_name,
                "batch_size": self.model.batch_size,
                "max_length": self.model.max_length,
                "scaler": self.model.scaler,
                "classifier": self.model.clf,
                "classes_": self.model.classes_,
            }
            config = {
                "model_type": "cross_encoder",
                "architecture": "cross-encoder-pairwise",
                "cross_encoder_name": self.model.cross_encoder_name,
                "batch_size": self.model.batch_size,
                "num_labels": len(self.label_map),
                "task": "sql-error-classification",
            }
        elif isinstance(self.model, MultiTowerClassifier):
            payload = {
                "model_type": "multi_tower",
                "encoder_name": self.model.encoder_name,
                "batch_size": self.model.batch_size,
                "scaler": self.model.scaler,
                "classifier": self.model.clf,
                "classes_": self.model.classes_,
            }
            config = {
                "model_type": "multi_tower",
                "architecture": "multi-tower-semantic-comparison",
                "encoder_name": self.model.encoder_name,
                "batch_size": self.model.batch_size,
                "num_labels": len(self.label_map),
                "task": "sql-error-classification",
            }
        else:
            raise ValueError("Only cross_encoder and multi_tower models can be published")

        joblib.dump(payload, save_dir / CLASSIFIER_NAME)
        with open(save_dir / CONFIG_NAME, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        categories = load_categories()
        cat_data = [
            {"id": c.id, "name": c.name, "description": c.description}
            for c in categories
        ]
        with open(save_dir / CATEGORIES_NAME, "w", encoding="utf-8") as f:
            json.dump(cat_data, f, indent=2)

        return save_dir

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *,
        token: Optional[str] = None,
    ) -> "SQLLErrorClassifierHF":
        """Load from a local directory or Hugging Face Hub repo.

        SECURITY: ``classifier.joblib`` is deserialised with ``joblib.load``,
        which unpickles arbitrary Python objects. Only load repositories or
        directories you trust.

        Raises:
            ValueError: If the stored ``model_type`` is something other than
                ``"cross_encoder"`` or ``"multi_tower"``.
        """
        path = _resolve_model_path(pretrained_model_name_or_path, token=token)

        with open(path / CONFIG_NAME, encoding="utf-8") as f:
            config = json.load(f)

        with open(path / CATEGORIES_NAME, encoding="utf-8") as f:
            categories = json.load(f)
        label_map = {c["id"]: c["name"] for c in categories}

        # See the security note in the docstring: this executes pickle data.
        obj = joblib.load(path / CLASSIFIER_NAME)
        model_type = config.get("model_type", obj.get("model_type"))

        if model_type == "cross_encoder":
            model = CrossEncoderClassifier(
                cross_encoder_name=obj.get(
                    "cross_encoder_name",
                    config.get("cross_encoder_name", "cross-encoder/ms-marco-MiniLM-L6-v2"),
                ),
                batch_size=obj.get("batch_size", 32),
                max_length=obj.get("max_length", 512),
            )
        elif model_type is None or model_type == "multi_tower":
            # A missing model_type keeps the historical multi-tower fallback.
            model = MultiTowerClassifier(
                encoder_name=obj.get("encoder_name", config.get("encoder_name", DEFAULT_ENCODER)),
                batch_size=obj.get("batch_size", 256),
            )
        else:
            # Previously any other value fell into the multi-tower branch and
            # failed later with an opaque KeyError on the missing scaler.
            raise ValueError(
                f"Unsupported model_type {model_type!r}: "
                "only cross_encoder and multi_tower models can be loaded"
            )

        # Both supported model kinds restore the same fitted state.
        model.scaler = obj["scaler"]
        model.clf = obj["classifier"]
        model.classes_ = obj.get("classes_", obj["classifier"].classes_)

        return cls(model=model, label_map=label_map)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _resolve_model_path(
    pretrained_model_name_or_path: Union[str, Path],
    token: Optional[str] = None,
) -> Path:
    """Return a local directory that holds the model files.

    If the argument is an existing path containing ``config.json`` it is
    returned as-is. Otherwise the argument is treated as a Hub repo id and
    only the three model files are fetched with ``snapshot_download``.
    """
    candidate = Path(pretrained_model_name_or_path)
    has_local_config = candidate.exists() and (candidate / CONFIG_NAME).exists()
    if has_local_config:
        return candidate

    # Imported lazily so purely local loads do not need huggingface_hub.
    from huggingface_hub import snapshot_download

    wanted_files = [CONFIG_NAME, CLASSIFIER_NAME, CATEGORIES_NAME]
    snapshot_dir = snapshot_download(
        repo_id=str(pretrained_model_name_or_path),
        token=token,
        allow_patterns=wanted_files,
    )
    return Path(snapshot_dir)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def package_for_hub(model_path: Path, output_dir: Path) -> Path:
    """Convert a local joblib model into HF Hub layout.

    Loads the model at ``model_path``, checks that it is one of
    ``SUPPORTED_CONTEXT_MODELS`` and writes the Hub files into ``output_dir``.
    Returns the directory reported by ``save_pretrained``.

    Raises:
        ValueError: If the loaded model is not a supported type.
    """
    loaded = load_model(model_path)
    if isinstance(loaded, SUPPORTED_CONTEXT_MODELS):
        id_to_name = {category.id: category.name for category in load_categories()}
        hub_wrapper = SQLLErrorClassifierHF(model=loaded, label_map=id_to_name)
        return hub_wrapper.save_pretrained(output_dir)

    raise ValueError(
        "Only cross_encoder and multi_tower models can be published to Hugging Face Hub"
    )
|
src/model.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQL error classifiers: TF-IDF baseline and MiniLM embedding model."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Literal, Optional, Protocol, Union
|
| 8 |
+
|
| 9 |
+
import joblib
|
| 10 |
+
import numpy as np
|
| 11 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 12 |
+
from sklearn.linear_model import SGDClassifier
|
| 13 |
+
from sklearn.pipeline import FeatureUnion, Pipeline
|
| 14 |
+
|
| 15 |
+
# Parent of the directory that contains this module; used as the repository root.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Default location of the serialized classifier and default sentence encoder.
DEFAULT_MODEL_PATH = PROJECT_ROOT.joinpath("models", "sql_error_classifier.joblib")
DEFAULT_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"

# Names accepted by build_classifier().
ModelType = Literal[
    "cross_encoder",
    "cross_encoder_ft",
    "multi_tower",
    "minilm",
    "tfidf",
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TextClassifier(Protocol):
    """Structural interface for classifiers that operate on lists of strings."""

    # Class labels exposed by the classifier.
    classes_: np.ndarray

    def fit(self, texts: List[str], y: np.ndarray) -> "TextClassifier":
        """Train on ``texts`` with targets ``y`` and return a classifier."""
        ...

    def predict(self, texts: List[str]) -> np.ndarray:
        """Return predictions for ``texts``."""
        ...

    def predict_proba(self, texts: List[str]) -> np.ndarray:
        """Return class probabilities for ``texts``."""
        ...
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def combine_features(
    queries: List[str],
    error_messages: Optional[List[str]] = None,
    schemas: Optional[List[str]] = None,
    questions: Optional[List[str]] = None,
) -> List[str]:
    """Fuse question, schema, query, and optional error message.

    For each query a single space-joined string is built in the order
    ``QUESTION: …``, ``SCHEMA: …``, ``QUERY: …``, ``ERROR: …``. The optional
    parts are indexed in parallel with ``queries`` and are omitted when their
    list is missing/empty or the entry at that index is falsy.
    """

    def _labelled(label: str, values: Optional[List[str]], idx: int) -> List[str]:
        # Zero or one segment, so the caller can splice it straight in.
        if values and values[idx]:
            return [f"{label}: {values[idx]}"]
        return []

    combined: List[str] = []
    for idx, sql in enumerate(queries):
        segments = [
            *_labelled("QUESTION", questions, idx),
            *_labelled("SCHEMA", schemas, idx),
            f"QUERY: {sql}",
            *_labelled("ERROR", error_messages, idx),
        ]
        combined.append(" ".join(segments))
    return combined
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _build_text_features() -> FeatureUnion:
    """Build the combined word-level and character-level TF-IDF feature extractor."""
    # Word uni/bi-grams; the second alternative of the token pattern also keeps
    # non-space runs that directly follow one of the characters = < > !.
    word_vectorizer = TfidfVectorizer(
        analyzer="word",
        ngram_range=(1, 2),
        max_features=30_000,
        sublinear_tf=True,
        strip_accents="unicode",
        token_pattern=r"(?u)\b\w+\b|(?<=[=<>!])\S+",
    )
    # Character 2- to 5-grams with the "char_wb" analyzer.
    char_vectorizer = TfidfVectorizer(
        analyzer="char_wb",
        ngram_range=(2, 5),
        max_features=20_000,
        sublinear_tf=True,
    )
    transformers = [("word", word_vectorizer), ("char", char_vectorizer)]
    return FeatureUnion(transformers)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_tfidf_classifier() -> Pipeline:
    """Bag-of-words baseline. Fast but no deep semantic understanding.

    Returns an unfitted pipeline with a ``"tfidf"`` feature step followed by a
    ``"clf"`` step holding a log-loss ``SGDClassifier``.
    """
    sgd_settings = dict(
        loss="log_loss",
        penalty="l2",
        alpha=1e-5,
        max_iter=1000,
        tol=1e-3,
        class_weight="balanced",
        random_state=42,
    )
    linear_model = SGDClassifier(**sgd_settings)
    steps = [
        ("tfidf", _build_text_features()),
        ("clf", linear_model),
    ]
    return Pipeline(steps)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class EmbeddingClassifier:
    """
    MiniLM sentence embeddings + linear classifier.

    Understands question intent (e.g. 'average' vs wrong aggregate) because
    the encoder models full sentence context, not isolated word counts.
    """

    def __init__(
        self,
        encoder_name: str = DEFAULT_ENCODER,
        batch_size: int = 256,
    ):
        """Record encoder settings and create the unfitted linear classifier.

        The sentence encoder itself is not loaded here; it is created on the
        first call to :meth:`encode`.
        """
        self.encoder_name = encoder_name
        self.batch_size = batch_size
        self.encoder = None
        sgd_settings = dict(
            loss="log_loss",
            penalty="l2",
            alpha=1e-4,
            max_iter=1000,
            tol=1e-3,
            class_weight="balanced",
            random_state=42,
        )
        self.clf = SGDClassifier(**sgd_settings)
        self.classes_: Optional[np.ndarray] = None

    def _load_encoder(self):
        """Create the SentenceTransformer once; later calls do nothing."""
        if self.encoder is not None:
            return

        # Imported lazily so the dependency is only needed when encoding.
        from sentence_transformers import SentenceTransformer

        self.encoder = SentenceTransformer(self.encoder_name)

    def encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Embed ``texts`` with the sentence encoder and return a NumPy array."""
        self._load_encoder()
        embeddings = self.encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )
        return embeddings

    def fit(self, texts: List[str], y: np.ndarray) -> "EmbeddingClassifier":
        """Embed ``texts``, fit the linear classifier on them and return ``self``."""
        features = self.encode(texts, show_progress=True)
        self.clf.fit(features, y)
        self.classes_ = self.clf.classes_
        return self

    def predict(self, texts: List[str]) -> np.ndarray:
        """Return the predicted class for each text."""
        features = self.encode(texts)
        return self.clf.predict(features)

    def predict_proba(self, texts: List[str]) -> np.ndarray:
        """Return class probabilities for each text."""
        features = self.encode(texts)
        return self.clf.predict_proba(features)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def build_classifier(
    model_type: ModelType = "cross_encoder",
) -> Union[
    Pipeline,
    EmbeddingClassifier,
    "MultiTowerClassifier",
    "CrossEncoderClassifier",
    "FineTunedCrossEncoderClassifier",
]:
    """Create a new, unfitted classifier for the requested ``model_type``.

    The multi-tower and cross-encoder implementations are imported only when
    they are selected.

    Raises:
        ValueError: If ``model_type`` is not one of the known names.
    """

    def _make_multi_tower():
        from src.multi_tower_model import MultiTowerClassifier

        return MultiTowerClassifier()

    def _make_cross_encoder():
        from src.cross_encoder_model import CrossEncoderClassifier

        return CrossEncoderClassifier()

    def _make_fine_tuned():
        from src.cross_encoder_model import FineTunedCrossEncoderClassifier

        return FineTunedCrossEncoderClassifier()

    # Checked in order; lambdas defer the name lookups until a match is found.
    registry = (
        ("tfidf", lambda: build_tfidf_classifier()),
        ("minilm", lambda: EmbeddingClassifier()),
        ("multi_tower", _make_multi_tower),
        ("cross_encoder", _make_cross_encoder),
        ("cross_encoder_ft", _make_fine_tuned),
    )
    for name, factory in registry:
        if model_type == name:
            return factory()
    raise ValueError(f"Unknown model_type: {model_type}")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def save_model(
    model: Union[
        Pipeline,
        EmbeddingClassifier,
        "MultiTowerClassifier",
        "CrossEncoderClassifier",
        "FineTunedCrossEncoderClassifier",
    ],
    path: Path = DEFAULT_MODEL_PATH,
    model_type: ModelType = "cross_encoder",
) -> Path:
    """Serialize ``model`` to disk in the layout that ``load_model`` reads.

    * Fine-tuned cross-encoders are saved through their own ``save`` method,
      into ``path`` itself when it is an existing directory and otherwise
      into ``path`` with a ``.ce`` suffix, plus a small meta JSON file.
    * Cross-encoder, multi-tower and embedding classifiers are stored as a
      joblib payload at ``path`` plus a ``<path stem>.meta.json`` file.
    * Anything else is treated as a TF-IDF pipeline and dumped with joblib.

    Args:
        model: The classifier to save.
        path: Target file (``str`` or ``Path``); parent directories are created.
        model_type: Kept for backward compatibility. The stored type tag is
            derived from the class of ``model``.

    Returns:
        The path that was written (the ``.ce`` location for fine-tuned
        cross-encoders, otherwise ``path``).
    """
    from src.cross_encoder_model import (
        CrossEncoderClassifier,
        FineTunedCrossEncoderClassifier,
    )
    from src.multi_tower_model import MultiTowerClassifier

    # Accept plain strings as well, mirroring load_model().
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(model, FineTunedCrossEncoderClassifier):
        # The former `str(path).endswith("/")` test could never be true because
        # Path drops trailing slashes, so only the is_dir() check remains.
        ft_path = path if path.is_dir() else path.with_suffix(".ce")
        if ft_path.suffix == ".joblib":
            ft_path = ft_path.with_suffix(".ce")
        model.save(ft_path)
        meta_path = ft_path / "meta.json" if ft_path.is_dir() else path.with_suffix(".meta.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump({"model_type": "cross_encoder_ft", "path": str(ft_path)}, f, indent=2)
        return ft_path

    if isinstance(model, CrossEncoderClassifier):
        meta = {
            "model_type": "cross_encoder",
            "cross_encoder_name": model.cross_encoder_name,
        }
        payload = {
            **meta,
            "batch_size": model.batch_size,
            "max_length": model.max_length,
            "scaler": model.scaler,
            "classifier": model.clf,
            "classes_": model.classes_,
        }
    elif isinstance(model, MultiTowerClassifier):
        meta = {"model_type": "multi_tower", "encoder_name": model.encoder_name}
        payload = {
            **meta,
            "batch_size": model.batch_size,
            "scaler": model.scaler,
            "classifier": model.clf,
            "classes_": model.classes_,
        }
    elif isinstance(model, EmbeddingClassifier):
        # Always tag as "minilm": that is the only tag load_model() maps back to
        # EmbeddingClassifier. Storing the caller-supplied model_type (default
        # "cross_encoder") produced files that could not be loaded again.
        meta = {"model_type": "minilm", "encoder_name": model.encoder_name}
        payload = {
            **meta,
            "batch_size": model.batch_size,
            "classifier": model.clf,
            "classes_": model.classes_,
        }
    else:
        # TF-IDF pipelines are stored without a meta file.
        joblib.dump({"model_type": "tfidf", "pipeline": model}, path)
        return path

    joblib.dump(payload, path)
    with open(path.with_suffix(".meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    return path
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def load_model(
    path: Path = DEFAULT_MODEL_PATH,
) -> Union[
    Pipeline,
    EmbeddingClassifier,
    "MultiTowerClassifier",
    "CrossEncoderClassifier",
    "FineTunedCrossEncoderClassifier",
]:
    """Load a classifier previously written to ``path``.

    Resolution order:
      1. a fine-tuned cross-encoder directory (``path`` itself, or ``path``
         with ``.ce`` in place of ``.joblib``) that contains ``config.json``;
      2. a ``.meta.json`` file next to ``path`` that points at a fine-tuned
         cross-encoder;
      3. the joblib file at ``path``, rebuilt according to its ``model_type``.
    """
    from src.cross_encoder_model import (
        CrossEncoderClassifier,
        FineTunedCrossEncoderClassifier,
    )
    from src.multi_tower_model import MultiTowerClassifier

    path = Path(path)

    # Fine-tuned cross-encoder saved as directory
    if path.suffix == ".joblib":
        ce_path = path.with_suffix(".ce")
    else:
        ce_path = path
    if ce_path.exists() and (ce_path / "config.json").exists():
        return FineTunedCrossEncoderClassifier.load(ce_path)

    meta_path = path.with_suffix(".meta.json")
    if meta_path.exists():
        with open(meta_path) as f:
            meta = json.load(f)
        if meta.get("model_type") == "cross_encoder_ft":
            ft_path = Path(meta.get("path", str(ce_path)))
            return FineTunedCrossEncoderClassifier.load(ft_path)

    obj = joblib.load(path)
    if not isinstance(obj, dict):
        # Not a tagged payload: hand back whatever object was stored.
        return obj

    kind = obj.get("model_type")
    if kind == "cross_encoder":
        model = CrossEncoderClassifier(
            cross_encoder_name=obj["cross_encoder_name"],
            batch_size=obj.get("batch_size", 32),
            max_length=obj.get("max_length", 512),
        )
    elif kind == "multi_tower":
        model = MultiTowerClassifier(
            encoder_name=obj["encoder_name"],
            batch_size=obj.get("batch_size", 256),
        )
    elif kind == "minilm":
        model = EmbeddingClassifier(
            encoder_name=obj["encoder_name"],
            batch_size=obj.get("batch_size", 256),
        )
    else:
        # Any other tagged payload is expected to carry a pipeline.
        return obj["pipeline"]

    # Only the cross-encoder and multi-tower payloads carry a scaler.
    if kind != "minilm":
        model.scaler = obj["scaler"]
    model.clf = obj["classifier"]
    model.classes_ = obj.get("classes_", obj["classifier"].classes_)
    return model
|
src/multi_tower_model.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-tower semantic comparison architecture for SQL error classification."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from sklearn.linear_model import LogisticRegression
|
| 10 |
+
from sklearn.preprocessing import StandardScaler
|
| 11 |
+
|
| 12 |
+
from src.model import DEFAULT_ENCODER
|
| 13 |
+
from src.sql_features import extract_sql_features
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class QueryContext:
    """Text inputs the SQL playground supplies for a single prediction."""

    # Natural-language task statement.
    question: str
    # Textual description of the database schema.
    schema: str
    # Reference SQL solution for the task.
    correct_query: str
    # SQL submitted by the student.
    student_query: str
    # Optional error text accompanying the student query; None when absent.
    error_message: Optional[str] = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
| 28 |
+
denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
|
| 29 |
+
denom = np.maximum(denom, 1e-8)
|
| 30 |
+
return np.sum(a * b, axis=1) / denom
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MultiTowerClassifier:
    """
    Recommended architecture for SQL error classification.

    Three semantic towers (shared MiniLM encoder):
      1. Intent tower     — question + schema → what should be answered
      2. Reference tower  — correct_query → ground-truth solution
      3. Student tower    — student_query → what the student wrote

    Comparison layer fuses:
      - tower embeddings
      - |student − reference|  (what changed)
      - student ⊙ reference    (interaction)
      - cosine similarities    (semantic alignment)
      - SQL structural features (join/null/agg rules)

    A light linear head maps the fused vector → 15 error categories.
    """

    def __init__(
        self,
        encoder_name: str = DEFAULT_ENCODER,
        batch_size: int = 256,
    ):
        """Store configuration; the sentence encoder itself is loaded lazily."""
        self.encoder_name = encoder_name
        self.batch_size = batch_size
        self.encoder = None
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        self.classes_: Optional[np.ndarray] = None

    def _load_encoder(self):
        """Instantiate the SentenceTransformer on first use."""
        if self.encoder is None:
            # Imported here so the dependency is only needed once encoding starts.
            from sentence_transformers import SentenceTransformer

            self.encoder = SentenceTransformer(self.encoder_name)

    def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Encode *texts* with the shared encoder and return a NumPy array."""
        self._load_encoder()
        return self.encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )

    @staticmethod
    def _intent_text(ctx: QueryContext) -> str:
        """Text fed to the intent tower."""
        return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"

    @staticmethod
    def _reference_text(ctx: QueryContext) -> str:
        """Text fed to the reference tower."""
        return f"REFERENCE: {ctx.correct_query}"

    @staticmethod
    def _student_text(ctx: QueryContext) -> str:
        """Text fed to the student tower; the error message is appended when truthy."""
        parts = [f"STUDENT: {ctx.student_query}"]
        if ctx.error_message:
            parts.append(f"ERROR: {ctx.error_message}")
        return " ".join(parts)

    def _encode_towers(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return (intent, reference, student) embeddings for *contexts*.

        Only the intent pass honours *show_progress*; the other two passes
        always run without a progress bar.
        """
        intent_emb = self._encode([self._intent_text(c) for c in contexts], show_progress)
        ref_emb = self._encode(
            [self._reference_text(c) for c in contexts], show_progress=False
        )
        student_emb = self._encode(
            [self._student_text(c) for c in contexts], show_progress=False
        )
        return intent_emb, ref_emb, student_emb

    def _build_feature_matrix(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Build the fused feature matrix, one row per context.

        Columns are, in order: intent, reference and student embeddings,
        |student - reference|, student * reference, the three pairwise cosine
        similarities, and the rule-based SQL features.
        """
        intent_emb, ref_emb, student_emb = self._encode_towers(contexts, show_progress)

        diff = np.abs(student_emb - ref_emb)
        prod = student_emb * ref_emb
        cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1)
        cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1)
        cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1)

        sql_feats = np.array(
            [
                extract_sql_features(c.student_query, c.correct_query)
                for c in contexts
            ],
            dtype=np.float64,
        )

        return np.hstack(
            [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats]
        )

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier":
        """Fit the scaler and the linear head on *contexts* / *y*; returns self."""
        X = self._build_feature_matrix(contexts, show_progress=True)
        X = self.scaler.fit_transform(X)
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Scale inference features and replace NaN / infinities with finite values."""
        X = self.scaler.transform(self._build_feature_matrix(contexts))
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return the predicted class label for each context."""
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return class probabilities for each context (columns follow ``classes_``)."""
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_similarities(self, ctx: QueryContext) -> dict:
        """Diagnostic scores for the playground UI."""
        # Only the three tower embeddings are needed here; building the full
        # feature matrix as well would encode every text a second time.
        intent_emb, ref_emb, student_emb = self._encode_towers([ctx])
        return {
            "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]),
            "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]),
            "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]),
        }
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def contexts_from_dataframe(df) -> List[QueryContext]:
    """Build QueryContext list from a training dataframe.

    Reads the ``question``, ``schema``, ``correct_query`` and ``query``
    columns of every row. ``error_message`` is used only when that column
    exists; a float NaN cell is converted to ``None``.
    """
    has_error = "error_message" in df.columns

    def _error_or_none(row):
        if not has_error:
            return None
        value = row["error_message"]
        # NaN is truthy, so without this check the student text would end
        # with the literal "ERROR: nan".
        if isinstance(value, float) and value != value:
            return None
        return value

    return [
        QueryContext(
            question=row["question"],
            schema=row["schema"],
            correct_query=row["correct_query"],
            student_query=row["query"],
            error_message=_error_or_none(row),
        )
        for row in df.to_dict("records")
    ]
|
src/predict.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference API for SQL error classification."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from dataclasses import asdict, dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
|
| 11 |
+
from src.categories import id_to_name, load_categories
|
| 12 |
+
from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
|
| 13 |
+
from src.cross_encoder_model import (
|
| 14 |
+
CrossEncoderClassifier,
|
| 15 |
+
FineTunedCrossEncoderClassifier,
|
| 16 |
+
)
|
| 17 |
+
from src.multi_tower_model import MultiTowerClassifier, QueryContext
|
| 18 |
+
|
| 19 |
+
CONTEXT_MODELS = (
|
| 20 |
+
CrossEncoderClassifier,
|
| 21 |
+
FineTunedCrossEncoderClassifier,
|
| 22 |
+
MultiTowerClassifier,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class Prediction:
    """Result of classifying one student query."""

    # Identifier of the top-ranked class.
    label_id: int
    # Human-readable name looked up for ``label_id``.
    label_name: str
    # Probability assigned to the top-ranked class.
    confidence: float
    # Best-ranked classes as dicts with label_id / label_name / confidence.
    top_k: List[dict]
    # Cosine diagnostics; only populated for MultiTowerClassifier models.
    similarities: Optional[dict] = None
    # Pair-score diagnostics; only populated for CrossEncoderClassifier models.
    pair_scores: Optional[dict] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SQLErrorClassifier:
    """Classifier wrapper for playground integration."""

    def __init__(self, model_path: Path = DEFAULT_MODEL_PATH):
        """Load the persisted model and the id → name label mapping."""
        self.model = load_model(model_path)
        self.label_map = id_to_name(load_categories())

    def predict(
        self,
        query: str,
        error_message: Optional[str] = None,
        schema: Optional[str] = None,
        question: Optional[str] = None,
        correct_query: Optional[str] = None,
        top_k: int = 3,
    ) -> Prediction:
        """Classify *query* and return the ranked result.

        Context models (see ``CONTEXT_MODELS``) need ``schema``, ``question``
        and ``correct_query``; any other model receives a single combined
        text built by ``combine_features``.

        Raises:
            ValueError: a context model is loaded and one of the required
                context fields is missing or empty.
        """
        # Diagnostics stay None unless the loaded model type provides them.
        similarities = None
        pair_scores = None

        if isinstance(self.model, CONTEXT_MODELS):
            if not (schema and question and correct_query):
                raise ValueError(
                    "context models require schema, question, and correct_query"
                )
            context = QueryContext(
                question=question,
                schema=schema,
                correct_query=correct_query,
                student_query=query,
                error_message=error_message,
            )
            proba = self.model.predict_proba([context])[0]
            if isinstance(self.model, MultiTowerClassifier):
                similarities = self.model.explain_similarities(context)
            if isinstance(self.model, CrossEncoderClassifier):
                pair_scores = self.model.explain_pair_scores(context)
        else:
            combined = combine_features(
                queries=[query],
                error_messages=[error_message] if error_message else None,
                schemas=[schema] if schema else None,
                questions=[question] if question else None,
            )
            proba = self.model.predict_proba([combined[0]])[0]

        # Most probable class first.
        ranking = sorted(
            zip(self.model.classes_, proba),
            key=lambda pair: pair[1],
            reverse=True,
        )
        top_class, top_probability = ranking[0]
        best_id = int(top_class)

        shortlist = []
        for class_id, probability in ranking[:top_k]:
            shortlist.append(
                {
                    "label_id": int(class_id),
                    "label_name": self.label_map[int(class_id)],
                    "confidence": float(probability),
                }
            )

        return Prediction(
            label_id=best_id,
            label_name=self.label_map[best_id],
            confidence=float(top_probability),
            top_k=shortlist,
            similarities=similarities,
            pair_scores=pair_scores,
        )
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def main() -> None:
    """Command-line entry point: classify one query and print the result as JSON."""
    parser = argparse.ArgumentParser(description="Classify SQL error type")
    parser.add_argument("--query", type=str, required=True)
    # Optional free-text inputs all share the same argparse configuration.
    for flag in ("--correct-query", "--error-message", "--schema", "--question"):
        parser.add_argument(flag, type=str, default=None)
    parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
    parser.add_argument("--top-k", type=int, default=3)
    options = parser.parse_args()

    classifier = SQLErrorClassifier(options.model)
    prediction = classifier.predict(
        query=options.query,
        error_message=options.error_message,
        schema=options.schema,
        question=options.question,
        correct_query=options.correct_query,
        top_k=options.top_k,
    )
    print(json.dumps(asdict(prediction), indent=2))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
main()
|
src/sql_features.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Lightweight structural SQL features for the classification head."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
AGG_FUNCS = ("COUNT", "SUM", "AVG", "MAX", "MIN")
WINDOW_FUNCS = ("ROW_NUMBER", "RANK", "DENSE_RANK", "OVER")

# Any run of whitespace (spaces, tabs, newlines).
_WHITESPACE_RUN = re.compile(r"\s+")
# OVER as a standalone keyword, so identifiers such as OVERTIME do not match.
_OVER_KEYWORD = re.compile(r"\bOVER\b")


def _upper(sql: str) -> str:
    """Return *sql* upper-cased for case-insensitive keyword checks."""
    return sql.upper()


def _keyword_view(sql: str) -> str:
    """Upper-case *sql*, collapse whitespace runs, and pad one space on each side.

    The padding lets space-delimited checks such as ``" WHERE "`` match a
    keyword that sits at the very start or end of the query or is separated
    from its neighbours by newlines or tabs.
    """
    return f" {_WHITESPACE_RUN.sub(' ', _upper(sql)).strip()} "


def extract_sql_features(student_query: str, correct_query: str = "") -> List[float]:
    """
    Rule-based signals that complement semantic embeddings.

    Returns a fixed-length float vector of 16 values: fifteen 0.0 / 1.0
    flags followed by the student/reference length ratio. Keyword checks
    run on a whitespace-normalised, upper-cased copy of each query, so the
    flags do not depend on how the SQL is laid out across lines.

    Args:
        student_query: SQL written by the student.
        correct_query: Reference SQL; an empty or ``None`` value disables
            the reference-dependent flags and yields a length divisor of 1.
    """
    s = _keyword_view(student_query)
    c = _keyword_view(correct_query) if correct_query else ""

    has_agg = any(f"{f}(" in s for f in AGG_FUNCS)
    has_group = "GROUP BY" in s
    has_join = "JOIN" in s
    has_on = " ON " in s
    has_where = " WHERE " in s
    has_having = " HAVING " in s
    has_distinct = "DISTINCT" in s
    has_subquery = "(" in s and "SELECT" in s[s.find("(") :]
    has_window = bool(_OVER_KEYWORD.search(s))
    has_null_eq = "= NULL" in s or "=NULL" in s
    has_is_null = "IS NULL" in s or "IS NOT NULL" in s
    has_select_star = bool(re.search(r"SELECT\s+\*", s))
    has_or = " OR " in s
    has_and = " AND " in s

    correct_has_distinct = "DISTINCT" in c
    correct_has_inner = "INNER JOIN" in c
    student_has_left = "LEFT JOIN" in s

    # The ratio uses the raw upper-cased lengths, not the normalised copies,
    # so it is unaffected by the whitespace collapsing and padding above.
    student_len = len(_upper(student_query))
    correct_len = len(_upper(correct_query)) if correct_query else 0

    return [
        float(has_agg),
        float(has_agg and not has_group),
        float(has_join and not has_on),
        float(has_join),
        float(has_where and has_having),
        float(has_agg and has_where and not has_having),
        float(has_distinct),
        float(correct_has_distinct and not has_distinct),
        float(has_subquery),
        float(has_window),
        float(has_null_eq),
        float(has_is_null),
        float(has_select_star),
        float(has_or and has_and),
        float(correct_has_inner and student_has_left),
        float(student_len / max(correct_len, 1)),  # length ratio vs reference
    ]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
FEATURE_NAMES = [
|
| 65 |
+
"has_aggregate",
|
| 66 |
+
"agg_without_group_by",
|
| 67 |
+
"join_without_on",
|
| 68 |
+
"has_join",
|
| 69 |
+
"where_and_having",
|
| 70 |
+
"agg_in_where",
|
| 71 |
+
"has_distinct",
|
| 72 |
+
"missing_distinct_vs_correct",
|
| 73 |
+
"has_subquery",
|
| 74 |
+
"has_window",
|
| 75 |
+
"null_equals",
|
| 76 |
+
"is_null_check",
|
| 77 |
+
"select_star",
|
| 78 |
+
"and_or_mix",
|
| 79 |
+
"left_vs_inner_join",
|
| 80 |
+
"length_ratio",
|
| 81 |
+
]
|
src/sql_templates.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Error injectors that transform exercise context into labeled mistakes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from typing import Callable, Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
from src.exercises import Exercise
|
| 9 |
+
|
| 10 |
+
FAKE_COLUMNS = ["fullname", "studentname", "coursename", "dept_name", "totals"]
|
| 11 |
+
FAKE_TABLES = ["student", "course", "enrolment", "employe", "orderz"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _pick(rng: random.Random, items: List[str], k: int = 1) -> str | List[str]:
|
| 15 |
+
if k == 1:
|
| 16 |
+
return rng.choice(items)
|
| 17 |
+
return rng.sample(items, k)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _first_table(exercise: Exercise) -> str:
|
| 21 |
+
return exercise.tables[0]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _second_table(exercise: Exercise) -> str:
|
| 25 |
+
return exercise.tables[1] if len(exercise.tables) > 1 else exercise.tables[0]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# --- Error injectors: (exercise) -> (erroneous_sql, error_message) ---
|
| 29 |
+
|
| 30 |
+
def inject_syntax_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a mutated copy of the correct query plus a syntax-error message.

    Each mutation is applied to ``exercise.correct_query`` and only results
    that differ from the original are eligible. Without that filter, a
    query lacking ``)``, ``,`` or an upper-case ``SELECT`` / ``FROM`` could
    be returned unchanged while still being labelled as a syntax error.
    """
    sql = exercise.correct_query
    mutations = [
        lambda s: s.replace("SELECT", "SELEC", 1),
        lambda s: s.replace("FROM", "FRO", 1),
        lambda s: s[:-1],
        lambda s: s.replace(")", "", 1),
        lambda s: s + " WHERE",
        lambda s: s.replace(",", "", 1),
        lambda s: s.replace("'", '"', 1) if "'" in s else s + " 'unclosed",
    ]
    # Appending " WHERE" always changes the text, so this list is never empty.
    candidates = [mutated for mutated in (mutate(sql) for mutate in mutations) if mutated != sql]
    return rng.choice(candidates), "syntax error at or near unexpected token"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def inject_join_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return one of four templated join queries together with a join-error message.

    The templates are built from the exercise's first two tables and one
    randomly picked column.
    """
    left = _first_table(exercise)
    right = _second_table(exercise)
    column = _pick(rng, list(exercise.columns))

    bare_join = f"SELECT {column} FROM {left} JOIN {right}"
    id_join = f"SELECT {column} FROM {left} INNER JOIN {right} ON {left}.id = {right}.id"
    column_join = (
        f"SELECT {left}.{column} FROM {left} "
        f"LEFT JOIN {right} ON {left}.{column} = {right}.{column}"
    )
    comma_join = f"SELECT * FROM {left}, {right} WHERE {left}.wrong_id = {right}.wrong_id"

    chosen = rng.choice([bare_join, id_join, column_join, comma_join])
    return chosen, "missing ON clause or invalid join condition"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def inject_aggregation_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return an AVG query that selects a plain column without GROUP BY.

    Uses the exercise's first column as the plain column and its last
    column as the aggregated one. *rng* is accepted for signature parity
    with the other injectors and is not consulted.
    """
    table = _first_table(exercise)
    columns = list(exercise.columns)
    plain_column, aggregated_column = columns[0], columns[-1]
    query = f"SELECT {plain_column}, AVG({aggregated_column}) FROM {table}"
    message = "column must appear in GROUP BY clause or be used in aggregate function"
    return query, message
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def inject_having_where_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query that filters on COUNT(...) inside WHERE.

    The comparison threshold is a random integer between 1 and 5 inclusive.
    """
    table = _first_table(exercise)
    columns = list(exercise.columns)
    plain_column = columns[0]
    counted_column = columns[-1]
    threshold = rng.randint(1, 5)
    query = (
        f"SELECT {plain_column}, COUNT({counted_column}) FROM {table} "
        f"WHERE COUNT({counted_column}) > {threshold}"
    )
    return query, "aggregate functions are not allowed in WHERE"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def inject_subquery_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query comparing a column with ``=`` against an unrestricted subquery."""
    outer_table = _first_table(exercise)
    inner_table = _second_table(exercise)
    column = _pick(rng, list(exercise.columns))
    subquery = f"(SELECT {column} FROM {inner_table})"
    query = f"SELECT {column} FROM {outer_table} WHERE {column} = {subquery}"
    return query, "subquery returned more than one row"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def inject_window_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return one of three templated window-function queries with a window-error message."""
    table = _first_table(exercise)
    column = _pick(rng, list(exercise.columns))

    empty_over = f"SELECT {column}, ROW_NUMBER() OVER () FROM {table}"
    window_with_group = (
        f"SELECT {column}, SUM({column}) OVER (ORDER BY {column}) FROM {table} GROUP BY {column}"
    )
    partition_without_by = f"SELECT {column}, RANK() OVER (PARTITION {column}) FROM {table}"

    chosen = rng.choice([empty_over, window_with_group, partition_without_by])
    return chosen, "window function requires PARTITION BY or ORDER BY"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def inject_null_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query that compares a randomly picked column to NULL with ``=``."""
    table = _first_table(exercise)
    column = _pick(rng, list(exercise.columns))
    query = f"SELECT * FROM {table} WHERE {column} = NULL"
    message = "use IS NULL or IS NOT NULL to test for null values"
    return query, message
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def inject_date_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return one of four templated date predicates on the exercise's first table."""
    table = _first_table(exercise)
    predicates = (
        "order_date = '31/02/2023'",
        "order_date = DATE '2023-13-40'",
        "STR_TO_DATE('bad-date', '%Y-%m-%d')",
        "hire_date > 'yesterday'",
    )
    queries = [f"SELECT * FROM {table} WHERE {predicate}" for predicate in predicates]
    return rng.choice(queries), "invalid date format or unknown date function"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def inject_column_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query selecting a name drawn from FAKE_COLUMNS, with a matching message."""
    table = _first_table(exercise)
    fake_column = _pick(rng, FAKE_COLUMNS)
    query = f"SELECT {fake_column} FROM {table}"
    message = f"column '{fake_column}' does not exist"
    return query, message
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def inject_table_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query reading from a name drawn from FAKE_TABLES, with a matching message."""
    fake_table = _pick(rng, FAKE_TABLES)
    column = _pick(rng, list(exercise.columns))
    query = f"SELECT {column} FROM {fake_table}"
    message = f"relation '{fake_table}' does not exist"
    return query, message
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def inject_datatype_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query comparing a randomly picked column to a quoted text literal."""
    table = _first_table(exercise)
    column = _pick(rng, list(exercise.columns))
    # Drawn after the column so the sequence of rng calls stays the same.
    literal = rng.choice(['abc', 'ten', 'N/A'])
    query = f"SELECT {column} FROM {table} WHERE {column} = '{literal}'"
    return query, "operator does not exist: integer = character varying"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def inject_duplicate_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Drop DISTINCT when the question asks for unique values.

    When the correct query contains the DISTINCT keyword in any letter
    case, every occurrence and the whitespace after it is removed. If it
    does not, a plain single-column SELECT on the first table is returned.
    """
    # Local import: this module does not import ``re`` at file level.
    import re

    sql = exercise.correct_query
    distinct_keyword = re.compile(r"\bDISTINCT\b\s*", flags=re.IGNORECASE)
    if distinct_keyword.search(sql):
        # Case-insensitive removal also covers spellings such as "Distinct"
        # and "DISTINCT(", which an exact-case "DISTINCT " replace misses.
        bad = distinct_keyword.sub("", sql)
    else:
        col = _pick(rng, list(exercise.columns))
        bad = f"SELECT {col} FROM {_first_table(exercise)}"
    return bad, "query returns duplicate rows; DISTINCT may be required"
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def inject_logical_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """
    Produce a query that runs against the schema but answers the question incorrectly.

    Variants are tied to the exercise question and correct answer. Each
    rule below fires on keywords found in the question or in the correct
    query and contributes one or more candidates. Candidates identical to
    the correct query are dropped, because several rules rely on
    case-sensitive ``str.replace`` calls that may not change anything. If
    no candidate is left, two generic fallback queries are used.
    """
    sql = exercise.correct_query
    q = exercise.question.lower()
    sql_lower = sql.lower()
    sql_upper = sql.upper()
    variants: List[str] = []

    if "average" in q or "avg" in sql_lower:
        variants.append(sql.replace("AVG(", "SUM(", 1))
        variants.append(sql.replace("AVG(", "MAX(", 1))
    if " and " in q and " AND " in sql:
        variants.append(sql.replace(" AND ", " OR ", 1))
    if "join" in sql_lower:
        t1, t2 = _first_table(exercise), _second_table(exercise)
        variants.append(
            f"SELECT {t1}.name, {t2}.name FROM {t1} "
            f"JOIN {t2} ON {t1}.id = {t2}.id"
        )
        variants.append(sql.replace("INNER JOIN", "LEFT JOIN", 1))
    if "between" in q and "BETWEEN" in sql_upper:
        between_part = sql_upper.split("BETWEEN", 1)[1]
        bounds = between_part.split("AND", 1)
        if len(bounds) == 2:
            lo = bounds[0].strip().split()[-1]
            hi = bounds[1].strip().split()[0]
            # Swapped bounds: the range becomes hi..lo.
            variants.append(
                sql.split("WHERE", 1)[0]
                + f" WHERE price BETWEEN {hi} AND {lo}"
            )
    if "rank" in q or "over" in sql_lower:
        col = _pick(rng, list(exercise.columns))
        variants.append(
            f"SELECT name, {col} FROM {_first_table(exercise)} ORDER BY {col} DESC"
        )
    if "total" in q and "WHERE" in sql_upper:
        variants.append(sql.replace("active", "inactive"))
    if "highest" in q or "max" in sql_lower:
        col = _pick(rng, list(exercise.columns))
        variants.append(
            f"SELECT name FROM {_first_table(exercise)} "
            f"WHERE {col} >= (SELECT AVG({col}) FROM {_second_table(exercise)})"
        )
    if "enrolled" in q and "INNER JOIN" in sql_upper:
        variants.append(sql.replace("INNER JOIN", "LEFT JOIN", 1))
    if "not provided" in q or "is null" in sql_lower:
        variants.append(sql.replace("IS NULL", "= ''"))

    # A replace that matched nothing leaves the correct query in the list;
    # returning it would label a correct answer as a logical error.
    variants = [candidate for candidate in variants if candidate != sql]

    if not variants:
        col = _pick(rng, list(exercise.columns))
        t = _first_table(exercise)
        variants = [
            f"SELECT {col} FROM {t} ORDER BY {col} DESC LIMIT 10",
            f"SELECT COUNT(*) FROM {t}",
        ]

    bad = rng.choice(variants)
    return bad, "query executes but produces incorrect result set"
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def inject_performance_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return one of four templated ``SELECT *`` queries with an inefficiency message."""
    left = _first_table(exercise)
    right = _second_table(exercise)
    # Both draws happen before the final choice, in this order.
    like_column = _pick(rng, list(exercise.columns))
    like_letter = rng.choice(['a', 'e', 'i'])

    full_scan = f"SELECT * FROM {left}"
    star_join = f"SELECT * FROM {left} JOIN {right} ON {left}.id = {right}.id"
    wildcard_like = f"SELECT * FROM {left} WHERE {like_column} LIKE '%{like_letter}%'"
    cross_join = f"SELECT * FROM {left} CROSS JOIN {right}"

    chosen = rng.choice([full_scan, star_join, wildcard_like, cross_join])
    return chosen, "inefficient query: SELECT * or cartesian join detected"
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def inject_filtering_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Return a query whose WHERE clause selects the wrong rows.

    Two candidates are derived from the correct query (first comparison
    operator changed, first `` AND `` turned into `` OR ``) and two are
    templated from a random column and threshold. Candidates equal to the
    correct query are discarded, since the operator rewrite changes nothing
    when the query contains neither ``>`` nor ``=``.
    """
    sql = exercise.correct_query
    col = _pick(rng, list(exercise.columns))
    t = _first_table(exercise)
    threshold = rng.randint(50, 90)
    variants = [
        sql.replace(">", "<", 1) if ">" in sql else sql.replace("=", "!=", 1),
        f"SELECT {col} FROM {t} WHERE {col} > {threshold} AND {col} < {threshold - 20}",
        f"SELECT {col} FROM {t} WHERE NOT {col} > {threshold}",
        sql.replace(" AND ", " OR ", 1) if " AND " in sql else (
            f"SELECT {col} FROM {t} WHERE {col} BETWEEN {threshold} AND {threshold - 10}"
        ),
    ]
    # Keep the full list as a last resort so rng.choice never sees an empty sequence.
    candidates = [variant for variant in variants if variant != sql] or variants
    return rng.choice(candidates), "WHERE clause filters incorrect rows"
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
ERROR_INJECTORS: Dict[int, Callable[[random.Random, Exercise], Tuple[str, str]]] = {
|
| 243 |
+
0: inject_syntax_error,
|
| 244 |
+
1: inject_join_error,
|
| 245 |
+
2: inject_aggregation_error,
|
| 246 |
+
3: inject_having_where_error,
|
| 247 |
+
4: inject_subquery_error,
|
| 248 |
+
5: inject_window_error,
|
| 249 |
+
6: inject_null_error,
|
| 250 |
+
7: inject_date_error,
|
| 251 |
+
8: inject_column_error,
|
| 252 |
+
9: inject_table_error,
|
| 253 |
+
10: inject_datatype_error,
|
| 254 |
+
11: inject_duplicate_error,
|
| 255 |
+
12: inject_logical_error,
|
| 256 |
+
13: inject_performance_error,
|
| 257 |
+
14: inject_filtering_error,
|
| 258 |
+
}
|
src/train.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train the SQL error classifier."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from sklearn.metrics import classification_report
|
| 11 |
+
from sklearn.model_selection import train_test_split
|
| 12 |
+
|
| 13 |
+
from src.categories import id_to_name, load_categories
|
| 14 |
+
from src.cross_encoder_model import (
|
| 15 |
+
CrossEncoderClassifier,
|
| 16 |
+
FineTunedCrossEncoderClassifier,
|
| 17 |
+
)
|
| 18 |
+
from src.model import (
|
| 19 |
+
DEFAULT_MODEL_PATH,
|
| 20 |
+
ModelType,
|
| 21 |
+
build_classifier,
|
| 22 |
+
combine_features,
|
| 23 |
+
save_model,
|
| 24 |
+
)
|
| 25 |
+
from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe
|
| 26 |
+
|
| 27 |
+
# Repository root: two directory levels above this module's resolved path.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Default input dataset (parquet) and default location of the metrics JSON.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
DEFAULT_METRICS = PROJECT_ROOT.joinpath("models", "metrics.json")
|
| 30 |
+
|
| 31 |
+
# Classifier classes that train() feeds with the output of
# contexts_from_dataframe() instead of flat combined text strings.
# Used as the second argument of isinstance(), hence a tuple.
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _split_dataframe(
    df: pd.DataFrame, test_size: float, val_size: float, seed: int
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split *df* into stratified train, validation and test frames.

    ``test_size`` and ``val_size`` are both fractions of the full frame.
    The test split is carved off first; the validation split is then taken
    from the remainder. Both splits are stratified on the ``label_id``
    column and use the same ``seed``.

    Returns:
        A ``(train, val, test)`` tuple of DataFrames.
    """
    remainder, test_part = train_test_split(
        df, test_size=test_size, random_state=seed, stratify=df["label_id"]
    )
    # val_size is expressed relative to the full frame, so rescale it to the
    # share of the remaining (non-test) rows before the second split.
    val_fraction = val_size / (1 - test_size)
    train_part, val_part = train_test_split(
        remainder,
        test_size=val_fraction,
        random_state=seed,
        stratify=remainder["label_id"],
    )
    return train_part, val_part, test_part
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def train(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    metrics_path: Path = DEFAULT_METRICS,
    test_size: float = 0.1,
    val_size: float = 0.1,
    use_error_message: bool = True,
    max_train_samples: int | None = None,
    model_type: ModelType = "cross_encoder",
    epochs: int = 1,
    seed: int = 42,
) -> dict:
    """Train a classifier on a parquet dataset and persist model and metrics.

    Args:
        data_path: Parquet file read with ``pd.read_parquet``. Must contain
            ``label_id``; the text-feature path also reads ``query``.
        model_path: Destination passed to ``save_model``.
        metrics_path: JSON file the returned metrics dict is written to;
            its parent directory is created if missing.
        test_size: Fraction of the data held out for testing.
        val_size: Fraction of the data held out for validation.
        use_error_message: When false, the ``error_message`` column is
            dropped before splitting.
        max_train_samples: If set and smaller than the dataset, the data is
            randomly down-sampled to this many rows before splitting.
        model_type: Forwarded to ``build_classifier`` and ``save_model``.
        epochs: Passed to ``fit`` only for ``FineTunedCrossEncoderClassifier``.
        seed: Random state for sampling and splitting.

    Returns:
        The metrics dict that is also written to ``metrics_path``.
    """
    print(f"Loading data from {data_path}...")
    df = pd.read_parquet(data_path)

    if max_train_samples and len(df) > max_train_samples:
        df = df.sample(n=max_train_samples, random_state=seed)

    if not use_error_message and "error_message" in df.columns:
        df = df.drop(columns=["error_message"])

    train_df, val_df, test_df = _split_dataframe(df, test_size, val_size, seed)
    print(
        f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}"
    )

    model = build_classifier(model_type=model_type)
    print(f"Training {model_type} classifier...")

    y_train = train_df["label_id"].values

    if isinstance(model, CONTEXT_MODELS):
        # Context-based models are fed the output of contexts_from_dataframe.
        train_inputs = contexts_from_dataframe(train_df)
        val_inputs = contexts_from_dataframe(val_df)
        test_inputs = contexts_from_dataframe(test_df)

        if isinstance(model, FineTunedCrossEncoderClassifier):
            # Swap a ".joblib" suffix for ".ce" in the path handed to fit().
            if model_path.suffix == ".joblib":
                checkpoint_path = model_path.with_suffix(".ce")
            else:
                checkpoint_path = model_path
            model.fit(
                train_inputs,
                y_train,
                epochs=epochs,
                output_path=checkpoint_path,
            )
        else:
            model.fit(train_inputs, y_train)

        val_preds = model.predict(val_inputs)
        test_preds = model.predict(test_inputs)
    else:

        def to_texts(frame: pd.DataFrame) -> list[str]:
            """Combine the available text columns of *frame* per row."""

            def column_or_none(name: str) -> list | None:
                return frame[name].tolist() if name in frame.columns else None

            return combine_features(
                queries=frame["query"].tolist(),
                error_messages=column_or_none("error_message"),
                schemas=column_or_none("schema"),
                questions=column_or_none("question"),
            )

        model.fit(to_texts(train_df), y_train)
        val_preds = model.predict(to_texts(val_df))
        test_preds = model.predict(to_texts(test_df))

    y_val = val_df["label_id"].values
    y_test = test_df["label_id"].values

    val_report = classification_report(
        y_val, val_preds, output_dict=True, zero_division=0
    )
    print(f"Validation accuracy: {val_report['accuracy']:.4f}")

    test_report = classification_report(
        y_test, test_preds, output_dict=True, zero_division=0
    )
    print(f"Test accuracy: {test_report['accuracy']:.4f}")

    save_model(model, model_path, model_type=model_type)
    print(f"Model saved to {model_path}")

    label_map = id_to_name(load_categories())
    metrics = {
        "train_size": len(train_df),
        "val_size": len(val_df),
        "test_size": len(test_df),
        "model_type": model_type,
        # Epochs are only recorded for the fine-tuned cross-encoder type.
        "epochs": epochs if model_type == "cross_encoder_ft" else None,
        "use_error_message": use_error_message,
        "validation": val_report,
        "test": test_report,
        # JSON object keys must be strings.
        "label_map": {str(label_id): name for label_id, name in label_map.items()},
    }
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metrics_path, "w") as handle:
        json.dump(metrics, handle, indent=2)
    print(f"Metrics saved to {metrics_path}")

    return metrics
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def main() -> None:
    """Parse command-line options and run :func:`train` with them."""
    cli = argparse.ArgumentParser(description="Train SQL error classifier")

    # Input / output locations.
    cli.add_argument("--data", type=Path, default=DEFAULT_DATA)
    cli.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
    cli.add_argument("--metrics", type=Path, default=DEFAULT_METRICS)

    # Data handling.
    cli.add_argument("--test-size", type=float, default=0.1)
    cli.add_argument("--val-size", type=float, default=0.1)
    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--max-samples", type=int, default=None)

    # Model selection and training.
    cli.add_argument(
        "--model-type",
        choices=["cross_encoder", "cross_encoder_ft", "multi_tower", "minilm", "tfidf"],
        default="cross_encoder",
        help="cross_encoder (recommended): joint attention pairs; "
        "cross_encoder_ft: fine-tuned end-to-end (best accuracy)",
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Epochs for cross_encoder_ft fine-tuning",
    )
    cli.add_argument("--seed", type=int, default=42)

    options = cli.parse_args()

    train(
        data_path=options.data,
        model_path=options.model,
        metrics_path=options.metrics,
        test_size=options.test_size,
        val_size=options.val_size,
        # The CLI flag is negative; train() takes the positive form.
        use_error_message=not options.no_error_message,
        max_train_samples=options.max_samples,
        model_type=options.model_type,
        epochs=options.epochs,
        seed=options.seed,
    )


if __name__ == "__main__":
    main()
|
train_space_app.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Space — CodeBERT SQL Error Classifier Training UI.
|
| 3 |
+
|
| 4 |
+
Deploy as a Gradio Space with app_file: train_space_app.py
|
| 5 |
+
Set hardware to GPU (t4-small recommended).
|
| 6 |
+
Add HF_TOKEN secret to push trained models to your Hub account.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import shutil
|
| 14 |
+
import tempfile
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import gradio as gr
|
| 18 |
+
import pandas as pd
|
| 19 |
+
|
| 20 |
+
from src.hf_train_codebert import train
|
| 21 |
+
|
| 22 |
+
# Directory containing this app file; data/ and models/ are resolved from it.
PROJECT_ROOT = Path(__file__).parent
# Dataset used when the dropdown choice is not a known bundled dataset.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_dev.parquet")
# Directory the trained model (and metrics.json) is written to.
OUTPUT_DIR = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
# Dropdown label -> parquet path (as str) for the datasets shipped in data/.
BUNDLED_DATASETS = {
    "Dev (15K samples)": str(DEFAULT_DATA),
    "Full (1M samples)": str(DEFAULT_DATA.with_name("sql_errors_1m.parquet")),
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _format_metrics(metrics: dict) -> str:
    """Render a training metrics dict as a Markdown summary.

    Reads the sample counts from the top level of *metrics*, the
    ``eval_f1_*`` scores from ``metrics["validation"]`` and the ``f1_*`` /
    ``subset_accuracy`` scores from ``metrics["test"]``. Missing entries are
    shown as 0. A Hub link is appended when ``metrics["hub_url"]`` is set.
    """
    val_scores = metrics.get("validation", {})
    test_scores = metrics.get("test", {})

    sample_counts = [
        f"- Train samples: **{metrics.get('train_samples', 0):,}**",
        f"- Val samples: **{metrics.get('val_samples', 0):,}**",
        f"- Test samples: **{metrics.get('test_samples', 0):,}**",
    ]
    validation_section = [
        "### Validation",
        f"- F1 macro: **{val_scores.get('eval_f1_macro', 0):.4f}**",
        f"- F1 micro: **{val_scores.get('eval_f1_micro', 0):.4f}**",
    ]
    test_section = [
        "### Test",
        f"- F1 macro: **{test_scores.get('f1_macro', 0):.4f}**",
        f"- F1 micro: **{test_scores.get('f1_micro', 0):.4f}**",
        f"- Subset accuracy: **{test_scores.get('subset_accuracy', 0):.4f}**",
    ]

    # Empty strings become blank lines that separate the Markdown sections.
    lines = [
        "## Training complete",
        "",
        *sample_counts,
        "",
        *validation_section,
        "",
        *test_section,
        "",
        f"Model saved to `{OUTPUT_DIR}`",
    ]

    hub_url = metrics.get("hub_url")
    if hub_url:
        lines.append(f"\n**Hub model:** {hub_url}")
    return "\n".join(lines)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def run_training(
    dataset_choice: str,
    uploaded_file,
    max_samples: int,
    epochs: float,
    batch_size: int,
    learning_rate: float,
    max_length: int,
    fp16: bool,
    push_to_hub: bool,
    hub_model_id: str,
    progress=gr.Progress(),
):
    """Gradio callback: validate the form, run training, return the results.

    Args:
        dataset_choice: Key into ``BUNDLED_DATASETS``; used only when no
            file was uploaded.
        uploaded_file: Value of the upload component, or ``None``. Accepts
            either a path string or an object exposing the path as ``.name``.
        max_samples: Row cap forwarded to ``train``; 0 or empty means no cap.
        epochs, batch_size, learning_rate, max_length, fp16: Forwarded to
            ``train`` unchanged.
        push_to_hub: Whether to push the trained model to the Hub. Requires
            an ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` environment variable
            and a non-empty ``hub_model_id``.
        hub_model_id: Target Hub repository id.
        progress: Gradio progress tracker.

    Returns:
        A ``(markdown_summary, metrics_json_path_or_None, output_dir_or_None)``
        tuple matching the three output components. On any validation or
        training failure the first element is an error message and the other
        two are ``None``.
    """
    progress(0, desc="Preparing dataset...")

    if uploaded_file is not None:
        # Depending on the Gradio version / File ``type`` setting, the upload
        # arrives as a plain path string or as a tempfile-like object whose
        # ``.name`` is the path. Support both instead of assuming ``.name``.
        data_path = Path(getattr(uploaded_file, "name", uploaded_file))
    else:
        data_path = Path(BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA))
        if not data_path.exists():
            return (
                f"Dataset not found: `{data_path}`. "
                "Upload a parquet file or include data/ in the Space repo.",
                None,
                None,
            )

    # Tolerate a None textbox value rather than crashing on ``.strip()``.
    model_id = (hub_model_id or "").strip()

    hub_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if push_to_hub and not hub_token:
        return (
            "Add `HF_TOKEN` to Space secrets to push models to the Hub.",
            None,
            None,
        )
    if push_to_hub and not model_id:
        return "Enter a Hub model id (e.g. `your-username/sql-codebert-classifier`).", None, None

    # Start from an empty output directory so stale artifacts are not reused.
    if OUTPUT_DIR.exists():
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    samples = int(max_samples) if max_samples and max_samples > 0 else None
    progress(0.1, desc="Starting CodeBERT training...")

    try:
        metrics = train(
            data_path=data_path,
            output_dir=OUTPUT_DIR,
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_length=max_length,
            max_samples=samples,
            fp16=fp16,
            save_strategy="no",
            push_to_hub=push_to_hub,
            hub_model_id=model_id or None,
            hub_token=hub_token,
        )
    except Exception as exc:
        # UI boundary: report the failure in the results pane instead of
        # letting the exception propagate out of the callback.
        return f"Training failed:\n\n```\n{exc}\n```", None, None

    progress(1.0, desc="Done")
    if push_to_hub and model_id:
        metrics["hub_url"] = f"https://huggingface.co/{model_id}"

    metrics_path = OUTPUT_DIR / "metrics.json"
    summary = _format_metrics(metrics)
    return summary, str(metrics_path) if metrics_path.exists() else None, str(OUTPUT_DIR)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_preview(dataset_choice: str, uploaded_file) -> str:
    """Return a Markdown preview (row count, columns, first rows) of a dataset.

    Args:
        dataset_choice: Key into ``BUNDLED_DATASETS``; used only when no
            file was uploaded.
        uploaded_file: Value of the upload component, or ``None``. Accepts
            either a path string or an object exposing the path as ``.name``.

    Returns:
        Markdown text describing the dataset, or a one-line message when the
        dataset is missing or cannot be read. Never raises.
    """
    try:
        if uploaded_file is not None:
            # The upload may be a plain path string or a tempfile-like object
            # with a ``.name`` path, depending on the Gradio version.
            source = getattr(uploaded_file, "name", uploaded_file)
        else:
            source = BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA)
            if not Path(source).exists():
                return f"Dataset not found: {source}"
        df = pd.read_parquet(source)
        cols = list(df.columns)
        sample = df.head(2).to_dict(orient="records")
        # default=str is deliberate: this is a human-readable preview, so cell
        # values json cannot encode (timestamps, arrays, ...) are shown via
        # their string form instead of failing the whole preview.
        sample_json = json.dumps(sample, indent=2, default=str)[:2000]
        return f"**Rows:** {len(df):,}\n\n**Columns:** `{cols}`\n\n**Sample:**\n```json\n{sample_json}\n```"
    except Exception as exc:
        # UI boundary: show the problem in the preview pane.
        return f"Could not load preview: {exc}"
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# NOTE: the Markdown literals below are kept flush-left so their text does
# not depend on any dedenting done by the Markdown component.
with gr.Blocks(title="SQL Error Classifier — Train") as demo:
    gr.Markdown(
        """
# SQL Error Classifier — CodeBERT Training
Train **microsoft/codebert-base** as a cross-encoder on this Space.

**Input format:** `QUESTION` + `SCHEMA` + `STUDENT_SQL` + `CORRECT_SQL` (single sequence)

**GPU recommended** — upgrade Space hardware to `t4-small` or better.
"""
    )

    with gr.Row():
        # Left column: dataset selection, hyperparameters and Hub options.
        with gr.Column(scale=1):
            dataset_choice = gr.Dropdown(
                choices=list(BUNDLED_DATASETS.keys()),
                value="Dev (15K samples)",
                label="Bundled dataset",
            )
            uploaded = gr.File(
                label="Or upload parquet",
                file_types=[".parquet"],
            )
            preview_btn = gr.Button("Preview dataset")
            preview_out = gr.Markdown()

            max_samples = gr.Number(
                label="Max samples (0 = all)",
                value=5000,
                precision=0,
            )
            epochs = gr.Slider(1, 10, value=2, step=1, label="Epochs")
            batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch size")
            learning_rate = gr.Number(label="Learning rate", value=2e-5)
            max_length = gr.Slider(128, 512, value=512, step=64, label="Max length")
            fp16 = gr.Checkbox(label="FP16 (GPU only)", value=True)

            push_to_hub = gr.Checkbox(label="Push to Hugging Face Hub", value=False)
            hub_model_id = gr.Textbox(
                label="Hub model id",
                placeholder="your-username/sql-codebert-classifier",
            )

            train_btn = gr.Button("Start Training", variant="primary")

        # Right column: training results.
        with gr.Column(scale=1):
            result = gr.Markdown(label="Results")
            metrics_file = gr.File(label="metrics.json")
            model_dir = gr.Textbox(label="Model output path", interactive=False)

    # Event wiring. The input order must match the positional parameters of
    # the callbacks (load_preview / run_training).
    preview_btn.click(load_preview, [dataset_choice, uploaded], preview_out)
    train_btn.click(
        run_training,
        [
            dataset_choice,
            uploaded,
            max_samples,
            epochs,
            batch_size,
            learning_rate,
            max_length,
            fp16,
            push_to_hub,
            hub_model_id,
        ],
        [result, metrics_file, model_dir],
    )

    gr.Markdown(
        """
### Space setup
1. Create a Gradio Space and push this repo
2. Set **Hardware → GPU (t4-small)**
3. Add secret `HF_TOKEN` (write token) to push models
4. Include `data/sql_errors_dev.parquet` in the repo (or upload at runtime)

### After training
Use the saved model with:
```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
clf = CodeBERTSQLErrorClassifier("models/codebert-cross-encoder")
```
"""
    )


if __name__ == "__main__":
    demo.launch()
|