File size: 2,462 Bytes
0660004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import io
import logging
from typing import Tuple

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image
import numpy as np
import tensorflow as tf

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vehicle-predictor")

# Saved Keras model expected to sit next to this module.
MODEL_FILENAME = "complete_model_model.h5"
MODEL_PATH = os.path.join(os.path.dirname(__file__), MODEL_FILENAME)
# Input resolution images are resized to before inference.
# NOTE(review): presumably the resolution the model was trained at — confirm.
IMG_SIZE = (224, 224)

# Output classes indexed by the model's softmax output position.
# NOTE(review): order must match the label encoding used at training time — verify.
CLASS_NAMES = [
    'Ambulance', 'Bicycle', 'Boat', 'Bus', 'Car', 'Helicopter', 'Limousine',
    'Motorcycle', 'PickUp', 'Segway', 'Snowmobile', 'Tank', 'Taxi', 'Truck', 'Van'
]

app = FastAPI(title="Vehicle Type Predictor")

# Allow cross-origin calls so a separately hosted front-end can reach the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# known CORS footgun — browsers reject a wildcard origin on credentialed
# requests. Tighten the origins list (or drop credentials) before relying on
# cookies/auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # you can tighten this later if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import/startup. On failure the API stays up but
# degraded: `model` is None and /predict answers 503 instead of crashing.
try:
    # Lazy %-style args keep formatting off the hot path and match logging
    # best practice. The original messages carried mojibake (UTF-8 emoji
    # bytes mis-decoded as a legacy codepage: "๐Ÿš€", "โœ…", "โŒ");
    # replaced with plain ASCII.
    logger.info("Loading model from %s ...", MODEL_PATH)
    model = tf.keras.models.load_model(MODEL_PATH)
    logger.info("Model loaded successfully.")
except Exception:
    # Broad catch is deliberate: any load failure (missing file, version
    # mismatch, corrupt weights) must not take the whole service down.
    # logger.exception records the full traceback; the unused `as e`
    # binding was dropped.
    logger.exception("Model load failed")
    model = None


class PredictionResponse(BaseModel):
    """JSON body returned by /predict: top-1 class label and its confidence."""
    # label is one of CLASS_NAMES; confidence is the softmax score for it,
    # a float (softmax output, so expected in [0, 1]).
    label: str
    confidence: float


def preprocess_image_file(file_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a model-ready batch of one.

    The image is forced to RGB, resized to IMG_SIZE, scaled to [0, 1]
    float32, and given a leading batch axis, yielding shape (1, H, W, 3).
    """
    with io.BytesIO(file_bytes) as buffer:
        # convert() fully loads pixel data, so the buffer may close afterwards.
        rgb = Image.open(buffer).convert("RGB").resize(IMG_SIZE)
        pixels = np.asarray(rgb, dtype="float32") / 255.0
    # Equivalent to np.expand_dims(pixels, axis=0).
    return pixels[np.newaxis, ...]


@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as one of CLASS_NAMES.

    Returns the top-1 label with its model score.
    Raises:
        HTTPException 503 if the model failed to load at startup,
        HTTPException 400 if the upload is not declared as an image,
        HTTPException 500 on any decoding/inference failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # BUGFIX: content_type can be None (some clients omit it); the original
    # called .startswith on it directly, turning that case into an
    # AttributeError and an opaque 500 instead of a clean 400.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        x = preprocess_image_file(contents)
        preds = model.predict(x)
        # Top-1 class: index of the highest score in the single-row batch.
        idx = int(np.argmax(preds[0]))
        label = CLASS_NAMES[idx]
        confidence = float(preds[0][idx])
        # Lazy %-args instead of an f-string so formatting only happens
        # when INFO is enabled.
        logger.info("Predicted %s (%.4f) for %s", label, confidence, file.filename)
        return PredictionResponse(label=label, confidence=confidence)
    except Exception:
        # Keep the stack trace in the logs but hide details from the client.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed")


@app.get("/health")
def health():
    """Liveness probe: service status plus whether the model is available."""
    payload = {
        "status": "ok",
        "model_loaded": model is not None,
    }
    return payload