Spaces:
Running
Running
File size: 4,470 Bytes
"""
AURIS startup script — downloads model artifacts from HuggingFace Hub
before the API server starts.

Usage (in Dockerfile CMD):
    python startup.py && uvicorn app.main:app --host 0.0.0.0 --port 7860

Environment variables:
    HF_TOKEN             — HuggingFace access token (optional for public repos);
                           HUGGING_FACE_HUB_TOKEN is used as a fallback.
    AURIS_MODELS_REPO    — HF repo ID, default: Rthur2003/auris-models
    MODELS_DIR           — Local destination, default: /app/models
    SKIP_MODEL_DOWNLOAD  — Set to "1" to skip the download (for local dev)
"""
from __future__ import annotations
import os
import sys
import time
from pathlib import Path
# Spellings accepted as "on" for boolean environment flags.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str) -> bool:
    """Return True if environment variable *name* holds a truthy value.

    "1", "true", "yes" and "on" are accepted case-insensitively, ignoring
    surrounding whitespace. An unset variable or any other value yields False.
    """
    return os.getenv(name, "").strip().lower() in _TRUTHY_ENV_VALUES


# Hugging Face repo that hosts the model artifacts.
REPO_ID = os.getenv("AURIS_MODELS_REPO", "Rthur2003/auris-models")
# Local directory the artifacts are checked in and downloaded into.
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/app/models"))
# Access token, with HUGGING_FACE_HUB_TOKEN as a fallback. Empty strings
# collapse to None so "no token" has a single representation.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or None
# Skip the whole download step (e.g. for local development).
SKIP = _env_flag("SKIP_MODEL_DOWNLOAD")

# Download manifest, fetched in this order. Every entry that is not also in
# OPTIONAL_FILES must exist locally for the API to work.
REQUIRED_FILES = [
    "auris_classifier_v1.pkl",
    "feature_scaler_v1.pkl",
    "feature_columns_v1.json",
    "feature_stats_v1.json",
    "training_results.json",
    "deep_learning_results.json",
    "model_lightgbm.pkl",
    "model_xgboost.pkl",
    "model_random_forest.pkl",
    "model_gradient_boosting.pkl",
    "model_svm_rbf.pkl",
    "model_mlp_neural_network.pkl",
    "model_logistic_regression.pkl",
    "model_dl_deep_mlp_512_256_128_64.pkl",
    "model_dl_1d_cnn.pkl",
    "model_dl_residual_mlp_3_blocks.pkl",
    "model_dl_attention_mlp.pkl",
    "wav2vec2_auris_v1.pt",
]

# Large entries of REQUIRED_FILES whose absence is tolerated: if one is missing
# from the repo or fails to download, that is logged rather than treated as fatal.
# NOTE(review): the original note suggests the service still runs without the
# wav2vec2 checkpoint — confirm against the API code.
OPTIONAL_FILES = {
    "wav2vec2_auris_v1.pt",
}
def _already_downloaded() -> bool:
    """Report whether every mandatory artifact is already in MODELS_DIR.

    Entries of REQUIRED_FILES that also appear in OPTIONAL_FILES are not
    checked. When anything is absent, the missing names are printed and
    False is returned; otherwise True.
    """
    absent: list[str] = []
    for name in REQUIRED_FILES:
        if name in OPTIONAL_FILES:
            continue  # optional artifacts never block startup
        if not (MODELS_DIR / name).exists():
            absent.append(name)

    if not absent:
        return True

    print(f"[startup] Missing files: {absent}")
    return False
def _import_hub():
    """Import huggingface_hub lazily and return ``(hf_hub_download, list_repo_files)``.

    Prints an install hint and exits with status 1 if the package is missing.
    """
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError:
        print("[startup] huggingface_hub not installed — pip install huggingface-hub")
        sys.exit(1)
    return hf_hub_download, list_repo_files


def _list_remote_files(list_repo_files, hub_kwargs: dict) -> set[str]:
    """Return the set of file names in the repo, or REQUIRED_FILES if listing fails.

    A failed listing is not fatal. The caller still attempts every required
    file, and real problems surface as per-file download errors.
    """
    try:
        return set(list_repo_files(**hub_kwargs))
    except Exception as e:  # best-effort probe; errors are reported per file later
        print(f"[startup] Cannot list repo files: {e}")
        print("[startup] Trying to download known files directly...")
        return set(REQUIRED_FILES)


def _fetch_file(hf_hub_download, filename: str, repo_files: set[str],
                hub_kwargs: dict) -> str | None:
    """Download *filename* into MODELS_DIR unless it is already present.

    Returns an ``"<filename>: <error>"`` string when a required file fails.
    Returns None otherwise (downloaded, already present, or optional and skipped).
    """
    if (MODELS_DIR / filename).exists():
        print(f"[startup] skip {filename} (exists)")
        return None

    is_optional = filename in OPTIONAL_FILES
    if is_optional and filename not in repo_files:
        print(f"[startup] skip {filename} (not in repo, optional)")
        return None

    try:
        # Progress line is flushed so it appears before a possibly long download.
        print(f"[startup] dl {filename} ...", end=" ", flush=True)
        path = hf_hub_download(
            filename=filename,
            local_dir=str(MODELS_DIR),
            **hub_kwargs,
        )
        size_mb = Path(path).stat().st_size / 1024 / 1024
    except Exception as e:
        if is_optional:
            print(f"SKIP (optional: {e})")
            return None
        print(f"ERROR: {e}")
        return f"{filename}: {e}"

    print(f"OK ({size_mb:.1f} MB)")
    return None


def download_models() -> None:
    """Ensure every artifact in REQUIRED_FILES exists under MODELS_DIR.

    Does nothing when SKIP is set or when all non-optional files are already
    present. Otherwise each missing file is fetched from REPO_ID. Failures for
    OPTIONAL_FILES are logged and ignored. Any required-file failure prints a
    summary and exits with status 1, so a chained ``&& uvicorn ...`` command
    does not start.
    """
    if SKIP:
        print("[startup] SKIP_MODEL_DOWNLOAD=1 — skipping download.")
        return

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    if _already_downloaded():
        print("[startup] All required model files already present — skipping download.")
        return

    print(f"[startup] Downloading models from {REPO_ID} → {MODELS_DIR}")
    t0 = time.perf_counter()  # monotonic clock for elapsed time
    hf_hub_download, list_repo_files = _import_hub()

    # The same kwargs serve both the listing and each download. The token is
    # only included when one is configured.
    hub_kwargs = {"repo_id": REPO_ID, "repo_type": "model"}
    if HF_TOKEN:
        hub_kwargs["token"] = HF_TOKEN

    repo_files = _list_remote_files(list_repo_files, hub_kwargs)

    errors: list[str] = []
    for filename in REQUIRED_FILES:
        error = _fetch_file(hf_hub_download, filename, repo_files, hub_kwargs)
        if error is not None:
            errors.append(error)

    elapsed = time.perf_counter() - t0
    print(f"[startup] Download complete in {elapsed:.1f}s")

    if errors:
        print(f"[startup] FATAL — {len(errors)} required file(s) failed:")
        for err in errors:
            print(f" - {err}")
        print("[startup] Set AURIS_MODELS_REPO and HF_TOKEN env vars if repo is private.")
        sys.exit(1)
def _main() -> None:
    """Script entry point: fetch the model artifacts, then announce readiness.

    download_models() calls sys.exit(1) on a fatal failure, so "Ready." is
    printed only after it returns normally.
    """
    download_models()
    print("[startup] Ready.")


if __name__ == "__main__":
    _main()
|