nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
Raw
History Blame Contribute Delete
2.43 kB
#!/usr/bin/env python3
"""Package and push the SQL error classifier to Hugging Face Hub."""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
# Project root: two directory levels above this file's resolved path.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Put the project root first on sys.path; presumably this is what lets the
# `src.huggingface` import below resolve when the script is run directly.
sys.path.insert(0, str(PROJECT_ROOT))
from huggingface_hub import HfApi, create_repo
from src.huggingface import package_for_hub
# Default model file to package (also the default of the --model option).
DEFAULT_MODEL = PROJECT_ROOT / "models" / "multi_tower_dev.joblib"
# Default directory handed to package_for_hub and then uploaded
# (also the default of the --package-dir option).
DEFAULT_PACKAGE = PROJECT_ROOT / "models" / "hf_package"
# Optional model card; when this file exists, push() uploads it as README.md.
MODEL_CARD = PROJECT_ROOT / "hub" / "MODEL_CARD.md"
def push(
    model_path: Path = DEFAULT_MODEL,
    package_dir: Path = DEFAULT_PACKAGE,
    repo_id: str = "",
    private: bool = False,
    token: str | None = None,
) -> str:
    """Package the model and publish it to a Hugging Face Hub model repo.

    Args:
        model_path: Model file passed to ``package_for_hub``.
        package_dir: Directory passed to ``package_for_hub`` and then
            uploaded with ``upload_folder``.
        repo_id: Target repo id; must be non-empty.
        private: Forwarded to ``create_repo``.
        token: Hub access token. When falsy, the value of the ``HF_TOKEN``
            environment variable is used instead.

    Returns:
        The ``https://huggingface.co/<repo_id>`` URL of the repo.

    Raises:
        ValueError: If ``repo_id`` is empty.
    """
    if not repo_id:
        raise ValueError("--repo-id is required (e.g. your-username/sql-error-classifier)")

    auth_token = token or os.environ.get("HF_TOKEN")
    hub = HfApi(token=auth_token)
    # Every Hub call below targets the same model repo with the same token.
    target = {"repo_id": repo_id, "repo_type": "model", "token": auth_token}

    print(f"Packaging model from {model_path}...")
    package_for_hub(model_path, package_dir)

    print(f"Creating repo {repo_id}...")
    # exist_ok=True is passed so an already-existing repo is accepted.
    create_repo(private=private, exist_ok=True, **target)

    print("Uploading model files...")
    hub.upload_folder(folder_path=str(package_dir), **target)

    # The model card is optional; it is uploaded after the folder, as README.md.
    if MODEL_CARD.exists():
        hub.upload_file(
            path_or_fileobj=str(MODEL_CARD),
            path_in_repo="README.md",
            **target,
        )

    repo_url = f"https://huggingface.co/{repo_id}"
    print(f"Done: {repo_url}")
    return repo_url
def main() -> None:
    """Parse the command-line options and call :func:`push` with them.

    Options: ``--model``, ``--package-dir``, ``--repo-id`` (required),
    ``--private`` and ``--token``.
    """
    cli = argparse.ArgumentParser(description="Push SQL error classifier to HF Hub")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--model", {"type": Path, "default": DEFAULT_MODEL}),
        ("--package-dir", {"type": Path, "default": DEFAULT_PACKAGE}),
        (
            "--repo-id",
            {
                "type": str,
                "required": True,
                "help": "Hugging Face repo id, e.g. nishantgupta/sql-error-classifier",
            },
        ),
        ("--private", {"action": "store_true"}),
        ("--token", {"type": str, "default": None}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()
    push(
        model_path=opts.model,
        package_dir=opts.package_dir,
        repo_id=opts.repo_id,
        private=opts.private,
        token=opts.token,
    )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()