Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import pickle | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| from typing import List | |
app = FastAPI(title="Headache Predictor API")

# Module-level model state, rebound by load_model(). clf stays None until a
# model object has been assigned; the prediction endpoints check for that
# before scoring, and threshold is the cut-off applied to the class-1 probability.
clf, threshold = None, 0.5
| # --- Pydantic Models --- | |
class SinglePredictionRequest(BaseModel):
    """Request body for /predict: the feature vector for a single day."""

    # One numeric value per model feature; the order presumably has to match
    # the model's training features (TODO confirm against the training code).
    features: List[float]
class SinglePredictionResponse(BaseModel):
    """Response body for /predict."""

    # 0/1 decision: 1 when probability is at or above the loaded threshold.
    prediction: int
    # Model probability of the headache class (class 1).
    probability: float
class BatchPredictionRequest(BaseModel):
    """Request body for /predict/batch: one feature vector per day."""

    # Row i carries the features for day i + 1 of the forecast.
    instances: List[List[float]]
class DayPrediction(BaseModel):
    """Prediction for a single day inside a batch forecast."""

    # 1-based position of this day within the submitted batch.
    day: int
    # 0/1 decision from comparing probability with the threshold.
    prediction: int
    # Probability of the headache class (class 1).
    probability: float
class BatchPredictionResponse(BaseModel):
    """Response body for /predict/batch."""

    # One entry per submitted instance, in submission order.
    predictions: List[DayPrediction]
| # --- Startup Event --- | |
@app.on_event("startup")
async def load_model():
    """Download the pickled model from the Hugging Face Hub and install it.

    Populates the module-level ``clf`` and ``threshold``. The pickle may hold
    either a bare estimator or a dict with a ``"model"`` entry and an optional
    ``"optimal_threshold"``; a missing (or None) threshold falls back to 0.5.

    Both globals are updated together only after the payload has been fully
    validated, so a malformed artifact never leaves a half-loaded state.
    Failures are printed and swallowed on purpose: the API still starts and
    the prediction endpoints answer 503 while ``clf`` is None.
    """
    global clf, threshold
    try:
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)
        hf_token = os.environ.get("HF_TOKEN")
        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,
        )
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. This is only acceptable because the repo above is trusted;
        # never point it at a user-controlled location.
        with open(model_path, "rb") as f:
            model_data = pickle.load(f)

        # Resolve everything into locals first; the globals are assigned in
        # one step below.
        if isinstance(model_data, dict):
            new_clf = model_data["model"]
            raw_threshold = model_data.get("optimal_threshold")
            # A stored None is treated the same as a missing key.
            new_threshold = 0.5 if raw_threshold is None else float(raw_threshold)
            message = f"✅ Model loaded (optimal_threshold={new_threshold})"
        else:
            new_clf = model_data
            new_threshold = 0.5
            message = "✅ Model loaded (threshold=0.5 default)"

        # Both endpoints call predict_proba; fail here rather than on every request.
        if not hasattr(new_clf, "predict_proba"):
            raise TypeError(
                f"Loaded model {type(new_clf).__name__} has no predict_proba method"
            )
        # The threshold is compared with a probability; values outside [0, 1]
        # (or NaN, which fails both comparisons) would make every prediction
        # constant.
        if not 0.0 <= new_threshold <= 1.0:
            raise ValueError(f"optimal_threshold must be in [0, 1], got {new_threshold}")

        clf, threshold = new_clf, new_threshold
        print(message)
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
| # --- Endpoints --- | |
@app.get("/")
def read_root():
    """Describe the API: available endpoints and example request bodies.

    The returned dict is a fixed literal; it does not depend on whether the
    model has been loaded.
    """
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": {
            "predict": "/predict - Single day prediction",
            "predict_batch": "/predict/batch - 7-day forecast",
            "health": "/health",
        },
        "examples": {
            "single": {
                "url": "/predict",
                # A complete 37-value example feature vector (not abbreviated).
                "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]},
            },
            "batch": {
                "url": "/predict/batch",
                # Placeholder text describing the shape only: a real request
                # sends one list of numeric features per day.
                "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"]]},
            },
        },
    }
@app.get("/health")
def health_check():
    """Report that the service is up and whether a model is currently loaded."""
    return {
        "status": "healthy",
        # False while load_model() has not installed a model (or it failed).
        "model_loaded": clf is not None,
    }
@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Args:
        request: Body carrying one feature vector in ``request.features``.

    Returns:
        SinglePredictionResponse holding the class-1 (headache) probability and
        the 0/1 decision from comparing it with the loaded ``threshold``.

    Raises:
        HTTPException: 503 if no model is loaded; 400 if the features cannot
            be scored (empty vector, wrong length, or an error from the model).
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # dtype=float mirrors /predict/batch; reshape produces the 1 x N row
        # that predict_proba is given.
        features = np.asarray(request.features, dtype=float).reshape(1, -1)
        n_features = features.shape[1]
        if n_features == 0:
            raise ValueError("features must not be empty")
        # scikit-learn-style estimators record their training width in
        # n_features_in_; check it up front for a clear error message.
        expected = getattr(clf, "n_features_in_", None)
        if expected is not None and n_features != expected:
            raise ValueError(f"Expected {expected} features, got {n_features}")

        # Row 0 of predict_proba is [P(class 0), P(class 1)] for this binary
        # model; always report the headache (class 1) probability.
        headache_probability = float(clf.predict_proba(features)[0][1])
        prediction = 1 if headache_probability >= threshold else 0
        return SinglePredictionResponse(
            prediction=prediction,
            probability=headache_probability,
        )
    except Exception as e:
        # Any failure while scoring is reported to the client as a bad request.
        raise HTTPException(status_code=400, detail=f"Prediction error: {e}") from e
@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (batch).

    Args:
        request: Body whose ``instances`` holds one feature vector per day.

    Returns:
        BatchPredictionResponse with one DayPrediction per instance, numbered
        from day 1 in submission order. An empty ``instances`` list yields an
        empty ``predictions`` list without calling the model.

    Raises:
        HTTPException: 503 if no model is loaded; 400 if the instances cannot
            be scored (ragged or empty rows, wrong feature count, or an error
            from the model).
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.instances:
        # Nothing to score; skip the model rather than feed it a zero-row array.
        return BatchPredictionResponse(predictions=[])
    try:
        # Ragged rows make this raise ValueError, which becomes a 400 below.
        X = np.asarray(request.instances, dtype=float)
        if X.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {X.shape}")
        n_features = X.shape[1]
        if n_features == 0:
            raise ValueError("instances must not contain empty feature lists")
        # scikit-learn-style estimators record their training width in
        # n_features_in_; check it up front for a clear error message.
        expected = getattr(clf, "n_features_in_", None)
        if expected is not None and n_features != expected:
            raise ValueError(
                f"Expected {expected} features per instance, got {n_features}"
            )

        probas = clf.predict_proba(X)[:, 1]  # class-1 (headache) probability
        preds = (probas >= threshold).astype(int)
        results = [
            DayPrediction(day=day, prediction=int(pred), probability=float(proba))
            for day, (pred, proba) in enumerate(zip(preds, probas), start=1)
        ]
        return BatchPredictionResponse(predictions=results)
    except Exception as e:
        # Any failure while scoring is reported to the client as a bad request.
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {e}") from e