File size: 9,475 Bytes
e72f783
 
 
 
 
 
 
 
ce1542b
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64f4176
 
 
 
 
 
 
 
 
 
 
 
2ac1b86
64f4176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
64f4176
 
 
 
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3474d8
343ba20
c3474d8
343ba20
c3474d8
343ba20
c3474d8
343ba20
c3474d8
343ba20
e72f783
ce1542b
 
 
 
c3474d8
ce1542b
 
 
 
 
e72f783
 
 
ce1542b
 
 
e72f783
ce1542b
 
e72f783
ce1542b
 
e72f783
c3474d8
e72f783
 
c3474d8
e72f783
ce1542b
 
 
e72f783
ce1542b
 
c3474d8
e72f783
ce1542b
 
e72f783
c3474d8
e72f783
c3474d8
 
e72f783
ce1542b
 
 
e72f783
ce1542b
 
c3474d8
ce1542b
 
 
e72f783
 
 
343ba20
 
ce1542b
 
343ba20
c3474d8
343ba20
c3474d8
 
343ba20
c3474d8
 
e72f783
 
ce1542b
 
c3474d8
ce1542b
 
343ba20
 
ce1542b
 
c3474d8
343ba20
c3474d8
343ba20
e72f783
ce1542b
 
 
e72f783
c3474d8
 
 
 
e72f783
 
 
 
 
 
 
 
 
 
 
ce1542b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# api/startup.py
# All model and index loading happens here β€” once at FastAPI startup
# Everything stays in memory for the entire server lifetime
# Never load models per-request

import os
import json
import time
import sys
import torch
import clip

from src.patchcore import patchcore
from src.retriever import retriever
from src.graph import knowledge_graph
from src.depth import depth_estimator
from src.xai import gradcam, shap_explainer
from src.cache import inference_cache
from src.orchestrator import init_orchestrator
from api.logger import init_logger


# Startup timestamp β€” used for uptime calculation in /health
STARTUP_TIME = None  # set once by load_all(); None means startup has not run yet
MODEL_VERSION = "v1.0"  # reported in the startup banner (and presumably /health β€” verify)

def download_artifacts():
    """Fetch every required artifact from the HF dataset repo into ./data.

    Files already present on disk are skipped. Each download failure is
    reported as a warning and does not abort startup β€” downstream loaders
    handle missing files themselves.
    """
    from huggingface_hub import hf_hub_download, snapshot_download
    import shutil

    HF_REPO = "CaffeinatedCoding/anomalyos-logs"
    token = os.environ.get("HF_TOKEN")

    os.makedirs("data", exist_ok=True)

    # Fixed artifacts: (path inside the HF repo, destination on local disk).
    artifacts = [
        ("models/pca_256.pkl",              "data/pca_256.pkl"),
        ("models/midas_small.onnx",         "data/midas_small.onnx"),
        ("configs/thresholds.json",          "data/thresholds.json"),
        ("graph/knowledge_graph.json",       "data/knowledge_graph.json"),
        ("indexes/index1_category.faiss",    "data/index1_category.faiss"),
        ("indexes/index1_metadata.json",     "data/index1_metadata.json"),
        ("indexes/index2_defect.faiss",      "data/index2_defect.faiss"),
        ("indexes/index2_metadata.json",     "data/index2_metadata.json"),
    ]

    # Index 3 is sharded per MVTec category β€” one FAISS file each.
    categories = (
        'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
        'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
        'transistor', 'wood', 'zipper',
    )
    artifacts += [
        (f"indexes/index3_{cat}.faiss", f"data/index3_{cat}.faiss")
        for cat in categories
    ]

    for remote_path, local_path in artifacts:
        if os.path.exists(local_path):
            print(f"Already exists: {local_path}")
            continue
        try:
            print(f"Downloading {remote_path}...")
            fetched = hf_hub_download(
                repo_id=HF_REPO,
                filename=remote_path,
                repo_type="dataset",
                token=token,
                local_dir="/tmp/artifacts",
            )
            # hf_hub_download lands the file under /tmp/artifacts; mirror it
            # into the flat data/ layout the loaders expect.
            shutil.copy(fetched, local_path)
            print(f"  β†’ {local_path}")
        except Exception as e:
            print(f"  WARNING: Could not download {remote_path}: {e}")

def _load_clip():
    """Load CLIP ViT-B/32 on CPU in eval mode.

    CLIP is a hard requirement for the orchestrator, so any failure is
    logged and re-raised to abort startup.
    """
    print("Loading CLIP ViT-B/32...", flush=True)
    try:
        clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
        clip_model.eval()
        print("CLIP loaded βœ“", flush=True)
        return clip_model, clip_preprocess
    except Exception as e:
        print(f"ERROR loading CLIP: {e}", flush=True)
        raise


