sql-error-classifier-train / src /huggingface.py
nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
Raw
History Blame Contribute Delete
7.65 kB
"""Hugging Face Hub integration for the SQL error classifier."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union
import joblib
from src.categories import load_categories
from src.cross_encoder_model import CrossEncoderClassifier
from src.model import DEFAULT_ENCODER, load_model
from src.multi_tower_model import MultiTowerClassifier, QueryContext
# Directory one level above the package containing this module (the module
# sits in the `src` package, judging by the `src.*` imports above). Not used
# by the code in this module.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# File names of the three artifacts that make up a published model directory.
# save_pretrained writes them, from_pretrained reads them, and
# _resolve_model_path restricts Hub downloads to exactly these files.
CONFIG_NAME = "config.json"
CLASSIFIER_NAME = "classifier.joblib"
CATEGORIES_NAME = "categories.json"
# Model classes that can be wrapped and published; package_for_hub uses this
# tuple as its isinstance() check before packaging.
SUPPORTED_CONTEXT_MODELS = (CrossEncoderClassifier, MultiTowerClassifier)
class SQLLErrorClassifierHF:
    """
    Hugging Face–compatible wrapper for SQL error classifiers.

    Pairs a fitted ``CrossEncoderClassifier`` or ``MultiTowerClassifier`` with a
    mapping from integer class id to category name, and saves/loads that pair
    in a Hub-style directory (``config.json``, ``classifier.joblib``,
    ``categories.json``).

    Usage:
        clf = SQLLErrorClassifierHF.from_pretrained("username/sql-error-classifier")
        result = clf.predict(
            question="...", schema="...", correct_query="...", student_query="..."
        )
    """

    def __init__(self, model, label_map: Dict[int, str]):
        """
        Args:
            model: Fitted classifier exposing ``predict_proba`` and ``classes_``.
            label_map: Maps integer class id to its category name; used for
                every ``label_name`` in ``predict`` output and written to
                ``categories.json`` by ``save_pretrained``.
        """
        self.model = model
        self.label_map = label_map

    def predict(
        self,
        question: str,
        schema: str,
        correct_query: str,
        student_query: str,
        error_message: Optional[str] = None,
        top_k: int = 3,
    ) -> Dict[str, Any]:
        """Classify one student query against the reference query.

        Args:
            question, schema, correct_query, student_query, error_message:
                Packed into a single ``QueryContext`` for the model.
            top_k: How many highest-probability labels to list under
                ``"top_k"``. Negative values yield an empty list; ``None``
                lists every class.

        Returns:
            Dict with the best ``label_id``, ``label_name`` and ``confidence``,
            a ``top_k`` list of the same fields in descending probability, plus
            ``pair_scores`` (cross-encoder) or ``similarities`` (multi-tower)
            diagnostics from the model.
        """
        ctx = QueryContext(
            question=question,
            schema=schema,
            correct_query=correct_query,
            student_query=student_query,
            error_message=error_message,
        )
        # predict_proba works on a batch; score the single context and take row 0.
        proba = self.model.predict_proba([ctx])[0]
        ranked = self._rank(self.model.classes_, proba)
        best_id, best_confidence = ranked[0]

        diagnostics: Dict[str, Any] = {}
        if isinstance(self.model, CrossEncoderClassifier):
            diagnostics["pair_scores"] = self.model.explain_pair_scores(ctx)
        elif isinstance(self.model, MultiTowerClassifier):
            diagnostics["similarities"] = self.model.explain_similarities(ctx)

        return {
            "label_id": best_id,
            "label_name": self.label_map[best_id],
            "confidence": best_confidence,
            "top_k": self._top_k_entries(ranked, self.label_map, top_k),
            **diagnostics,
        }

    @staticmethod
    def _rank(classes, proba):
        """Return ``(class_id, probability)`` pairs sorted by descending probability.

        Ids are cast to ``int`` and probabilities to ``float`` so the result is
        plain-Python (JSON-friendly) whatever scalar types the model returns.
        """
        ordered = sorted(zip(classes, proba), key=lambda pair: pair[1], reverse=True)
        return [(int(cls), float(p)) for cls, p in ordered]

    @staticmethod
    def _top_k_entries(ranked, label_map, top_k):
        """Build the ``top_k`` result list from pairs already sorted by ``_rank``."""
        # A negative slice bound would drop the *lowest* entries instead of
        # returning none, so clamp at 0; None keeps "all classes".
        limit = None if top_k is None else max(top_k, 0)
        return [
            {"label_id": cls, "label_name": label_map[cls], "confidence": p}
            for cls, p in ranked[:limit]
        ]

    def save_pretrained(self, save_directory: Union[str, Path]) -> Path:
        """Save model artifacts in Hugging Face Hub layout.

        Writes ``classifier.joblib`` (fitted scaler/classifier and constructor
        settings), ``config.json`` (readable metadata) and ``categories.json``
        (one ``{"id", "name", "description"}`` entry per ``self.label_map``
        item) into *save_directory*, creating it if necessary.

        Returns:
            The save directory as a ``Path``.

        Raises:
            ValueError: If the wrapped model is neither a cross-encoder nor a
                multi-tower classifier (raised before anything is written).
        """
        save_dir = Path(save_directory)
        payload, config = self._serialize_model()
        save_dir.mkdir(parents=True, exist_ok=True)

        joblib.dump(payload, save_dir / CLASSIFIER_NAME)
        with open(save_dir / CONFIG_NAME, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        cat_data = self._categories_payload(self.label_map, load_categories())
        with open(save_dir / CATEGORIES_NAME, "w", encoding="utf-8") as f:
            json.dump(cat_data, f, indent=2)
        return save_dir

    def _serialize_model(self):
        """Return ``(joblib_payload, config_dict)`` describing the wrapped model."""
        model = self.model
        if isinstance(model, CrossEncoderClassifier):
            model_type, architecture = "cross_encoder", "cross-encoder-pairwise"
            settings = {
                "cross_encoder_name": model.cross_encoder_name,
                "batch_size": model.batch_size,
                "max_length": model.max_length,
            }
        elif isinstance(model, MultiTowerClassifier):
            model_type, architecture = "multi_tower", "multi-tower-semantic-comparison"
            settings = {
                "encoder_name": model.encoder_name,
                "batch_size": model.batch_size,
            }
        else:
            raise ValueError("Only cross_encoder and multi_tower models can be published")

        payload = {
            "model_type": model_type,
            **settings,
            "scaler": model.scaler,
            "classifier": model.clf,
            "classes_": model.classes_,
        }
        config = {
            "model_type": model_type,
            "architecture": architecture,
            **settings,
            "num_labels": len(self.label_map),
            "task": "sql-error-classification",
        }
        return payload, config

    @staticmethod
    def _categories_payload(label_map, categories):
        """Serialize *label_map* as the list stored in ``categories.json``.

        Ids and names come from *label_map* (the mapping ``predict`` uses) so a
        wrapper round-trips its labels even when the local category definitions
        differ. Descriptions are copied from *categories* (objects with ``id``
        and ``description``) when an entry with the same id exists, else ``""``.
        """
        descriptions = {c.id: c.description for c in categories}
        return [
            {"id": label_id, "name": name, "description": descriptions.get(label_id, "")}
            for label_id, name in label_map.items()
        ]

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *,
        token: Optional[str] = None,
    ) -> "SQLLErrorClassifierHF":
        """Load from a local directory or Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Local directory holding
                ``config.json``, or a Hub repo id to download from.
            token: Optional Hub access token passed to the download.

        Raises:
            ValueError: If the stored ``model_type`` is not recognised.
        """
        path = _resolve_model_path(pretrained_model_name_or_path, token=token)
        with open(path / CONFIG_NAME, encoding="utf-8") as f:
            config = json.load(f)
        with open(path / CATEGORIES_NAME, encoding="utf-8") as f:
            categories = json.load(f)
        label_map = {c["id"]: c["name"] for c in categories}

        # SECURITY: joblib.load unpickles arbitrary objects, so loading an
        # untrusted repository can execute code. Only load repos you trust.
        obj = joblib.load(path / CLASSIFIER_NAME)
        model_type = cls._resolve_model_type(config, obj)
        if model_type == "cross_encoder":
            model = CrossEncoderClassifier(
                cross_encoder_name=obj.get(
                    "cross_encoder_name",
                    config.get("cross_encoder_name", "cross-encoder/ms-marco-MiniLM-L6-v2"),
                ),
                batch_size=obj.get("batch_size", 32),
                max_length=obj.get("max_length", config.get("max_length", 512)),
            )
        else:
            model = MultiTowerClassifier(
                encoder_name=obj.get("encoder_name", config.get("encoder_name", DEFAULT_ENCODER)),
                batch_size=obj.get("batch_size", 256),
            )

        model.scaler = obj["scaler"]
        model.clf = obj["classifier"]
        # Only consult the classifier's attribute when classes_ was not stored;
        # a dict.get default would be evaluated (and could fail) regardless.
        model.classes_ = obj["classes_"] if "classes_" in obj else obj["classifier"].classes_
        return cls(model=model, label_map=label_map)

    @staticmethod
    def _resolve_model_type(config, obj):
        """Return the stored model type, preferring ``config.json`` over the payload.

        Artifacts that record no type at all are loaded as ``"multi_tower"``;
        any other unrecognised value raises instead of being silently treated
        as a multi-tower model.
        """
        model_type = config.get("model_type", obj.get("model_type"))
        if model_type is None:
            return "multi_tower"
        if model_type not in ("cross_encoder", "multi_tower"):
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                "expected 'cross_encoder' or 'multi_tower'"
            )
        return model_type
def _resolve_model_path(
    pretrained_model_name_or_path: Union[str, Path],
    token: Optional[str] = None,
) -> Path:
    """Return a directory that contains the published model artifacts.

    If the argument names an existing local directory that already holds
    ``config.json``, that directory is returned unchanged. Otherwise the
    argument is treated as a Hub repo id, and only the config, classifier and
    categories files are fetched with ``snapshot_download``.
    """
    candidate = Path(pretrained_model_name_or_path)
    has_local_config = candidate.exists() and (candidate / CONFIG_NAME).exists()
    if not has_local_config:
        # Imported only on this path, so huggingface_hub is needed just for Hub loads.
        from huggingface_hub import snapshot_download

        wanted_files = [CONFIG_NAME, CLASSIFIER_NAME, CATEGORIES_NAME]
        downloaded_dir = snapshot_download(
            repo_id=str(pretrained_model_name_or_path),
            token=token,
            allow_patterns=wanted_files,
        )
        return Path(downloaded_dir)
    return candidate
def package_for_hub(model_path: Path, output_dir: Path) -> Path:
    """Convert a local joblib model into HF Hub layout.

    Args:
        model_path: Location of the locally saved model, read with ``load_model``.
        output_dir: Directory that receives the Hub artifact files.

    Returns:
        The directory returned by ``SQLLErrorClassifierHF.save_pretrained``.

    Raises:
        ValueError: If the loaded model is not one of ``SUPPORTED_CONTEXT_MODELS``.
    """
    loaded = load_model(model_path)
    if isinstance(loaded, SUPPORTED_CONTEXT_MODELS):
        names_by_id = {category.id: category.name for category in load_categories()}
        hub_wrapper = SQLLErrorClassifierHF(model=loaded, label_map=names_by_id)
        return hub_wrapper.save_pretrained(output_dir)
    raise ValueError(
        "Only cross_encoder and multi_tower models can be published to Hugging Face Hub"
    )