Spaces:
Running
Running
import os
from collections.abc import Callable
from typing import Any

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
# Edge length, in pixels, of the square image fed to the model.
IMAGE_SIZE = 299

# Class labels; position i names the model's output index i.
CLASS_NAMES = ("No-Stroke", "Stroke")

# Hub repository holding the model; can be overridden with STROKE_MODEL_REPO.
REPO_ID = os.environ.get(
    "STROKE_MODEL_REPO",
    "melisklc0/efficientnet-b0-stroke-distilled",
)

# File name of the ONNX export inside REPO_ID.
ONNX_FILENAME = "model.onnx"
| def _softmax(x: np.ndarray) -> np.ndarray: | |
| x = x.astype(np.float64) | |
| x = x - np.max(x, axis=-1, keepdims=True) | |
| e = np.exp(x) | |
| return (e / e.sum(axis=-1, keepdims=True)).astype(np.float32) | |
def preprocess_image(img: Image.Image, image_size: int = IMAGE_SIZE) -> np.ndarray:
    """Convert a PIL image into a single-image NCHW float32 batch.

    The image is forced to RGB and bilinearly resized to
    ``image_size`` x ``image_size``. Pixels are scaled to [0, 1] and
    normalized with the ImageNet per-channel mean/std.

    Returns:
        Array of shape (1, 3, image_size, image_size), dtype float32.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    resized = img.convert("RGB").resize(
        (image_size, image_size), Image.Resampling.BILINEAR
    )
    # RGB image -> (H, W, 3) array scaled to [0, 1].
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Mean/std broadcast over the trailing channel axis.
    normalized = (pixels - imagenet_mean) / imagenet_std
    chw = normalized.transpose(2, 0, 1)
    return chw[np.newaxis, ...]
def load_stroke_model():
    """Download the ONNX model from the Hub and build an inference session.

    ``ONNX_FILENAME`` is fetched from the model repository ``REPO_ID``. The CUDA
    execution provider is requested only when this onnxruntime installation
    lists it among its available providers. The CPU provider is always
    included as a fallback.

    Returns:
        ``(session, preprocess_image)``: the ``ort.InferenceSession`` together
        with the preprocessing function that produces its input.
    """
    onnx_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=ONNX_FILENAME,
        repo_type="model",
    )
    # ort.get_device() describes how the package was built, not which
    # execution providers can actually be used, so query the providers directly.
    providers = ["CPUExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.insert(0, "CUDAExecutionProvider")
    session = ort.InferenceSession(onnx_path, providers=providers)
    return session, preprocess_image
def predict(
    session: ort.InferenceSession,
    preprocess: Callable[[Image.Image], np.ndarray],
    img: Image.Image,
) -> tuple[str, float, dict[str, float]]:
    """Classify a single image.

    Args:
        session: ONNX Runtime session. The batch from ``preprocess`` is bound
            to its first input.
        preprocess: Turns ``img`` into the model's input array
            (e.g. ``preprocess_image``).
        img: Image to classify.

    Returns:
        ``(prediction, confidence, results)``:
        - ``prediction``: the most probable entry of ``CLASS_NAMES``.
        - ``confidence``: that entry's probability.
        - ``results``: maps every class name to its probability.

    Raises:
        ValueError: if the model's first output does not contain exactly one
            score per entry of ``CLASS_NAMES`` for the image.
    """
    batch = preprocess(img)
    input_name = session.get_inputs()[0].name
    logits = session.run(None, {input_name: batch})[0]
    # A single image is sent, so only the first row of the batch is relevant.
    probs = _softmax(logits[0])
    if probs.shape != (len(CLASS_NAMES),):
        raise ValueError(
            f"expected {len(CLASS_NAMES)} scores per image, "
            f"got model output of shape {probs.shape}"
        )

    results = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    prediction = CLASS_NAMES[int(np.argmax(probs))]
    confidence = results[prediction]
    return prediction, confidence, results