"""FastAPI application factory for the SQuADDS ML inference service."""

import logging
from pathlib import Path

from fastapi import FastAPI, HTTPException

from .registry import BundleConfigError, ModelRegistry, RequestValidationError
from .schemas import PredictionRequest

logger = logging.getLogger(__name__)


def create_app(bundle_root: Path) -> FastAPI:
    """Build the inference API backed by the model bundles under *bundle_root*.

    A single ``ModelRegistry`` is constructed from ``bundle_root`` and shared
    by every route handler through closure.

    Routes registered on the returned app:

    * ``GET /``        -- static service descriptor with endpoint locations.
    * ``GET /health``  -- status, available model ids, and the bundle root.
    * ``GET /models``  -- the registry's description of the models.
    * ``POST /predict`` -- run inference for one ``PredictionRequest``.

    Args:
        bundle_root: Path handed to ``ModelRegistry`` and echoed back by
            ``/health``.

    Returns:
        The configured ``FastAPI`` application.
    """
    registry = ModelRegistry(bundle_root)

    app = FastAPI(
        title="SQuADDS ML Inference API",
        version="0.1.0",
        description=(
            "HTTP API for running inference against ML models trained in "
            "ML_qubit_design and packaged for the SQuADDS Hugging Face Space."
        ),
    )

    @app.get("/")
    def root() -> dict:
        """Return a static descriptor pointing at the other endpoints."""
        return {
            "service": "SQuADDS ML Inference API",
            "docs": "/docs",
            "models_endpoint": "/models",
            "predict_endpoint": "/predict",
        }

    @app.get("/health")
    def health() -> dict:
        """Report service status along with the registry's model ids."""
        return {
            "status": "ok",
            "available_models": registry.available_model_ids(),
            "bundle_root": str(bundle_root),
        }

    @app.get("/models")
    def list_models() -> dict:
        """Return the registry's model descriptions under the ``models`` key."""
        return {"models": registry.describe_models()}

    @app.post("/predict")
    def predict(request: PredictionRequest) -> dict:
        """Run inference via the registry and return its payload.

        Raises:
            HTTPException: 400 when the registry raises
                ``RequestValidationError``; 500 when it raises
                ``BundleConfigError`` or any other exception.
        """
        try:
            payload = registry.predict(
                model_id=request.model_id,
                inputs=request.inputs,
                include_scaled_outputs=request.options.include_scaled_outputs,
            )
        except RequestValidationError as exc:
            # Caller error: surface the validation message as a 400.
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except BundleConfigError as exc:
            # Server-side failure: record the traceback before it is
            # converted into an HTTP response, otherwise it is lost.
            logger.exception(
                "Bundle configuration error for model '%s'", request.model_id
            )
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        except Exception as exc:
            # Top-level boundary for anything unexpected: log the traceback,
            # then respond with a 500.
            # NOTE(review): the detail string exposes the raw exception text
            # to the client; kept as-is for compatibility -- confirm this is
            # acceptable for the deployment.
            logger.exception(
                "Unexpected inference error for model '%s'", request.model_id
            )
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected inference error for model '{request.model_id}': {exc}",
            ) from exc
        return payload

    return app