Spaces:
Running
Running
File size: 2,684 Bytes
93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf 93fc243 67d88cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import io
import torch
import timm
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from torchvision import transforms
# =========================
# App Init
# =========================
app = FastAPI(title="Vehicle Type Classifier API")

# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Checkpoint file read by load_model() when the app starts.
MODEL_PATH = "vehicle_classifier_best.pth"

# Both are filled in by load_model(); they stay None until startup succeeds.
model = None
class_names = None
# =========================
# Response Schema
# =========================
class PredictionResponse(BaseModel):
    """Response body returned by the /predict endpoint."""

    # Predicted class name, taken from the checkpoint's class_names list.
    label: str
    # Softmax probability of the predicted class (the endpoint rounds it to 4 decimals).
    confidence: float
# =========================
# Image Transform
# =========================
# Preprocessing applied to every uploaded image before inference.
transform = transforms.Compose(
    [
        # Force a fixed 224x224 input; a (h, w) tuple does not keep aspect ratio.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor with values scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel standardization using the standard ImageNet mean/std values.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# =========================
# Load Model on Startup
# =========================
@app.on_event("startup")
def load_model():
    """Load the classifier checkpoint into the module-level ``model``/``class_names``.

    Reads ``MODEL_PATH`` (tensors mapped onto ``device``), rebuilds a
    ``convnext_tiny`` with one output per entry of the checkpoint's
    ``"class_names"``, loads the ``"model_state_dict"`` weights, moves the
    network to ``device`` and puts it in eval mode.

    The globals are assigned only after every step has succeeded, so a failed
    load never leaves ``class_names`` set while ``model`` is still ``None``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (missing file,
        missing checkpoint keys, mismatched weights); the error propagates and
        aborts startup.
    """
    global model, class_names

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only point MODEL_PATH at trusted checkpoints; consider
    # weights_only=True if the checkpoint contains only tensors/plain types.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    names = checkpoint["class_names"]

    # pretrained=False: weights come from the checkpoint below, not a download.
    net = timm.create_model(
        "convnext_tiny",
        pretrained=False,
        num_classes=len(names),
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    net.to(device)
    net.eval()

    # Publish both globals together, only once loading fully succeeded.
    model, class_names = net, names
    print("✅ Model loaded successfully")
# =========================
# Prediction Endpoint
# =========================
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top-1 label with its probability.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        PredictionResponse holding the predicted class name and its softmax
        probability rounded to 4 decimal places.

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 400 if the
            upload is not declared as an image or cannot be decoded as one;
            500 if preprocessing or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # content_type may be None when the client omits the part's Content-Type
    # header; treat that as "not an image" instead of crashing with a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    image_bytes = await file.read()

    # A payload that cannot be decoded (empty, truncated, not an image, or
    # oversized) is a client error, so report it as 400 rather than 500.
    # PIL's UnidentifiedImageError is a subclass of OSError.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {e}"
        ) from e

    # NOTE(review): the forward pass below runs synchronously on the event-loop
    # thread; under concurrent load consider offloading it to a threadpool.
    try:
        img_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim
        with torch.no_grad():
            outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        label = class_names[pred_idx]
        confidence = probs[0][pred_idx].item()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictionResponse(
        label=label,
        confidence=round(confidence, 4),
    )
# =========================
# Health Check
# =========================
@app.get("/health")
def root():
    """Liveness probe: report that the API process is up and serving requests.

    Returns a static payload; it does not check whether the model has been
    loaded.
    """
    # NOTE(review): the status text previously contained mis-encoded bytes
    # ("๐๐๏ธ"); the emoji here are a best-effort restoration — confirm glyphs.
    return {"status": "Vehicle Classifier API is running 🚗🏎️"}
|