crowncode-backend / startup.py
Rthur2003's picture
fix: correct AURIS model repository ID in Dockerfile, startup script, and upload script
e8c02eb
"""
AURIS startup script — downloads model artifacts from HuggingFace Hub
before the API server starts.
Usage (in Dockerfile CMD):
python startup.py && uvicorn app.main:app --host 0.0.0.0 --port 7860
Environment variables:
HF_TOKEN — HuggingFace access token (optional for public repos); HUGGING_FACE_HUB_TOKEN is used when HF_TOKEN is unset or empty
AURIS_MODELS_REPO — HF repo ID, default: Rthur2003/auris-models
MODELS_DIR — Local destination, default: /app/models
SKIP_MODEL_DOWNLOAD — Set to exactly "1" to skip the download (e.g. for local dev)
"""
from __future__ import annotations
import os
import sys
import time
from pathlib import Path
# Runtime configuration, read once from the environment at import time.
REPO_ID = os.environ.get("AURIS_MODELS_REPO", "Rthur2003/auris-models")
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "/app/models"))
# HUGGING_FACE_HUB_TOKEN is consulted only when HF_TOKEN is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
# Only the literal string "1" turns skipping on.
SKIP = os.environ.get("SKIP_MODEL_DOWNLOAD", "0") == "1"

# Every artifact download_models() tries to fetch, in download order.
# Entries that also appear in OPTIONAL_FILES may be absent without
# failing startup.
REQUIRED_FILES = [
    # Classifier, scaler and feature metadata
    "auris_classifier_v1.pkl",
    "feature_scaler_v1.pkl",
    "feature_columns_v1.json",
    "feature_stats_v1.json",
    # Result summaries
    "training_results.json",
    "deep_learning_results.json",
    # Per-model pickles
    "model_lightgbm.pkl",
    "model_xgboost.pkl",
    "model_random_forest.pkl",
    "model_gradient_boosting.pkl",
    "model_svm_rbf.pkl",
    "model_mlp_neural_network.pkl",
    "model_logistic_regression.pkl",
    # Per-model "dl" pickles
    "model_dl_deep_mlp_512_256_128_64.pkl",
    "model_dl_1d_cnn.pkl",
    "model_dl_residual_mlp_3_blocks.pkl",
    "model_dl_attention_mlp.pkl",
    # Large checkpoint; also listed in OPTIONAL_FILES below
    "wav2vec2_auris_v1.pt",
]

# Files whose absence is tolerated: _already_downloaded() ignores them,
# download_models() skips them when the repo listing lacks them, and their
# download errors are not fatal.
OPTIONAL_FILES = {"wav2vec2_auris_v1.pt"}
def _already_downloaded() -> bool:
    """Report whether every non-optional model file is already on disk.

    Files in OPTIONAL_FILES are not checked. When any other entry of
    REQUIRED_FILES is absent from MODELS_DIR, the absent names are printed
    and False is returned.
    """
    absent = []
    for name in REQUIRED_FILES:
        if name in OPTIONAL_FILES:
            # Optional artifacts never block the "already downloaded" fast path.
            continue
        if not (MODELS_DIR / name).exists():
            absent.append(name)

    if not absent:
        return True
    print(f"[startup] Missing files: {absent}")
    return False
def download_models() -> None:
    """Fetch any missing model artifacts from REPO_ID into MODELS_DIR.

    Returns immediately when SKIP is set, or when ``_already_downloaded()``
    reports every non-optional file present. Otherwise each entry of
    REQUIRED_FILES that is not yet on disk is fetched with
    ``huggingface_hub.hf_hub_download``. Entries in OPTIONAL_FILES are
    skipped when the repo listing does not contain them. Their download
    failures are printed but tolerated.

    Exits the process with status 1 when ``huggingface_hub`` cannot be
    imported, or when at least one non-optional file fails to download.
    """
    if SKIP:
        print("[startup] SKIP_MODEL_DOWNLOAD=1 — skipping download.")
        return

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    if _already_downloaded():
        print("[startup] All required model files already present — skipping download.")
        return

    print(f"[startup] Downloading models from {REPO_ID} → {MODELS_DIR}")
    t0 = time.time()

    # Imported lazily so the SKIP / already-present paths work without the package.
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError:
        print("[startup] huggingface_hub not installed — pip install huggingface-hub")
        sys.exit(1)

    # Shared arguments for every Hub call. The token is sent only when one is
    # configured, identically for the listing and for each download.
    hub_kwargs = {"repo_id": REPO_ID, "repo_type": "model"}
    if HF_TOKEN:
        hub_kwargs["token"] = HF_TOKEN

    try:
        # Kept as a set: it is only used for membership tests below.
        repo_files = set(list_repo_files(**hub_kwargs))
    except Exception as e:
        print(f"[startup] Cannot list repo files: {e}")
        print("[startup] Trying to download known files directly...")
        # Without a listing, assume every known file exists so each is attempted.
        repo_files = set(REQUIRED_FILES)

    errors: list[str] = []
    for filename in REQUIRED_FILES:
        if (MODELS_DIR / filename).exists():
            print(f"[startup] skip {filename} (exists)")
            continue

        is_optional = filename in OPTIONAL_FILES
        if is_optional and filename not in repo_files:
            print(f"[startup] skip {filename} (not in repo, optional)")
            continue

        try:
            print(f"[startup] dl {filename} ...", end=" ", flush=True)
            path = hf_hub_download(
                filename=filename,
                local_dir=str(MODELS_DIR),
                **hub_kwargs,
            )
            size_mb = Path(path).stat().st_size / 1024 / 1024
            print(f"OK ({size_mb:.1f} MB)")
        except Exception as e:
            # Optional files degrade gracefully; required ones are collected and
            # reported together so a single run shows every failure.
            if is_optional:
                print(f"SKIP (optional: {e})")
            else:
                print(f"ERROR: {e}")
                errors.append(f"{filename}: {e}")

    elapsed = time.time() - t0
    print(f"[startup] Download complete in {elapsed:.1f}s")

    if errors:
        print(f"[startup] FATAL — {len(errors)} required file(s) failed:")
        for err in errors:
            print(f"  - {err}")
        print("[startup] Set AURIS_MODELS_REPO and HF_TOKEN env vars if repo is private.")
        # Non-zero exit stops `python startup.py && uvicorn ...` from starting the API.
        sys.exit(1)
# Script entry point. "Ready." is printed only if download_models() returns
# normally; its fatal paths call sys.exit(1) first.
if __name__ == "__main__":
    download_models()
    print("[startup] Ready.")