Spaces:
Sleeping
Sleeping
import logging
import os
import secrets

import joblib
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, create_model

from automl_engine import load_pipeline
# Pull settings such as API_SECRET_KEY from a local .env file, if present.
load_dotenv()

app = FastAPI(title="AutoML Model Serving API")

# Shared secret clients must present (checked by verify_api_key).
# The previous fallback was a hard-coded key that is readable in the source,
# which left the API effectively unprotected whenever the variable was unset
# (and an empty value was accepted as a valid key). Fail closed instead:
# generate an unguessable per-process key and warn loudly so the operator
# configures a real one.
API_SECRET_KEY = os.getenv("API_SECRET_KEY")
if not API_SECRET_KEY:
    API_SECRET_KEY = secrets.token_urlsafe(32)
    logging.getLogger(__name__).warning(
        "API_SECRET_KEY is not set; using a random per-process key. Requests "
        "guarded by verify_api_key will be rejected until it is configured."
    )
def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the expected API key.

    The key is read from the ``x-api-key`` request header (FastAPI maps the
    ``x_api_key`` parameter name to that header) and compared with
    ``API_SECRET_KEY``.

    Returns:
        The validated key, so a route may also accept it as a parameter.

    Raises:
        HTTPException: 403 if the supplied key does not match.
    """
    # Constant-time comparison: a plain ``!=`` can return early on the first
    # differing character, letting response timing leak the key. Compare
    # bytes because compare_digest raises TypeError for non-ASCII ``str``.
    if not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_SECRET_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return x_api_key
# Process-wide cache for the currently loaded pipeline. Both slots hold None
# until load_latest_model() fills them.
model_assets = dict(processor=None, model=None)
def load_latest_model():
    """Load the last ``.pkl`` pipeline in ``models/`` into ``model_assets``.

    Candidate files are ranked by file name (lexicographic order), and the
    greatest one is loaded with ``load_pipeline``. That call must return a
    ``(processor, model)`` pair, which replaces the cached entries.

    Returns:
        True if a pipeline was loaded; False if ``models/`` is not a
        directory or holds no ``.pkl`` files.
    """
    model_dir = "models"
    # isdir rather than exists: a stray *file* called "models" would
    # otherwise make os.listdir raise NotADirectoryError.
    if not os.path.isdir(model_dir):
        return False
    # Only regular files can be loaded; a sub-directory whose name happens
    # to end in ".pkl" must never be chosen as the model.
    pickles = [
        name
        for name in os.listdir(model_dir)
        if name.endswith(".pkl") and os.path.isfile(os.path.join(model_dir, name))
    ]
    if not pickles:
        return False
    # NOTE(review): "latest" means the greatest file name. That equals the
    # most recent model only if names embed a sortable (e.g. zero-padded)
    # timestamp; "model_10.pkl" sorts before "model_9.pkl". Confirm against
    # the naming scheme used when pipelines are saved.
    latest_path = os.path.join(model_dir, max(pickles))
    processor, model = load_pipeline(latest_path)
    model_assets["processor"] = processor
    model_assets["model"] = model
    return True
# Without this registration the hook was never invoked, so no model was
# preloaded at startup.
@app.on_event("startup")
async def startup_event():
    """Try to preload the newest pipeline when the server starts.

    A missing model is not fatal: the boolean result is ignored, and
    ``predict`` attempts the load again on first use. Exceptions raised while
    loading a pipeline file are not caught here.
    """
    load_latest_model()
# Registered as the root route; without a decorator FastAPI never served it.
@app.get("/")
def read_root():
    """Health check: report the service is up and whether a model is cached."""
    return {"status": "online", "model_loaded": model_assets["model"] is not None}
class PredictionRequest(BaseModel):
    """JSON body accepted by the prediction endpoint.

    ``data`` carries the input rows, intended as one dict per row mapping
    feature names to values. The list is handed directly to
    ``pd.DataFrame``, so its element shape is not validated here.
    """

    data: list
# Registered as POST /predict and guarded by verify_api_key. Previously the
# handler was never routed, and the API-key dependency defined above was
# never applied to it.
@app.post("/predict", dependencies=[Depends(verify_api_key)])
def predict(request: PredictionRequest):
    """Run the cached pipeline on the submitted rows and return predictions.

    Requires a valid ``x-api-key`` header (enforced by ``verify_api_key``).
    If no model is cached yet, one load attempt is made from ``models/``.

    Returns:
        ``{"predictions": [...]}`` as plain Python values. Predictions are
        label-decoded when the processor carries a truthy ``label_encoder``.

    Raises:
        HTTPException: 400 if no model can be loaded or ``data`` is empty;
            500 if preprocessing or prediction fails (detail is the error
            text).
    """
    if model_assets["model"] is None and not load_latest_model():
        raise HTTPException(status_code=400, detail="No model loaded. Train a model first.")
    # An empty payload is a client error, not a server failure.
    if not request.data:
        raise HTTPException(status_code=400, detail="Request contains no rows to predict.")

    processor = model_assets["processor"]
    model = model_assets["model"]
    try:
        df = pd.DataFrame(request.data)
        predictions = model.predict(processor.transform(df))
        # Classifiers trained on string labels carry an encoder; decode back.
        label_encoder = getattr(processor, "label_encoder", None)
        if label_encoder:
            predictions = label_encoder.inverse_transform(predictions)
        # asarray copes with models that return plain lists as well as
        # ndarrays; tolist yields JSON-serializable Python scalars.
        result = np.asarray(predictions).tolist()
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"predictions": result}
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces (0.0.0.0), not just localhost. The PORT
    # environment variable overrides the port for platforms that assign one;
    # when it is unset the server uses 8000 as before.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))