File size: 2,026 Bytes
b02f059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from typing import Any

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Output labels; predict() pairs CLASS_NAMES[i] with the i-th softmax probability.
CLASS_NAMES = ("No-Stroke", "Stroke")
# Square side length (pixels) preprocess_image resizes to by default.
IMAGE_SIZE = 299

# Hugging Face model repository and file name of the exported ONNX graph.
# The STROKE_MODEL_REPO environment variable overrides the default repository.
REPO_ID = os.getenv("STROKE_MODEL_REPO", "melisklc0/efficientnet-b0-stroke-distilled")
ONNX_FILENAME = "model.onnx"


def _softmax(x: np.ndarray) -> np.ndarray:
    x = x.astype(np.float64)
    x = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(x)
    return (e / e.sum(axis=-1, keepdims=True)).astype(np.float32)


def preprocess_image(img: Image.Image, image_size: int = IMAGE_SIZE) -> np.ndarray:
    """Convert a PIL image into a normalized NCHW float32 batch of one.

    Steps: convert to RGB, resize to ``image_size`` x ``image_size`` with
    bilinear resampling, scale to [0, 1], then normalize each channel with
    the ImageNet mean/std.

    Args:
        img: Source image in any mode PIL can convert to RGB.
        image_size: Target side length in pixels (defaults to ``IMAGE_SIZE``).

    Returns:
        C-contiguous float32 array of shape ``(1, 3, image_size, image_size)``.
    """
    rgb = img.convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)
    # PIL yields channel-last (H, W, 3) data; scale 8-bit values to [0, 1].
    hwc = np.asarray(rgb, dtype=np.float32) / 255.0
    # Normalize while still channel-last, so mean/std broadcast over the last axis.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    hwc = (hwc - mean) / std
    # transpose() only returns a strided view. Materialize a C-contiguous CHW
    # buffer so downstream runtimes get a dense NCHW tensor without an extra copy.
    chw = np.ascontiguousarray(hwc.transpose(2, 0, 1))
    return np.expand_dims(chw, axis=0)


def load_stroke_model():
    """Download the ONNX model from the Hub and build an inference session.

    ``ONNX_FILENAME`` is fetched from the model repository ``REPO_ID`` with
    ``hf_hub_download``. The session uses ``CUDAExecutionProvider`` when the
    installed onnxruntime build reports it as available, and always keeps
    ``CPUExecutionProvider`` as the fallback.

    Returns:
        ``(session, preprocess)``: the ``ort.InferenceSession`` and the
        ``preprocess_image`` function that builds its input tensor.
    """
    onnx_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=ONNX_FILENAME,
        repo_type="model",
    )
    # Ask the build which providers it actually ships. get_device() == "GPU"
    # alone does not guarantee that the CUDA provider is present.
    available = set(ort.get_available_providers())
    preferred = ("CUDAExecutionProvider", "CPUExecutionProvider")
    providers: list[str] = [p for p in preferred if p in available] or ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_path, providers=providers)
    return session, preprocess_image


def predict(session: ort.InferenceSession, preprocess: Any, img: Image.Image):
    """Classify a single image with the stroke model.

    Args:
        session: ONNX Runtime session (e.g. the one from ``load_stroke_model``).
        preprocess: Callable mapping a PIL image to the session's input array
            (``load_stroke_model`` returns ``preprocess_image`` for this).
        img: Image to classify.

    Returns:
        ``(prediction, confidence, results)``: the top class name, its softmax
        probability, and a dict mapping every name in ``CLASS_NAMES`` to its
        probability. Only the first batch item of the model output is scored.

    Raises:
        ValueError: If the model's first output is not a 2-D
            ``(batch, len(CLASS_NAMES))`` array with at least one row.
    """
    x = preprocess(img)
    input_name = session.get_inputs()[0].name
    logits = np.asarray(session.run(None, {input_name: x})[0])

    # Fail loudly if the model (REPO_ID is env-overridable) does not emit one
    # logit per class. Otherwise indexing below raises an obscure error, or
    # silently drops the extra classes.
    n_classes = len(CLASS_NAMES)
    if logits.ndim != 2 or logits.shape[0] < 1 or logits.shape[1] != n_classes:
        raise ValueError(
            f"Expected model output of shape (batch, {n_classes}), got {logits.shape}"
        )

    probs = _softmax(logits[0])
    results = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    pred_idx = int(np.argmax(probs))
    return CLASS_NAMES[pred_idx], float(probs[pred_idx]), results