Spaces:
Sleeping
Sleeping
| """Hugging Face Hub integration for the SQL error classifier.""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Union | |
| import joblib | |
| from src.categories import load_categories | |
| from src.cross_encoder_model import CrossEncoderClassifier | |
| from src.model import DEFAULT_ENCODER, load_model | |
| from src.multi_tower_model import MultiTowerClassifier, QueryContext | |
# Parent of the directory that contains this file (presumably the repository
# root — confirm). NOTE(review): not referenced anywhere in this module's
# visible code; verify it is still needed.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# File names of the three artifacts that make up a saved/published model
# directory (written by save_pretrained, read by from_pretrained, and the only
# files fetched from the Hub by _resolve_model_path).
CONFIG_NAME = "config.json"
CLASSIFIER_NAME = "classifier.joblib"
CATEGORIES_NAME = "categories.json"
# Model classes that package_for_hub accepts for publishing; other model
# types are rejected with a ValueError.
SUPPORTED_CONTEXT_MODELS = (CrossEncoderClassifier, MultiTowerClassifier)
class SQLLErrorClassifierHF:
    """
    Hugging Face–compatible wrapper for SQL error classifiers.

    Pairs a fitted ``CrossEncoderClassifier`` or ``MultiTowerClassifier`` with
    a mapping from class id to category name, and round-trips both through a
    Hub-style directory (``config.json``, ``classifier.joblib``,
    ``categories.json``).

    Usage:
        clf = SQLLErrorClassifierHF.from_pretrained("username/sql-error-classifier")
        result = clf.predict(
            question="...", schema="...", correct_query="...", student_query="..."
        )
    """

    def __init__(self, model, label_map: Dict[int, str]):
        """Store the fitted model and its id -> category-name mapping.

        Args:
            model: Fitted classifier exposing ``predict_proba`` and
                ``classes_``. ``save_pretrained`` only supports
                ``CrossEncoderClassifier`` and ``MultiTowerClassifier``.
            label_map: Mapping from integer class id to category name.
        """
        self.model = model
        self.label_map = label_map

    def predict(
        self,
        question: str,
        schema: str,
        correct_query: str,
        student_query: str,
        error_message: Optional[str] = None,
        top_k: int = 3,
    ) -> Dict[str, Any]:
        """Classify a single student query and return ranked predictions.

        Args:
            question: Natural-language task the query should answer.
            schema: Database schema text given to the model.
            correct_query: Reference SQL solution.
            student_query: SQL submitted by the student.
            error_message: Optional database error raised by ``student_query``.
            top_k: Number of highest-probability labels to list under
                ``"top_k"``; values <= 0 yield an empty list.

        Returns:
            Dict with ``label_id``, ``label_name`` and ``confidence`` of the
            best class, a ``top_k`` list of the same fields, plus
            ``pair_scores`` (cross-encoder) or ``similarities`` (multi-tower)
            diagnostics from the underlying model.

        Raises:
            KeyError: If a predicted class id is missing from ``label_map``.
        """
        ctx = QueryContext(
            question=question,
            schema=schema,
            correct_query=correct_query,
            student_query=student_query,
            error_message=error_message,
        )
        proba = self.model.predict_proba([ctx])[0]
        # Pair each class id with its probability, most likely first.
        ranked = sorted(
            zip(self.model.classes_, proba), key=lambda pair: pair[1], reverse=True
        )
        best_id = int(ranked[0][0])

        diagnostics: Dict[str, Any] = {}
        if isinstance(self.model, CrossEncoderClassifier):
            diagnostics["pair_scores"] = self.model.explain_pair_scores(ctx)
        elif isinstance(self.model, MultiTowerClassifier):
            diagnostics["similarities"] = self.model.explain_similarities(ctx)

        # Clamp so a negative top_k returns nothing instead of "all but the
        # last N" (Python's negative-slice semantics).
        top = [
            {
                "label_id": int(class_id),
                "label_name": self.label_map[int(class_id)],
                "confidence": float(p),
            }
            for class_id, p in ranked[: max(top_k, 0)]
        ]
        return {
            "label_id": best_id,
            "label_name": self.label_map[best_id],
            "confidence": float(ranked[0][1]),
            "top_k": top,
            **diagnostics,
        }

    def save_pretrained(self, save_directory: Union[str, Path]) -> Path:
        """Save model artifacts in Hugging Face Hub layout.

        Writes ``classifier.joblib`` (model hyper-parameters, fitted scaler,
        classifier and ``classes_``), ``config.json`` (human-readable model
        metadata) and ``categories.json`` (one entry per ``label_map`` item,
        with descriptions taken from ``load_categories()`` where available).

        Args:
            save_directory: Target directory; created (with parents) if needed.

        Returns:
            The target directory as a ``Path``.

        Raises:
            ValueError: If the wrapped model is not a cross-encoder or
                multi-tower classifier. The directory is not created then.
        """
        if isinstance(self.model, CrossEncoderClassifier):
            model_type = "cross_encoder"
            architecture = "cross-encoder-pairwise"
            hyperparams: Dict[str, Any] = {
                "cross_encoder_name": self.model.cross_encoder_name,
                "batch_size": self.model.batch_size,
                "max_length": self.model.max_length,
            }
        elif isinstance(self.model, MultiTowerClassifier):
            model_type = "multi_tower"
            architecture = "multi-tower-semantic-comparison"
            hyperparams = {
                "encoder_name": self.model.encoder_name,
                "batch_size": self.model.batch_size,
            }
        else:
            raise ValueError("Only cross_encoder and multi_tower models can be published")

        # Validate before touching the filesystem so unsupported models do not
        # leave an empty directory behind.
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        payload = {
            "model_type": model_type,
            **hyperparams,
            "scaler": self.model.scaler,
            "classifier": self.model.clf,
            "classes_": self.model.classes_,
        }
        # config.json mirrors the payload's hyper-parameters (including
        # max_length for cross-encoders) so from_pretrained can fall back to it.
        config = {
            "model_type": model_type,
            "architecture": architecture,
            **hyperparams,
            "num_labels": len(self.label_map),
            "task": "sql-error-classification",
        }
        joblib.dump(payload, save_dir / CLASSIFIER_NAME)
        with open(save_dir / CONFIG_NAME, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Write the wrapper's own label_map so a save/load round trip
        # reproduces it exactly; descriptions come from the project categories.
        descriptions = {c.id: c.description for c in load_categories()}
        cat_data = [
            {
                "id": int(label_id),
                "name": name,
                "description": descriptions.get(label_id, ""),
            }
            for label_id, name in self.label_map.items()
        ]
        with open(save_dir / CATEGORIES_NAME, "w", encoding="utf-8") as f:
            json.dump(cat_data, f, indent=2)
        return save_dir

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *,
        token: Optional[str] = None,
    ) -> "SQLLErrorClassifierHF":
        """Load from a local directory or Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``config.json`` or a Hub repo id.
            token: Optional Hub access token (for private repos).

        Returns:
            A wrapper around the reconstructed model and its label map.

        Raises:
            ValueError: If the stored ``model_type`` is neither
                ``"cross_encoder"`` nor ``"multi_tower"``.

        Security:
            ``classifier.joblib`` is loaded with ``joblib.load``, which
            unpickles and can execute arbitrary code. Only load artifacts from
            repositories or directories you trust.
        """
        path = _resolve_model_path(pretrained_model_name_or_path, token=token)
        with open(path / CONFIG_NAME, encoding="utf-8") as f:
            config = json.load(f)
        with open(path / CATEGORIES_NAME, encoding="utf-8") as f:
            categories = json.load(f)
        label_map = {c["id"]: c["name"] for c in categories}

        # SECURITY: unpickling — see docstring; do not use on untrusted repos.
        obj = joblib.load(path / CLASSIFIER_NAME)
        model_type = config.get("model_type", obj.get("model_type"))

        if model_type == "cross_encoder":
            model = CrossEncoderClassifier(
                cross_encoder_name=obj.get(
                    "cross_encoder_name",
                    config.get("cross_encoder_name", "cross-encoder/ms-marco-MiniLM-L6-v2"),
                ),
                batch_size=obj.get("batch_size", config.get("batch_size", 32)),
                max_length=obj.get("max_length", config.get("max_length", 512)),
            )
        elif model_type in ("multi_tower", None):
            # Artifacts with no model_type at all keep the historical
            # multi-tower fallback; unknown explicit types are rejected below.
            model = MultiTowerClassifier(
                encoder_name=obj.get("encoder_name", config.get("encoder_name", DEFAULT_ENCODER)),
                batch_size=obj.get("batch_size", config.get("batch_size", 256)),
            )
        else:
            raise ValueError(
                f"Unsupported model_type {model_type!r} in model artifacts at {path}"
            )

        model.scaler = obj["scaler"]
        model.clf = obj["classifier"]
        # Only touch classifier.classes_ when the payload lacks classes_
        # (dict.get would evaluate the fallback eagerly).
        model.classes_ = obj["classes_"] if "classes_" in obj else obj["classifier"].classes_
        return cls(model=model, label_map=label_map)
def _resolve_model_path(
    pretrained_model_name_or_path: Union[str, Path],
    token: Optional[str] = None,
) -> Path:
    """Return a local directory that holds the model artifacts.

    If the argument names an existing local directory, that directory is used
    as-is and must contain ``CONFIG_NAME``. Otherwise the argument is treated
    as a Hugging Face Hub repo id, and only the three artifact files are
    downloaded via ``huggingface_hub.snapshot_download``.

    Args:
        pretrained_model_name_or_path: Local directory or Hub repo id.
        token: Optional Hub access token, forwarded to ``snapshot_download``.

    Returns:
        Path to a directory containing the (downloaded or local) artifacts.

    Raises:
        FileNotFoundError: If a local directory is given but has no
            ``CONFIG_NAME``. Forwarding it to the Hub as a repo id would only
            produce a confusing validation error.
    """
    local = Path(pretrained_model_name_or_path)
    if local.is_dir():
        if (local / CONFIG_NAME).is_file():
            return local
        raise FileNotFoundError(
            f"{local} is a local directory but does not contain {CONFIG_NAME}; "
            f"expected {CONFIG_NAME}, {CLASSIFIER_NAME} and {CATEGORIES_NAME}"
        )

    # Imported lazily so the Hub client is only needed for remote loads.
    from huggingface_hub import snapshot_download

    return Path(
        snapshot_download(
            repo_id=str(pretrained_model_name_or_path),
            token=token,
            allow_patterns=[CONFIG_NAME, CLASSIFIER_NAME, CATEGORIES_NAME],
        )
    )
def package_for_hub(model_path: Path, output_dir: Path) -> Path:
    """Repackage a locally saved joblib model into the Hugging Face Hub layout.

    Args:
        model_path: Location of the model artifact understood by ``load_model``.
        output_dir: Directory that receives the Hub-layout artifacts.

    Returns:
        The directory returned by ``SQLLErrorClassifierHF.save_pretrained``.

    Raises:
        ValueError: If the loaded model is not one of ``SUPPORTED_CONTEXT_MODELS``.
    """
    loaded = load_model(model_path)
    if isinstance(loaded, SUPPORTED_CONTEXT_MODELS):
        names_by_id = {category.id: category.name for category in load_categories()}
        hub_wrapper = SQLLErrorClassifierHF(model=loaded, label_map=names_by_id)
        return hub_wrapper.save_pretrained(output_dir)

    raise ValueError(
        "Only cross_encoder and multi_tower models can be published to Hugging Face Hub"
    )