"""FastAPI service that classifies an uploaded image into a vehicle type.

A Keras model is loaded once at import time from a file located next to this
module. ``POST /predict`` runs inference on an uploaded image and
``GET /health`` reports whether the model was loaded.
"""

import io
import logging
import os
from typing import Tuple

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vehicle-predictor")

MODEL_FILENAME = "complete_model_model.h5"
MODEL_PATH = os.path.join(os.path.dirname(__file__), MODEL_FILENAME)

# Size every image is resized to before being handed to the model.
IMG_SIZE: Tuple[int, int] = (224, 224)

# Labels indexed by the argmax of the model output.
# NOTE(review): presumably in the same order as the classes used at training
# time -- confirm against the training pipeline.
CLASS_NAMES = [
    'Ambulance', 'Bicycle', 'Boat', 'Bus', 'Car', 'Helicopter', 'Limousine',
    'Motorcycle', 'PickUp', 'Segway', 'Snowmobile', 'Tank', 'Taxi', 'Truck',
    'Van',
]

app = FastAPI(title="Vehicle Type Predictor")

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; consider listing explicit origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # you can tighten this later if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model at startup. A failure is logged and leaves ``model`` as None so
# the app still starts; /predict then answers 503 and /health reports it.
try:
    logger.info("🚀 Loading model...")
    model = tf.keras.models.load_model(MODEL_PATH)
    logger.info("✅ Model loaded successfully.")
except Exception:
    logger.exception("❌ Model load failed")
    model = None


class PredictionResponse(BaseModel):
    """Response body of ``POST /predict``."""

    # Entry of CLASS_NAMES selected by the highest model output.
    label: str
    # Raw model output value at the selected index.
    confidence: float


def preprocess_image_file(file_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a single-image batch for the model.

    The image is converted to RGB, resized to ``IMG_SIZE``, scaled to the
    range [0, 1] as float32 and given a leading batch axis.

    Args:
        file_bytes: Encoded image data (any format Pillow can open).

    Returns:
        A float32 array of shape ``(1, 224, 224, 3)``.

    Raises:
        OSError: If Pillow cannot identify or decode the data (this includes
            ``PIL.UnidentifiedImageError`` on Pillow versions that define it).
    """
    # The context manager releases the decoder's resources once the converted
    # copy has been produced.
    with Image.open(io.BytesIO(file_bytes)) as raw:
        img = raw.convert("RGB").resize(IMG_SIZE)
    arr = np.asarray(img).astype("float32") / 255.0
    return np.expand_dims(arr, axis=0)


@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top label with its score.

    Args:
        file: The uploaded file; its declared content type must start with
            ``image/``.

    Returns:
        A ``PredictionResponse`` holding the best label and its score.

    Raises:
        HTTPException: 503 if the model failed to load, 400 if the upload is
            not declared as an image or cannot be decoded as one, 500 for any
            other failure during prediction.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # content_type may be absent on the upload; treat that as "not an image"
    # instead of failing with AttributeError.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        try:
            x = preprocess_image_file(contents)
        except OSError as exc:
            # Undecodable, empty or truncated uploads are a client error.
            logger.warning("Could not decode upload %s: %s", file.filename, exc)
            raise HTTPException(
                status_code=400, detail="Invalid or unreadable image file"
            ) from exc

        preds = model.predict(x)
        idx = int(np.argmax(preds[0]))
        label = CLASS_NAMES[idx]
        confidence = float(preds[0][idx])
        logger.info(
            "Predicted %s (%.4f) for %s", label, confidence, file.filename
        )
        return PredictionResponse(label=label, confidence=confidence)
    except HTTPException:
        # Let the deliberate 400 above through untouched.
        raise
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed")


@app.get("/health")
def health() -> dict:
    """Report service liveness and whether the model was loaded."""
    return {"status": "ok", "model_loaded": model is not None}