File size: 4,470 Bytes
c1db98e
 
 
 
 
 
 
 
 
e8c02eb
c1db98e
 
 
 
 
 
 
 
 
 
 
e8c02eb
c1db98e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55f921
 
 
c1db98e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
AURIS startup script — downloads model artifacts from HuggingFace Hub
before the API server starts.

Usage (in Dockerfile CMD):
    python startup.py && uvicorn app.main:app --host 0.0.0.0 --port 7860

Environment variables:
    HF_TOKEN          — HuggingFace access token (optional for public repos);
                        HUGGING_FACE_HUB_TOKEN is used when HF_TOKEN is unset or empty
    AURIS_MODELS_REPO — HF repo ID, default: Rthur2003/auris-models
    MODELS_DIR        — Local destination, default: /app/models
    SKIP_MODEL_DOWNLOAD — Set to "1" to skip (for local dev)
"""

from __future__ import annotations

import os
import sys
import time
from pathlib import Path

# Runtime configuration, read from the environment once at import time.

# Hub repository ID the artifacts are downloaded from.
REPO_ID = os.environ.get("AURIS_MODELS_REPO", "Rthur2003/auris-models")

# Local directory the artifacts are stored in.
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "/app/models"))

# Access token: HF_TOKEN is preferred; HUGGING_FACE_HUB_TOKEN is consulted
# when HF_TOKEN is unset or empty. None when neither is provided.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# True only when SKIP_MODEL_DOWNLOAD is exactly the string "1".
SKIP = os.environ.get("SKIP_MODEL_DOWNLOAD", "0") == "1"

# Artifacts fetched at startup, processed in this order. Every entry that is
# not also listed in OPTIONAL_FILES must exist for the API to work.
REQUIRED_FILES = [
    # Primary classifier, preprocessing artifacts and result summaries
    "auris_classifier_v1.pkl",
    "feature_scaler_v1.pkl",
    "feature_columns_v1.json",
    "feature_stats_v1.json",
    "training_results.json",
    "deep_learning_results.json",
    # Individual "model_*" pickles
    "model_lightgbm.pkl",
    "model_xgboost.pkl",
    "model_random_forest.pkl",
    "model_gradient_boosting.pkl",
    "model_svm_rbf.pkl",
    "model_mlp_neural_network.pkl",
    "model_logistic_regression.pkl",
    # "model_dl_*" pickles
    "model_dl_deep_mlp_512_256_128_64.pkl",
    "model_dl_1d_cnn.pkl",
    "model_dl_residual_mlp_3_blocks.pkl",
    "model_dl_attention_mlp.pkl",
    # wav2vec2 checkpoint (also listed in OPTIONAL_FILES)
    "wav2vec2_auris_v1.pt",
]

# Entries of REQUIRED_FILES whose absence is tolerated: the completeness check
# ignores them, and a failed download is reported but is not fatal.
OPTIONAL_FILES = {"wav2vec2_auris_v1.pt"}


def _already_downloaded() -> bool:
    """Report whether every mandatory artifact is already on disk.

    Entries of REQUIRED_FILES that also appear in OPTIONAL_FILES are ignored.
    When something is absent, the absent names are printed in REQUIRED_FILES
    order.

    Returns:
        True if no mandatory file is absent from MODELS_DIR, otherwise False.
    """
    absent = []
    for name in REQUIRED_FILES:
        if name in OPTIONAL_FILES:
            continue
        if (MODELS_DIR / name).exists():
            continue
        absent.append(name)

    if not absent:
        return True

    print(f"[startup] Missing files: {absent}")
    return False


def download_models() -> None:
    """Fetch any missing model artifacts from the HuggingFace Hub into MODELS_DIR.

    Does nothing when SKIP is set, or when every non-optional entry of
    REQUIRED_FILES already exists locally. Otherwise each entry of
    REQUIRED_FILES that is not yet on disk is downloaded individually; files
    that already exist are left untouched.

    Failures on files listed in OPTIONAL_FILES are reported and ignored.
    Failures on any other file are collected and reported together once all
    files have been attempted.

    Raises:
        SystemExit: with status 1 if ``huggingface_hub`` cannot be imported,
            or if at least one non-optional file could not be downloaded.
    """
    if SKIP:
        print("[startup] SKIP_MODEL_DOWNLOAD=1 — skipping download.")
        return

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    if _already_downloaded():
        print("[startup] All required model files already present — skipping download.")
        return

    # Keep a visible separator between the source repo and the destination.
    print(f"[startup] Downloading models from {REPO_ID} → {MODELS_DIR}")
    # Monotonic clock: the reported duration is immune to wall-clock jumps.
    started = time.monotonic()

    # Imported lazily so the skip / already-present paths work without the
    # dependency installed.
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError:
        print("[startup] huggingface_hub not installed — pip install huggingface-hub")
        sys.exit(1)

    # The token is only forwarded to hf_hub_download when one is configured.
    download_kwargs = {"repo_id": REPO_ID, "repo_type": "model"}
    if HF_TOKEN:
        download_kwargs["token"] = HF_TOKEN

    # The repo listing is only used to skip optional files the repo does not
    # contain. If listing fails, assume every known file may be there and let
    # the per-file download decide.
    try:
        repo_files = set(
            list_repo_files(repo_id=REPO_ID, repo_type="model", token=HF_TOKEN)
        )
    except Exception as e:  # best-effort: any listing failure falls back
        print(f"[startup] Cannot list repo files: {e}")
        print("[startup] Trying to download known files directly...")
        repo_files = set(REQUIRED_FILES)

    errors: list[str] = []
    for filename in REQUIRED_FILES:
        dest = MODELS_DIR / filename
        if dest.exists():
            print(f"[startup]   skip  {filename} (exists)")
            continue

        is_optional = filename in OPTIONAL_FILES
        if is_optional and filename not in repo_files:
            print(f"[startup]   skip  {filename} (not in repo, optional)")
            continue

        try:
            # No newline yet: the outcome is appended to the same line below.
            print(f"[startup]   dl    {filename} ...", end=" ", flush=True)
            path = hf_hub_download(
                filename=filename,
                local_dir=str(MODELS_DIR),
                **download_kwargs,
            )
            print(f"OK ({Path(path).stat().st_size / 1024 / 1024:.1f} MB)")
        except Exception as e:  # keep going so every failure is reported
            if is_optional:
                print(f"SKIP (optional: {e})")
            else:
                print(f"ERROR: {e}")
                errors.append(f"{filename}: {e}")

    elapsed = time.monotonic() - started
    print(f"[startup] Download complete in {elapsed:.1f}s")

    if errors:
        print(f"[startup] FATAL — {len(errors)} required file(s) failed:")
        for err in errors:
            print(f"  - {err}")
        print("[startup] Set AURIS_MODELS_REPO and HF_TOKEN env vars if repo is private.")
        sys.exit(1)


# Script entry point: runs only when the file is executed directly, not on import.
if __name__ == "__main__":
    # download_models() exits the process with status 1 on a fatal failure,
    # so "Ready." is printed only when it returns normally.
    download_models()
    print("[startup] Ready.")