Spaces:
Running
Running
File size: 2,026 Bytes
b02f059 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import os
from typing import Any
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
# Side length (pixels) of the square that input images are resized to.
IMAGE_SIZE: int = 299
# Output labels; position i labels the i-th model score.
CLASS_NAMES: tuple[str, ...] = ("No-Stroke", "Stroke")
# Hugging Face Hub model repository; can be overridden with STROKE_MODEL_REPO.
REPO_ID: str = os.environ.get(
    "STROKE_MODEL_REPO",
    "melisklc0/efficientnet-b0-stroke-distilled",
)
# Name of the ONNX file inside REPO_ID.
ONNX_FILENAME: str = "model.onnx"
def _softmax(x: np.ndarray) -> np.ndarray:
x = x.astype(np.float64)
x = x - np.max(x, axis=-1, keepdims=True)
e = np.exp(x)
return (e / e.sum(axis=-1, keepdims=True)).astype(np.float32)
def preprocess_image(img: Image.Image, image_size: int = IMAGE_SIZE) -> np.ndarray:
    """Turn a PIL image into a normalized single-image NCHW float32 batch.

    The image is converted to RGB and resized to ``image_size`` x ``image_size``
    with bilinear filtering. Pixels are scaled to [0, 1] and normalized with the
    ImageNet channel mean/std.

    Returns:
        Array of shape ``(1, 3, image_size, image_size)``, dtype float32.
    """
    target = (image_size, image_size)
    resized = img.convert("RGB").resize(target, Image.Resampling.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    # Channels are the last axis here, so the (3,) stats broadcast per channel.
    normalized = (pixels - channel_mean) / channel_std
    chw = normalized.transpose(2, 0, 1)  # HWC -> CHW
    return chw[np.newaxis, ...]
def load_stroke_model():
    """Fetch the ONNX model from the Hugging Face Hub and build an inference session.

    The file ``ONNX_FILENAME`` is downloaded from the model repository ``REPO_ID``.
    The session uses the CUDA execution provider when this onnxruntime build
    exposes it; otherwise it runs on CPU only.

    Returns:
        A ``(session, preprocess)`` pair: the ``ort.InferenceSession`` and the
        ``preprocess_image`` function, suitable for passing to ``predict``.
    """
    onnx_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=ONNX_FILENAME,
        repo_type="model",
    )
    # ``ort.get_device()`` only reports the device the package was built for;
    # ask which execution providers are actually available and keep the
    # preferred order (CUDA first, CPU as fallback).
    available = set(ort.get_available_providers())
    preferred = ("CUDAExecutionProvider", "CPUExecutionProvider")
    providers: list[str] = [name for name in preferred if name in available]
    if not providers:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_path, providers=providers)
    return session, preprocess_image
def predict(session: ort.InferenceSession, preprocess: Any, img: Image.Image):
    """Classify one image and return the top label plus per-class probabilities.

    Args:
        session: ONNX Runtime session, e.g. the one from ``load_stroke_model``.
        preprocess: Callable mapping a PIL image to the model's input array
            (``load_stroke_model`` returns ``preprocess_image`` for this).
        img: Image to classify.

    Returns:
        ``(prediction, confidence, results)``: the most probable label from
        ``CLASS_NAMES``, its probability as a float, and a
        ``{label: probability}`` dict covering every class.

    Raises:
        ValueError: If the first row of the model's first output does not hold
            exactly one score per entry in ``CLASS_NAMES``.
    """
    batch = preprocess(img)
    input_name = session.get_inputs()[0].name
    # Only the first model output is used, and only its first row (batch of
    # one). NOTE(review): the scores are treated as raw logits; confirm the
    # exported graph does not already apply softmax.
    logits = session.run(None, {input_name: batch})[0]
    probs = _softmax(logits[0])
    if probs.shape != (len(CLASS_NAMES),):
        # Without this check a width mismatch either drops classes silently or
        # fails later with an opaque IndexError on CLASS_NAMES.
        raise ValueError(
            f"Model produced scores of shape {probs.shape}, expected "
            f"({len(CLASS_NAMES)},) to match CLASS_NAMES={CLASS_NAMES}"
        )
    results = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    pred_idx = int(np.argmax(probs))
    return CLASS_NAMES[pred_idx], float(probs[pred_idx]), results
|