#!/usr/bin/env python3
"""Package and push the SQL error classifier to Hugging Face Hub."""

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

# The project root is put on sys.path before the project-local import below,
# so that `src` can be imported when this file is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from huggingface_hub import HfApi, create_repo  # noqa: E402

from src.huggingface import package_for_hub  # noqa: E402

DEFAULT_MODEL = PROJECT_ROOT / "models" / "multi_tower_dev.joblib"
DEFAULT_PACKAGE = PROJECT_ROOT / "models" / "hf_package"
MODEL_CARD = PROJECT_ROOT / "hub" / "MODEL_CARD.md"


def push(
    model_path: Path = DEFAULT_MODEL,
    package_dir: Path = DEFAULT_PACKAGE,
    repo_id: str = "",
    private: bool = False,
    token: str | None = None,
) -> str:
    """Package the model and upload it to a Hugging Face model repo.

    Steps, in order: run ``package_for_hub(model_path, package_dir)``,
    create the repo (an existing repo is accepted), upload the contents of
    ``package_dir``, and, if ``MODEL_CARD`` exists on disk, upload it as
    ``README.md``.

    Args:
        model_path: Model file handed to ``package_for_hub``.
        package_dir: Directory handed to ``package_for_hub`` and then
            uploaded as a folder.
        repo_id: Target repo id; must be non-empty.
        private: Passed through to ``create_repo``.
        token: Access token. When falsy, the ``HF_TOKEN`` environment
            variable is used instead (which may itself be unset).

    Returns:
        The ``https://huggingface.co/<repo_id>`` URL of the repo.

    Raises:
        ValueError: If ``repo_id`` is empty.
    """
    if not repo_id:
        raise ValueError("--repo-id is required (e.g. your-username/sql-error-classifier)")

    hub_token = token if token else os.getenv("HF_TOKEN")
    client = HfApi(token=hub_token)
    # Keyword arguments common to both upload calls.
    destination = {"repo_id": repo_id, "repo_type": "model", "token": hub_token}

    print(f"Packaging model from {model_path}...")
    package_for_hub(model_path, package_dir)

    print(f"Creating repo {repo_id}...")
    create_repo(
        repo_id,
        repo_type="model",
        private=private,
        exist_ok=True,
        token=hub_token,
    )

    print("Uploading model files...")
    client.upload_folder(folder_path=str(package_dir), **destination)

    # The model card is uploaded after the folder, under the name README.md.
    if MODEL_CARD.exists():
        client.upload_file(
            path_or_fileobj=str(MODEL_CARD),
            path_in_repo="README.md",
            **destination,
        )

    hub_url = f"https://huggingface.co/{repo_id}"
    print(f"Done: {hub_url}")
    return hub_url


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser used by ``main``."""
    cli = argparse.ArgumentParser(description="Push SQL error classifier to HF Hub")
    cli.add_argument("--model", type=Path, default=DEFAULT_MODEL)
    cli.add_argument("--package-dir", type=Path, default=DEFAULT_PACKAGE)
    cli.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Hugging Face repo id, e.g. nishantgupta/sql-error-classifier",
    )
    cli.add_argument("--private", action="store_true")
    cli.add_argument("--token", type=str, default=None)
    return cli


def main() -> None:
    """Parse command-line arguments and forward them to ``push``."""
    options = _build_parser().parse_args()
    push(
        model_path=options.model,
        package_dir=options.package_dir,
        repo_id=options.repo_id,
        private=options.private,
        token=options.token,
    )


if __name__ == "__main__":
    main()