"""Hugging Face Hub integration for the SQL error classifier."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Optional, Union

import joblib

from src.categories import load_categories
from src.cross_encoder_model import CrossEncoderClassifier
from src.model import DEFAULT_ENCODER, load_model
from src.multi_tower_model import MultiTowerClassifier, QueryContext

PROJECT_ROOT = Path(__file__).resolve().parent.parent

# File names that make up the on-disk / on-Hub layout of a published model.
CONFIG_NAME = "config.json"
CLASSIFIER_NAME = "classifier.joblib"
CATEGORIES_NAME = "categories.json"

# Only these model classes can be wrapped and published.
SUPPORTED_CONTEXT_MODELS = (CrossEncoderClassifier, MultiTowerClassifier)


def _read_json(path: Path) -> Any:
    """Parse the JSON file at *path*, decoding it as UTF-8."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _write_json(path: Path, data: Any) -> None:
    """Write *data* to *path* as indented JSON, encoded as UTF-8."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


class SQLLErrorClassifierHF:
    """
    Hugging Face–compatible wrapper for SQL error classifiers.

    Wraps a ``CrossEncoderClassifier`` or ``MultiTowerClassifier`` together
    with a mapping from label id to label name.

    Usage:
        clf = SQLLErrorClassifierHF.from_pretrained("username/sql-error-classifier")
        result = clf.predict(
            question="...",
            schema="...",
            correct_query="...",
            student_query="...",
        )
    """

    def __init__(self, model, label_map: Dict[int, str]):
        """Store the wrapped *model* and its id-to-name *label_map*."""
        self.model = model
        self.label_map = label_map

    def predict(
        self,
        question: str,
        schema: str,
        correct_query: str,
        student_query: str,
        error_message: Optional[str] = None,
        top_k: int = 3,
    ) -> Dict[str, Any]:
        """Classify one student query and return the ranked labels.

        Args:
            question: Passed to ``QueryContext`` as ``question``.
            schema: Passed to ``QueryContext`` as ``schema``.
            correct_query: Passed to ``QueryContext`` as ``correct_query``.
            student_query: Passed to ``QueryContext`` as ``student_query``.
            error_message: Optional, passed to ``QueryContext`` as
                ``error_message``.
            top_k: Maximum number of ranked labels to include under
                ``"top_k"``. Negative values are treated as 0.

        Returns:
            A dict with ``label_id``, ``label_name`` and ``confidence`` for
            the highest-scoring class, a ``top_k`` list of the same three
            fields, plus ``pair_scores`` (cross-encoder models) or
            ``similarities`` (multi-tower models).

        Raises:
            KeyError: If a ranked class id is missing from ``label_map``.
        """
        ctx = QueryContext(
            question=question,
            schema=schema,
            correct_query=correct_query,
            student_query=student_query,
            error_message=error_message,
        )
        scores = self.model.predict_proba([ctx])[0]
        ranked = sorted(
            zip(self.model.classes_, scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        best_cls, best_score = ranked[0]
        best_id = int(best_cls)

        diagnostics: Dict[str, Any] = {}
        if isinstance(self.model, CrossEncoderClassifier):
            diagnostics["pair_scores"] = self.model.explain_pair_scores(ctx)
        elif isinstance(self.model, MultiTowerClassifier):
            diagnostics["similarities"] = self.model.explain_similarities(ctx)

        # A negative top_k would otherwise slice from the end and silently
        # drop the lowest-ranked classes instead of limiting the list.
        limit = max(top_k, 0)
        return {
            "label_id": best_id,
            "label_name": self.label_map[best_id],
            "confidence": float(best_score),
            "top_k": [
                {
                    "label_id": int(cls),
                    "label_name": self.label_map[int(cls)],
                    "confidence": float(p),
                }
                for cls, p in ranked[:limit]
            ],
            **diagnostics,
        }

    def _build_artifacts(self) -> tuple[Dict[str, Any], Dict[str, Any]]:
        """Return the ``(joblib payload, JSON config)`` pair for the model.

        Raises:
            ValueError: If the wrapped model is neither a cross-encoder nor
                a multi-tower classifier.
        """
        model = self.model
        if isinstance(model, CrossEncoderClassifier):
            payload = {
                "model_type": "cross_encoder",
                "cross_encoder_name": model.cross_encoder_name,
                "batch_size": model.batch_size,
                "max_length": model.max_length,
                "scaler": model.scaler,
                "classifier": model.clf,
                "classes_": model.classes_,
            }
            config = {
                "model_type": "cross_encoder",
                "architecture": "cross-encoder-pairwise",
                "cross_encoder_name": model.cross_encoder_name,
                "batch_size": model.batch_size,
                "num_labels": len(self.label_map),
                "task": "sql-error-classification",
            }
            return payload, config
        if isinstance(model, MultiTowerClassifier):
            payload = {
                "model_type": "multi_tower",
                "encoder_name": model.encoder_name,
                "batch_size": model.batch_size,
                "scaler": model.scaler,
                "classifier": model.clf,
                "classes_": model.classes_,
            }
            config = {
                "model_type": "multi_tower",
                "architecture": "multi-tower-semantic-comparison",
                "encoder_name": model.encoder_name,
                "batch_size": model.batch_size,
                "num_labels": len(self.label_map),
                "task": "sql-error-classification",
            }
            return payload, config
        raise ValueError("Only cross_encoder and multi_tower models can be published")

    def save_pretrained(self, save_directory: Union[str, Path]) -> Path:
        """Save model artifacts in Hugging Face Hub layout.

        Writes ``classifier.joblib``, ``config.json`` and ``categories.json``
        into *save_directory*, creating it if needed.

        Args:
            save_directory: Target directory for the three files.

        Returns:
            The target directory as a ``Path``.

        Raises:
            ValueError: If the wrapped model type cannot be published. The
                check runs before the directory is created, so nothing is
                left on disk in that case.
        """
        # Validate first so an unsupported model does not leave an empty
        # directory behind.
        payload, config = self._build_artifacts()

        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        joblib.dump(payload, save_dir / CLASSIFIER_NAME)
        _write_json(save_dir / CONFIG_NAME, config)

        # NOTE(review): categories come from load_categories(), not from
        # self.label_map, so the two can disagree if the wrapper was built
        # with a different label set — confirm this is intended.
        cat_data = [
            {"id": c.id, "name": c.name, "description": c.description}
            for c in load_categories()
        ]
        _write_json(save_dir / CATEGORIES_NAME, cat_data)
        return save_dir

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *,
        token: Optional[str] = None,
    ) -> "SQLLErrorClassifierHF":
        """Load from a local directory or Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``config.json``, or a Hub repo id.
            token: Optional token forwarded to the Hub download.

        Returns:
            A wrapper around the restored classifier and its label map.

        Raises:
            ValueError: If the stored ``model_type`` is set but is neither
                ``"cross_encoder"`` nor ``"multi_tower"``.
        """
        path = _resolve_model_path(pretrained_model_name_or_path, token=token)
        config = _read_json(path / CONFIG_NAME)
        categories = _read_json(path / CATEGORIES_NAME)
        label_map = {c["id"]: c["name"] for c in categories}

        # SECURITY: joblib.load unpickles arbitrary objects and can execute
        # code embedded in the file. Only load artifacts from repos or
        # directories you trust.
        obj = joblib.load(path / CLASSIFIER_NAME)

        model_type = config.get("model_type", obj.get("model_type"))
        if model_type == "cross_encoder":
            model = CrossEncoderClassifier(
                cross_encoder_name=obj.get(
                    "cross_encoder_name",
                    config.get(
                        "cross_encoder_name",
                        "cross-encoder/ms-marco-MiniLM-L6-v2",
                    ),
                ),
                batch_size=obj.get("batch_size", 32),
                max_length=obj.get("max_length", 512),
            )
        elif model_type in ("multi_tower", None):
            # A missing model_type keeps the historical multi-tower fallback
            # so artifacts without that field still load.
            model = MultiTowerClassifier(
                encoder_name=obj.get(
                    "encoder_name",
                    config.get("encoder_name", DEFAULT_ENCODER),
                ),
                batch_size=obj.get("batch_size", 256),
            )
        else:
            # Previously any unknown type was silently built as multi-tower.
            raise ValueError(
                f"Unsupported model_type {model_type!r} in {path}; "
                "expected 'cross_encoder' or 'multi_tower'"
            )

        # Both model kinds restore the same fitted components.
        model.scaler = obj["scaler"]
        model.clf = obj["classifier"]
        model.classes_ = obj.get("classes_", obj["classifier"].classes_)
        return cls(model=model, label_map=label_map)


def _resolve_model_path(
    pretrained_model_name_or_path: Union[str, Path],
    token: Optional[str] = None,
) -> Path:
    """Return a local directory holding the model files.

    If the argument is an existing path that contains ``config.json`` it is
    returned as-is. Otherwise it is treated as a Hub repo id and the three
    model files are fetched with ``snapshot_download``.
    """
    local = Path(pretrained_model_name_or_path)
    if local.exists() and (local / CONFIG_NAME).exists():
        return local

    # Imported lazily so purely local use does not require huggingface_hub.
    from huggingface_hub import snapshot_download

    return Path(
        snapshot_download(
            repo_id=str(pretrained_model_name_or_path),
            token=token,
            allow_patterns=[CONFIG_NAME, CLASSIFIER_NAME, CATEGORIES_NAME],
        )
    )


def package_for_hub(model_path: Path, output_dir: Path) -> Path:
    """Convert a local joblib model into HF Hub layout.

    Args:
        model_path: Path passed to ``load_model`` to restore the classifier.
        output_dir: Directory that receives the Hub-layout files.

    Returns:
        The output directory, as returned by ``save_pretrained``.

    Raises:
        ValueError: If the loaded model is not a cross-encoder or
            multi-tower classifier.
    """
    sklearn_model = load_model(model_path)
    if not isinstance(sklearn_model, SUPPORTED_CONTEXT_MODELS):
        raise ValueError(
            "Only cross_encoder and multi_tower models can be published to Hugging Face Hub"
        )
    label_map = {c.id: c.name for c in load_categories()}
    wrapper = SQLLErrorClassifierHF(model=sklearn_model, label_map=label_map)
    return wrapper.save_pretrained(output_dir)