botInfinity commited on
Commit
67d88cf
Β·
verified Β·
1 Parent(s): 6834fe7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +76 -55
main.py CHANGED
@@ -1,55 +1,75 @@
1
- import os
2
  import io
3
- import logging
4
- from typing import Tuple
5
-
6
  from fastapi import FastAPI, File, UploadFile, HTTPException
7
- from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel
9
- from PIL import Image
10
 
11
- # Roboflow inference
12
- from inference import get_model
 
 
13
 
14
- logging.basicConfig(level=logging.INFO)
15
- logger = logging.getLogger("vehicle-predictor")
16
 
17
- # FastAPI setup
18
- app = FastAPI(title="Vehicle Type Predictor")
 
19
 
20
- app.add_middleware(
21
- CORSMiddleware,
22
- allow_origins=["*"], # you can tighten this later if needed
23
- allow_credentials=True,
24
- allow_methods=["*"],
25
- allow_headers=["*"],
26
- )
27
 
28
- # Load Roboflow model at startup
29
- ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
30
- MODEL_ID = "vehicle-classification-eapcd/19"
31
-
32
- if ROBOFLOW_API_KEY is None:
33
- logger.error("❌ ROBOFLOW_API_KEY not found in environment variables")
34
- model = None
35
- else:
36
- try:
37
- logger.info("πŸš€ Loading Roboflow model...")
38
- model = get_model(model_id=MODEL_ID, api_key=ROBOFLOW_API_KEY)
39
- logger.info("βœ… Roboflow model loaded successfully")
40
- except Exception as e:
41
- logger.exception("❌ Failed to load Roboflow model")
42
- model = None
43
-
44
-
45
- # Response model
46
  class PredictionResponse(BaseModel):
47
  label: str
48
  confidence: float
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.post("/predict", response_model=PredictionResponse)
52
  async def predict(file: UploadFile = File(...)):
 
53
  if model is None:
54
  raise HTTPException(status_code=503, detail="Model not loaded")
55
 
@@ -57,30 +77,31 @@ async def predict(file: UploadFile = File(...)):
57
  raise HTTPException(status_code=400, detail="File must be an image")
58
 
59
  try:
60
- contents = await file.read()
61
-
62
- # Roboflow accepts PIL Image directly
63
- img = Image.open(io.BytesIO(contents)).convert("RGB")
64
 
65
- # Run inference
66
- result = model.infer(img)
67
 
68
- if not result.get("predictions"):
69
- raise HTTPException(status_code=500, detail="No predictions returned")
 
 
70
 
71
- # Take top prediction
72
- pred = result["predictions"][0]
73
- label = pred.get("class", "Unknown")
74
- confidence = float(pred.get("confidence", 0.0))
75
 
76
- logger.info(f"Predicted {label} ({confidence:.4f}) for {file.filename}")
77
- return PredictionResponse(label=label, confidence=confidence)
 
 
78
 
79
  except Exception as e:
80
- logger.exception("Prediction failed")
81
- raise HTTPException(status_code=500, detail="Prediction failed")
82
 
83
 
 
 
 
84
  @app.get("/health")
85
- def health():
86
- return {"status": "ok", "model_loaded": model is not None}
 
 
1
  import io
2
+ import torch
3
+ import timm
4
+ from PIL import Image
5
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
6
  from pydantic import BaseModel
7
+ from torchvision import transforms
8
 
9
+ # =========================
10
+ # App Init
11
+ # =========================
12
+ app = FastAPI(title="Vehicle Type Classifier API")
13
 
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ print("Using device:", device)
16
 
17
+ MODEL_PATH = "vehicle_classifier_best.pth"
18
+ model = None
19
+ class_names = None
20
 
 
 
 
 
 
 
 
21
 
22
+ # =========================
23
+ # Response Schema
24
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class PredictionResponse(BaseModel):
26
  label: str
27
  confidence: float
28
 
29
 
30
+ # =========================
31
+ # Image Transform
32
+ # =========================
33
+ transform = transforms.Compose([
34
+ transforms.Resize((224, 224)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize(
37
+ mean=[0.485, 0.456, 0.406],
38
+ std=[0.229, 0.224, 0.225]
39
+ )
40
+ ])
41
+
42
+
43
+ # =========================
44
+ # Load Model on Startup
45
+ # =========================
46
+ @app.on_event("startup")
47
+ def load_model():
48
+ global model, class_names
49
+
50
+ checkpoint = torch.load(MODEL_PATH, map_location=device)
51
+ class_names = checkpoint["class_names"]
52
+ num_classes = len(class_names)
53
+
54
+ model = timm.create_model(
55
+ "convnext_tiny",
56
+ pretrained=False,
57
+ num_classes=num_classes
58
+ )
59
+
60
+ model.load_state_dict(checkpoint["model_state_dict"])
61
+ model.to(device)
62
+ model.eval()
63
+
64
+ print("βœ… Model loaded successfully")
65
+
66
+
67
+ # =========================
68
+ # Prediction Endpoint
69
+ # =========================
70
  @app.post("/predict", response_model=PredictionResponse)
71
  async def predict(file: UploadFile = File(...)):
72
+
73
  if model is None:
74
  raise HTTPException(status_code=503, detail="Model not loaded")
75
 
 
77
  raise HTTPException(status_code=400, detail="File must be an image")
78
 
79
  try:
80
+ image_bytes = await file.read()
81
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
82
 
83
+ img_tensor = transform(image).unsqueeze(0).to(device)
 
84
 
85
+ with torch.no_grad():
86
+ outputs = model(img_tensor)
87
+ probs = torch.softmax(outputs, dim=1)
88
+ pred_idx = torch.argmax(probs, dim=1).item()
89
 
90
+ label = class_names[pred_idx]
91
+ confidence = probs[0][pred_idx].item()
 
 
92
 
93
+ return PredictionResponse(
94
+ label=label,
95
+ confidence=round(confidence, 4)
96
+ )
97
 
98
  except Exception as e:
99
+ raise HTTPException(status_code=500, detail=str(e))
 
100
 
101
 
102
+ # =========================
103
+ # Health Check
104
+ # =========================
105
  @app.get("/health")
106
+ def root():
107
+ return {"status": "Vehicle Classifier API is running πŸš—πŸοΈ"}