Spaces:
Running
Running
| """ | |
| AURIS startup script — downloads model artifacts from HuggingFace Hub | |
| before the API server starts. | |
| Usage (in Dockerfile CMD): | |
| python startup.py && uvicorn app.main:app --host 0.0.0.0 --port 7860 | |
| Environment variables: | |
| HF_TOKEN — HuggingFace access token (optional for public repos; HUGGING_FACE_HUB_TOKEN is used as a fallback when HF_TOKEN is unset or empty) | |
| AURIS_MODELS_REPO — HF repo ID, default: Rthur2003/auris-models | |
| MODELS_DIR — Local destination, default: /app/models | |
| SKIP_MODEL_DOWNLOAD — Set to "1" to skip (for local dev) | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
# Runtime configuration. Every value can be overridden through the
# environment; the second argument of each lookup is the built-in default.
REPO_ID = os.environ.get("AURIS_MODELS_REPO", "Rthur2003/auris-models")
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "/app/models"))
# HF_TOKEN wins; HUGGING_FACE_HUB_TOKEN is consulted only when HF_TOKEN is
# unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
# Only the exact string "1" disables the download.
SKIP = os.environ.get("SKIP_MODEL_DOWNLOAD", "0") == "1"
# Every artifact the startup script tries to fetch, in download order.
# Despite the name, entries that also appear in OPTIONAL_FILES are
# best-effort: they are not needed for the "already downloaded" check and a
# failed download of one does not abort startup.
REQUIRED_FILES = [
    "auris_classifier_v1.pkl",
    "feature_scaler_v1.pkl",
    "feature_columns_v1.json",
    "feature_stats_v1.json",
    "training_results.json",
    "deep_learning_results.json",
    "model_lightgbm.pkl",
    "model_xgboost.pkl",
    "model_random_forest.pkl",
    "model_gradient_boosting.pkl",
    "model_svm_rbf.pkl",
    "model_mlp_neural_network.pkl",
    "model_logistic_regression.pkl",
    "model_dl_deep_mlp_512_256_128_64.pkl",
    "model_dl_1d_cnn.pkl",
    "model_dl_residual_mlp_3_blocks.pkl",
    "model_dl_attention_mlp.pkl",
    "wav2vec2_auris_v1.pt",
]

# Large files whose absence is tolerated by this script.
# NOTE(review): the original comment implies the API still works without the
# wav2vec2 weights; that cannot be confirmed from this file.
OPTIONAL_FILES = {"wav2vec2_auris_v1.pt"}
def _already_downloaded() -> bool:
    """Report whether every non-optional artifact is present in MODELS_DIR.

    Walks REQUIRED_FILES, ignoring names listed in OPTIONAL_FILES, and
    checks each remaining name for existence under MODELS_DIR. When any are
    absent, their names are printed and False is returned.
    """
    absent = []
    for name in REQUIRED_FILES:
        if name in OPTIONAL_FILES:
            continue
        if (MODELS_DIR / name).exists():
            continue
        absent.append(name)

    if not absent:
        return True

    print(f"[startup] Missing files: {absent}")
    return False
def download_models() -> None:
    """Fetch the artifacts in REQUIRED_FILES from the Hub into MODELS_DIR.

    Returns immediately when SKIP is set, or when every non-optional file is
    already on disk. Files that already exist locally are never downloaded
    again. A failure on a file listed in OPTIONAL_FILES is logged and
    ignored; a failure on any other file is collected and reported once all
    files have been attempted.

    Raises:
        SystemExit: with status 1 when ``huggingface_hub`` is not installed,
            or when at least one non-optional file could not be downloaded.
    """
    if SKIP:
        print("[startup] SKIP_MODEL_DOWNLOAD=1 — skipping download.")
        return

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    if _already_downloaded():
        print("[startup] All required model files already present — skipping download.")
        return

    print(f"[startup] Downloading models from {REPO_ID} → {MODELS_DIR}")
    t0 = time.time()

    # Imported lazily so the two skip paths above work even when the
    # dependency is not installed.
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError:
        print("[startup] huggingface_hub not installed — pip install huggingface-hub")
        sys.exit(1)

    # The token is passed only when one is configured, so anonymous access
    # to public repos uses the library default instead of an empty value.
    kwargs = {"repo_id": REPO_ID, "repo_type": "model"}
    if HF_TOKEN:
        kwargs["token"] = HF_TOKEN

    # The listing is only used to skip optional files that the repo does not
    # contain; if it cannot be obtained, assume every known file is there.
    try:
        repo_files = set(list_repo_files(**kwargs))
    except Exception as e:
        print(f"[startup] Cannot list repo files: {e}")
        print("[startup] Trying to download known files directly...")
        repo_files = set(REQUIRED_FILES)

    errors: list[str] = []
    for filename in REQUIRED_FILES:
        dest = MODELS_DIR / filename
        if dest.exists():
            print(f"[startup] skip {filename} (exists)")
            continue

        is_optional = filename in OPTIONAL_FILES
        if is_optional and filename not in repo_files:
            print(f"[startup] skip {filename} (not in repo, optional)")
            continue

        try:
            # Flushed so the file name is visible before the blocking call.
            print(f"[startup] dl {filename} ...", end=" ", flush=True)
            path = hf_hub_download(
                filename=filename,
                local_dir=str(MODELS_DIR),
                **kwargs,
            )
            print(f"OK ({Path(path).stat().st_size / 1024 / 1024:.1f} MB)")
        except Exception as e:
            # Best-effort per file: keep going so that a single run reports
            # every failure, not just the first one.
            if is_optional:
                print(f"SKIP (optional: {e})")
            else:
                print(f"ERROR: {e}")
                errors.append(f"{filename}: {e}")

    elapsed = time.time() - t0
    if not errors:
        print(f"[startup] Download complete in {elapsed:.1f}s")
        return

    # Do not claim completion when required files are missing.
    print(f"[startup] Download finished with errors in {elapsed:.1f}s")
    print(f"[startup] FATAL — {len(errors)} required file(s) failed:")
    for err in errors:
        print(f" - {err}")
    print("[startup] Set AURIS_MODELS_REPO and HF_TOKEN env vars if repo is private.")
    sys.exit(1)
def _main() -> None:
    """Script entry point: fetch model artifacts, then announce readiness."""
    download_models()
    # Reached only when download_models() returned instead of exiting.
    print("[startup] Ready.")


if __name__ == "__main__":
    _main()