def _load_thresholds() -> dict:
    """Load per-category anomaly thresholds from DATA_DIR/thresholds.json.

    Returns an empty dict when the file is missing; callers then fall back
    to a fixed score > 0.5 decision rule.
    """
    print("Loading thresholds...", flush=True)
    thresholds_path = os.path.join(
        os.environ.get("DATA_DIR", "data"), "thresholds.json"
    )
    if os.path.exists(thresholds_path):
        with open(thresholds_path) as f:
            thresholds = json.load(f)
        print(f"Thresholds loaded βœ“ {len(thresholds)} categories", flush=True)
        return thresholds
    thresholds = {}
    print("WARNING: thresholds.json not found β€” using score > 0.5 fallback", flush=True)
    return thresholds


def _load_xai():
    """Load optional explainability components (GradCAM++ and SHAP background).

    Both are best-effort: any failure is logged as a warning and inference
    continues without the corresponding explanation.
    """
    # GradCAM++ β€” used by forensics mode only.
    print("Loading GradCAM++...", flush=True)
    try:
        gradcam.load()
        print("GradCAM++ loaded βœ“", flush=True)
    except Exception as e:
        print(f"WARNING: GradCAM++ load failed: {e}", flush=True)
        print("Forensics mode will run without GradCAM++", flush=True)

    # SHAP background samples β€” optional; the explainer has a default.
    print("Loading SHAP background...", flush=True)
    bg_path = os.path.join(
        os.environ.get("DATA_DIR", "data"), "shap_background.npy"
    )
    try:
        if os.path.exists(bg_path):
            shap_explainer.load_background(bg_path)
            print("SHAP background loaded βœ“", flush=True)
        else:
            print(f"WARNING: SHAP background not found at {bg_path}", flush=True)
            print("SHAP explanations will use default background", flush=True)
    except Exception as e:
        print(f"WARNING: SHAP background load failed: {e}", flush=True)
        print("SHAP explanations will use default background", flush=True)


def load_all():
    """
    Called once from FastAPI lifespan on startup.
    Order matters β€” patchcore before orchestrator, logger before anything logs.

    Returns:
        dict with "clip_model", "clip_preprocess" and "thresholds", which the
        caller injects into request handlers.

    Raises:
        Exception: propagated from CLIP or orchestrator init β€” both are
        required components, so startup must fail loudly.
    """
    global STARTUP_TIME
    STARTUP_TIME = time.time()

    print("=" * 50)
    print("AnomalyOS startup sequence")
    print("=" * 50)

    # Artifacts must be on disk before any loader below runs.
    # NOTE(review): download_artifacts() writes to a hard-coded "data/" dir
    # while the loaders below honor DATA_DIR β€” confirm these always agree.
    download_artifacts()

    # ── CPU thread tuning ─────────────────────────────────────
    # HF Spaces CPU Basic = 2 vCPU; matching the thread count prevents
    # over-subscription.
    torch.set_num_threads(2)
    torch.set_default_dtype(torch.float32)
    print(f"PyTorch threads: {torch.get_num_threads()}")

    # ── Logger ────────────────────────────────────────────────
    init_logger(os.environ.get("HF_TOKEN", ""))

    # ── PatchCore extractor ───────────────────────────────────
    patchcore.load()

    # ── FAISS indexes (Index 3 is lazy-loaded β€” not loaded here) ──
    retriever.load_indexes()

    # ── Knowledge graph ───────────────────────────────────────
    knowledge_graph.load()

    # ── MiDaS depth estimator (optional) ──────────────────────
    try:
        depth_estimator.load()
    except FileNotFoundError as e:
        print(f"WARNING: {e}")
        print("Depth features will return zeros β€” inference continues")

    # ── Required + optional components ────────────────────────
    clip_model, clip_preprocess = _load_clip()
    thresholds = _load_thresholds()
    _load_xai()

    # ── Inject into orchestrator ──────────────────────────────
    print("Initializing orchestrator...", flush=True)
    try:
        init_orchestrator(clip_model, clip_preprocess, thresholds)
        print("Orchestrator initialized βœ“", flush=True)
    except Exception as e:
        print(f"ERROR initializing orchestrator: {e}", flush=True)
        raise

    elapsed = time.time() - STARTUP_TIME
    print("=" * 50, flush=True)
    print(f"Startup complete in {elapsed:.1f}s βœ“", flush=True)
    print(f"Model version: {MODEL_VERSION}", flush=True)
    print("=" * 50, flush=True)

    return {
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "thresholds": thresholds
    }


def get_uptime() -> float:
    """Return seconds elapsed since load_all() ran, or 0.0 before startup."""
    return 0.0 if STARTUP_TIME is None else time.time() - STARTUP_TIME