nishu08 committed on
Commit 9b2cded · verified · 1 Parent(s): 5c948aa

Deploy CodeBERT training Space

.gitignore ADDED
@@ -0,0 +1,13 @@
+ __pycache__/
+ *.py[cod]
+ .venv/
+ venv/
+ .env
+ data/*.parquet
+ data/*.csv
+ models/*.joblib
+ models/evaluation/
+ .DS_Store
+ *.egg-info/
+ dist/
+ build/
README.md CHANGED
@@ -1,13 +1,46 @@
  ---
- title: Sql Error Classifier Train
- emoji: 📉
- colorFrom: gray
- colorTo: gray
+ title: SQL Error Classifier Training
+ emoji: 🧠
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 6.17.3
- python_version: '3.13'
- app_file: app.py
+ sdk_version: 4.44.0
+ app_file: train_space_app.py
  pinned: false
+ license: mit
+ hardware: t4-small
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SQL Error Classifier CodeBERT Training Space
+
+ Train `microsoft/codebert-base` as a **cross-encoder** for multi-label SQL error classification.
+
+ ## Setup
+
+ 1. **Hardware:** Settings → Hardware → **GPU t4-small** (recommended)
+ 2. **Secrets:** Settings → Secrets → add `HF_TOKEN` (Hugging Face write token) to push models to your account
+ 3. **Data:** Include `data/sql_errors_dev.parquet` in this Space repo, or upload parquet at runtime
+
+ ## Usage
+
+ 1. Choose bundled dataset or upload your own parquet
+ 2. Set epochs, batch size, max samples
+ 3. Click **Start Training**
+ 4. Optionally enable **Push to Hub** with model id `your-username/sql-codebert-classifier`
+
+ ## Dataset columns
+
+ Required (aliases supported):
+
+ | Column | Aliases |
+ |--------|---------|
+ | `question` | — |
+ | `schema` | — |
+ | `student_sql` | `query` |
+ | `correct_sql` | `correct_query` |
+ | `error_labels` | `label_name` |
+
+ ## Labels (9-class multi-label)
+
+ `JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
+ `SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
README_HF_SPACE.md ADDED
@@ -0,0 +1,17 @@
+ ---
+ title: SQL Error Classifier
+ emoji: 🗄️
+ colorFrom: blue
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # SQL Error Classifier
+
+ Demo Space for the multi-tower SQL error classification model.
+
+ Set `SPACE_MODEL_ID` in Space secrets to your model repo (e.g. `username/sql-error-classifier`).
README_TRAIN_SPACE.md ADDED
@@ -0,0 +1,46 @@
+ ---
+ title: SQL Error Classifier Training
+ emoji: 🧠
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: train_space_app.py
+ pinned: false
+ license: mit
+ hardware: t4-small
+ ---
+
+ # SQL Error Classifier — CodeBERT Training Space
+
+ Train `microsoft/codebert-base` as a **cross-encoder** for multi-label SQL error classification.
+
+ ## Setup
+
+ 1. **Hardware:** Settings → Hardware → **GPU t4-small** (recommended)
+ 2. **Secrets:** Settings → Secrets → add `HF_TOKEN` (Hugging Face write token) to push models to your account
+ 3. **Data:** Include `data/sql_errors_dev.parquet` in this Space repo, or upload parquet at runtime
+
+ ## Usage
+
+ 1. Choose bundled dataset or upload your own parquet
+ 2. Set epochs, batch size, max samples
+ 3. Click **Start Training**
+ 4. Optionally enable **Push to Hub** with model id `your-username/sql-codebert-classifier`
+
+ ## Dataset columns
+
+ Required (aliases supported):
+
+ | Column | Aliases |
+ |--------|---------|
+ | `question` | — |
+ | `schema` | — |
+ | `student_sql` | `query` |
+ | `correct_sql` | `correct_query` |
+ | `error_labels` | `label_name` |
+
+ ## Labels (9-class multi-label)
+
+ `JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
+ `SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
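
The column aliases in the README table amount to a rename step followed by a required-column check. A minimal sketch with plain dictionaries, assuming each row is a `dict`; `normalize_row`, `ALIASES`, and `REQUIRED` are illustrative names and the sample row is made up:

```python
# Hypothetical helper illustrating the alias table; not the Space's own code.
ALIASES = {
    "query": "student_sql",
    "correct_query": "correct_sql",
    "label_name": "error_labels",
}
REQUIRED = ["question", "schema", "student_sql", "correct_sql", "error_labels"]


def normalize_row(row: dict) -> dict:
    """Rename aliased keys to canonical names, then check required columns."""
    out = {ALIASES.get(key, key): value for key, value in row.items()}
    missing = [col for col in REQUIRED if col not in out]
    if missing:
        raise ValueError(f"Row is missing required columns: {missing}")
    return out


# Made-up sample row that uses the alias names throughout.
row = normalize_row(
    {
        "question": "List all students.",
        "schema": "students(id, name)",
        "query": "SELECT nam FROM students",
        "correct_query": "SELECT name FROM students",
        "label_name": "COLUMN_REFERENCE_ERROR",
    }
)
print(sorted(row))
```

A row that already uses the canonical names passes through unchanged, since unknown keys map to themselves.
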
app.py ADDED
@@ -0,0 +1,156 @@
+ """
+ Gradio app for Hugging Face Spaces.
+
+ Deploy: create a Space with sdk=gradio and point app_file to this file.
+ Set SPACE_MODEL_ID env var to your HF model repo (e.g. username/sql-error-classifier).
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from pathlib import Path
+
+ import gradio as gr
+
+ from src.huggingface import SQLLErrorClassifierHF
+
+ MODEL_ID = os.getenv("SPACE_MODEL_ID", "models/hf_package")
+ LOCAL_PACKAGE = Path(__file__).parent / "models" / "hf_package"
+
+ EXAMPLE = {
+     "question": "What is the average score of students in each department?",
+     "schema": "students(id, name, score, department_id) | departments(id, name)",
+     "correct_query": (
+         "SELECT department_id, AVG(score) FROM students GROUP BY department_id"
+     ),
+     "student_query": (
+         "SELECT department_id, SUM(score) FROM students GROUP BY department_id"
+     ),
+     "error_message": "query executes but produces incorrect result set",
+ }
+
+
+ def _load_classifier() -> SQLLErrorClassifierHF:
+     if LOCAL_PACKAGE.exists() and (LOCAL_PACKAGE / "config.json").exists():
+         return SQLLErrorClassifierHF.from_pretrained(LOCAL_PACKAGE)
+     return SQLLErrorClassifierHF.from_pretrained(MODEL_ID)
+
+
+ clf = _load_classifier()
+
+
+ def classify(
+     question: str,
+     schema: str,
+     correct_query: str,
+     student_query: str,
+     error_message: str,
+ ) -> tuple[str, str, str]:
+     result = clf.predict(
+         question=question.strip(),
+         schema=schema.strip(),
+         correct_query=correct_query.strip(),
+         student_query=student_query.strip(),
+         error_message=error_message.strip() or None,
+     )
+
+     summary = (
+         f"**{result['label_name']}** \n"
+         f"Confidence: **{result['confidence']:.1%}**"
+     )
+     top_k = "\n".join(
+         f"- {item['label_name']}: {item['confidence']:.1%}"
+         for item in result["top_k"]
+     )
+     sims = result.get("similarities") or result.get("pair_scores") or {}
+     diagnostics = "\n".join(
+         f"- **{k.replace('_', ' ').title()}**: {v:.3f}"
+         for k, v in sims.items()
+     )
+     if not diagnostics:
+         diagnostics = "_No diagnostic scores for this model type._"
+     return summary, top_k, diagnostics
+
+
+ with gr.Blocks(title="SQL Error Classifier") as demo:
+     gr.Markdown(
+         """
+         # SQL Error Classifier
+         Classify **which mistake area** a student is struggling with, using:
+         **question**, **schema**, **correct query**, and the **student's query**.
+
+         Powered by a multi-tower MiniLM architecture on Hugging Face.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             question = gr.Textbox(
+                 label="Question",
+                 lines=2,
+                 value=EXAMPLE["question"],
+             )
+             schema = gr.Textbox(
+                 label="Schema",
+                 lines=2,
+                 value=EXAMPLE["schema"],
+             )
+             correct_query = gr.Textbox(
+                 label="Correct Query",
+                 lines=3,
+                 value=EXAMPLE["correct_query"],
+             )
+             student_query = gr.Textbox(
+                 label="Student Query",
+                 lines=3,
+                 value=EXAMPLE["student_query"],
+             )
+             error_message = gr.Textbox(
+                 label="DB Error Message (optional)",
+                 lines=2,
+                 value=EXAMPLE["error_message"],
+             )
+             run_btn = gr.Button("Classify", variant="primary")
+
+         with gr.Column():
+             prediction = gr.Markdown(label="Prediction")
+             top_k = gr.Markdown(label="Top 3")
+             diagnostics = gr.Markdown(label="Semantic Diagnostics")
+
+     run_btn.click(
+         classify,
+         inputs=[question, schema, correct_query, student_query, error_message],
+         outputs=[prediction, top_k, diagnostics],
+     )
+
+     gr.Examples(
+         examples=[
+             [
+                 EXAMPLE["question"],
+                 EXAMPLE["schema"],
+                 EXAMPLE["correct_query"],
+                 EXAMPLE["student_query"],
+                 EXAMPLE["error_message"],
+             ],
+             [
+                 "Find students who have not provided an email address.",
+                 "students(id, name, email, phone)",
+                 "SELECT name FROM students WHERE email IS NULL",
+                 "SELECT name FROM students WHERE email = NULL",
+                 "use IS NULL or IS NOT NULL to test for null values",
+             ],
+             [
+                 "List each student's name along with their department name.",
+                 "students(id, name, department_id) | departments(id, name)",
+                 "SELECT students.name, departments.name FROM students "
+                 "INNER JOIN departments ON students.department_id = departments.id",
+                 "SELECT students.name, departments.name FROM students JOIN departments",
+                 "missing ON clause or invalid join condition",
+             ],
+         ],
+         inputs=[question, schema, correct_query, student_query, error_message],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
config/codebert_labels.yaml ADDED
@@ -0,0 +1,29 @@
+ # Primary labels for CodeBERT cross-encoder training
+ labels:
+   - JOIN_ERROR
+   - AGGREGATION_ERROR
+   - FILTER_ERROR
+   - WINDOW_FUNCTION_ERROR
+   - SUBQUERY_ERROR
+   - NULL_HANDLING_ERROR
+   - PERFORMANCE_ERROR
+   - LOGICAL_ERROR
+   - SYNTAX_ERROR
+
+ # Map dataset label_name values → one or more CodeBERT labels (multi-label)
+ alias_map:
+   JOIN_ERROR: [JOIN_ERROR]
+   AGGREGATION_ERROR: [AGGREGATION_ERROR]
+   HAVING_WHERE_ERROR: [AGGREGATION_ERROR]
+   FILTERING_ERROR: [FILTER_ERROR]
+   WINDOW_FUNCTION_ERROR: [WINDOW_FUNCTION_ERROR]
+   SUBQUERY_ERROR: [SUBQUERY_ERROR]
+   NULL_HANDLING_ERROR: [NULL_HANDLING_ERROR]
+   PERFORMANCE_ERROR: [PERFORMANCE_ERROR]
+   LOGICAL_QUERY_ERROR: [LOGICAL_ERROR]
+   SYNTAX_ERROR: [SYNTAX_ERROR]
+   DATE_FUNCTION_ERROR: [SYNTAX_ERROR]
+   COLUMN_REFERENCE_ERROR: [SYNTAX_ERROR]
+   TABLE_REFERENCE_ERROR: [SYNTAX_ERROR]
+   DATA_TYPE_ERROR: [SYNTAX_ERROR]
+   DUPLICATE_RECORD_ERROR: [FILTER_ERROR]
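
The alias map collapses 15 dataset label names into the 9 training labels. One way to sanity-check such a config is to invert it and confirm that every target is a known label and every label has at least one source. The sketch below inlines the config as Python literals (the real file would be read with `yaml.safe_load`); `invert_alias_map` is a hypothetical helper, not part of this commit:

```python
from collections import defaultdict

LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR",
    "LOGICAL_ERROR", "SYNTAX_ERROR",
]
ALIAS_MAP = {
    "JOIN_ERROR": ["JOIN_ERROR"],
    "AGGREGATION_ERROR": ["AGGREGATION_ERROR"],
    "HAVING_WHERE_ERROR": ["AGGREGATION_ERROR"],
    "FILTERING_ERROR": ["FILTER_ERROR"],
    "WINDOW_FUNCTION_ERROR": ["WINDOW_FUNCTION_ERROR"],
    "SUBQUERY_ERROR": ["SUBQUERY_ERROR"],
    "NULL_HANDLING_ERROR": ["NULL_HANDLING_ERROR"],
    "PERFORMANCE_ERROR": ["PERFORMANCE_ERROR"],
    "LOGICAL_QUERY_ERROR": ["LOGICAL_ERROR"],
    "SYNTAX_ERROR": ["SYNTAX_ERROR"],
    "DATE_FUNCTION_ERROR": ["SYNTAX_ERROR"],
    "COLUMN_REFERENCE_ERROR": ["SYNTAX_ERROR"],
    "TABLE_REFERENCE_ERROR": ["SYNTAX_ERROR"],
    "DATA_TYPE_ERROR": ["SYNTAX_ERROR"],
    "DUPLICATE_RECORD_ERROR": ["FILTER_ERROR"],
}


def invert_alias_map(alias_map, labels):
    """Return {training label: [dataset names mapped to it]}, validating targets."""
    known = set(labels)
    inverse = defaultdict(list)
    for source, targets in alias_map.items():
        unknown = [t for t in targets if t not in known]
        if unknown:
            raise ValueError(f"{source} maps to unknown labels: {unknown}")
        for target in targets:
            inverse[target].append(source)
    return dict(inverse)


inverse = invert_alias_map(ALIAS_MAP, LABELS)
for label in LABELS:
    print(label, "<-", inverse.get(label, []))
```

Labels that absorb several dataset categories (here `SYNTAX_ERROR` and `FILTER_ERROR`) will be over-represented relative to the others if the source data is balanced per dataset category.
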
config/error_categories.yaml ADDED
@@ -0,0 +1,46 @@
+ categories:
+   - id: 0
+     name: SYNTAX_ERROR
+     description: Missing comma, bracket, quote
+   - id: 1
+     name: JOIN_ERROR
+     description: Missing ON, wrong join condition
+   - id: 2
+     name: AGGREGATION_ERROR
+     description: Missing GROUP BY
+   - id: 3
+     name: HAVING_WHERE_ERROR
+     description: Using aggregate in WHERE
+   - id: 4
+     name: SUBQUERY_ERROR
+     description: Multiple rows returned
+   - id: 5
+     name: WINDOW_FUNCTION_ERROR
+     description: Incorrect OVER/PARTITION BY
+   - id: 6
+     name: NULL_HANDLING_ERROR
+     description: "= NULL instead of IS NULL"
+   - id: 7
+     name: DATE_FUNCTION_ERROR
+     description: Incorrect date format/function
+   - id: 8
+     name: COLUMN_REFERENCE_ERROR
+     description: Column doesn't exist
+   - id: 9
+     name: TABLE_REFERENCE_ERROR
+     description: Table doesn't exist
+   - id: 10
+     name: DATA_TYPE_ERROR
+     description: Comparing integer with string
+   - id: 11
+     name: DUPLICATE_RECORD_ERROR
+     description: Missing DISTINCT
+   - id: 12
+     name: LOGICAL_QUERY_ERROR
+     description: Query runs but answer is wrong
+   - id: 13
+     name: PERFORMANCE_ERROR
+     description: "SELECT *, inefficient joins"
+   - id: 14
+     name: FILTERING_ERROR
+     description: Incorrect WHERE clause
hub/CODEBERT_MODEL_CARD.md ADDED
@@ -0,0 +1,65 @@
+ ---
+ language: en
+ license: mit
+ tags:
+ - codebert
+ - sql
+ - education
+ - text-classification
+ - cross-encoder
+ base_model: microsoft/codebert-base
+ pipeline_tag: text-classification
+ ---
+
+ # SQL CodeBERT Cross-Encoder
+
+ Multi-label SQL error classifier using **microsoft/codebert-base** as a cross-encoder.
+
+ ## Input Format
+
+ All fields are concatenated into one sequence:
+
+ ```
+ QUESTION:
+ {question}
+
+ SCHEMA:
+ {schema}
+
+ STUDENT_SQL:
+ {student_sql}
+
+ CORRECT_SQL:
+ {correct_sql}
+ ```
+
+ ## Labels
+
+ `JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
+ `SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
+
+ ## Training
+
+ ```bash
+ python -m src.hf_train_codebert \
+     --data data/sql_errors_1m.parquet \
+     --output-dir models/codebert-cross-encoder \
+     --epochs 3 \
+     --push-to-hub \
+     --hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
+ ```
+
+ ## Inference
+
+ ```python
+ from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
+
+ clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")
+ result = clf.predict(
+     question="What is the average score per department?",
+     schema="students(id, score, department_id)",
+     student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
+     correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
+ )
+ print(result["error_labels"])
+ ```
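
Because the task is multi-label, each of the nine outputs is typically scored independently with a sigmoid and compared against a threshold, rather than reduced to a single argmax. The standard-library sketch below illustrates that decoding step. The logits are made up, `decode_labels` is a hypothetical helper rather than the API of `src.hf_predict_codebert`, and the 0.5 threshold mirrors the default of `multihot_to_label_names` in `src/codebert_labels.py`:

```python
import math

LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR",
    "LOGICAL_ERROR", "SYNTAX_ERROR",
]


def decode_labels(logits, threshold=0.5):
    """Apply a per-label sigmoid and keep labels whose probability clears the threshold."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [name for name, p in zip(LABELS, probs) if p >= threshold]


# Made-up logits: only positions 1 and 7 are positive, so only those exceed 0.5.
example_logits = [-3.0, 2.2, -1.5, -4.0, -2.0, -3.5, -2.5, 0.4, -5.0]
print(decode_labels(example_logits))
```

Raising the threshold trades recall for precision; a single shared threshold is the simplest choice, not necessarily the best one for imbalanced labels.
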
hub/MODEL_CARD.md ADDED
@@ -0,0 +1,97 @@
+ ---
+ language: en
+ license: mit
+ tags:
+ - sql
+ - education
+ - text-classification
+ - sentence-transformers
+ - multi-tower
+ pipeline_tag: text-classification
+ ---
+
+ # SQL Error Classifier (Multi-Tower)
+
+ Lightweight classifier that identifies **which SQL mistake area** a student is struggling with, given:
+
+ - **Question** — natural-language task
+ - **Schema** — available tables and columns
+ - **Correct query** — reference solution
+ - **Student query** — what the student submitted
+ - **Error message** *(optional)* — database error text
+
+ ## Architecture
+
+ Multi-tower semantic comparison using `sentence-transformers/all-MiniLM-L6-v2`:
+
+ 1. **Intent tower** — question + schema
+ 2. **Reference tower** — correct query
+ 3. **Student tower** — student query (+ error)
+ 4. **Comparison layer** — embedding diff, interaction, cosine similarities, SQL rule features
+ 5. **Linear head** — 15 error categories
+
+ ## Error Categories (15)
+
+ | ID | Category |
+ |----|----------|
+ | 0 | SYNTAX_ERROR |
+ | 1 | JOIN_ERROR |
+ | 2 | AGGREGATION_ERROR |
+ | 3 | HAVING_WHERE_ERROR |
+ | 4 | SUBQUERY_ERROR |
+ | 5 | WINDOW_FUNCTION_ERROR |
+ | 6 | NULL_HANDLING_ERROR |
+ | 7 | DATE_FUNCTION_ERROR |
+ | 8 | COLUMN_REFERENCE_ERROR |
+ | 9 | TABLE_REFERENCE_ERROR |
+ | 10 | DATA_TYPE_ERROR |
+ | 11 | DUPLICATE_RECORD_ERROR |
+ | 12 | LOGICAL_QUERY_ERROR |
+ | 13 | PERFORMANCE_ERROR |
+ | 14 | FILTERING_ERROR |
+
+ ## Usage
+
+ ```python
+ from src.huggingface import SQLLErrorClassifierHF
+
+ clf = SQLLErrorClassifierHF.from_pretrained("YOUR_USERNAME/sql-error-classifier")
+
+ result = clf.predict(
+     question="What is the average score of students in each department?",
+     schema="students(id, name, score, department_id) | departments(id, name)",
+     correct_query="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
+     student_query="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
+ )
+
+ print(result["label_name"])  # LOGICAL_QUERY_ERROR
+ print(result["confidence"])  # 0.94
+ print(result["similarities"])  # semantic alignment scores
+ ```
+
+ ## Gradio Demo
+
+ Deploy as a [Hugging Face Space](https://huggingface.co/docs/hub/spaces) with `app.py` from this repository.
+
+ ## Model Details
+
+ - **Encoder**: `sentence-transformers/all-MiniLM-L6-v2` (loaded from Hub, not bundled)
+ - **Head**: scikit-learn SGDClassifier + StandardScaler
+ - **Size**: ~5 MB classifier head (encoder ~80 MB, cached separately)
+ - **Inference**: ~100–200 ms on CPU
+
+ ## Training Data
+
+ Synthetically generated from exercise templates with per-category error injectors.
+ 1M balanced samples across 15 classes.
+
+ ## Citation
+
+ ```bibtex
+ @misc{sql-error-classifier,
+   title = {SQL Error Classifier - Multi-Tower},
+   author = {SQLErrorClassification},
+   year = {2025},
+   publisher = {Hugging Face},
+ }
+ ```
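
The model card describes the comparison layer only at a high level (embedding diff, interaction, cosine similarities). The code that builds those features is not part of this diff, so the sketch below is an assumption about one common way to assemble them from three tower embeddings. It uses pure Python lists, toy 2-dimensional vectors, and a hypothetical `comparison_features` function, and it leaves out the SQL rule features entirely:

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def comparison_features(intent, reference, student):
    """Concatenate |ref - student|, ref * student, and three pairwise cosines."""
    diff = [abs(r - s) for r, s in zip(reference, student)]
    interaction = [r * s for r, s in zip(reference, student)]
    sims = [
        cosine(intent, reference),
        cosine(intent, student),
        cosine(reference, student),
    ]
    return diff + interaction + sims


# Toy embeddings: the student vector is orthogonal to intent and reference.
features = comparison_features([1.0, 0.0], [1.0, 0.0], [0.0, 1.0])
print(features)
```

With real 384-dimensional MiniLM embeddings the same layout would yield 384 + 384 + 3 features, which a linear head such as `SGDClassifier` can consume after scaling.
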
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ numpy>=1.24.0
+ pandas>=2.0.0
+ scikit-learn>=1.3.0
+ joblib>=1.3.0
+ pyarrow>=14.0.0
+ tqdm>=4.66.0
+ pyyaml>=6.0
+ matplotlib>=3.7.0
+ sentence-transformers>=2.2.0
+ torch>=2.0.0
+ transformers>=4.36.0
+ accelerate>=0.25.0
+ datasets>=2.16.0
+ huggingface_hub>=0.20.0
+ gradio>=4.44.0
scripts/create_hf_package.py ADDED
@@ -0,0 +1,35 @@
+ #!/usr/bin/env python3
+ """Create a local Hugging Face Hub package from a trained model."""
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+ from src.huggingface import package_for_hub
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Create HF Hub package locally")
+     parser.add_argument(
+         "--model",
+         type=Path,
+         default=PROJECT_ROOT / "models" / "multi_tower_dev.joblib",
+     )
+     parser.add_argument(
+         "--output",
+         type=Path,
+         default=PROJECT_ROOT / "models" / "hf_package",
+     )
+     args = parser.parse_args()
+
+     out = package_for_hub(args.model, args.output)
+     print(f"Package ready at {out}")
+     print("Files:", [p.name for p in out.iterdir()])
+
+
+ if __name__ == "__main__":
+     main()
scripts/deploy_train_space.sh ADDED
@@ -0,0 +1,54 @@
+ #!/usr/bin/env bash
+ # Deploy training Space to Hugging Face.
+ # Usage: ./scripts/deploy_train_space.sh YOUR_HF_USERNAME/sql-error-classifier-train
+ set -euo pipefail
+
+ SPACE_ID="${1:-}"
+ if [[ -z "${SPACE_ID}" ]]; then
+   echo "Usage: $0 YOUR_USERNAME/sql-error-classifier-train"
+   exit 1
+ fi
+
+ ROOT="$(cd "$(dirname "$0")/.." && pwd)"
+ TOKEN="${HF_TOKEN:-${HUGGING_FACE_HUB_TOKEN:-}}"
+
+ if [[ -z "${TOKEN}" ]]; then
+   echo "Set HF_TOKEN before deploying."
+   exit 1
+ fi
+
+ WORKDIR=$(mktemp -d)
+ trap 'rm -rf "${WORKDIR}"' EXIT
+
+ echo "==> Preparing Space files in ${WORKDIR}..."
+ rsync -a \
+   --exclude '.venv' \
+   --exclude 'models' \
+   --exclude '__pycache__' \
+   --exclude '.git' \
+   "${ROOT}/" "${WORKDIR}/"
+
+ cp "${ROOT}/README_TRAIN_SPACE.md" "${WORKDIR}/README.md"
+
+ echo "==> Creating / updating Space ${SPACE_ID}..."
+ python - <<PY
+ from huggingface_hub import HfApi
+ api = HfApi(token="${TOKEN}")
+ api.create_repo("${SPACE_ID}", repo_type="space", space_sdk="gradio", exist_ok=True)
+ PY
+
+ echo "==> Uploading to Hugging Face Space..."
+ python - <<PY
+ from huggingface_hub import HfApi
+ api = HfApi(token="${TOKEN}")
+ api.upload_folder(
+     folder_path="${WORKDIR}",
+     repo_id="${SPACE_ID}",
+     repo_type="space",
+     commit_message="Deploy CodeBERT training Space",
+ )
+ PY
+
+ echo "==> Done: https://huggingface.co/spaces/${SPACE_ID}"
+ echo "Next: Space Settings → Hardware → GPU t4-small"
+ echo "      Space Settings → Secrets → HF_TOKEN"
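
The script resolves the token in the shell with a fallback chain (`HF_TOKEN`, then `HUGGING_FACE_HUB_TOKEN`) and interpolates it into the Python heredocs. An alternative would be to perform the same lookup inside Python, so the secret is read from the environment rather than written into the generated source. A minimal sketch, with `resolve_token` as a hypothetical helper and a placeholder token value:

```python
import os
from typing import Mapping, Optional


def resolve_token(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the first non-empty token, mirroring ${HF_TOKEN:-${HUGGING_FACE_HUB_TOKEN:-}}."""
    env = os.environ if env is None else env
    for name in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
        value = env.get(name, "")
        if value:  # like ':-' in bash, an empty value counts as unset
            return value
    raise RuntimeError("Set HF_TOKEN before deploying.")


# Placeholder value for illustration only; real tokens come from the environment.
print(bool(resolve_token({"HUGGING_FACE_HUB_TOKEN": "hf_example"})))
```

The returned value could then be passed as `HfApi(token=...)`, the same constructor argument the script already uses.
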
scripts/push_to_hub.py ADDED
@@ -0,0 +1,87 @@
+ #!/usr/bin/env python3
+ """Package and push the SQL error classifier to Hugging Face Hub."""
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import sys
+ from pathlib import Path
+
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+ from huggingface_hub import HfApi, create_repo
+
+ from src.huggingface import package_for_hub
+ DEFAULT_MODEL = PROJECT_ROOT / "models" / "multi_tower_dev.joblib"
+ DEFAULT_PACKAGE = PROJECT_ROOT / "models" / "hf_package"
+ MODEL_CARD = PROJECT_ROOT / "hub" / "MODEL_CARD.md"
+
+
+ def push(
+     model_path: Path = DEFAULT_MODEL,
+     package_dir: Path = DEFAULT_PACKAGE,
+     repo_id: str = "",
+     private: bool = False,
+     token: str | None = None,
+ ) -> str:
+     if not repo_id:
+         raise ValueError("--repo-id is required (e.g. your-username/sql-error-classifier)")
+
+     token = token or os.getenv("HF_TOKEN")
+     api = HfApi(token=token)
+
+     print(f"Packaging model from {model_path}...")
+     package_for_hub(model_path, package_dir)
+
+     print(f"Creating repo {repo_id}...")
+     create_repo(repo_id, repo_type="model", private=private, exist_ok=True, token=token)
+
+     print("Uploading model files...")
+     api.upload_folder(
+         folder_path=str(package_dir),
+         repo_id=repo_id,
+         repo_type="model",
+         token=token,
+     )
+
+     if MODEL_CARD.exists():
+         api.upload_file(
+             path_or_fileobj=str(MODEL_CARD),
+             path_in_repo="README.md",
+             repo_id=repo_id,
+             repo_type="model",
+             token=token,
+         )
+
+     url = f"https://huggingface.co/{repo_id}"
+     print(f"Done: {url}")
+     return url
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Push SQL error classifier to HF Hub")
+     parser.add_argument("--model", type=Path, default=DEFAULT_MODEL)
+     parser.add_argument("--package-dir", type=Path, default=DEFAULT_PACKAGE)
+     parser.add_argument(
+         "--repo-id",
+         type=str,
+         required=True,
+         help="Hugging Face repo id, e.g. nishantgupta/sql-error-classifier",
+     )
+     parser.add_argument("--private", action="store_true")
+     parser.add_argument("--token", type=str, default=None)
+     args = parser.parse_args()
+
+     push(
+         model_path=args.model,
+         package_dir=args.package_dir,
+         repo_id=args.repo_id,
+         private=args.private,
+         token=args.token,
+     )
+
+
+ if __name__ == "__main__":
+     main()
scripts/run_codebert_training.sh ADDED
@@ -0,0 +1,17 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ DATA="${1:-data/sql_errors_dev.parquet}"
+ OUTPUT="${2:-models/codebert-cross-encoder}"
+ SAMPLES="${3:-}"
+
+ CMD=(python -m src.hf_train_codebert --data "${DATA}" --output-dir "${OUTPUT}")
+
+ if [[ -n "${SAMPLES}" ]]; then
+   CMD+=(--max-samples "${SAMPLES}")
+ fi
+
+ echo "==> Training CodeBERT cross-encoder..."
+ "${CMD[@]}"
+
+ echo "==> Done. Model at ${OUTPUT}"
scripts/run_pipeline.sh ADDED
@@ -0,0 +1,16 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ SAMPLES="${1:-1000000}"
+ WORKERS="${2:-8}"
+
+ echo "==> Generating ${SAMPLES} labeled SQL samples..."
+ python -m src.generate_dataset --samples "${SAMPLES}" --workers "${WORKERS}"
+
+ echo "==> Training classifier..."
+ python -m src.train
+
+ echo "==> Evaluating..."
+ python -m src.evaluate
+
+ echo "==> Done. Model at models/sql_error_classifier.joblib"
src/__init__.py ADDED
@@ -0,0 +1 @@
+ """SQL error classification package."""
src/categories.py ADDED
@@ -0,0 +1,35 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List
+
+ import yaml
+
+ CONFIG_PATH = Path(__file__).resolve().parent.parent / "config" / "error_categories.yaml"
+
+
+ @dataclass(frozen=True)
+ class ErrorCategory:
+     id: int
+     name: str
+     description: str
+
+
+ def load_categories(config_path: Path = CONFIG_PATH) -> List[ErrorCategory]:
+     with open(config_path) as f:
+         data = yaml.safe_load(f)
+     return [
+         ErrorCategory(id=c["id"], name=c["name"], description=c["description"])
+         for c in data["categories"]
+     ]
+
+
+ def id_to_name(categories: List[ErrorCategory] | None = None) -> Dict[int, str]:
+     cats = categories or load_categories()
+     return {c.id: c.name for c in cats}
+
+
+ def name_to_id(categories: List[ErrorCategory] | None = None) -> Dict[str, int]:
+     cats = categories or load_categories()
+     return {c.name: c.id for c in cats}
src/codebert_dataset.py ADDED
@@ -0,0 +1,117 @@
+ """PyTorch Dataset and preprocessing for CodeBERT Trainer."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset
+ from transformers import PreTrainedTokenizerBase
+
+ from src.codebert_formatting import format_cross_encoder_input
+ from src.codebert_labels import label_to_multihot, load_codebert_labels
+
+
+ def normalize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+     """Map project column names to the canonical training schema."""
+     col_map = {
+         "query": "student_sql",
+         "correct_query": "correct_sql",
+         "label_name": "error_labels",
+     }
+     out = df.rename(columns={k: v for k, v in col_map.items() if k in df.columns}).copy()
+
+     required = ["question", "schema", "student_sql", "correct_sql", "error_labels"]
+     missing = [c for c in required if c not in out.columns]
+     if missing:
+         raise ValueError(
+             f"Dataset missing required columns: {missing}. "
+             f"Expected {required} (or aliases query/correct_query/label_name)."
+         )
+     return out
+
+
+ class SQLCodeBERTDataset(Dataset):
+     """Tokenized SQL error dataset for Hugging Face Trainer."""
+
+     def __init__(
+         self,
+         df: pd.DataFrame,
+         tokenizer: PreTrainedTokenizerBase,
+         label_list: Optional[List[str]] = None,
+         max_length: int = 512,
+     ):
+         self.df = normalize_dataframe(df).reset_index(drop=True)
+         self.tokenizer = tokenizer
+         self.label_list = label_list or load_codebert_labels()
+         self.max_length = max_length
+         self.num_labels = len(self.label_list)
+
+     def __len__(self) -> int:
+         return len(self.df)
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         row = self.df.iloc[idx]
+         text = format_cross_encoder_input(
+             question=str(row["question"]),
+             schema=str(row["schema"]),
+             student_sql=str(row["student_sql"]),
+             correct_sql=str(row["correct_sql"]),
+         )
+         encoded = self.tokenizer(
+             text,
+             truncation=True,
+             max_length=self.max_length,
+             padding=False,
+             return_tensors=None,
+         )
+         labels = label_to_multihot(str(row["error_labels"]), self.label_list)
+         encoded["labels"] = labels.tolist()
+         return encoded
+
+
+ class SQLCodeBERTDataCollator:
+     """Pad batches dynamically for Trainer."""
+
+     def __init__(self, tokenizer: PreTrainedTokenizerBase):
+         self.tokenizer = tokenizer
+
+     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+         labels = [f.pop("labels") for f in features]
+         batch = self.tokenizer.pad(features, padding=True, return_tensors="pt")
+         batch["labels"] = torch.tensor(labels, dtype=torch.float)
+         return batch
+
+
+ def prepare_datasets(
+     df: pd.DataFrame,
+     tokenizer: PreTrainedTokenizerBase,
+     test_size: float = 0.1,
+     val_size: float = 0.1,
+     max_length: int = 512,
+     seed: int = 42,
+ ) -> tuple[SQLCodeBERTDataset, SQLCodeBERTDataset, SQLCodeBERTDataset]:
+     from sklearn.model_selection import train_test_split
+
+     df = normalize_dataframe(df)
+     trainval, test_df = train_test_split(
+         df,
+         test_size=test_size,
+         random_state=seed,
+         stratify=df["error_labels"],
+     )
+     relative_val = val_size / (1 - test_size)
+     train_df, val_df = train_test_split(
+         trainval,
+         test_size=relative_val,
+         random_state=seed,
+         stratify=trainval["error_labels"],
+     )
+
+     return (
+         SQLCodeBERTDataset(train_df, tokenizer, max_length=max_length),
+         SQLCodeBERTDataset(val_df, tokenizer, max_length=max_length),
+         SQLCodeBERTDataset(test_df, tokenizer, max_length=max_length),
+     )
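
`prepare_datasets` takes the validation split from what remains after the test split, which is why the second call uses `val_size / (1 - test_size)` rather than `val_size` directly. The arithmetic for the default values can be sketched as follows; `split_sizes` is an illustrative helper and ignores the integer rounding scikit-learn applies:

```python
def split_sizes(n, test_size=0.1, val_size=0.1):
    """Return approximate (train, val, test) sizes for a two-stage split."""
    n_test = n * test_size
    remaining = n - n_test
    # Fraction of the remainder that equals val_size of the full dataset.
    relative_val = val_size / (1 - test_size)
    n_val = remaining * relative_val
    n_train = remaining - n_val
    return n_train, n_val, n_test


train, val, test = split_sizes(10_000)
print(round(train), round(val), round(test))
```

Passing `val_size` unchanged to the second split would give 9% validation and 81% training data instead of the intended 10% and 80%.
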
src/codebert_formatting.py ADDED
@@ -0,0 +1,28 @@
+ """Cross-encoder input formatting for CodeBERT."""
+
+ from __future__ import annotations
+
+ QUESTION_TAG = "QUESTION:"
+ SCHEMA_TAG = "SCHEMA:"
+ STUDENT_TAG = "STUDENT_SQL:"
+ CORRECT_TAG = "CORRECT_SQL:"
+
+
+ def format_cross_encoder_input(
+     question: str,
+     schema: str,
+     student_sql: str,
+     correct_sql: str,
+ ) -> str:
+     """
+     Concatenate all fields into a single CodeBERT input sequence.
+
+     The model attends jointly across question intent, schema, student SQL,
+     and the reference solution — cross-encoder style in one forward pass.
+     """
+     return (
+         f"{QUESTION_TAG}\n{question.strip()}\n\n"
+         f"{SCHEMA_TAG}\n{schema.strip()}\n\n"
+         f"{STUDENT_TAG}\n{student_sql.strip()}\n\n"
+         f"{CORRECT_TAG}\n{correct_sql.strip()}"
+     )
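
For a concrete picture of the sequence the tokenizer receives, the formatter can be reproduced and called directly. The snippet below copies the function body from the diff so that it runs standalone, and reuses the SUM/AVG pair from the model card as input:

```python
QUESTION_TAG = "QUESTION:"
SCHEMA_TAG = "SCHEMA:"
STUDENT_TAG = "STUDENT_SQL:"
CORRECT_TAG = "CORRECT_SQL:"


def format_cross_encoder_input(question, schema, student_sql, correct_sql):
    """Join the four fields into one tagged block, separated by blank lines."""
    return (
        f"{QUESTION_TAG}\n{question.strip()}\n\n"
        f"{SCHEMA_TAG}\n{schema.strip()}\n\n"
        f"{STUDENT_TAG}\n{student_sql.strip()}\n\n"
        f"{CORRECT_TAG}\n{correct_sql.strip()}"
    )


text = format_cross_encoder_input(
    question="  What is the average score per department?  ",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(text)
```

Since the dataset truncates at `max_length` tokens and `CORRECT_SQL` comes last, a very long schema would cut into the reference query first.
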
src/codebert_labels.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Label utilities for CodeBERT multi-label classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Dict, List, Sequence, Union
7
+
8
+ import numpy as np
9
+ import yaml
10
+
11
+ CONFIG_PATH = (
12
+ Path(__file__).resolve().parent.parent / "config" / "codebert_labels.yaml"
13
+ )
14
+
15
+
16
+ def load_codebert_labels(config_path: Path = CONFIG_PATH) -> List[str]:
17
+ with open(config_path) as f:
18
+ data = yaml.safe_load(f)
19
+ return list(data["labels"])
20
+
21
+
22
+ def load_alias_map(config_path: Path = CONFIG_PATH) -> Dict[str, List[str]]:
23
+ with open(config_path) as f:
24
+ data = yaml.safe_load(f)
25
+ return {k: list(v) for k, v in data["alias_map"].items()}
26
+
27
+
28
+ def label_to_multihot(
29
+ error_labels: Union[str, Sequence[str]],
30
+ label_list: List[str] | None = None,
31
+ alias_map: Dict[str, List[str]] | None = None,
32
+ ) -> np.ndarray:
33
+ """
34
+ Convert error label(s) to multi-hot vector.
35
+
36
+ Accepts:
37
+ - comma-separated string: "JOIN_ERROR,AGGREGATION_ERROR"
38
+ - list of label strings
39
+ - single dataset label_name (resolved via alias_map)
40
+ """
41
+ labels = label_list or load_codebert_labels()
42
+ aliases = alias_map or load_alias_map()
43
+ index = {name: i for i, name in enumerate(labels)}
44
+ vec = np.zeros(len(labels), dtype=np.float32)
45
+
46
+ if isinstance(error_labels, str):
47
+ raw = [s.strip() for s in error_labels.split(",") if s.strip()]
48
+ if len(raw) == 1 and raw[0] in aliases:
49
+ raw = aliases[raw[0]]
50
+ elif len(raw) == 1 and raw[0] in index:
51
+ raw = [raw[0]]
52
+ elif len(raw) == 1 and raw[0] not in index:
53
+ mapped = aliases.get(raw[0], [])
54
+ raw = mapped
55
+ else:
56
+ raw = list(error_labels)
57
+ expanded: List[str] = []
58
+ for item in raw:
59
+ if item in aliases:
60
+ expanded.extend(aliases[item])
61
+ elif item in index:
62
+ expanded.append(item)
63
+ raw = expanded
64
+
65
+ for name in raw:
66
+ if name not in index:
67
+ raise ValueError(f"Unknown label '{name}'. Expected one of {labels}")
68
+ vec[index[name]] = 1.0
69
+
70
+ if vec.sum() == 0:
71
+ raise ValueError(f"No valid labels found in {error_labels}")
72
+ return vec
73
+
74
+
75
+ def multihot_to_label_names(
76
+ vec: np.ndarray,
77
+ label_list: List[str] | None = None,
78
+ threshold: float = 0.5,
79
+ ) -> List[str]:
80
+ labels = label_list or load_codebert_labels()
81
+ indices = np.where(vec >= threshold)[0]
82
+ return [labels[i] for i in indices]
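The round trip between label names and multi-hot vectors can be sketched without the YAML config. This is a simplified stand-in, not the functions above: the label list and alias map are passed in memory, `FILTER_ERROR` and the `missing_group_by` alias are hypothetical names, and aliases are expanded for every comma-separated item (the file above only resolves an alias when the string holds a single item).

```python
import numpy as np

# JOIN_ERROR and AGGREGATION_ERROR appear in the docstring above;
# FILTER_ERROR and the alias below are hypothetical.
LABELS = ["JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR"]
ALIASES = {"missing_group_by": ["AGGREGATION_ERROR"]}


def to_multihot(text, labels=LABELS, aliases=ALIASES):
    """Turn a comma-separated label string into a multi-hot float vector."""
    index = {name: i for i, name in enumerate(labels)}
    vec = np.zeros(len(labels), dtype=np.float32)
    for item in (s.strip() for s in text.split(",") if s.strip()):
        # An alias expands to one or more canonical labels.
        for name in aliases.get(item, [item]):
            if name not in index:
                raise ValueError(f"Unknown label '{name}'")
            vec[index[name]] = 1.0
    return vec


def to_names(vec, labels=LABELS, threshold=0.5):
    """Inverse mapping: keep every label whose score reaches the threshold."""
    return [labels[i] for i in np.where(vec >= threshold)[0]]


vec = to_multihot("JOIN_ERROR, missing_group_by")
names = to_names(vec)
print(vec, names)
```

The same `to_names` call works on sigmoid probabilities at inference time, since it only compares each entry against the threshold.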
src/cross_encoder_model.py ADDED
@@ -0,0 +1,312 @@
1
+ """Cross-encoder architecture for SQL error classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ from sklearn.linear_model import LogisticRegression
11
+ from sklearn.preprocessing import StandardScaler
12
+
13
+ from src.multi_tower_model import QueryContext, contexts_from_dataframe
14
+ from src.sql_features import extract_sql_features
15
+
16
+ DEFAULT_CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L6-v2"
17
+ DEFAULT_FINETUNED_CE = "cross-encoder/ms-marco-MiniLM-L6-v2"
18
+
19
+ PAIR_NAMES = (
20
+ "intent_vs_student",
21
+ "reference_vs_student",
22
+ "intent_vs_reference",
23
+ )
24
+
25
+
26
+ @dataclass(frozen=True)
27
+ class CrossEncoderPair:
28
+ name: str
29
+ text_a: str
30
+ text_b: str
31
+
32
+
33
+ def _intent_text(ctx: QueryContext) -> str:
34
+ return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"
35
+
36
+
37
+ def _reference_text(ctx: QueryContext) -> str:
38
+ return f"REFERENCE: {ctx.correct_query}"
39
+
40
+
41
+ def _student_text(ctx: QueryContext) -> str:
42
+ parts = [f"STUDENT: {ctx.student_query}"]
43
+ if ctx.error_message:
44
+ parts.append(f"ERROR: {ctx.error_message}")
45
+ return " ".join(parts)
46
+
47
+
48
+ def _context_text(ctx: QueryContext) -> str:
49
+ """Full task context for fine-tuned cross-encoder."""
50
+ return (
51
+ f"QUESTION: {ctx.question} "
52
+ f"SCHEMA: {ctx.schema} "
53
+ f"REFERENCE: {ctx.correct_query}"
54
+ )
55
+
56
+
57
+ def build_pairs(ctx: QueryContext) -> List[CrossEncoderPair]:
58
+ intent, reference, student = (
59
+ _intent_text(ctx),
60
+ _reference_text(ctx),
61
+ _student_text(ctx),
62
+ )
63
+ return [
64
+ CrossEncoderPair("intent_vs_student", intent, student),
65
+ CrossEncoderPair("reference_vs_student", reference, student),
66
+ CrossEncoderPair("intent_vs_reference", intent, reference),
67
+ ]
68
+
69
+
70
+ class CrossEncoderClassifier:
71
+ """
72
+ Hybrid cross-encoder: frozen pairwise relevance + linear head.
73
+
74
+ Unlike bi-encoders (multi-tower), the cross-encoder attends jointly over
75
+ each (context, student) pair — better for logical and filtering errors.
76
+
77
+ Three pairs are scored:
78
+ 1. intent vs student — does the query address the question?
79
+ 2. reference vs student — how far is the student from the answer?
80
+ 3. intent vs reference — task-answer alignment baseline
81
+
82
+ Pair scores + SQL rule features → LogisticRegression → 15 classes.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ cross_encoder_name: str = DEFAULT_CROSS_ENCODER,
88
+ batch_size: int = 32,
89
+ max_length: int = 512,
90
+ ):
91
+ self.cross_encoder_name = cross_encoder_name
92
+ self.batch_size = batch_size
93
+ self.max_length = max_length
94
+ self.cross_encoder = None
95
+ self.scaler = StandardScaler()
96
+ self.clf = LogisticRegression(
97
+ max_iter=1000,
98
+ solver="lbfgs",
99
+ class_weight="balanced",
100
+ random_state=42,
101
+ )
102
+ self.classes_: Optional[np.ndarray] = None
103
+
104
+ def _load_cross_encoder(self):
105
+ if self.cross_encoder is None:
106
+ from sentence_transformers import CrossEncoder
107
+
108
+ self.cross_encoder = CrossEncoder(
109
+ self.cross_encoder_name,
110
+ max_length=self.max_length,
111
+ )
112
+
113
+ def _pair_batches(self, contexts: List[QueryContext]) -> List[List[Tuple[str, str]]]:
114
+ """One batch list per pair type across all contexts."""
115
+ pair_lists: List[List[Tuple[str, str]]] = [[], [], []]
116
+ for ctx in contexts:
117
+ pairs = build_pairs(ctx)
118
+ for i, pair in enumerate(pairs):
119
+ pair_lists[i].append((pair.text_a, pair.text_b))
120
+ return pair_lists
121
+
122
+ def _score_pairs(
123
+ self,
124
+ contexts: List[QueryContext],
125
+ show_progress: bool = False,
126
+ ) -> np.ndarray:
127
+ self._load_cross_encoder()
128
+ pair_batches = self._pair_batches(contexts)
129
+ scores = []
130
+ for batch in pair_batches:
131
+ raw = self.cross_encoder.predict(
132
+ batch,
133
+ batch_size=self.batch_size,
134
+ show_progress_bar=show_progress,
135
+ )
136
+ scores.append(np.asarray(raw, dtype=np.float64).reshape(-1, 1))
137
+ return np.hstack(scores) # (n, 3)
138
+
139
+ def _build_features(
140
+ self,
141
+ contexts: List[QueryContext],
142
+ show_progress: bool = False,
143
+ ) -> np.ndarray:
144
+ pair_scores = self._score_pairs(contexts, show_progress=show_progress)
145
+ s_is, s_rs, s_ir = pair_scores[:, 0], pair_scores[:, 1], pair_scores[:, 2]
146
+
147
+ derived = np.column_stack(
148
+ [
149
+ s_rs - s_is, # reference closer than intent?
150
+ s_is - s_ir, # student-intent gap vs baseline
151
+ s_rs - s_ir, # student-reference gap vs baseline
152
+ s_is * s_rs, # interaction
153
+ np.abs(s_rs - s_is), # absolute gap between the reference-vs-student and intent-vs-student scores
154
+ ]
155
+ )
156
+
157
+ sql_feats = np.array(
158
+ [extract_sql_features(c.student_query, c.correct_query) for c in contexts],
159
+ dtype=np.float64,
160
+ )
161
+
162
+ return np.hstack([pair_scores, derived, sql_feats])
163
+
164
+ def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
165
+ X = self.scaler.transform(self._build_features(contexts))
166
+ return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)
167
+
168
+ def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "CrossEncoderClassifier":
169
+ X = self._build_features(contexts, show_progress=True)
170
+ X = self.scaler.fit_transform(X)
171
+ X = np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)
172
+ self.clf.fit(X, y)
173
+ self.classes_ = self.clf.classes_
174
+ return self
175
+
176
+ def predict(self, contexts: List[QueryContext]) -> np.ndarray:
177
+ return self.clf.predict(self._prepare_features(contexts))
178
+
179
+ def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
180
+ return self.clf.predict_proba(self._prepare_features(contexts))
181
+
182
+ def explain_pair_scores(self, ctx: QueryContext) -> dict:
183
+ scores = self._score_pairs([ctx])[0]
184
+ return {
185
+ PAIR_NAMES[0]: float(scores[0]),
186
+ PAIR_NAMES[1]: float(scores[1]),
187
+ PAIR_NAMES[2]: float(scores[2]),
188
+ }
189
+
190
+
191
+ class FineTunedCrossEncoderClassifier:
192
+ """
193
+ End-to-end fine-tuned cross-encoder (highest accuracy).
194
+
195
+ Single cross-attention pass over [task_context | student_query] with
196
+ num_labels=15. Slower to train; best on smaller high-quality datasets.
197
+ """
198
+
199
+ def __init__(
200
+ self,
201
+ cross_encoder_name: str = DEFAULT_FINETUNED_CE,
202
+ batch_size: int = 16,
203
+ max_length: int = 512,
204
+ num_labels: int = 15,
205
+ ):
206
+ self.cross_encoder_name = cross_encoder_name
207
+ self.batch_size = batch_size
208
+ self.max_length = max_length
209
+ self.num_labels = num_labels
210
+ self.model = None
211
+ self.classes_: Optional[np.ndarray] = None
212
+
213
+ def _load_model(self, num_labels: Optional[int] = None):
214
+ if self.model is None:
215
+ from sentence_transformers import CrossEncoder
216
+
217
+ self.model = CrossEncoder(
218
+ self.cross_encoder_name,
219
+ num_labels=num_labels or self.num_labels,
220
+ max_length=self.max_length,
221
+ )
222
+
223
+ def _to_examples(self, contexts: List[QueryContext], labels: Optional[np.ndarray] = None):
224
+ from sentence_transformers import InputExample
225
+
226
+ examples = []
227
+ for i, ctx in enumerate(contexts):
228
+ label = int(labels[i]) if labels is not None else 0  # integer class index for multi-class training
229
+ examples.append(
230
+ InputExample(
231
+ texts=[_context_text(ctx), _student_text(ctx)],
232
+ label=label,
233
+ )
234
+ )
235
+ return examples
236
+
237
+ def fit(
238
+ self,
239
+ contexts: List[QueryContext],
240
+ y: np.ndarray,
241
+ epochs: int = 1,
242
+ warmup_steps: int = 100,
243
+ output_path: Optional[Path] = None,
244
+ ) -> "FineTunedCrossEncoderClassifier":
245
+ from torch.utils.data import DataLoader
246
+
247
+ self._load_model(num_labels=len(np.unique(y)))
248
+ train_examples = self._to_examples(contexts, y)
249
+ loader = DataLoader(
250
+ train_examples,
251
+ shuffle=True,
252
+ batch_size=self.batch_size,
253
+ )
254
+ self.model.fit(
255
+ train_dataloader=loader,
256
+ epochs=epochs,
257
+ warmup_steps=min(warmup_steps, max(10, len(train_examples) // 10)),
258
+ show_progress_bar=True,
259
+ output_path=str(output_path) if output_path else None,
260
+ )
261
+ self.classes_ = np.sort(np.unique(y))
262
+ return self
263
+
264
+ def predict(self, contexts: List[QueryContext]) -> np.ndarray:
265
+ self._load_model()
266
+ pairs = [[_context_text(c), _student_text(c)] for c in contexts]
267
+ logits = self.model.predict(
268
+ pairs,
269
+ batch_size=self.batch_size,
270
+ show_progress_bar=False,
271
+ convert_to_numpy=True,
272
+ )
273
+ logits = np.asarray(logits)
274
+ if logits.ndim == 1:
275
+ return logits.astype(int)
276
+ return logits.argmax(axis=1)
277
+
278
+ def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
279
+ self._load_model()
280
+ pairs = [[_context_text(c), _student_text(c)] for c in contexts]
281
+ logits = self.model.predict(
282
+ pairs,
283
+ batch_size=self.batch_size,
284
+ show_progress_bar=False,
285
+ convert_to_numpy=True,
286
+ )
287
+ logits = np.asarray(logits, dtype=np.float64)
288
+ if logits.ndim == 1:
289
+ # binary fallback
290
+ probs = np.zeros((len(contexts), len(self.classes_)))
291
+ for i, pred in enumerate(logits.astype(int)):
292
+ idx = np.where(self.classes_ == pred)[0][0]
293
+ probs[i, idx] = 1.0
294
+ return probs
295
+ # softmax
296
+ exp = np.exp(logits - logits.max(axis=1, keepdims=True))
297
+ return exp / exp.sum(axis=1, keepdims=True)
298
+
299
+ def save(self, path: Path) -> Path:
300
+ path.mkdir(parents=True, exist_ok=True)
301
+ self._load_model()
302
+ self.model.save(str(path))
303
+ return path
304
+
305
+ @classmethod
306
+ def load(cls, path: Path) -> "FineTunedCrossEncoderClassifier":
307
+ from sentence_transformers import CrossEncoder
308
+
309
+ instance = cls()
310
+ instance.model = CrossEncoder(str(path))
311
+ instance.classes_ = np.arange(instance.model.num_labels)
312
+ return instance
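The softmax at the end of `predict_proba` subtracts the row maximum before exponentiating. A small NumPy sketch of why that matters, using made-up logits (not model output): without the shift, `np.exp(1000.0)` overflows to `inf` and the division yields `nan`.

```python
import numpy as np


def stable_softmax(logits):
    # Subtracting the per-row max leaves the result unchanged mathematically
    # but keeps every exponent <= 0, so np.exp cannot overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Illustrative logits: one row with very large values, one uniform row.
logits = np.array([[1000.0, 1001.0, 999.0], [0.0, 0.0, 0.0]])
probs = stable_softmax(logits)
print(probs)
```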
src/evaluate.py ADDED
@@ -0,0 +1,114 @@
1
+ """Evaluate trained model with confusion matrix and per-class metrics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pandas as pd
14
+ from sklearn.metrics import ConfusionMatrixDisplay, classification_report
15
+
16
+ from src.categories import id_to_name, load_categories
17
+ from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
18
+ from src.cross_encoder_model import (
19
+ CrossEncoderClassifier,
20
+ FineTunedCrossEncoderClassifier,
21
+ )
22
+ from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe
23
+
24
+ CONTEXT_MODELS = (
25
+ CrossEncoderClassifier,
26
+ FineTunedCrossEncoderClassifier,
27
+ MultiTowerClassifier,
28
+ )
29
+
30
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
31
+ DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
32
+ DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "evaluation"
33
+
34
+
35
+ def evaluate(
36
+ data_path: Path = DEFAULT_DATA,
37
+ model_path: Path = DEFAULT_MODEL_PATH,
38
+ output_dir: Path = DEFAULT_OUTPUT,
39
+ sample_size: int = 100_000,
40
+ use_error_message: bool = True,
41
+ seed: int = 42,
42
+ ) -> dict:
43
+ output_dir.mkdir(parents=True, exist_ok=True)
44
+
45
+ df = pd.read_parquet(data_path)
46
+ if len(df) > sample_size:
47
+ df = df.sample(n=sample_size, random_state=seed)
48
+
49
+ labels = df["label_id"].values
50
+ model = load_model(model_path)
51
+
52
+ if isinstance(model, CONTEXT_MODELS):
53
+ if not use_error_message and "error_message" in df.columns:
54
+ df = df.drop(columns=["error_message"])
55
+ preds = model.predict(contexts_from_dataframe(df))
56
+ else:
57
+ texts = combine_features(
58
+ queries=df["query"].tolist(),
59
+ error_messages=df["error_message"].tolist() if use_error_message else None,
60
+ schemas=df["schema"].tolist() if "schema" in df.columns else None,
61
+ questions=df["question"].tolist() if "question" in df.columns else None,
62
+ )
63
+ preds = model.predict(texts)
64
+
65
+ categories = load_categories()
66
+ target_names = [c.name for c in categories]
67
+ report = classification_report(
68
+ labels, preds, labels=[c.id for c in categories], target_names=target_names, output_dict=True, zero_division=0
69
+ )
70
+
71
+ with open(output_dir / "classification_report.json", "w") as f:
72
+ json.dump(report, f, indent=2)
73
+
74
+ cm = ConfusionMatrixDisplay.from_predictions(
75
+ labels,
76
+ preds,
77
+ labels=[c.id for c in categories], display_labels=target_names,
78
+ xticks_rotation=90,
79
+ cmap="Blues",
80
+ colorbar=False,
81
+ )
82
+ fig = cm.figure_
83
+ fig.set_size_inches(14, 12)
84
+ fig.tight_layout()
85
+ fig.savefig(output_dir / "confusion_matrix.png", dpi=150)
86
+ plt.close(fig)
87
+
88
+ print(f"Accuracy: {report['accuracy']:.4f}")
89
+ print(f"Reports saved to {output_dir}")
90
+ return report
91
+
92
+
93
+ def main() -> None:
94
+ parser = argparse.ArgumentParser(description="Evaluate SQL error classifier")
95
+ parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
96
+ parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
97
+ parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
98
+ parser.add_argument("--sample-size", type=int, default=100_000)
99
+ parser.add_argument("--no-error-message", action="store_true")
100
+ parser.add_argument("--seed", type=int, default=42)
101
+ args = parser.parse_args()
102
+
103
+ evaluate(
104
+ data_path=args.data,
105
+ model_path=args.model,
106
+ output_dir=args.output,
107
+ sample_size=args.sample_size,
108
+ use_error_message=not args.no_error_message,
109
+ seed=args.seed,
110
+ )
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
src/exercises.py ADDED
@@ -0,0 +1,228 @@
1
+ """Generate playground exercises: schema + question + correct SQL."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from dataclasses import dataclass
7
+ from typing import Callable, Dict, List
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class Exercise:
12
+ schema: str
13
+ question: str
14
+ correct_query: str
15
+ tables: tuple[str, ...]
16
+ columns: tuple[str, ...]
17
+
18
+
19
+ def _fmt_schema(tables: Dict[str, List[str]]) -> str:
20
+ parts = [f"{name}({', '.join(cols)})" for name, cols in tables.items()]
21
+ return " | ".join(parts)
22
+
23
+
24
+ ExerciseBuilder = Callable[[random.Random], Exercise]
25
+
26
+ EXERCISE_BUILDERS: List[ExerciseBuilder] = []
27
+
28
+
29
+ def _register(builder: ExerciseBuilder) -> ExerciseBuilder:
30
+ EXERCISE_BUILDERS.append(builder)
31
+ return builder
32
+
33
+
34
+ @_register
35
+ def exercise_avg_by_department(rng: random.Random) -> Exercise:
36
+ tables = {
37
+ "students": ["id", "name", "email", "score", "department_id"],
38
+ "departments": ["id", "name", "city"],
39
+ }
40
+ return Exercise(
41
+ schema=_fmt_schema(tables),
42
+ question="What is the average score of students in each department?",
43
+ correct_query=(
44
+ "SELECT department_id, AVG(score) "
45
+ "FROM students GROUP BY department_id"
46
+ ),
47
+ tables=tuple(tables),
48
+ columns=("department_id", "score"),
49
+ )
50
+
51
+
52
+ @_register
53
+ def exercise_student_department_names(rng: random.Random) -> Exercise:
54
+ tables = {
55
+ "students": ["id", "name", "department_id"],
56
+ "departments": ["id", "name"],
57
+ }
58
+ return Exercise(
59
+ schema=_fmt_schema(tables),
60
+ question="List each student's name along with their department name.",
61
+ correct_query=(
62
+ "SELECT students.name, departments.name "
63
+ "FROM students "
64
+ "INNER JOIN departments ON students.department_id = departments.id"
65
+ ),
66
+ tables=tuple(tables),
67
+ columns=("name", "department_id"),
68
+ )
69
+
70
+
71
+ @_register
72
+ def exercise_high_scoring_students(rng: random.Random) -> Exercise:
73
+ threshold = rng.randint(70, 90)
74
+ tables = {"students": ["id", "name", "age", "score", "status"]}
75
+ return Exercise(
76
+ schema=_fmt_schema(tables),
77
+ question=(
78
+ f"Find names of students older than 18 with a score above {threshold}."
79
+ ),
80
+ correct_query=(
81
+ f"SELECT name FROM students "
82
+ f"WHERE age > 18 AND score > {threshold}"
83
+ ),
84
+ tables=tuple(tables),
85
+ columns=("name", "age", "score", "status"),
86
+ )
87
+
88
+
89
+ @_register
90
+ def exercise_unique_cities(rng: random.Random) -> Exercise:
91
+ tables = {"students": ["id", "name", "city", "country"]}
92
+ return Exercise(
93
+ schema=_fmt_schema(tables),
94
+ question="List the unique cities where students live.",
95
+ correct_query="SELECT DISTINCT city FROM students",
96
+ tables=tuple(tables),
97
+ columns=("city",),
98
+ )
99
+
100
+
101
+ @_register
102
+ def exercise_top_scorer(rng: random.Random) -> Exercise:
103
+ tables = {"students": ["id", "name", "score"], "grades": ["id", "score"]}
104
+ return Exercise(
105
+ schema=_fmt_schema(tables),
106
+ question="Find students whose score equals the highest score in the class.",
107
+ correct_query=(
108
+ "SELECT name FROM students "
109
+ "WHERE score = (SELECT MAX(score) FROM grades)"
110
+ ),
111
+ tables=tuple(tables),
112
+ columns=("name", "score"),
113
+ )
114
+
115
+
116
+ @_register
117
+ def exercise_departments_over_budget(rng: random.Random) -> Exercise:
118
+ budget = rng.randint(3, 8)
119
+ tables = {"employees": ["id", "name", "department_id", "salary"]}
120
+ return Exercise(
121
+ schema=_fmt_schema(tables),
122
+ question=f"Which departments have more than {budget} employees?",
123
+ correct_query=(
124
+ f"SELECT department_id, COUNT(*) AS cnt "
125
+ f"FROM employees GROUP BY department_id "
126
+ f"HAVING COUNT(*) > {budget}"
127
+ ),
128
+ tables=tuple(tables),
129
+ columns=("department_id", "salary"),
130
+ )
131
+
132
+
133
+ @_register
134
+ def exercise_recent_orders(rng: random.Random) -> Exercise:
135
+ year = rng.randint(2020, 2024)
136
+ tables = {"orders": ["id", "customer_id", "amount", "order_date", "status"]}
137
+ return Exercise(
138
+ schema=_fmt_schema(tables),
139
+ question=f"Show orders placed on or after January 1, {year}.",
140
+ correct_query=(
141
+ f"SELECT id, amount FROM orders "
142
+ f"WHERE order_date >= DATE '{year}-01-01'"
143
+ ),
144
+ tables=tuple(tables),
145
+ columns=("order_date", "amount", "status"),
146
+ )
147
+
148
+
149
+ @_register
150
+ def exercise_missing_email(rng: random.Random) -> Exercise:
151
+ tables = {"students": ["id", "name", "email", "phone"]}
152
+ return Exercise(
153
+ schema=_fmt_schema(tables),
154
+ question="Find students who have not provided an email address.",
155
+ correct_query="SELECT name FROM students WHERE email IS NULL",
156
+ tables=tuple(tables),
157
+ columns=("email", "name"),
158
+ )
159
+
160
+
161
+ @_register
162
+ def exercise_rank_by_score(rng: random.Random) -> Exercise:
163
+ tables = {"students": ["id", "name", "score", "department_id"]}
164
+ return Exercise(
165
+ schema=_fmt_schema(tables),
166
+ question="Rank students by score within each department.",
167
+ correct_query=(
168
+ "SELECT name, score, "
169
+ "RANK() OVER (PARTITION BY department_id ORDER BY score DESC) AS rnk "
170
+ "FROM students"
171
+ ),
172
+ tables=tuple(tables),
173
+ columns=("name", "score", "department_id"),
174
+ )
175
+
176
+
177
+ @_register
178
+ def exercise_course_enrollment_count(rng: random.Random) -> Exercise:
179
+ tables = {
180
+ "courses": ["id", "title"],
181
+ "enrollments": ["id", "course_id", "student_id"],
182
+ }
183
+ return Exercise(
184
+ schema=_fmt_schema(tables),
185
+ question="How many students are enrolled in each course?",
186
+ correct_query=(
187
+ "SELECT courses.title, COUNT(enrollments.student_id) AS enrolled "
188
+ "FROM courses "
189
+ "INNER JOIN enrollments ON courses.id = enrollments.course_id "
190
+ "GROUP BY courses.title"
191
+ ),
192
+ tables=tuple(tables),
193
+ columns=("title", "student_id", "course_id"),
194
+ )
195
+
196
+
197
+ @_register
198
+ def exercise_active_employees(rng: random.Random) -> Exercise:
199
+ tables = {"employees": ["id", "name", "salary", "status", "hire_date"]}
200
+ return Exercise(
201
+ schema=_fmt_schema(tables),
202
+ question="What is the total salary paid to active employees?",
203
+ correct_query=(
204
+ "SELECT SUM(salary) FROM employees WHERE status = 'active'"
205
+ ),
206
+ tables=tuple(tables),
207
+ columns=("salary", "status"),
208
+ )
209
+
210
+
211
+ @_register
212
+ def exercise_product_price_filter(rng: random.Random) -> Exercise:
213
+ lo, hi = rng.randint(10, 50), rng.randint(100, 500)
214
+ tables = {"products": ["id", "name", "price", "category"]}
215
+ return Exercise(
216
+ schema=_fmt_schema(tables),
217
+ question=f"List products priced between {lo} and {hi}.",
218
+ correct_query=(
219
+ f"SELECT name, price FROM products "
220
+ f"WHERE price BETWEEN {lo} AND {hi}"
221
+ ),
222
+ tables=tuple(tables),
223
+ columns=("name", "price", "category"),
224
+ )
225
+
226
+
227
+ def generate_exercise(rng: random.Random) -> Exercise:
228
+ return rng.choice(EXERCISE_BUILDERS)(rng)
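The registry pattern used in this file (a decorator that appends each builder to a module-level list, plus a seeded `random.Random` passed through every call) can be sketched in isolation. The trimmed `Exercise` dataclass and the two builders below are reduced copies for illustration; names such as `register` and `generate` are local to this sketch.

```python
import random
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class Exercise:
    question: str
    correct_query: str


BUILDERS: List[Callable[[random.Random], Exercise]] = []


def register(builder):
    # Decorator: record the builder, then return it unchanged.
    BUILDERS.append(builder)
    return builder


@register
def high_scoring_students(rng):
    threshold = rng.randint(70, 90)
    return Exercise(
        question=f"Find names of students older than 18 with a score above {threshold}.",
        correct_query=f"SELECT name FROM students WHERE age > 18 AND score > {threshold}",
    )


@register
def unique_cities(rng):
    return Exercise(
        question="List the unique cities where students live.",
        correct_query="SELECT DISTINCT city FROM students",
    )


def generate(rng):
    # The same rng picks the builder and fills in its random parameters.
    return rng.choice(BUILDERS)(rng)


a = generate(random.Random(7))
b = generate(random.Random(7))
print(a == b)
```

Passing the generator explicitly, rather than using the global `random` module, is what lets each worker batch reproduce its exercises from a seed.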
src/generate_dataset.py ADDED
@@ -0,0 +1,115 @@
1
+ """Generate labeled SQL error dataset at scale."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import random
7
+ from concurrent.futures import ProcessPoolExecutor, as_completed
8
+ from pathlib import Path
9
+ from typing import Dict, List, Tuple
10
+
11
+ import pandas as pd
12
+ from tqdm import tqdm
13
+
14
+ from src.categories import load_categories
15
+ from src.exercises import generate_exercise
16
+ from src.sql_templates import ERROR_INJECTORS
17
+
18
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
19
+ DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
20
+
21
+
22
+ def generate_dataset(
23
+ total_samples: int = 1_000_000,
24
+ output_path: Path = DEFAULT_OUTPUT,
25
+ batch_size: int = 10_000,
26
+ workers: int = 8,
27
+ seed: int = 42,
28
+ ) -> Path:
29
+ categories = load_categories()
30
+ label_ids = [c.id for c in categories]
31
+ samples_per_class = total_samples // len(label_ids)
32
+ remainder = total_samples % len(label_ids)
33
+
34
+ # Balanced label schedule: each class gets equal share (+1 for first `remainder` classes)
35
+ schedule: List[int] = []
36
+ for cat in categories:
37
+ count = samples_per_class + (1 if cat.id < remainder else 0)
38
+ schedule.extend([cat.id] * count)
39
+ random.Random(seed).shuffle(schedule)
40
+
41
+ output_path.parent.mkdir(parents=True, exist_ok=True)
42
+ chunks: List[pd.DataFrame] = []
43
+ num_batches = (total_samples + batch_size - 1) // batch_size
44
+
45
+ with ProcessPoolExecutor(max_workers=workers) as executor:
46
+ futures = []
47
+ offset = 0
48
+ for batch_idx in range(num_batches):
49
+ current_batch = min(batch_size, total_samples - offset)
50
+ batch_labels = schedule[offset : offset + current_batch]
51
+ futures.append(
52
+ executor.submit(
53
+ _generate_batch_with_labels,
54
+ batch_labels,
55
+ seed + batch_idx,
56
+ )
57
+ )
58
+ offset += current_batch
59
+
60
+ for future in tqdm(as_completed(futures), total=len(futures), desc="Generating"):
61
+ rows = future.result()
62
+ chunks.append(pd.DataFrame(rows))
63
+
64
+ df = pd.concat(chunks, ignore_index=True)
65
+ df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
66
+ df.to_parquet(output_path, index=False)
67
+
68
+ print(f"Saved {len(df):,} samples to {output_path}")
69
+ print("\nClass distribution:")
70
+ print(df["label_name"].value_counts().sort_index().to_string())
71
+ return output_path
72
+
73
+
74
+ def _generate_batch_with_labels(label_ids: List[int], seed: int) -> List[Dict]:
75
+ rng = random.Random(seed)
76
+ categories = load_categories()
77
+ rows = []
78
+ for label_id in label_ids:
79
+ exercise = generate_exercise(rng)
80
+ injector = ERROR_INJECTORS[label_id]
81
+ query, error_message = injector(rng, exercise)
82
+ rows.append(
83
+ {
84
+ "schema": exercise.schema,
85
+ "question": exercise.question,
86
+ "correct_query": exercise.correct_query,
87
+ "query": query.strip(),
88
+ "error_message": error_message,
89
+ "label_id": label_id,
90
+ "label_name": categories[label_id].name,
91
+ }
92
+ )
93
+ return rows
94
+
95
+
96
+ def main() -> None:
97
+ parser = argparse.ArgumentParser(description="Generate labeled SQL error dataset")
98
+ parser.add_argument("--samples", type=int, default=1_000_000, help="Total samples")
99
+ parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
100
+ parser.add_argument("--batch-size", type=int, default=10_000)
101
+ parser.add_argument("--workers", type=int, default=8)
102
+ parser.add_argument("--seed", type=int, default=42)
103
+ args = parser.parse_args()
104
+
105
+ generate_dataset(
106
+ total_samples=args.samples,
107
+ output_path=args.output,
108
+ batch_size=args.batch_size,
109
+ workers=args.workers,
110
+ seed=args.seed,
111
+ )
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
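The balanced label schedule built in `generate_dataset` can be checked on small numbers. This sketch assumes class ids are the contiguous integers `0..num_classes-1`, which the `cat.id < remainder` test above also relies on; `balanced_schedule` is a name introduced here for illustration.

```python
import random
from collections import Counter


def balanced_schedule(total_samples, num_classes, seed):
    # Every class gets the same base share; the first `remainder` ids get one extra.
    per_class, remainder = divmod(total_samples, num_classes)
    schedule = []
    for class_id in range(num_classes):
        count = per_class + (1 if class_id < remainder else 0)
        schedule.extend([class_id] * count)
    # Shuffle with a dedicated generator so the order depends only on the seed.
    random.Random(seed).shuffle(schedule)
    return schedule


schedule = balanced_schedule(total_samples=100, num_classes=15, seed=42)
counts = Counter(schedule)
print(sorted(counts.items()))
```

With 100 samples and 15 classes, `divmod` gives a base of 6 and a remainder of 10, so ids 0 to 9 receive 7 samples and ids 10 to 14 receive 6.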
src/hf_eval_codebert.py ADDED
@@ -0,0 +1,69 @@
1
+ """Evaluate a trained CodeBERT cross-encoder."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import pandas as pd
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
11
+
12
+ from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe
13
+ from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
14
+
15
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
16
+ DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet"
17
+ DEFAULT_MODEL = PROJECT_ROOT / "models" / "codebert-cross-encoder"
18
+
19
+
20
+ def evaluate(
21
+ model_dir: Path = DEFAULT_MODEL,
22
+ data_path: Path = DEFAULT_DATA,
23
+ sample_size: int = 10_000,
24
+ threshold: float = 0.5,
25
+ seed: int = 42,
26
+ ) -> dict:
27
+ df = normalize_dataframe(pd.read_parquet(data_path))
28
+ if len(df) > sample_size:
29
+ df = df.sample(n=sample_size, random_state=seed)
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
32
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
33
+ dataset = SQLCodeBERTDataset(df, tokenizer)
34
+
35
+ trainer_kwargs = dict(
36
+ model=model,
37
+ data_collator=SQLCodeBERTDataCollator(tokenizer),
38
+ compute_metrics=build_compute_metrics(threshold=threshold),
39
+ )
40
+ try:
41
+ trainer = Trainer(processing_class=tokenizer, **trainer_kwargs)
42
+ except TypeError:
43
+ trainer = Trainer(tokenizer=tokenizer, **trainer_kwargs)
44
+
45
+ output = trainer.predict(dataset)
46
+ metrics = compute_multilabel_metrics(
47
+ output.predictions, output.label_ids, threshold=threshold
48
+ )
49
+ print(json.dumps(metrics, indent=2))
50
+ return metrics
51
+
52
+
53
+ def main() -> None:
54
+ parser = argparse.ArgumentParser(description="Evaluate CodeBERT SQL classifier")
55
+ parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL)
56
+ parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
57
+ parser.add_argument("--sample-size", type=int, default=10_000)
58
+ parser.add_argument("--threshold", type=float, default=0.5)
59
+ args = parser.parse_args()
60
+ evaluate(
61
+ model_dir=args.model_dir,
62
+ data_path=args.data,
63
+ sample_size=args.sample_size,
64
+ threshold=args.threshold,
65
+ )
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
src/hf_metrics.py ADDED
@@ -0,0 +1,52 @@
1
+ """Evaluation metrics for multi-label SQL error classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict
6
+
7
+ import numpy as np
8
+ from sklearn.metrics import (
9
+ accuracy_score,
10
+ f1_score,
11
+ hamming_loss,
12
+ precision_score,
13
+ recall_score,
14
+ )
15
+
16
+
17
+ def sigmoid(x: np.ndarray) -> np.ndarray:
18
+ return 1.0 / (1.0 + np.exp(-x))
19
+
20
+
21
+ def compute_multilabel_metrics(
22
+ logits: np.ndarray,
23
+ labels: np.ndarray,
24
+ threshold: float = 0.5,
25
+ ) -> Dict[str, float]:
26
+ probs = sigmoid(logits)
27
+ preds = (probs >= threshold).astype(int)
28
+ labels = labels.astype(int)
29
+
30
+ return {
31
+ "accuracy": float(accuracy_score(labels, preds)),  # exact-match ratio on multi-label input; same value as subset_accuracy
32
+ "f1_macro": float(f1_score(labels, preds, average="macro", zero_division=0)),
33
+ "f1_micro": float(f1_score(labels, preds, average="micro", zero_division=0)),
34
+ "precision_macro": float(
35
+ precision_score(labels, preds, average="macro", zero_division=0)
36
+ ),
37
+ "recall_macro": float(
38
+ recall_score(labels, preds, average="macro", zero_division=0)
39
+ ),
40
+ "hamming_loss": float(hamming_loss(labels, preds)),
41
+ "subset_accuracy": float((preds == labels).all(axis=1).mean()),
42
+ }
43
+
44
+
45
+ def build_compute_metrics(threshold: float = 0.5):
46
+ """Factory for Hugging Face Trainer compute_metrics callback."""
47
+
48
+ def compute_metrics(eval_pred) -> Dict[str, float]:
49
+ logits, labels = eval_pred
50
+ return compute_multilabel_metrics(logits, labels, threshold=threshold)
51
+
52
+ return compute_metrics
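Note on `src/hf_metrics.py`: the pipeline in `compute_multilabel_metrics` (sigmoid, threshold, then row-wise comparison) can be sketched with the standard library alone. The snippet below is a minimal illustration, not the code from this commit; the logits, labels, and helper names are made up for the example.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function for a single logit."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def to_multihot(logits, threshold=0.5):
    """Turn one row of logits into 0/1 predictions, one per label."""
    return [1 if sigmoid(v) >= threshold else 0 for v in logits]


def subset_accuracy(preds, labels):
    """Fraction of rows where every label matches exactly."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def hamming_loss(preds, labels):
    """Fraction of individual label slots that are wrong."""
    wrong = sum(a != b for p, y in zip(preds, labels) for a, b in zip(p, y))
    return wrong / (len(labels) * len(labels[0]))


# Hypothetical logits for 2 examples and 3 labels.
logits = [[2.0, -1.0, 0.3], [-3.0, 4.0, -0.2]]
labels = [[1, 0, 1], [0, 1, 1]]
preds = [to_multihot(row) for row in logits]

print(preds)
print(subset_accuracy(preds, labels), hamming_loss(preds, labels))
```

One thing the sketch makes visible: scikit-learn's `accuracy_score` on multi-label indicator arrays is also an exact-match score, so the `accuracy` and `subset_accuracy` entries returned by `compute_multilabel_metrics` should report the same number.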
src/hf_predict_codebert.py ADDED
@@ -0,0 +1,161 @@
1
+ """Inference for CodeBERT SQL error cross-encoder."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+ from typing import List, Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
13
+
14
+ from src.codebert_formatting import format_cross_encoder_input
15
+ from src.codebert_labels import load_codebert_labels, multihot_to_label_names
16
+ from src.hf_metrics import sigmoid
17
+
18
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
19
+ DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"
20
+
21
+
22
+ class CodeBERTSQLErrorClassifier:
23
+ """CodeBERT cross-encoder inference wrapper."""
24
+
25
+ def __init__(
26
+ self,
27
+ model_dir: Union[str, Path] = DEFAULT_MODEL_DIR,
28
+ threshold: float = 0.5,
29
+ device: Optional[str] = None,
30
+ ):
31
+ self.model_dir = Path(model_dir)
32
+ self.threshold = threshold
33
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
34
+
35
+ config_path = self.model_dir / "label_config.json"
36
+ if config_path.exists():
37
+ with open(config_path) as f:
38
+ cfg = json.load(f)
39
+ self.label_list = cfg.get("labels", load_codebert_labels())
40
+ self.threshold = cfg.get("threshold", threshold)
41
+ self.max_length = cfg.get("max_length", 512)
42
+ else:
43
+ self.label_list = load_codebert_labels()
44
+ self.max_length = 512
45
+
46
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
47
+ self.model = AutoModelForSequenceClassification.from_pretrained(
48
+ self.model_dir
49
+ ).to(self.device)
50
+ self.model.eval()
51
+
52
+ def predict(
53
+ self,
54
+ question: str,
55
+ schema: str,
56
+ student_sql: str,
57
+ correct_sql: str,
58
+ threshold: Optional[float] = None,
59
+ top_k: int = 5,
60
+ ) -> dict:
61
+ text = format_cross_encoder_input(
62
+ question=question,
63
+ schema=schema,
64
+ student_sql=student_sql,
65
+ correct_sql=correct_sql,
66
+ )
67
+ encoded = self.tokenizer(
68
+ text,
69
+ truncation=True,
70
+ max_length=self.max_length,
71
+ padding=True,
72
+ return_tensors="pt",
73
+ ).to(self.device)
74
+
75
+ with torch.no_grad():
76
+ logits = self.model(**encoded).logits.cpu().numpy()[0]
77
+
78
+ probs = sigmoid(logits)
79
+ thr = threshold if threshold is not None else self.threshold
80
+ predicted = multihot_to_label_names(probs, self.label_list, threshold=thr)
81
+
82
+ ranked = sorted(
83
+ zip(self.label_list, probs.tolist()),
84
+ key=lambda x: x[1],
85
+ reverse=True,
86
+ )[:top_k]
87
+
88
+ return {
89
+ "error_labels": predicted,
90
+ "probabilities": {name: float(p) for name, p in ranked},
91
+ "top_k": [
92
+ {"label": name, "probability": float(p)} for name, p in ranked
93
+ ],
94
+ "primary_label": ranked[0][0],
95
+ "primary_confidence": float(ranked[0][1]),
96
+ }
97
+
98
+ def predict_batch(
99
+ self,
100
+ examples: List[dict],
101
+ batch_size: int = 16,
102
+ ) -> List[dict]:
103
+ results = []
104
+ for i in range(0, len(examples), batch_size):
105
+ chunk = examples[i : i + batch_size]
106
+ texts = [
107
+ format_cross_encoder_input(
108
+ question=x["question"],
109
+ schema=x["schema"],
110
+ student_sql=x["student_sql"],
111
+ correct_sql=x["correct_sql"],
112
+ )
113
+ for x in chunk
114
+ ]
115
+ encoded = self.tokenizer(
116
+ texts,
117
+ truncation=True,
118
+ max_length=self.max_length,
119
+ padding=True,
120
+ return_tensors="pt",
121
+ ).to(self.device)
122
+
123
+ with torch.no_grad():
124
+ logits = self.model(**encoded).logits.cpu().numpy()
125
+
126
+ for row in logits:
127
+ probs = sigmoid(row)
128
+ results.append(
129
+ {
130
+ "error_labels": multihot_to_label_names(
131
+ probs, self.label_list, self.threshold
132
+ ),
133
+ "primary_label": self.label_list[int(np.argmax(probs))],
134
+ "primary_confidence": float(np.max(probs)),
135
+ }
136
+ )
137
+ return results
138
+
139
+
140
+ def main() -> None:
141
+ parser = argparse.ArgumentParser(description="CodeBERT SQL error inference")
142
+ parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR)
143
+ parser.add_argument("--question", type=str, required=True)
144
+ parser.add_argument("--schema", type=str, required=True)
145
+ parser.add_argument("--student-sql", type=str, required=True)
146
+ parser.add_argument("--correct-sql", type=str, required=True)
147
+ parser.add_argument("--threshold", type=float, default=0.5)
148
+ args = parser.parse_args()
149
+
150
+ clf = CodeBERTSQLErrorClassifier(args.model_dir, threshold=args.threshold)
151
+ result = clf.predict(
152
+ question=args.question,
153
+ schema=args.schema,
154
+ student_sql=args.student_sql,
155
+ correct_sql=args.correct_sql,
156
+ )
157
+ print(json.dumps(result, indent=2))
158
+
159
+
160
+ if __name__ == "__main__":
161
+ main()
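Note on `src/hf_predict_codebert.py`: the result dictionary built in `predict()` mixes two selection rules, a threshold for `error_labels` and a ranking for `top_k` and `primary_label`. The sketch below reproduces that shape with plain Python. It assumes that `multihot_to_label_names` keeps every label whose probability is at or above the threshold (that helper is not shown in this diff), and the label names are hypothetical.

```python
def rank_labels(label_list, probs, threshold=0.5, top_k=5):
    """Build a result dict shaped like the one returned by predict()."""
    # Thresholded labels: assumed behaviour of multihot_to_label_names.
    predicted = [name for name, p in zip(label_list, probs) if p >= threshold]
    # Ranked labels: highest probability first, truncated to top_k.
    ranked = sorted(zip(label_list, probs), key=lambda x: x[1], reverse=True)[:top_k]
    return {
        "error_labels": predicted,
        "top_k": [{"label": n, "probability": float(p)} for n, p in ranked],
        "primary_label": ranked[0][0],
        "primary_confidence": float(ranked[0][1]),
    }


# Hypothetical label names and probabilities.
LABELS = ["wrong_aggregate", "missing_join", "syntax_error"]

confident = rank_labels(LABELS, [0.2, 0.7, 0.55], threshold=0.5, top_k=2)
uncertain = rank_labels(LABELS, [0.1, 0.3, 0.2], threshold=0.5, top_k=2)

print(confident["error_labels"], confident["primary_label"])
print(uncertain["error_labels"], uncertain["primary_label"])
```

Under these assumptions `primary_label` is always populated, even when no label clears the threshold and `error_labels` is empty, so callers may want to check `primary_confidence` before treating it as a prediction.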
src/hf_train_codebert.py ADDED
@@ -0,0 +1,226 @@
1
+ """Train CodeBERT cross-encoder for SQL error classification with HF Trainer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+ from transformers import (
13
+ AutoModelForSequenceClassification,
14
+ AutoTokenizer,
15
+ EarlyStoppingCallback,
16
+ Trainer,
17
+ TrainingArguments,
18
+ )
19
+
20
+ from src.codebert_dataset import (
21
+ SQLCodeBERTDataCollator,
22
+ prepare_datasets,
23
+ )
24
+ from src.codebert_labels import load_codebert_labels
25
+ from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
26
+
27
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
28
+ DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
29
+ DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "codebert-cross-encoder"
30
+ DEFAULT_MODEL = "microsoft/codebert-base"
31
+
32
+
33
+ def train(
34
+ data_path: Path | None = DEFAULT_DATA,
35
+ dataframe: pd.DataFrame | None = None,
36
+ output_dir: Path = DEFAULT_OUTPUT,
37
+ model_name: str = DEFAULT_MODEL,
38
+ epochs: float = 3.0,
39
+ batch_size: int = 16,
40
+ eval_batch_size: int = 32,
41
+ learning_rate: float = 2e-5,
42
+ weight_decay: float = 0.01,
43
+ warmup_ratio: float = 0.06,
44
+ max_length: int = 512,
45
+ max_samples: int | None = None,
46
+ test_size: float = 0.1,
47
+ val_size: float = 0.1,
48
+ threshold: float = 0.5,
49
+ seed: int = 42,
50
+ push_to_hub: bool = False,
51
+ hub_model_id: str | None = None,
52
+ fp16: bool = False,
53
+ save_strategy: str = "no",
54
+ hub_token: str | None = None,
55
+ ) -> dict:
56
+ if dataframe is not None:
57
+ df = dataframe.copy()
58
+ print(f"Loaded dataframe with {len(df):,} rows")
59
+ elif data_path is not None:
60
+ print(f"Loading dataset from {data_path}...")
61
+ df = pd.read_parquet(data_path)
62
+ else:
63
+ raise ValueError("Either data_path or dataframe must be provided")
64
+ if max_samples and len(df) > max_samples:
65
+ df = df.sample(n=max_samples, random_state=seed)
66
+
67
+ label_list = load_codebert_labels()
68
+ num_labels = len(label_list)
69
+ print(f"Labels ({num_labels}): {label_list}")
70
+ print(f"Samples: {len(df):,}")
71
+
72
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
73
+ model = AutoModelForSequenceClassification.from_pretrained(
74
+ model_name,
75
+ num_labels=num_labels,
76
+ problem_type="multi_label_classification",
77
+ id2label={i: name for i, name in enumerate(label_list)},
78
+ label2id={name: i for i, name in enumerate(label_list)},
79
+ )
80
+
81
+ train_ds, val_ds, test_ds = prepare_datasets(
82
+ df,
83
+ tokenizer,
84
+ test_size=test_size,
85
+ val_size=val_size,
86
+ max_length=max_length,
87
+ seed=seed,
88
+ )
89
+ print(f"Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
90
+
91
+ output_dir.mkdir(parents=True, exist_ok=True)
92
+ label_info = {
93
+ "labels": label_list,
94
+ "model_name": model_name,
95
+ "architecture": "codebert-cross-encoder",
96
+ "input_format": "QUESTION + SCHEMA + STUDENT_SQL + CORRECT_SQL",
97
+ "max_length": max_length,
98
+ "threshold": threshold,
99
+ }
100
+ with open(output_dir / "label_config.json", "w") as f:
101
+ json.dump(label_info, f, indent=2)
102
+
103
+ training_args = TrainingArguments(
104
+ output_dir=str(output_dir),
105
+ num_train_epochs=epochs,
106
+ per_device_train_batch_size=batch_size,
107
+ per_device_eval_batch_size=eval_batch_size,
108
+ learning_rate=learning_rate,
109
+ weight_decay=weight_decay,
110
+ warmup_ratio=warmup_ratio,
111
+ eval_strategy="epoch",
112
+ save_strategy=save_strategy,
113
+ logging_strategy="steps",
114
+ logging_steps=50,
115
+ load_best_model_at_end=save_strategy == "epoch",
116
+ metric_for_best_model="f1_macro",
117
+ greater_is_better=True,
118
+ save_total_limit=1,
119
+ seed=seed,
120
+ report_to="none",
121
+ fp16=fp16 and torch.cuda.is_available(),
122
+ push_to_hub=push_to_hub,
123
+ hub_model_id=hub_model_id,
124
+ hub_token=hub_token,
125
+ )
126
+
127
+ callbacks = []
128
+ if save_strategy == "epoch":
129
+ callbacks.append(EarlyStoppingCallback(early_stopping_patience=2))
130
+
131
+ trainer_kwargs = dict(
132
+ model=model,
133
+ args=training_args,
134
+ train_dataset=train_ds,
135
+ eval_dataset=val_ds,
136
+ data_collator=SQLCodeBERTDataCollator(tokenizer),
137
+ compute_metrics=build_compute_metrics(threshold=threshold),
138
+ callbacks=callbacks,
139
+ )
140
+ try:
141
+ trainer = Trainer(processing_class=tokenizer, **trainer_kwargs)
142
+ except TypeError:
143
+ trainer = Trainer(tokenizer=tokenizer, **trainer_kwargs)
144
+
145
+ print("Starting CodeBERT cross-encoder training...")
146
+ train_result = trainer.train()
147
+
148
+ print("Evaluating on validation set...")
149
+ val_metrics = trainer.evaluate()
150
+
151
+ print("Evaluating on held-out test set...")
152
+ test_output = trainer.predict(test_ds)
153
+ test_metrics = compute_multilabel_metrics(
154
+ test_output.predictions,
155
+ test_output.label_ids,
156
+ threshold=threshold,
157
+ )
158
+
159
+ trainer.save_model(str(output_dir))
160
+ tokenizer.save_pretrained(str(output_dir))
161
+
162
+ metrics = {
163
+ "train_samples": len(train_ds),
164
+ "val_samples": len(val_ds),
165
+ "test_samples": len(test_ds),
166
+ "train_runtime": train_result.metrics.get("train_runtime"),
167
+ "validation": val_metrics,
168
+ "test": test_metrics,
169
+ }
170
+ with open(output_dir / "metrics.json", "w") as f:
171
+ json.dump(metrics, f, indent=2, default=float)
172
+
173
+ print(f"\nValidation F1 (macro): {val_metrics.get('eval_f1_macro', 0):.4f}")
174
+ print(f"Test F1 (macro): {test_metrics['f1_macro']:.4f}")
175
+ print(f"Test subset accuracy: {test_metrics['subset_accuracy']:.4f}")
176
+ print(f"Model saved to {output_dir}")
177
+ return metrics
178
+
179
+
180
+ def main() -> None:
181
+ parser = argparse.ArgumentParser(
182
+ description="Train CodeBERT cross-encoder with Hugging Face Trainer"
183
+ )
184
+ parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
185
+ parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT)
186
+ parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL)
187
+ parser.add_argument("--epochs", type=float, default=3.0)
188
+ parser.add_argument("--batch-size", type=int, default=16)
189
+ parser.add_argument("--eval-batch-size", type=int, default=32)
190
+ parser.add_argument("--learning-rate", type=float, default=2e-5)
191
+ parser.add_argument("--max-length", type=int, default=512)
192
+ parser.add_argument("--max-samples", type=int, default=None)
193
+ parser.add_argument("--threshold", type=float, default=0.5)
194
+ parser.add_argument("--seed", type=int, default=42)
195
+ parser.add_argument("--push-to-hub", action="store_true")
196
+ parser.add_argument("--hub-model-id", type=str, default=None)
197
+ parser.add_argument("--fp16", action="store_true")
198
+ parser.add_argument(
199
+ "--save-strategy",
200
+ choices=["no", "epoch"],
201
+ default="no",
202
+ help="Use 'no' to save only final model (saves disk space)",
203
+ )
204
+ args = parser.parse_args()
205
+
206
+ train(
207
+ data_path=args.data,
208
+ output_dir=args.output_dir,
209
+ model_name=args.model_name,
210
+ epochs=args.epochs,
211
+ batch_size=args.batch_size,
212
+ eval_batch_size=args.eval_batch_size,
213
+ learning_rate=args.learning_rate,
214
+ max_length=args.max_length,
215
+ max_samples=args.max_samples,
216
+ threshold=args.threshold,
217
+ seed=args.seed,
218
+ push_to_hub=args.push_to_hub,
219
+ hub_model_id=args.hub_model_id,
220
+ fp16=args.fp16,
221
+ save_strategy=args.save_strategy,
222
+ )
223
+
224
+
225
+ if __name__ == "__main__":
226
+ main()
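Note on `src/hf_train_codebert.py`: `train()` writes `label_config.json` next to the model, and `CodeBERTSQLErrorClassifier` reads it back with fallbacks. The round trip can be sketched as follows. This is a simplified illustration using only the standard library; the helper names and label names are hypothetical, and only three of the config keys are reproduced.

```python
import json
import tempfile
from pathlib import Path

CONFIG_FILE = "label_config.json"  # file name used by the training script


def write_label_config(output_dir, labels, max_length=512, threshold=0.5):
    """Persist the label list and inference defaults next to the model."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / CONFIG_FILE
    with open(path, "w") as f:
        json.dump(
            {"labels": labels, "max_length": max_length, "threshold": threshold},
            f,
            indent=2,
        )
    return path


def read_label_config(model_dir, default_labels, default_threshold=0.5):
    """Read the config if present, otherwise fall back to the defaults."""
    path = Path(model_dir) / CONFIG_FILE
    if not path.exists():
        return {"labels": default_labels, "threshold": default_threshold, "max_length": 512}
    with open(path) as f:
        cfg = json.load(f)
    return {
        "labels": cfg.get("labels", default_labels),
        "threshold": cfg.get("threshold", default_threshold),
        "max_length": cfg.get("max_length", 512),
    }


# Hypothetical label names, used only for this illustration.
LABELS = ["syntax_error", "wrong_aggregate"]

with tempfile.TemporaryDirectory() as tmp:
    write_label_config(tmp, LABELS, max_length=256, threshold=0.4)
    loaded = read_label_config(tmp, default_labels=[], default_threshold=0.5)
    missing = read_label_config(Path(tmp) / "no_such_model", default_labels=LABELS)

print(loaded)
print(missing)
```

Because the saved value wins over the caller's default, a threshold passed when constructing the inference wrapper (including the `--threshold` CLI flag) appears to take effect only when no `label_config.json` is present in the model directory.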
src/huggingface.py ADDED
@@ -0,0 +1,210 @@
1
+ """Hugging Face Hub integration for the SQL error classifier."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Union
8
+
9
+ import joblib
10
+
11
+ from src.categories import load_categories
12
+ from src.cross_encoder_model import CrossEncoderClassifier
13
+ from src.model import DEFAULT_ENCODER, load_model
14
+ from src.multi_tower_model import MultiTowerClassifier, QueryContext
15
+
16
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
17
+ CONFIG_NAME = "config.json"
18
+ CLASSIFIER_NAME = "classifier.joblib"
19
+ CATEGORIES_NAME = "categories.json"
20
+
21
+ SUPPORTED_CONTEXT_MODELS = (CrossEncoderClassifier, MultiTowerClassifier)
22
+
23
+
24
+ class SQLLErrorClassifierHF:
25
+ """
26
+ Hugging Face–compatible wrapper for SQL error classifiers.
27
+
28
+ Usage:
29
+ clf = SQLLErrorClassifierHF.from_pretrained("username/sql-error-classifier")
30
+ result = clf.predict(
31
+ question="...", schema="...", correct_query="...", student_query="..."
32
+ )
33
+ """
34
+
35
+ def __init__(self, model, label_map: Dict[int, str]):
36
+ self.model = model
37
+ self.label_map = label_map
38
+
39
+ def predict(
40
+ self,
41
+ question: str,
42
+ schema: str,
43
+ correct_query: str,
44
+ student_query: str,
45
+ error_message: Optional[str] = None,
46
+ top_k: int = 3,
47
+ ) -> Dict[str, Any]:
48
+ ctx = QueryContext(
49
+ question=question,
50
+ schema=schema,
51
+ correct_query=correct_query,
52
+ student_query=student_query,
53
+ error_message=error_message,
54
+ )
55
+ proba = self.model.predict_proba([ctx])[0]
56
+ classes = self.model.classes_
57
+ ranked = sorted(zip(classes, proba), key=lambda x: x[1], reverse=True)
58
+ best_id = int(ranked[0][0])
59
+
60
+ diagnostics: Dict[str, Any] = {}
61
+ if isinstance(self.model, CrossEncoderClassifier):
62
+ diagnostics["pair_scores"] = self.model.explain_pair_scores(ctx)
63
+ elif isinstance(self.model, MultiTowerClassifier):
64
+ diagnostics["similarities"] = self.model.explain_similarities(ctx)
65
+
66
+ return {
67
+ "label_id": best_id,
68
+ "label_name": self.label_map[best_id],
69
+ "confidence": float(ranked[0][1]),
70
+ "top_k": [
71
+ {
72
+ "label_id": int(cls),
73
+ "label_name": self.label_map[int(cls)],
74
+ "confidence": float(p),
75
+ }
76
+ for cls, p in ranked[:top_k]
77
+ ],
78
+ **diagnostics,
79
+ }
80
+
81
+ def save_pretrained(self, save_directory: Union[str, Path]) -> Path:
82
+ """Save model artifacts in Hugging Face Hub layout."""
83
+ save_dir = Path(save_directory)
84
+ save_dir.mkdir(parents=True, exist_ok=True)
85
+
86
+ if isinstance(self.model, CrossEncoderClassifier):
87
+ payload = {
88
+ "model_type": "cross_encoder",
89
+ "cross_encoder_name": self.model.cross_encoder_name,
90
+ "batch_size": self.model.batch_size,
91
+ "max_length": self.model.max_length,
92
+ "scaler": self.model.scaler,
93
+ "classifier": self.model.clf,
94
+ "classes_": self.model.classes_,
95
+ }
96
+ config = {
97
+ "model_type": "cross_encoder",
98
+ "architecture": "cross-encoder-pairwise",
99
+ "cross_encoder_name": self.model.cross_encoder_name,
100
+ "batch_size": self.model.batch_size,
101
+ "num_labels": len(self.label_map),
102
+ "task": "sql-error-classification",
103
+ }
104
+ elif isinstance(self.model, MultiTowerClassifier):
105
+ payload = {
106
+ "model_type": "multi_tower",
107
+ "encoder_name": self.model.encoder_name,
108
+ "batch_size": self.model.batch_size,
109
+ "scaler": self.model.scaler,
110
+ "classifier": self.model.clf,
111
+ "classes_": self.model.classes_,
112
+ }
113
+ config = {
114
+ "model_type": "multi_tower",
115
+ "architecture": "multi-tower-semantic-comparison",
116
+ "encoder_name": self.model.encoder_name,
117
+ "batch_size": self.model.batch_size,
118
+ "num_labels": len(self.label_map),
119
+ "task": "sql-error-classification",
120
+ }
121
+ else:
122
+ raise ValueError("Only cross_encoder and multi_tower models can be published")
123
+
124
+ joblib.dump(payload, save_dir / CLASSIFIER_NAME)
125
+ with open(save_dir / CONFIG_NAME, "w") as f:
126
+ json.dump(config, f, indent=2)
127
+
128
+ categories = load_categories()
129
+ cat_data = [
130
+ {"id": c.id, "name": c.name, "description": c.description}
131
+ for c in categories
132
+ ]
133
+ with open(save_dir / CATEGORIES_NAME, "w") as f:
134
+ json.dump(cat_data, f, indent=2)
135
+
136
+ return save_dir
137
+
138
+ @classmethod
139
+ def from_pretrained(
140
+ cls,
141
+ pretrained_model_name_or_path: Union[str, Path],
142
+ *,
143
+ token: Optional[str] = None,
144
+ ) -> "SQLLErrorClassifierHF":
145
+ """Load from a local directory or Hugging Face Hub repo."""
146
+ path = _resolve_model_path(pretrained_model_name_or_path, token=token)
147
+
148
+ with open(path / CONFIG_NAME) as f:
149
+ config = json.load(f)
150
+
151
+ with open(path / CATEGORIES_NAME) as f:
152
+ categories = json.load(f)
153
+ label_map = {c["id"]: c["name"] for c in categories}
154
+
155
+ obj = joblib.load(path / CLASSIFIER_NAME)
156
+ model_type = config.get("model_type", obj.get("model_type"))
157
+
158
+ if model_type == "cross_encoder":
159
+ model = CrossEncoderClassifier(
160
+ cross_encoder_name=obj.get(
161
+ "cross_encoder_name",
162
+ config.get("cross_encoder_name", "cross-encoder/ms-marco-MiniLM-L6-v2"),
163
+ ),
164
+ batch_size=obj.get("batch_size", 32),
165
+ max_length=obj.get("max_length", 512),
166
+ )
167
+ model.scaler = obj["scaler"]
168
+ model.clf = obj["classifier"]
169
+ model.classes_ = obj.get("classes_", obj["classifier"].classes_)
170
+ else:
171
+ model = MultiTowerClassifier(
172
+ encoder_name=obj.get("encoder_name", config.get("encoder_name", DEFAULT_ENCODER)),
173
+ batch_size=obj.get("batch_size", 256),
174
+ )
175
+ model.scaler = obj["scaler"]
176
+ model.clf = obj["classifier"]
177
+ model.classes_ = obj.get("classes_", obj["classifier"].classes_)
178
+
179
+ return cls(model=model, label_map=label_map)
180
+
181
+
182
+ def _resolve_model_path(
183
+ pretrained_model_name_or_path: Union[str, Path],
184
+ token: Optional[str] = None,
185
+ ) -> Path:
186
+ local = Path(pretrained_model_name_or_path)
187
+ if local.exists() and (local / CONFIG_NAME).exists():
188
+ return local
189
+
190
+ from huggingface_hub import snapshot_download
191
+
192
+ return Path(
193
+ snapshot_download(
194
+ repo_id=str(pretrained_model_name_or_path),
195
+ token=token,
196
+ allow_patterns=[CONFIG_NAME, CLASSIFIER_NAME, CATEGORIES_NAME],
197
+ )
198
+ )
199
+
200
+
201
+ def package_for_hub(model_path: Path, output_dir: Path) -> Path:
202
+ """Convert a local joblib model into HF Hub layout."""
203
+ sklearn_model = load_model(model_path)
204
+ if not isinstance(sklearn_model, SUPPORTED_CONTEXT_MODELS):
205
+ raise ValueError(
206
+ "Only cross_encoder and multi_tower models can be published to Hugging Face Hub"
207
+ )
208
+ label_map = {c.id: c.name for c in load_categories()}
209
+ wrapper = SQLLErrorClassifierHF(model=sklearn_model, label_map=label_map)
210
+ return wrapper.save_pretrained(output_dir)
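Note on `src/huggingface.py`: `save_pretrained` stores the categories as a JSON list of objects, and `from_pretrained` rebuilds `label_map` from that list. The sketch below shows only this part of the layout, using the standard library. The category entries and helper names are hypothetical, and the `classifier.joblib` and `config.json` files are left out.

```python
import json
import tempfile
from pathlib import Path

CATEGORIES_NAME = "categories.json"  # same file name as in the module above


def save_categories(save_dir, categories):
    """Write the category list in the layout used by save_pretrained."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    with open(save_dir / CATEGORIES_NAME, "w") as f:
        json.dump(categories, f, indent=2)
    return save_dir


def load_label_map(model_dir):
    """Rebuild the id -> name mapping the way from_pretrained does."""
    with open(Path(model_dir) / CATEGORIES_NAME) as f:
        categories = json.load(f)
    return {c["id"]: c["name"] for c in categories}


# Hypothetical categories, used only for this illustration.
CATEGORIES = [
    {"id": 0, "name": "syntax_error", "description": "Query does not parse"},
    {"id": 1, "name": "wrong_aggregate", "description": "Wrong aggregate function"},
]

with tempfile.TemporaryDirectory() as tmp:
    save_categories(tmp, CATEGORIES)
    label_map = load_label_map(tmp)

print(label_map)
```

Storing the ids as values inside a list keeps them as integers after the round trip. Dumping a `{id: name}` dictionary directly would turn the integer keys into strings, and the `self.label_map[best_id]` lookup with an `int` key would then fail.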
src/model.py ADDED
@@ -0,0 +1,321 @@
1
+ """SQL error classifiers: TF-IDF baseline and MiniLM embedding model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import List, Literal, Optional, Protocol, Union
8
+
9
+ import joblib
10
+ import numpy as np
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.linear_model import SGDClassifier
13
+ from sklearn.pipeline import FeatureUnion, Pipeline
14
+
15
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
16
+ DEFAULT_MODEL_PATH = PROJECT_ROOT / "models" / "sql_error_classifier.joblib"
17
+ DEFAULT_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
18
+
19
+ ModelType = Literal["cross_encoder", "cross_encoder_ft", "multi_tower", "minilm", "tfidf"]
20
+
21
+
22
+ class TextClassifier(Protocol):
23
+ classes_: np.ndarray
24
+
25
+ def fit(self, texts: List[str], y: np.ndarray) -> "TextClassifier": ...
26
+ def predict(self, texts: List[str]) -> np.ndarray: ...
27
+ def predict_proba(self, texts: List[str]) -> np.ndarray: ...
28
+
29
+
30
+ def combine_features(
31
+ queries: List[str],
32
+ error_messages: Optional[List[str]] = None,
33
+ schemas: Optional[List[str]] = None,
34
+ questions: Optional[List[str]] = None,
35
+ ) -> List[str]:
36
+ """Fuse question, schema, query, and optional error message."""
37
+ texts: List[str] = []
38
+ for i, query in enumerate(queries):
39
+ parts: List[str] = []
40
+ if questions and questions[i]:
41
+ parts.append(f"QUESTION: {questions[i]}")
42
+ if schemas and schemas[i]:
43
+ parts.append(f"SCHEMA: {schemas[i]}")
44
+ parts.append(f"QUERY: {query}")
45
+ if error_messages and error_messages[i]:
46
+ parts.append(f"ERROR: {error_messages[i]}")
47
+ texts.append(" ".join(parts))
48
+ return texts
49
+
50
+
51
+ def _build_text_features() -> FeatureUnion:
52
+ return FeatureUnion(
53
+ [
54
+ (
55
+ "word",
56
+ TfidfVectorizer(
57
+ analyzer="word",
58
+ ngram_range=(1, 2),
59
+ max_features=30_000,
60
+ sublinear_tf=True,
61
+ strip_accents="unicode",
62
+ token_pattern=r"(?u)\b\w+\b|(?<=[=<>!])\S+",
63
+ ),
64
+ ),
65
+ (
66
+ "char",
67
+ TfidfVectorizer(
68
+ analyzer="char_wb",
69
+ ngram_range=(2, 5),
70
+ max_features=20_000,
71
+ sublinear_tf=True,
72
+ ),
73
+ ),
74
+ ]
75
+ )
76
+
77
+
78
+ def build_tfidf_classifier() -> Pipeline:
79
+ """Bag-of-words baseline. Fast but no deep semantic understanding."""
80
+ clf = SGDClassifier(
81
+ loss="log_loss",
82
+ penalty="l2",
83
+ alpha=1e-5,
84
+ max_iter=1000,
85
+ tol=1e-3,
86
+ class_weight="balanced",
87
+ random_state=42,
88
+ )
89
+ return Pipeline([("tfidf", _build_text_features()), ("clf", clf)])
90
+
91
+
92
+ class EmbeddingClassifier:
93
+ """
94
+ MiniLM sentence embeddings + linear classifier.
95
+
96
+ Understands question intent (e.g. 'average' vs wrong aggregate) because
97
+ the encoder models full sentence context, not isolated word counts.
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ encoder_name: str = DEFAULT_ENCODER,
103
+ batch_size: int = 256,
104
+ ):
105
+ self.encoder_name = encoder_name
106
+ self.batch_size = batch_size
107
+ self.encoder = None
108
+ self.clf = SGDClassifier(
109
+ loss="log_loss",
110
+ penalty="l2",
111
+ alpha=1e-4,
112
+ max_iter=1000,
113
+ tol=1e-3,
114
+ class_weight="balanced",
115
+ random_state=42,
116
+ )
117
+ self.classes_: Optional[np.ndarray] = None
118
+
119
+ def _load_encoder(self):
120
+ if self.encoder is None:
121
+ from sentence_transformers import SentenceTransformer
122
+
123
+ self.encoder = SentenceTransformer(self.encoder_name)
124
+
125
+ def encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
126
+ self._load_encoder()
127
+ return self.encoder.encode(
128
+ texts,
129
+ batch_size=self.batch_size,
130
+ show_progress_bar=show_progress,
131
+ convert_to_numpy=True,
132
+ )
133
+
134
+ def fit(self, texts: List[str], y: np.ndarray) -> "EmbeddingClassifier":
135
+ X = self.encode(texts, show_progress=True)
136
+ self.clf.fit(X, y)
137
+ self.classes_ = self.clf.classes_
138
+ return self
139
+
140
+ def predict(self, texts: List[str]) -> np.ndarray:
141
+ return self.clf.predict(self.encode(texts))
142
+
143
+ def predict_proba(self, texts: List[str]) -> np.ndarray:
144
+ return self.clf.predict_proba(self.encode(texts))
145
+
146
+
147
+ def build_classifier(
148
+ model_type: ModelType = "cross_encoder",
149
+ ) -> Union[
150
+ Pipeline,
151
+ EmbeddingClassifier,
152
+ "MultiTowerClassifier",
153
+ "CrossEncoderClassifier",
154
+ "FineTunedCrossEncoderClassifier",
155
+ ]:
156
+ if model_type == "tfidf":
157
+ return build_tfidf_classifier()
158
+ if model_type == "minilm":
159
+ return EmbeddingClassifier()
160
+ if model_type == "multi_tower":
161
+ from src.multi_tower_model import MultiTowerClassifier
162
+
163
+ return MultiTowerClassifier()
164
+ if model_type == "cross_encoder":
165
+ from src.cross_encoder_model import CrossEncoderClassifier
166
+
167
+ return CrossEncoderClassifier()
168
+ if model_type == "cross_encoder_ft":
169
+ from src.cross_encoder_model import FineTunedCrossEncoderClassifier
170
+
171
+ return FineTunedCrossEncoderClassifier()
172
+ raise ValueError(f"Unknown model_type: {model_type}")
173
+
174
+
175
+ def save_model(
176
+ model: Union[
177
+ Pipeline,
178
+ EmbeddingClassifier,
179
+ "MultiTowerClassifier",
180
+ "CrossEncoderClassifier",
181
+ "FineTunedCrossEncoderClassifier",
182
+ ],
183
+ path: Path = DEFAULT_MODEL_PATH,
184
+ model_type: ModelType = "cross_encoder",
185
+ ) -> Path:
186
+ from src.cross_encoder_model import (
187
+ CrossEncoderClassifier,
188
+ FineTunedCrossEncoderClassifier,
189
+ )
190
+ from src.multi_tower_model import MultiTowerClassifier
191
+
192
+ path.parent.mkdir(parents=True, exist_ok=True)
193
+ if isinstance(model, FineTunedCrossEncoderClassifier):
194
+ ft_path = path if path.is_dir() or str(path).endswith("/") else path.with_suffix(".ce")
195
+ if ft_path.suffix == ".joblib":
196
+ ft_path = ft_path.with_suffix(".ce")
197
+ model.save(ft_path)
198
+ meta_path = ft_path / "meta.json" if ft_path.is_dir() else path.with_suffix(".meta.json")
199
+ with open(meta_path, "w") as f:
200
+ json.dump({"model_type": "cross_encoder_ft", "path": str(ft_path)}, f, indent=2)
201
+ return ft_path
202
+ if isinstance(model, CrossEncoderClassifier):
203
+ payload = {
204
+ "model_type": "cross_encoder",
205
+ "cross_encoder_name": model.cross_encoder_name,
206
+ "batch_size": model.batch_size,
207
+ "max_length": model.max_length,
208
+ "scaler": model.scaler,
209
+ "classifier": model.clf,
210
+ "classes_": model.classes_,
211
+ }
212
+ joblib.dump(payload, path)
213
+ meta_path = path.with_suffix(".meta.json")
214
+ with open(meta_path, "w") as f:
215
+ json.dump(
216
+ {
217
+ "model_type": "cross_encoder",
218
+ "cross_encoder_name": model.cross_encoder_name,
219
+ },
220
+ f,
221
+ indent=2,
222
+ )
223
+ elif isinstance(model, MultiTowerClassifier):
224
+ payload = {
225
+ "model_type": "multi_tower",
226
+ "encoder_name": model.encoder_name,
227
+ "batch_size": model.batch_size,
228
+ "scaler": model.scaler,
229
+ "classifier": model.clf,
230
+ "classes_": model.classes_,
231
+ }
232
+ joblib.dump(payload, path)
233
+ meta_path = path.with_suffix(".meta.json")
234
+ with open(meta_path, "w") as f:
235
+ json.dump(
236
+ {"model_type": "multi_tower", "encoder_name": model.encoder_name},
237
+ f,
238
+ indent=2,
239
+ )
240
+ elif isinstance(model, EmbeddingClassifier):
241
+ payload = {
242
+ "model_type": "minilm",  # always MiniLM for EmbeddingClassifier; load_model() dispatches on this value
243
+ "encoder_name": model.encoder_name,
244
+ "batch_size": model.batch_size,
245
+ "classifier": model.clf,
246
+ "classes_": model.classes_,
247
+ }
248
+ joblib.dump(payload, path)
249
+ meta_path = path.with_suffix(".meta.json")
250
+ with open(meta_path, "w") as f:
251
+ json.dump(
252
+ {"model_type": "minilm", "encoder_name": model.encoder_name},
253
+ f,
254
+ indent=2,
255
+ )
256
+ else:
257
+ joblib.dump({"model_type": "tfidf", "pipeline": model}, path)
258
+ return path
259
+
260
+
261
+ def load_model(
262
+ path: Path = DEFAULT_MODEL_PATH,
263
+ ) -> Union[
264
+ Pipeline,
265
+ EmbeddingClassifier,
266
+ "MultiTowerClassifier",
267
+ "CrossEncoderClassifier",
268
+ "FineTunedCrossEncoderClassifier",
269
+ ]:
270
+ from src.cross_encoder_model import (
271
+ CrossEncoderClassifier,
272
+ FineTunedCrossEncoderClassifier,
273
+ )
274
+ from src.multi_tower_model import MultiTowerClassifier
275
+
276
+ path = Path(path)
277
+
278
+ # Fine-tuned cross-encoder saved as directory
279
+ ce_path = path.with_suffix(".ce") if path.suffix == ".joblib" else path
280
+ if ce_path.exists() and (ce_path / "config.json").exists():
281
+ return FineTunedCrossEncoderClassifier.load(ce_path)
282
+
283
+ meta_path = path.with_suffix(".meta.json")
284
+ if meta_path.exists():
285
+ with open(meta_path) as f:
286
+ meta = json.load(f)
287
+ if meta.get("model_type") == "cross_encoder_ft":
288
+ ft_path = Path(meta.get("path", str(ce_path)))
289
+ return FineTunedCrossEncoderClassifier.load(ft_path)
290
+
291
+ obj = joblib.load(path)
292
+ if isinstance(obj, dict):
293
+ if obj.get("model_type") == "cross_encoder":
294
+ model = CrossEncoderClassifier(
295
+ cross_encoder_name=obj["cross_encoder_name"],
296
+ batch_size=obj.get("batch_size", 32),
297
+ max_length=obj.get("max_length", 512),
298
+ )
299
+ model.scaler = obj["scaler"]
300
+ model.clf = obj["classifier"]
301
+ model.classes_ = obj.get("classes_", obj["classifier"].classes_)
302
+ return model
303
+ if obj.get("model_type") == "multi_tower":
304
+ model = MultiTowerClassifier(
305
+ encoder_name=obj["encoder_name"],
306
+ batch_size=obj.get("batch_size", 256),
307
+ )
308
+ model.scaler = obj["scaler"]
309
+ model.clf = obj["classifier"]
310
+ model.classes_ = obj.get("classes_", obj["classifier"].classes_)
311
+ return model
312
+ if obj.get("model_type") == "minilm":
313
+ model = EmbeddingClassifier(
314
+ encoder_name=obj["encoder_name"],
315
+ batch_size=obj.get("batch_size", 256),
316
+ )
317
+ model.clf = obj["classifier"]
318
+ model.classes_ = obj.get("classes_", obj["classifier"].classes_)
319
+ return model
320
+ return obj["pipeline"]
321
+ return obj
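The lookup order in `load_model` depends on how `pathlib` rewrites suffixes. The sketch below shows which locations such a loader would probe for a given model path. It is only an illustration: `candidate_paths` is a hypothetical helper, not part of `src/model.py`.

```python
from pathlib import Path


def candidate_paths(path):
    """Return the three locations a loader like load_model would probe (hypothetical helper)."""
    path = Path(path)
    # A ".joblib" path maps to a sibling ".ce" directory; any other path is used as-is.
    ce_dir = path.with_suffix(".ce") if path.suffix == ".joblib" else path
    # with_suffix replaces only the final suffix, so "clf.joblib" becomes "clf.meta.json".
    meta = path.with_suffix(".meta.json")
    return {"joblib": path, "cross_encoder_dir": ce_dir, "meta": meta}


paths = candidate_paths("models/clf.joblib")
print({name: p.as_posix() for name, p in paths.items()})
```

A path without the `.joblib` suffix is treated as the cross-encoder directory itself, so passing a directory path directly is enough to take the fine-tuned branch when it contains a `config.json`.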
src/multi_tower_model.py ADDED
@@ -0,0 +1,175 @@
1
+ """Multi-tower semantic comparison architecture for SQL error classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional
7
+
8
+ import numpy as np
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.preprocessing import StandardScaler
11
+
12
+ from src.model import DEFAULT_ENCODER
13
+ from src.sql_features import extract_sql_features
14
+
15
+
16
+ @dataclass
17
+ class QueryContext:
18
+ """Inputs available in the SQL playground at inference time."""
19
+
20
+ question: str
21
+ schema: str
22
+ correct_query: str
23
+ student_query: str
24
+ error_message: Optional[str] = None
25
+
26
+
27
+ def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
28
+ denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
29
+ denom = np.maximum(denom, 1e-8)
30
+ return np.sum(a * b, axis=1) / denom
31
+
32
+
33
+ class MultiTowerClassifier:
34
+ """
35
+ Recommended architecture for SQL error classification.
36
+
37
+ Three semantic towers (shared MiniLM encoder):
38
+ 1. Intent tower — question + schema → what should be answered
39
+ 2. Reference tower — correct_query → ground-truth solution
40
+ 3. Student tower — student_query → what the student wrote
41
+
42
+ Comparison layer fuses:
43
+ - tower embeddings
44
+ - |student − reference| (what changed)
45
+ - student ⊙ reference (interaction)
46
+ - cosine similarities (semantic alignment)
47
+ - SQL structural features (join/null/agg rules)
48
+
49
+ A light linear head maps the fused vector → 15 error categories.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ encoder_name: str = DEFAULT_ENCODER,
55
+ batch_size: int = 256,
56
+ ):
57
+ self.encoder_name = encoder_name
58
+ self.batch_size = batch_size
59
+ self.encoder = None
60
+ self.scaler = StandardScaler()
61
+ self.clf = LogisticRegression(
62
+ max_iter=1000,
63
+ solver="lbfgs",
64
+ class_weight="balanced",
65
+ random_state=42,
66
+ )
67
+ self.classes_: Optional[np.ndarray] = None
68
+
69
+ def _load_encoder(self):
70
+ if self.encoder is None:
71
+ from sentence_transformers import SentenceTransformer
72
+
73
+ self.encoder = SentenceTransformer(self.encoder_name)
74
+
75
+ def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
76
+ self._load_encoder()
77
+ return self.encoder.encode(
78
+ texts,
79
+ batch_size=self.batch_size,
80
+ show_progress_bar=show_progress,
81
+ convert_to_numpy=True,
82
+ )
83
+
84
+ @staticmethod
85
+ def _intent_text(ctx: QueryContext) -> str:
86
+ return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"
87
+
88
+ @staticmethod
89
+ def _reference_text(ctx: QueryContext) -> str:
90
+ return f"REFERENCE: {ctx.correct_query}"
91
+
92
+ @staticmethod
93
+ def _student_text(ctx: QueryContext) -> str:
94
+ parts = [f"STUDENT: {ctx.student_query}"]
95
+ if ctx.error_message:
96
+ parts.append(f"ERROR: {ctx.error_message}")
97
+ return " ".join(parts)
98
+
99
+ def _build_feature_matrix(
100
+ self,
101
+ contexts: List[QueryContext],
102
+ show_progress: bool = False,
103
+ ) -> np.ndarray:
104
+ intent_texts = [self._intent_text(c) for c in contexts]
105
+ ref_texts = [self._reference_text(c) for c in contexts]
106
+ student_texts = [self._student_text(c) for c in contexts]
107
+
108
+ intent_emb = self._encode(intent_texts, show_progress)
109
+ ref_emb = self._encode(ref_texts, show_progress=False)
110
+ student_emb = self._encode(student_texts, show_progress=False)
111
+
112
+ diff = np.abs(student_emb - ref_emb)
113
+ prod = student_emb * ref_emb
114
+ cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1)
115
+ cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1)
116
+ cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1)
117
+
118
+ sql_feats = np.array(
119
+ [
120
+ extract_sql_features(c.student_query, c.correct_query)
121
+ for c in contexts
122
+ ],
123
+ dtype=np.float64,
124
+ )
125
+
126
+ return np.hstack(
127
+ [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats]
128
+ )
129
+
130
+ def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier":
131
+ X = self._build_feature_matrix(contexts, show_progress=True)
132
+ X = self.scaler.fit_transform(X)
133
+ self.clf.fit(X, y)
134
+ self.classes_ = self.clf.classes_
135
+ return self
136
+
137
+ def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
138
+ X = self.scaler.transform(self._build_feature_matrix(contexts))
139
+ return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)
140
+
141
+ def predict(self, contexts: List[QueryContext]) -> np.ndarray:
142
+ return self.clf.predict(self._prepare_features(contexts))
143
+
144
+ def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
145
+ return self.clf.predict_proba(self._prepare_features(contexts))
146
+
147
+ def explain_similarities(self, ctx: QueryContext) -> dict:
148
+ """Diagnostic scores for the playground UI."""
149
+         # Encode each tower once; the fused feature matrix is not needed for these scores.
150
+ intent_texts = [self._intent_text(ctx)]
151
+ ref_texts = [self._reference_text(ctx)]
152
+ student_texts = [self._student_text(ctx)]
153
+ intent_emb = self._encode(intent_texts)
154
+ ref_emb = self._encode(ref_texts)
155
+ student_emb = self._encode(student_texts)
156
+ return {
157
+ "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]),
158
+ "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]),
159
+ "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]),
160
+ }
161
+
162
+
163
+ def contexts_from_dataframe(df) -> List[QueryContext]:
164
+ """Build QueryContext list from a training dataframe."""
165
+ has_error = "error_message" in df.columns
166
+ return [
167
+ QueryContext(
168
+ question=row["question"],
169
+ schema=row["schema"],
170
+ correct_query=row["correct_query"],
171
+ student_query=row["query"],
172
+ error_message=row["error_message"] if has_error else None,
173
+ )
174
+ for row in df.to_dict("records")
175
+ ]
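The comparison layer in `_build_feature_matrix` can be hard to picture from the `np.hstack` call alone. The following is a minimal pure-Python sketch of the same fusion for a single sample, assuming toy 4-dimensional embeddings and 16 structural features. The names `cosine` and `fuse` are illustrative and do not exist in the repository.

```python
import math


def cosine(a, b):
    """Cosine similarity with the same 1e-8 floor on the denominator."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    denom = max(norm_a * norm_b, 1e-8)
    return sum(x * y for x, y in zip(a, b)) / denom


def fuse(intent, ref, student, sql_feats):
    """Concatenate towers, |student - ref|, student * ref, three cosines, SQL features."""
    diff = [abs(s - r) for s, r in zip(student, ref)]
    prod = [s * r for s, r in zip(student, ref)]
    sims = [cosine(student, ref), cosine(student, intent), cosine(ref, intent)]
    return [*intent, *ref, *student, *diff, *prod, *sims, *sql_feats]


# Toy input: the student query embedding equals the reference embedding.
same = [1.0, 0.0, 2.0, 0.5]
fused = fuse([0.1, 0.1, 0.1, 0.1], same, same, [0.0] * 16)
print(len(fused))
```

For an embedding width `d` the fused vector has `5 * d + 3 + 16` entries, so a 384-dimensional encoder (an assumption, since the model behind `DEFAULT_ENCODER` is not shown here) would yield 1939 features per sample.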
src/predict.py ADDED
@@ -0,0 +1,132 @@
1
+ """Inference API for SQL error classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from dataclasses import asdict, dataclass
8
+ from pathlib import Path
9
+ from typing import List, Optional
10
+
11
+ from src.categories import id_to_name, load_categories
12
+ from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
13
+ from src.cross_encoder_model import (
14
+ CrossEncoderClassifier,
15
+ FineTunedCrossEncoderClassifier,
16
+ )
17
+ from src.multi_tower_model import MultiTowerClassifier, QueryContext
18
+
19
+ CONTEXT_MODELS = (
20
+ CrossEncoderClassifier,
21
+ FineTunedCrossEncoderClassifier,
22
+ MultiTowerClassifier,
23
+ )
24
+
25
+
26
+ @dataclass
27
+ class Prediction:
28
+ label_id: int
29
+ label_name: str
30
+ confidence: float
31
+ top_k: List[dict]
32
+ similarities: Optional[dict] = None
33
+ pair_scores: Optional[dict] = None
34
+
35
+
36
+ class SQLErrorClassifier:
37
+ """Classifier wrapper for playground integration."""
38
+
39
+ def __init__(self, model_path: Path = DEFAULT_MODEL_PATH):
40
+ self.model = load_model(model_path)
41
+ self.label_map = id_to_name(load_categories())
42
+
43
+ def predict(
44
+ self,
45
+ query: str,
46
+ error_message: Optional[str] = None,
47
+ schema: Optional[str] = None,
48
+ question: Optional[str] = None,
49
+ correct_query: Optional[str] = None,
50
+ top_k: int = 3,
51
+ ) -> Prediction:
52
+ if isinstance(self.model, CONTEXT_MODELS):
53
+ if not all([schema, question, correct_query]):
54
+ raise ValueError(
55
+ "context models require schema, question, and correct_query"
56
+ )
57
+ ctx = QueryContext(
58
+ question=question,
59
+ schema=schema,
60
+ correct_query=correct_query,
61
+ student_query=query,
62
+ error_message=error_message,
63
+ )
64
+ proba = self.model.predict_proba([ctx])[0]
65
+ similarities = (
66
+ self.model.explain_similarities(ctx)
67
+ if isinstance(self.model, MultiTowerClassifier)
68
+ else None
69
+ )
70
+ pair_scores = (
71
+ self.model.explain_pair_scores(ctx)
72
+ if isinstance(self.model, CrossEncoderClassifier)
73
+ else None
74
+ )
75
+ else:
76
+ pair_scores = None
77
+ similarities = None
78
+ text = combine_features(
79
+ queries=[query],
80
+ error_messages=[error_message] if error_message else None,
81
+ schemas=[schema] if schema else None,
82
+ questions=[question] if question else None,
83
+ )[0]
84
+ proba = self.model.predict_proba([text])[0]
86
+
87
+ classes = self.model.classes_
88
+ ranked = sorted(zip(classes, proba), key=lambda x: x[1], reverse=True)
89
+ best_id = int(ranked[0][0])
90
+
91
+ return Prediction(
92
+ label_id=best_id,
93
+ label_name=self.label_map[best_id],
94
+ confidence=float(ranked[0][1]),
95
+ top_k=[
96
+ {
97
+ "label_id": int(cls),
98
+ "label_name": self.label_map[int(cls)],
99
+ "confidence": float(p),
100
+ }
101
+ for cls, p in ranked[:top_k]
102
+ ],
103
+ similarities=similarities,
104
+ pair_scores=pair_scores,
105
+ )
106
+
107
+
108
+ def main() -> None:
109
+ parser = argparse.ArgumentParser(description="Classify SQL error type")
110
+ parser.add_argument("--query", type=str, required=True)
111
+ parser.add_argument("--correct-query", type=str, default=None)
112
+ parser.add_argument("--error-message", type=str, default=None)
113
+ parser.add_argument("--schema", type=str, default=None)
114
+ parser.add_argument("--question", type=str, default=None)
115
+ parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
116
+ parser.add_argument("--top-k", type=int, default=3)
117
+ args = parser.parse_args()
118
+
119
+ clf = SQLErrorClassifier(args.model)
120
+ result = clf.predict(
121
+ args.query,
122
+ args.error_message,
123
+ args.schema,
124
+ args.question,
125
+ args.correct_query,
126
+ top_k=args.top_k,
127
+ )
128
+ print(json.dumps(asdict(result), indent=2))
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
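The ranking step in `SQLErrorClassifier.predict` pairs each class id with its probability and sorts in descending order. A small standalone sketch of that step, with made-up class ids and probabilities (`rank_top_k` is a hypothetical name), might look like this:

```python
def rank_top_k(classes, proba, top_k=3):
    """Pair class ids with probabilities and keep the top_k most likely."""
    ranked = sorted(zip(classes, proba), key=lambda pair: pair[1], reverse=True)
    return [
        {"label_id": int(cls), "confidence": float(p)}
        for cls, p in ranked[:top_k]
    ]


# Made-up probabilities for three class ids.
top = rank_top_k([0, 6, 12], [0.15, 0.70, 0.15], top_k=2)
print(top)
```

Because Python's sort is stable even with `reverse=True`, tied probabilities keep the order of `classes`, which makes the output deterministic for a given model.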
src/sql_features.py ADDED
@@ -0,0 +1,81 @@
1
+ """Lightweight structural SQL features for the classification head."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from typing import List
7
+
8
+ AGG_FUNCS = ("COUNT", "SUM", "AVG", "MAX", "MIN")
9
+ WINDOW_FUNCS = ("ROW_NUMBER", "RANK", "DENSE_RANK", "OVER")
10
+
11
+
12
+ def _upper(sql: str) -> str:
13
+ return sql.upper()
14
+
15
+
16
+ def extract_sql_features(student_query: str, correct_query: str = "") -> List[float]:
17
+ """
18
+ Rule-based signals that complement semantic embeddings.
19
+ Returns a fixed-length float vector.
20
+ """
21
+ s = _upper(student_query)
22
+ c = _upper(correct_query) if correct_query else ""
23
+
24
+     has_agg = any(f"{f}(" in s for f in AGG_FUNCS)
25
+ has_group = "GROUP BY" in s
26
+ has_join = "JOIN" in s
27
+ has_on = " ON " in s
28
+ has_where = " WHERE " in s
29
+ has_having = " HAVING " in s
30
+ has_distinct = "DISTINCT" in s
31
+ has_subquery = "(" in s and "SELECT" in s[s.find("(") :]
32
+ has_window = "OVER" in s
33
+ has_null_eq = "= NULL" in s or "=NULL" in s
34
+ has_is_null = "IS NULL" in s or "IS NOT NULL" in s
35
+ has_select_star = bool(re.search(r"SELECT\s+\*", s))
36
+ has_or = " OR " in s
37
+ has_and = " AND " in s
38
+
39
+ correct_has_distinct = "DISTINCT" in c
40
+ correct_has_group = "GROUP BY" in c
41
+ correct_has_inner = "INNER JOIN" in c
42
+ student_has_left = "LEFT JOIN" in s
43
+
44
+ return [
45
+ float(has_agg),
46
+ float(has_agg and not has_group),
47
+ float(has_join and not has_on),
48
+ float(has_join),
49
+ float(has_where and has_having),
50
+ float(has_agg and has_where and not has_having),
51
+ float(has_distinct),
52
+ float(correct_has_distinct and not has_distinct),
53
+ float(has_subquery),
54
+ float(has_window),
55
+ float(has_null_eq),
56
+ float(has_is_null),
57
+ float(has_select_star),
58
+ float(has_or and has_and),
59
+ float(correct_has_inner and student_has_left),
60
+ float(len(s) / max(len(c), 1)), # length ratio vs reference
61
+ ]
62
+
63
+
64
+ FEATURE_NAMES = [
65
+ "has_aggregate",
66
+ "agg_without_group_by",
67
+ "join_without_on",
68
+ "has_join",
69
+ "where_and_having",
70
+ "agg_in_where",
71
+ "has_distinct",
72
+ "missing_distinct_vs_correct",
73
+ "has_subquery",
74
+ "has_window",
75
+ "null_equals",
76
+ "is_null_check",
77
+ "select_star",
78
+ "and_or_mix",
79
+ "left_vs_inner_join",
80
+ "length_ratio",
81
+ ]
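One caveat of plain substring checks such as `"OVER" in s` is that they also match inside identifiers (for example a column named `turnover`). The sketch below shows one possible alternative using word boundaries. It is a suggestion rather than what `extract_sql_features` does, and `has_keyword` is a hypothetical helper.

```python
import re


def has_keyword(sql, keyword):
    """True if keyword occurs as whole word(s), ignoring case and extra spaces."""
    parts = [re.escape(part) for part in keyword.split()]
    pattern = r"\b" + r"\s+".join(parts) + r"\b"
    return re.search(pattern, sql, flags=re.IGNORECASE) is not None


print(has_keyword("SELECT turnover FROM sales", "OVER"))
print(has_keyword("SELECT RANK() OVER (ORDER BY x) FROM t", "OVER"))
print(has_keyword("select a from t group   by a", "GROUP BY"))
```

This sketch still does not skip string literals or comments, so it remains a heuristic. A full fix would need a SQL tokenizer.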
src/sql_templates.py ADDED
@@ -0,0 +1,258 @@
1
+ """Error injectors that transform exercise context into labeled mistakes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from typing import Callable, Dict, List, Tuple
7
+
8
+ from src.exercises import Exercise
9
+
10
+ FAKE_COLUMNS = ["fullname", "studentname", "coursename", "dept_name", "totals"]
11
+ FAKE_TABLES = ["student", "course", "enrolment", "employe", "orderz"]
12
+
13
+
14
+ def _pick(rng: random.Random, items: List[str], k: int = 1) -> str | List[str]:
15
+ if k == 1:
16
+ return rng.choice(items)
17
+ return rng.sample(items, k)
18
+
19
+
20
+ def _first_table(exercise: Exercise) -> str:
21
+ return exercise.tables[0]
22
+
23
+
24
+ def _second_table(exercise: Exercise) -> str:
25
+ return exercise.tables[1] if len(exercise.tables) > 1 else exercise.tables[0]
26
+
27
+
28
+ # --- Error injectors: (exercise) -> (erroneous_sql, error_message) ---
29
+
30
+ def inject_syntax_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
31
+ sql = exercise.correct_query
32
+ mutations = [
33
+ lambda s: s.replace("SELECT", "SELEC", 1),
34
+ lambda s: s.replace("FROM", "FRO", 1),
35
+ lambda s: s[:-1],
36
+ lambda s: s.replace(")", "", 1),
37
+ lambda s: s + " WHERE",
38
+ lambda s: s.replace(",", "", 1),
39
+ lambda s: s.replace("'", '"', 1) if "'" in s else s + " 'unclosed",
40
+ ]
41
+ return rng.choice(mutations)(sql), "syntax error at or near unexpected token"
42
+
43
+
44
+ def inject_join_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
45
+ t1, t2 = _first_table(exercise), _second_table(exercise)
46
+ col = _pick(rng, list(exercise.columns))
47
+ variants = [
48
+ f"SELECT {col} FROM {t1} JOIN {t2}",
49
+ f"SELECT {col} FROM {t1} INNER JOIN {t2} ON {t1}.id = {t2}.id",
50
+ (
51
+ f"SELECT {t1}.{col} FROM {t1} "
52
+ f"LEFT JOIN {t2} ON {t1}.{col} = {t2}.{col}"
53
+ ),
54
+ f"SELECT * FROM {t1}, {t2} WHERE {t1}.wrong_id = {t2}.wrong_id",
55
+ ]
56
+ return rng.choice(variants), "missing ON clause or invalid join condition"
57
+
58
+
59
+ def inject_aggregation_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
60
+ t = _first_table(exercise)
61
+ cols = list(exercise.columns)
62
+ group_col = cols[0]
63
+ agg_col = cols[-1]
64
+ bad = f"SELECT {group_col}, AVG({agg_col}) FROM {t}"
65
+ return bad, "column must appear in GROUP BY clause or be used in aggregate function"
66
+
67
+
68
+ def inject_having_where_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
69
+ t = _first_table(exercise)
70
+ cols = list(exercise.columns)
71
+ group_col, agg_col = cols[0], cols[-1]
72
+ bad = (
73
+ f"SELECT {group_col}, COUNT({agg_col}) FROM {t} "
74
+ f"WHERE COUNT({agg_col}) > {rng.randint(1, 5)}"
75
+ )
76
+ return bad, "aggregate functions are not allowed in WHERE"
77
+
78
+
79
+ def inject_subquery_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
80
+ t1, t2 = _first_table(exercise), _second_table(exercise)
81
+ col = _pick(rng, list(exercise.columns))
82
+ bad = f"SELECT {col} FROM {t1} WHERE {col} = (SELECT {col} FROM {t2})"
83
+ return bad, "subquery returned more than one row"
84
+
85
+
86
+ def inject_window_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
87
+ t = _first_table(exercise)
88
+ col = _pick(rng, list(exercise.columns))
89
+ variants = [
90
+ f"SELECT {col}, ROW_NUMBER() OVER () FROM {t}",
91
+ f"SELECT {col}, SUM({col}) OVER (ORDER BY {col}) FROM {t} GROUP BY {col}",
92
+ f"SELECT {col}, RANK() OVER (PARTITION {col}) FROM {t}",
93
+ ]
94
+ return rng.choice(variants), "window function requires PARTITION BY or ORDER BY"
95
+
96
+
97
+ def inject_null_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
98
+ t = _first_table(exercise)
99
+ col = _pick(rng, list(exercise.columns))
100
+ bad = f"SELECT * FROM {t} WHERE {col} = NULL"
101
+ return bad, "use IS NULL or IS NOT NULL to test for null values"
102
+
103
+
104
+ def inject_date_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
105
+ t = _first_table(exercise)
106
+ variants = [
107
+ f"SELECT * FROM {t} WHERE order_date = '31/02/2023'",
108
+ f"SELECT * FROM {t} WHERE order_date = DATE '2023-13-40'",
109
+ f"SELECT * FROM {t} WHERE STR_TO_DATE('bad-date', '%Y-%m-%d')",
110
+ f"SELECT * FROM {t} WHERE hire_date > 'yesterday'",
111
+ ]
112
+ return rng.choice(variants), "invalid date format or unknown date function"
113
+
114
+
115
+ def inject_column_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
116
+ t = _first_table(exercise)
117
+ col = _pick(rng, FAKE_COLUMNS)
118
+ bad = f"SELECT {col} FROM {t}"
119
+ return bad, f"column '{col}' does not exist"
120
+
121
+
122
+ def inject_table_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
123
+ tbl = _pick(rng, FAKE_TABLES)
124
+ col = _pick(rng, list(exercise.columns))
125
+ bad = f"SELECT {col} FROM {tbl}"
126
+ return bad, f"relation '{tbl}' does not exist"
127
+
128
+
129
+ def inject_datatype_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
130
+ t = _first_table(exercise)
131
+ col = _pick(rng, list(exercise.columns))
132
+ bad = f"SELECT {col} FROM {t} WHERE {col} = '{rng.choice(['abc', 'ten', 'N/A'])}'"
133
+ return bad, "operator does not exist: integer = character varying"
134
+
135
+
136
+ def inject_duplicate_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
137
+ """Drop DISTINCT when the question asks for unique values."""
138
+ sql = exercise.correct_query
139
+ if "DISTINCT" in sql.upper():
140
+         # Strip the keyword from the original text so identifier casing is preserved.
142
+ bad = sql.replace("DISTINCT ", "").replace("distinct ", "")
143
+ else:
144
+ col = _pick(rng, list(exercise.columns))
145
+ bad = f"SELECT {col} FROM {_first_table(exercise)}"
146
+ return bad, "query returns duplicate rows; DISTINCT may be required"
147
+
148
+
149
+ def inject_logical_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
150
+ """
151
+ Produce a query that runs against the schema but answers the question incorrectly.
152
+ Variants are tied to the exercise question and correct answer.
153
+ """
154
+ sql = exercise.correct_query
155
+ q = exercise.question.lower()
156
+ variants: List[str] = []
157
+
158
+ if "average" in q or "avg" in sql.lower():
159
+ variants.append(sql.replace("AVG(", "SUM(", 1))
160
+ variants.append(sql.replace("AVG(", "MAX(", 1))
161
+ if " and " in q and " AND " in sql:
162
+ variants.append(sql.replace(" AND ", " OR ", 1))
163
+ if "join" in sql.lower():
164
+ t1, t2 = _first_table(exercise), _second_table(exercise)
165
+ variants.append(
166
+ f"SELECT {t1}.name, {t2}.name FROM {t1} "
167
+ f"JOIN {t2} ON {t1}.id = {t2}.id"
168
+ )
169
+ variants.append(sql.replace("INNER JOIN", "LEFT JOIN", 1))
170
+ if "between" in q and "BETWEEN" in sql.upper():
171
+ upper = sql.upper()
172
+ between_part = upper.split("BETWEEN", 1)[1]
173
+ bounds = between_part.split("AND", 1)
174
+ if len(bounds) == 2:
175
+ lo = bounds[0].strip().split()[-1]
176
+ hi = bounds[1].strip().split()[0]
177
+ variants.append(
178
+ sql.split("WHERE", 1)[0]
179
+ + f" WHERE price BETWEEN {hi} AND {lo}"
180
+ )
181
+ if "rank" in q or "over" in sql.lower():
182
+ col = _pick(rng, list(exercise.columns))
183
+ variants.append(
184
+ f"SELECT name, {col} FROM {_first_table(exercise)} ORDER BY {col} DESC"
185
+ )
186
+ if "total" in q and "WHERE" in sql.upper():
187
+ variants.append(sql.replace("active", "inactive"))
188
+ if "highest" in q or "max" in sql.lower():
189
+ col = _pick(rng, list(exercise.columns))
190
+ variants.append(
191
+ f"SELECT name FROM {_first_table(exercise)} "
192
+ f"WHERE {col} >= (SELECT AVG({col}) FROM {_second_table(exercise)})"
193
+ )
194
+ if "enrolled" in q and "INNER JOIN" in sql.upper():
195
+ variants.append(sql.replace("INNER JOIN", "LEFT JOIN", 1))
196
+ if "not provided" in q or "is null" in sql.lower():
197
+ variants.append(sql.replace("IS NULL", "= ''"))
198
+
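199
+     # Discard no-op mutations that left the reference query unchanged.
+     variants = [v for v in variants if v != sql]
+     if not variants: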
200
+ col = _pick(rng, list(exercise.columns))
201
+ t = _first_table(exercise)
202
+ variants = [
203
+ f"SELECT {col} FROM {t} ORDER BY {col} DESC LIMIT 10",
204
+ f"SELECT COUNT(*) FROM {t}",
205
+ ]
206
+
207
+ bad = rng.choice(variants)
208
+ return bad, "query executes but produces incorrect result set"
209
+
210
+
211
+ def inject_performance_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
212
+ t1, t2 = _first_table(exercise), _second_table(exercise)
213
+ variants = [
214
+ f"SELECT * FROM {t1}",
215
+ f"SELECT * FROM {t1} JOIN {t2} ON {t1}.id = {t2}.id",
216
+ (
217
+ f"SELECT * FROM {t1} "
218
+ f"WHERE {_pick(rng, list(exercise.columns))} "
219
+ f"LIKE '%{rng.choice(['a', 'e', 'i'])}%'"
220
+ ),
221
+ f"SELECT * FROM {t1} CROSS JOIN {t2}",
222
+ ]
223
+ return rng.choice(variants), "inefficient query: SELECT * or cartesian join detected"
224
+
225
+
226
+ def inject_filtering_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
227
+ sql = exercise.correct_query
228
+ col = _pick(rng, list(exercise.columns))
229
+ t = _first_table(exercise)
230
+ threshold = rng.randint(50, 90)
231
+ variants = [
232
+ sql.replace(">", "<", 1) if ">" in sql else sql.replace("=", "!=", 1),
233
+ f"SELECT {col} FROM {t} WHERE {col} > {threshold} AND {col} < {threshold - 20}",
234
+ f"SELECT {col} FROM {t} WHERE NOT {col} > {threshold}",
235
+ sql.replace(" AND ", " OR ", 1) if " AND " in sql else (
236
+ f"SELECT {col} FROM {t} WHERE {col} BETWEEN {threshold} AND {threshold - 10}"
237
+ ),
238
+ ]
239
+ return rng.choice(variants), "WHERE clause filters incorrect rows"
240
+
241
+
242
+ ERROR_INJECTORS: Dict[int, Callable[[random.Random, Exercise], Tuple[str, str]]] = {
243
+ 0: inject_syntax_error,
244
+ 1: inject_join_error,
245
+ 2: inject_aggregation_error,
246
+ 3: inject_having_where_error,
247
+ 4: inject_subquery_error,
248
+ 5: inject_window_error,
249
+ 6: inject_null_error,
250
+ 7: inject_date_error,
251
+ 8: inject_column_error,
252
+ 9: inject_table_error,
253
+ 10: inject_datatype_error,
254
+ 11: inject_duplicate_error,
255
+ 12: inject_logical_error,
256
+ 13: inject_performance_error,
257
+ 14: inject_filtering_error,
258
+ }
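The registry maps a label id to an injector that takes a seeded `random.Random` and an exercise. The sketch below shows how one labeled sample could be produced from it, under stated assumptions: the `Exercise` dataclass is a stand-in (the real one lives in `src/exercises.py`, which is not shown), `make_sample` is a hypothetical helper, and only the null-comparison injector is reproduced.

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass
class Exercise:
    """Stand-in for src.exercises.Exercise; the field names are assumptions."""
    question: str
    tables: Tuple[str, ...]
    columns: Tuple[str, ...]
    correct_query: str


def inject_null_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    col = rng.choice(list(exercise.columns))
    bad = f"SELECT * FROM {exercise.tables[0]} WHERE {col} = NULL"
    return bad, "use IS NULL or IS NOT NULL to test for null values"


INJECTORS: Dict[int, Callable[[random.Random, Exercise], Tuple[str, str]]] = {
    6: inject_null_error,
}


def make_sample(label_id: int, exercise: Exercise, seed: int) -> dict:
    """Build one labeled row; a fixed seed makes the mutation reproducible."""
    rng = random.Random(seed)
    bad_sql, message = INJECTORS[label_id](rng, exercise)
    return {
        "label_id": label_id,
        "question": exercise.question,
        "correct_query": exercise.correct_query,
        "query": bad_sql,
        "error_message": message,
    }


exercise = Exercise(
    question="List students without an email address.",
    tables=("students",),
    columns=("name", "email"),
    correct_query="SELECT name FROM students WHERE email IS NULL",
)
sample = make_sample(6, exercise, seed=42)
print(sample["query"])
```

The keys of the returned row mirror the columns read by `contexts_from_dataframe` (`question`, `correct_query`, `query`, `error_message`). The `schema` column is omitted here for brevity.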
src/train.py ADDED
@@ -0,0 +1,198 @@
1
+ """Train the SQL error classifier."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import pandas as pd
10
+ from sklearn.metrics import classification_report
11
+ from sklearn.model_selection import train_test_split
12
+
13
+ from src.categories import id_to_name, load_categories
14
+ from src.cross_encoder_model import (
15
+ CrossEncoderClassifier,
16
+ FineTunedCrossEncoderClassifier,
17
+ )
18
+ from src.model import (
19
+ DEFAULT_MODEL_PATH,
20
+ ModelType,
21
+ build_classifier,
22
+ combine_features,
23
+ save_model,
24
+ )
25
+ from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe
26
+
27
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
28
+ DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
29
+ DEFAULT_METRICS = PROJECT_ROOT / "models" / "metrics.json"
30
+
31
+ CONTEXT_MODELS = (
32
+ CrossEncoderClassifier,
33
+ FineTunedCrossEncoderClassifier,
34
+ MultiTowerClassifier,
35
+ )
36
+
37
+
38
+ def _split_dataframe(
39
+ df: pd.DataFrame, test_size: float, val_size: float, seed: int
40
+ ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
41
+ trainval, test = train_test_split(
42
+ df, test_size=test_size, random_state=seed, stratify=df["label_id"]
43
+ )
44
+ relative_val = val_size / (1 - test_size)
45
+ train, val = train_test_split(
46
+ trainval,
47
+ test_size=relative_val,
48
+ random_state=seed,
49
+ stratify=trainval["label_id"],
50
+ )
51
+ return train, val, test
52
+
53
+
54
+ def train(
55
+ data_path: Path = DEFAULT_DATA,
56
+ model_path: Path = DEFAULT_MODEL_PATH,
57
+ metrics_path: Path = DEFAULT_METRICS,
58
+ test_size: float = 0.1,
59
+ val_size: float = 0.1,
60
+ use_error_message: bool = True,
61
+ max_train_samples: int | None = None,
62
+ model_type: ModelType = "cross_encoder",
63
+ epochs: int = 1,
64
+ seed: int = 42,
65
+ ) -> dict:
66
+ print(f"Loading data from {data_path}...")
67
+ df = pd.read_parquet(data_path)
68
+
69
+ if max_train_samples and len(df) > max_train_samples:
70
+ df = df.sample(n=max_train_samples, random_state=seed)
71
+
72
+ if not use_error_message and "error_message" in df.columns:
73
+ df = df.drop(columns=["error_message"])
74
+
75
+ train_df, val_df, test_df = _split_dataframe(df, test_size, val_size, seed)
76
+ print(
77
+ f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}"
78
+ )
79
+
80
+ model = build_classifier(model_type=model_type)
81
+ print(f"Training {model_type} classifier...")
82
+
83
+ if isinstance(model, CONTEXT_MODELS):
84
+ train_ctx = contexts_from_dataframe(train_df)
85
+ val_ctx = contexts_from_dataframe(val_df)
86
+ test_ctx = contexts_from_dataframe(test_df)
87
+
88
+ if isinstance(model, FineTunedCrossEncoderClassifier):
89
+ model.fit(
90
+ train_ctx,
91
+ train_df["label_id"].values,
92
+ epochs=epochs,
93
+ output_path=model_path.with_suffix(".ce")
94
+ if model_path.suffix == ".joblib"
95
+ else model_path,
96
+ )
97
+ else:
98
+ model.fit(train_ctx, train_df["label_id"].values)
99
+
100
+ val_preds = model.predict(val_ctx)
101
+ test_preds = model.predict(test_ctx)
102
+ y_val = val_df["label_id"].values
103
+ y_test = test_df["label_id"].values
104
+ else:
105
+
106
+ def to_texts(frame: pd.DataFrame) -> list[str]:
107
+ return combine_features(
108
+ queries=frame["query"].tolist(),
109
+ error_messages=frame["error_message"].tolist()
110
+ if "error_message" in frame.columns
111
+ else None,
112
+ schemas=frame["schema"].tolist() if "schema" in frame.columns else None,
113
+ questions=frame["question"].tolist()
114
+ if "question" in frame.columns
115
+ else None,
116
+ )
117
+
118
+ model.fit(to_texts(train_df), train_df["label_id"].values)
119
+ val_preds = model.predict(to_texts(val_df))
120
+ test_preds = model.predict(to_texts(test_df))
121
+ y_val = val_df["label_id"].values
122
+ y_test = test_df["label_id"].values
123
+
124
+ val_report = classification_report(
125
+ y_val, val_preds, output_dict=True, zero_division=0
126
+ )
127
+ print(f"Validation accuracy: {val_report['accuracy']:.4f}")
128
+
129
+ test_report = classification_report(
130
+ y_test, test_preds, output_dict=True, zero_division=0
131
+ )
132
+ print(f"Test accuracy: {test_report['accuracy']:.4f}")
133
+
134
+ save_model(model, model_path, model_type=model_type)
135
+ print(f"Model saved to {model_path}")
136
+
137
+ categories = load_categories()
138
+ label_map = id_to_name(categories)
139
+ metrics = {
140
+ "train_size": len(train_df),
141
+ "val_size": len(val_df),
142
+ "test_size": len(test_df),
143
+ "model_type": model_type,
144
+ "epochs": epochs if model_type == "cross_encoder_ft" else None,
145
+ "use_error_message": use_error_message,
146
+ "validation": val_report,
147
+ "test": test_report,
148
+ "label_map": {str(k): v for k, v in label_map.items()},
149
+ }
150
+ metrics_path.parent.mkdir(parents=True, exist_ok=True)
151
+ with open(metrics_path, "w") as f:
152
+ json.dump(metrics, f, indent=2)
153
+ print(f"Metrics saved to {metrics_path}")
154
+
155
+ return metrics
156
+
157
+
158
+ def main() -> None:
159
+ parser = argparse.ArgumentParser(description="Train SQL error classifier")
160
+ parser.add_argument("--data", type=Path, default=DEFAULT_DATA)
161
+ parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
162
+ parser.add_argument("--metrics", type=Path, default=DEFAULT_METRICS)
163
+ parser.add_argument("--test-size", type=float, default=0.1)
164
+ parser.add_argument("--val-size", type=float, default=0.1)
165
+ parser.add_argument("--no-error-message", action="store_true")
166
+ parser.add_argument("--max-samples", type=int, default=None)
167
+ parser.add_argument(
168
+ "--model-type",
169
+ choices=["cross_encoder", "cross_encoder_ft", "multi_tower", "minilm", "tfidf"],
170
+ default="cross_encoder",
171
+ help="cross_encoder (recommended): joint attention pairs; "
172
+ "cross_encoder_ft: fine-tuned end-to-end (best accuracy)",
173
+ )
174
+ parser.add_argument(
175
+ "--epochs",
176
+ type=int,
177
+ default=1,
178
+ help="Epochs for cross_encoder_ft fine-tuning",
179
+ )
180
+ parser.add_argument("--seed", type=int, default=42)
181
+ args = parser.parse_args()
182
+
183
+ train(
184
+ data_path=args.data,
185
+ model_path=args.model,
186
+ metrics_path=args.metrics,
187
+ test_size=args.test_size,
188
+ val_size=args.val_size,
189
+ use_error_message=not args.no_error_message,
190
+ max_train_samples=args.max_samples,
191
+ model_type=args.model_type,
192
+ epochs=args.epochs,
193
+ seed=args.seed,
194
+ )
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
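`_split_dataframe` splits twice, so the validation fraction has to be rescaled to the data left after the test split. The arithmetic can be checked with a short sketch (`split_fractions` is a hypothetical helper, not part of `src/train.py`):

```python
def split_fractions(test_size=0.1, val_size=0.1):
    """Fractions of the full dataset that end up in each split."""
    remaining = 1 - test_size              # share left after the test split
    relative_val = val_size / remaining    # fraction passed to the second split
    val = remaining * relative_val         # equals val_size again
    train = remaining - val
    return {
        "relative_val": relative_val,
        "train": train,
        "val": val,
        "test": test_size,
    }


fractions = split_fractions(0.1, 0.1)
print(fractions)
```

With the defaults, the second `train_test_split` call receives roughly 0.111 rather than 0.1, which is what keeps the final proportions at about 80/10/10.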
train_space_app.py ADDED
@@ -0,0 +1,230 @@
+ """
+ Hugging Face Space — CodeBERT SQL Error Classifier Training UI.
+
+ Deploy as a Gradio Space with app_file: train_space_app.py
+ Set hardware to GPU (t4-small recommended).
+ Add HF_TOKEN secret to push trained models to your Hub account.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import shutil
+ import tempfile
+ from pathlib import Path
+
+ import gradio as gr
+ import pandas as pd
+
+ from src.hf_train_codebert import train
+
+ PROJECT_ROOT = Path(__file__).parent
+ DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet"
+ OUTPUT_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"
+ BUNDLED_DATASETS = {
+     "Dev (15K samples)": str(PROJECT_ROOT / "data" / "sql_errors_dev.parquet"),
+     "Full (1M samples)": str(PROJECT_ROOT / "data" / "sql_errors_1m.parquet"),
+ }
+
+
+ def _format_metrics(metrics: dict) -> str:
+     val = metrics.get("validation", {})
+     test = metrics.get("test", {})
+     lines = [
+         "## Training complete",
+         "",
+         f"- Train samples: **{metrics.get('train_samples', 0):,}**",
+         f"- Val samples: **{metrics.get('val_samples', 0):,}**",
+         f"- Test samples: **{metrics.get('test_samples', 0):,}**",
+         "",
+         "### Validation",
+         f"- F1 macro: **{val.get('eval_f1_macro', 0):.4f}**",
+         f"- F1 micro: **{val.get('eval_f1_micro', 0):.4f}**",
+         "",
+         "### Test",
+         f"- F1 macro: **{test.get('f1_macro', 0):.4f}**",
+         f"- F1 micro: **{test.get('f1_micro', 0):.4f}**",
+         f"- Subset accuracy: **{test.get('subset_accuracy', 0):.4f}**",
+         "",
+         f"Model saved to `{OUTPUT_DIR}`",
+     ]
+     if metrics.get("hub_url"):
+         lines.append(f"\n**Hub model:** {metrics['hub_url']}")
+     return "\n".join(lines)
+
+
+ def run_training(
+     dataset_choice: str,
+     uploaded_file,
+     max_samples: int,
+     epochs: float,
+     batch_size: int,
+     learning_rate: float,
+     max_length: int,
+     fp16: bool,
+     push_to_hub: bool,
+     hub_model_id: str,
+     progress=gr.Progress(),
+ ):
+     progress(0, desc="Preparing dataset...")
+
+     if uploaded_file is not None:
+         data_path = Path(uploaded_file.name)
+     else:
+         data_path = Path(BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA))
+     if not data_path.exists():
+         return (
+             f"Dataset not found: `{data_path}`. "
+             "Upload a parquet file or include data/ in the Space repo.",
+             None,
+             None,
+         )
+
+     hub_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
+     if push_to_hub and not hub_token:
+         return (
+             "Add `HF_TOKEN` to Space secrets to push models to the Hub.",
+             None,
+             None,
+         )
+     if push_to_hub and not hub_model_id.strip():
+         return "Enter a Hub model id (e.g. `your-username/sql-codebert-classifier`).", None, None
+
+     if OUTPUT_DIR.exists():
+         shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
+     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+     samples = int(max_samples) if max_samples and max_samples > 0 else None
+     progress(0.1, desc="Starting CodeBERT training...")
+
+     try:
+         metrics = train(
+             data_path=data_path,
+             output_dir=OUTPUT_DIR,
+             epochs=epochs,
+             batch_size=batch_size,
+             learning_rate=learning_rate,
+             max_length=max_length,
+             max_samples=samples,
+             fp16=fp16,
+             save_strategy="no",
+             push_to_hub=push_to_hub,
+             hub_model_id=hub_model_id.strip() or None,
+             hub_token=hub_token,
+         )
+     except Exception as exc:
+         return f"Training failed:\n\n```\n{exc}\n```", None, None
+
+     progress(1.0, desc="Done")
+     if push_to_hub and hub_model_id.strip():
+         metrics["hub_url"] = f"https://huggingface.co/{hub_model_id.strip()}"
+
+     metrics_path = OUTPUT_DIR / "metrics.json"
+     summary = _format_metrics(metrics)
+     return summary, str(metrics_path) if metrics_path.exists() else None, str(OUTPUT_DIR)
+
+
+ def load_preview(dataset_choice: str, uploaded_file) -> str:
+     try:
+         if uploaded_file is not None:
+             df = pd.read_parquet(uploaded_file.name)
+         else:
+             path = BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA)
+             if not Path(path).exists():
+                 return f"Dataset not found: {path}"
+             df = pd.read_parquet(path)
+         cols = list(df.columns)
+         sample = df.head(2).to_dict(orient="records")
+         # default=str keeps the preview working when a column holds values
+         # that json cannot serialize natively (e.g. timestamps or arrays).
+         return f"**Rows:** {len(df):,}\n\n**Columns:** `{cols}`\n\n**Sample:**\n```json\n{json.dumps(sample, indent=2, default=str)[:2000]}\n```"
+     except Exception as exc:
+         return f"Could not load preview: {exc}"
+
+
+ with gr.Blocks(title="SQL Error Classifier — Train") as demo:
+     gr.Markdown(
+         """
+         # SQL Error Classifier — CodeBERT Training
+         Train **microsoft/codebert-base** as a cross-encoder on this Space.
+
+         **Input format:** `QUESTION` + `SCHEMA` + `STUDENT_SQL` + `CORRECT_SQL` (single sequence)
+
+         **GPU recommended** — upgrade Space hardware to `t4-small` or better.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             dataset_choice = gr.Dropdown(
+                 choices=list(BUNDLED_DATASETS.keys()),
+                 value="Dev (15K samples)",
+                 label="Bundled dataset",
+             )
+             uploaded = gr.File(
+                 label="Or upload parquet",
+                 file_types=[".parquet"],
+             )
+             preview_btn = gr.Button("Preview dataset")
+             preview_out = gr.Markdown()
+
+             max_samples = gr.Number(
+                 label="Max samples (0 = all)",
+                 value=5000,
+                 precision=0,
+             )
+             epochs = gr.Slider(1, 10, value=2, step=1, label="Epochs")
+             batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch size")
+             learning_rate = gr.Number(label="Learning rate", value=2e-5)
+             max_length = gr.Slider(128, 512, value=512, step=64, label="Max length")
+             fp16 = gr.Checkbox(label="FP16 (GPU only)", value=True)
+
+             push_to_hub = gr.Checkbox(label="Push to Hugging Face Hub", value=False)
+             hub_model_id = gr.Textbox(
+                 label="Hub model id",
+                 placeholder="your-username/sql-codebert-classifier",
+             )
+
+             train_btn = gr.Button("Start Training", variant="primary")
+
+         with gr.Column(scale=1):
+             result = gr.Markdown(label="Results")
+             metrics_file = gr.File(label="metrics.json")
+             model_dir = gr.Textbox(label="Model output path", interactive=False)
+
+     preview_btn.click(load_preview, [dataset_choice, uploaded], preview_out)
+     train_btn.click(
+         run_training,
+         [
+             dataset_choice,
+             uploaded,
+             max_samples,
+             epochs,
+             batch_size,
+             learning_rate,
+             max_length,
+             fp16,
+             push_to_hub,
+             hub_model_id,
+         ],
+         [result, metrics_file, model_dir],
+     )
+
+     gr.Markdown(
+         """
+         ### Space setup
+         1. Create a Gradio Space and push this repo
+         2. Set **Hardware → GPU (t4-small)**
+         3. Add secret `HF_TOKEN` (write token) to push models
+         4. Include `data/sql_errors_dev.parquet` in the repo (or upload at runtime)
+
+         ### After training
+         Use the saved model with:
+         ```python
+         from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
+         clf = CodeBERTSQLErrorClassifier("models/codebert-cross-encoder")
+         ```
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
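The pre-flight checks in `run_training` can be exercised without Gradio or a GPU. The sketch below pulls them out into two pure functions; `normalize_max_samples` and `hub_push_error` are hypothetical names that do not exist in the committed file, and the messages are copied from the handler above.

```python
from typing import Optional


def normalize_max_samples(max_samples: Optional[float]) -> Optional[int]:
    """Mirror the UI rule: 0, a negative value, or an empty field means 'use every row'."""
    return int(max_samples) if max_samples and max_samples > 0 else None


def hub_push_error(push_to_hub: bool, hub_token: Optional[str], hub_model_id: str) -> Optional[str]:
    """Return the message run_training would show, or None when the push settings are usable."""
    if not push_to_hub:
        return None
    if not hub_token:
        return "Add `HF_TOKEN` to Space secrets to push models to the Hub."
    if not hub_model_id.strip():
        return "Enter a Hub model id (e.g. `your-username/sql-codebert-classifier`)."
    return None


print(normalize_max_samples(0))
print(normalize_max_samples(5000.0))
print(hub_push_error(True, None, "me/model"))
```

Keeping these checks separate from the click handler would make them straightforward to unit test, though the committed app inlines them.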