GetAround / api.py
Zbehel
Ajouter openpyxl pour lire les fichiers Excel
ab80adf
import os
from typing import Literal, List, Union

import joblib
import mlflow
import pandas as pd
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
# MLflow URI of the trained model to serve. The hard-coded default appears to
# be a placeholder (the run id is elided); set the MLFLOW_MODEL_URI
# environment variable to point at a real model, e.g. "runs:/<run_id>/model",
# without editing this file.
logged_model = os.environ.get("MLFLOW_MODEL_URI", 'runs:/.../model')
# Load the model once, at import time, as a generic PyFuncModel; the
# prediction endpoint reuses this single module-level instance.
loaded_model = mlflow.pyfunc.load_model(logged_model)
# OpenAPI tag metadata shown in the interactive docs (`/docs`). Each tag used
# by a route in this module gets an entry here so it is rendered with a
# description; the list order is the order the tags appear in the docs.
tags_metadata = [
    {
        "name": "Introduction Endpoints",
        "description": "Welcome message pointing to the API documentation.",
    },
    {
        "name": "Machine Learning",
        "description": "Prediction Endpoint.",
    },
]

# The FastAPI application that all routes below are registered on.
app = FastAPI(
    title="Car price prediction API",
    openapi_tags=tags_metadata,
)
class PredictionFeatures(BaseModel):
    """
    Features describing a single car, validated by pydantic from the JSON
    body of a prediction request.

    Field names are used as-is as the column names of the DataFrame passed
    to the model, so they must match the columns the model was trained on.
    """
    # Categorical inputs. They are plain strings, and no set of allowed values is enforced here.
    model_key: str  # car model identifier (presumably the brand) — TODO confirm
    # Numeric inputs; units are not stated in this file.
    mileage: int  # presumably kilometres — TODO confirm
    engine_power: int  # presumably horsepower — TODO confirm
    fuel: str
    car_type: str
    # Equipment / option flags.
    private_parking_available: bool
    has_gps: bool
    has_air_conditioning: bool
    automatic_car: bool
    has_getaround_connect: bool
    has_speed_regulator: bool
    winter_tires: bool
@app.get("/", tags=["Introduction Endpoints"])
async def index():
    """
    Welcome endpoint.

    Returns a short greeting string that points visitors to the
    interactive API documentation served at `/docs`.
    """
    return (
        "Hello world! This `/` is the most simple and default endpoint. "
        "If you want to learn more, check out documentation of the api at `/docs`"
    )
@app.post("/predict", tags=["Machine Learning"])
async def predict(predictionFeatures: PredictionFeatures):
    """
    Return the model's prediction for a single car.

    The validated request body is turned into a one-row DataFrame whose
    columns are the `PredictionFeatures` field names, in declaration order,
    and passed to the MLflow pyfunc model loaded at startup.

    Returns:
        A JSON object ``{"prediction": <value>}`` holding the model output
        for that single row.
    """
    # One row, one column per field. Iterating a pydantic model yields
    # (field_name, value) pairs, so the columns stay in sync with
    # `PredictionFeatures` instead of repeating every field name by hand.
    input_data = pd.DataFrame([dict(predictionFeatures)])

    # `predict` is synchronous, CPU-bound work. Running it in the worker
    # thread pool keeps it from blocking the event loop that serves other
    # requests, while this endpoint stays a coroutine.
    prediction = await run_in_threadpool(loaded_model.predict, input_data)

    # `.tolist()` turns the (presumably NumPy) output into plain Python values
    # that FastAPI can serialize; return the value for the single input row.
    return {"prediction": prediction.tolist()[0]}
if __name__ == "__main__":
    # Run the API with uvicorn when this module is executed directly as a
    # script, listening on all interfaces on port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )