"""
AURIS startup script — downloads model artifacts from HuggingFace Hub
before the API server starts.

Usage (in Dockerfile CMD):
    python startup.py && uvicorn app.main:app --host 0.0.0.0 --port 7860

Environment variables:
    HF_TOKEN            — HuggingFace access token (optional for public repos);
                          HUGGING_FACE_HUB_TOKEN is read as a fallback
    AURIS_MODELS_REPO   — HF repo ID, default: Rthur2003/auris-models
    MODELS_DIR          — Local destination, default: /app/models
    SKIP_MODEL_DOWNLOAD — Set to "1" to skip (for local dev)
"""

from __future__ import annotations

import os
import sys
import time
from pathlib import Path

REPO_ID = os.getenv("AURIS_MODELS_REPO", "Rthur2003/auris-models")
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/app/models"))
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
SKIP = os.getenv("SKIP_MODEL_DOWNLOAD", "0") == "1"

# Files fetched at startup. All of them must exist for the API to work,
# except the ones that are also listed in OPTIONAL_FILES.
REQUIRED_FILES = [
    "auris_classifier_v1.pkl",
    "feature_scaler_v1.pkl",
    "feature_columns_v1.json",
    "feature_stats_v1.json",
    "training_results.json",
    "deep_learning_results.json",
    "model_lightgbm.pkl",
    "model_xgboost.pkl",
    "model_random_forest.pkl",
    "model_gradient_boosting.pkl",
    "model_svm_rbf.pkl",
    "model_mlp_neural_network.pkl",
    "model_logistic_regression.pkl",
    "model_dl_deep_mlp_512_256_128_64.pkl",
    "model_dl_1d_cnn.pkl",
    "model_dl_residual_mlp_3_blocks.pkl",
    "model_dl_attention_mlp.pkl",
    "wav2vec2_auris_v1.pt",
]

# Large files that are optional: a missing or failed download of these is
# reported but never fatal.
# NOTE(review): the previous comment read "wav2vec2 tower works without
# wav2vec2" — presumably the API can run without the wav2vec2 tower; confirm.
OPTIONAL_FILES = {
    "wav2vec2_auris_v1.pt",
}


def _already_downloaded() -> bool:
    """Return True if every non-optional file in REQUIRED_FILES exists in MODELS_DIR.

    Prints the list of missing files when at least one is absent.
    """
    missing = [
        name
        for name in REQUIRED_FILES
        if name not in OPTIONAL_FILES and not (MODELS_DIR / name).exists()
    ]
    if missing:
        print(f"[startup] Missing files: {missing}")
        return False
    return True


def download_models() -> None:
    """Fetch missing model artifacts from the HuggingFace Hub into MODELS_DIR.

    Does nothing when SKIP_MODEL_DOWNLOAD=1 or when every non-optional file is
    already on disk. Files that already exist are never re-downloaded.
    Failures on files listed in OPTIONAL_FILES are reported and ignored.

    Raises:
        SystemExit: with code 1 if ``huggingface_hub`` cannot be imported, or
            if any non-optional file fails to download.
    """
    if SKIP:
        print("[startup] SKIP_MODEL_DOWNLOAD=1 — skipping download.")
        return

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    if _already_downloaded():
        print("[startup] All required model files already present — skipping download.")
        return

    print(f"[startup] Downloading models from {REPO_ID} → {MODELS_DIR}")
    # Monotonic clock: the reported duration must not be skewed by wall-clock
    # adjustments happening while the downloads run.
    started = time.monotonic()

    # Imported lazily so the skip / already-present paths above work even when
    # the package is not installed.
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError:
        print("[startup] huggingface_hub not installed — pip install huggingface-hub")
        sys.exit(1)

    hub_kwargs = {"repo_id": REPO_ID, "repo_type": "model"}
    if HF_TOKEN:
        hub_kwargs["token"] = HF_TOKEN

    # Get the set of files in the repo; it is only used to skip optional files
    # that were never uploaded. Listing is best-effort: on any failure, assume
    # every known file is present and let the per-file download decide.
    try:
        repo_files = set(
            list_repo_files(repo_id=REPO_ID, repo_type="model", token=HF_TOKEN)
        )
    except Exception as e:
        print(f"[startup] Cannot list repo files: {e}")
        print("[startup] Trying to download known files directly...")
        repo_files = set(REQUIRED_FILES)

    errors: list[str] = []
    for filename in REQUIRED_FILES:
        dest = MODELS_DIR / filename
        if dest.exists():
            print(f"[startup] skip {filename} (exists)")
            continue

        is_optional = filename in OPTIONAL_FILES
        if is_optional and filename not in repo_files:
            print(f"[startup] skip {filename} (not in repo, optional)")
            continue

        # No newline: the OK / SKIP / ERROR status is appended to this line.
        print(f"[startup] dl {filename} ...", end=" ", flush=True)
        try:
            path = hf_hub_download(
                filename=filename,
                local_dir=str(MODELS_DIR),
                **hub_kwargs,
            )
            print(f"OK ({Path(path).stat().st_size / 1024 / 1024:.1f} MB)")
        except Exception as e:
            # Deliberately broad: any failure is collected so that all files
            # are attempted before deciding whether startup must abort.
            if is_optional:
                print(f"SKIP (optional: {e})")
            else:
                print(f"ERROR: {e}")
                errors.append(f"{filename}: {e}")

    elapsed = time.monotonic() - started
    if not errors:
        print(f"[startup] Download complete in {elapsed:.1f}s")
        return

    # Do not claim completion when required files are missing.
    print(f"[startup] Download finished with errors in {elapsed:.1f}s")
    print(f"[startup] FATAL — {len(errors)} required file(s) failed:")
    for err in errors:
        print(f" - {err}")
    print("[startup] Set AURIS_MODELS_REPO and HF_TOKEN env vars if repo is private.")
    sys.exit(1)


if __name__ == "__main__":
    download_models()
    print("[startup] Ready.")