Spaces:
Build error
Build error
| from typing import List, Dict, Any | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| import tensorflow as tf | |
# Labels must mirror src/classification-model/index.ts.
# Order matters: PreTrainedModel.predict maps the argmax index of the model
# output directly to LABELS[index].
LABELS: List[str] = [
    "battery", "biological", "brown-glass", "cardboard",
    "clothes", "green-glass", "metal", "paper",
    "plastic", "shoes", "trash", "white-glass",
]
def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
    """Return the pixels of *image* as an RGB NumPy array.

    An image already in PIL mode "RGB" is used as-is; any other mode is
    converted to "RGB" first.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    return np.asarray(rgb_image)
def _resize_224(img_rgb: np.ndarray) -> np.ndarray:
    """Resize an image array to exactly 224x224 with nearest-neighbour sampling.

    The aspect ratio is not preserved. The array is round-tripped through
    PIL, so it must be something ``Image.fromarray`` accepts (presumably a
    uint8 HxWx3 array from _load_image_to_rgb — confirm with callers).
    """
    target_size = (224, 224)
    return np.asarray(Image.fromarray(img_rgb).resize(target_size, Image.NEAREST))
def _preprocess(image_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a model-ready batch containing one image.

    Mirrors the TS preprocessing: decode, force RGB, resize to 224x224, and
    keep the raw 0..255 pixel range (no normalization is applied).

    Returns:
        A float32 array of shape (1, 224, 224, 3).
    """
    decoded = Image.open(io.BytesIO(image_bytes))
    pixels = _resize_224(_load_image_to_rgb(decoded)).astype("float32")
    # Prepend the batch axis: (224, 224, 3) -> (1, 224, 224, 3).
    return pixels[np.newaxis, ...]
class PreTrainedModel:
    """Keras image classifier that returns the top-1 label for raw image bytes."""

    def __init__(self, model_path: str = "model/model_resnet50.keras") -> None:
        """Load the Keras model saved at *model_path*.

        A relative path is resolved against the current working directory.
        """
        self.model = tf.keras.models.load_model(model_path)

    def predict(self, inputs: bytes) -> List[Dict[str, Any]]:
        """Classify one encoded image and return its top-1 prediction.

        Args:
            inputs: Raw encoded image bytes (anything PIL's ``Image.open``
                can decode).

        Returns:
            A single-element list ``[{"label": str, "score": float}]``.

        Raises:
            ValueError: If the number of scores produced by the model does
                not match the number of entries in ``LABELS``.
        """
        x = _preprocess(inputs)
        # verbose=0 suppresses Keras' progress bar, which would otherwise be
        # printed on every single request.
        preds = self.model.predict(x, verbose=0)
        if isinstance(preds, (list, tuple)):
            # If the model returns several outputs, only the first is used.
            preds = preds[0]
        # Flatten the single-sample batch to 1-D. Unlike squeeze(), this never
        # collapses to a 0-d scalar, which would break indexing below.
        probs = np.asarray(preds).reshape(-1)
        if probs.size != len(LABELS):
            # A mismatch would otherwise silently mislabel (or IndexError).
            raise ValueError(
                f"model produced {probs.size} scores but "
                f"{len(LABELS)} labels are defined"
            )
        # Top-1 output following TS behavior
        idx = int(np.argmax(probs))
        return [{"label": LABELS[idx], "score": float(probs[idx])}]
def load_model(model_dir: str = ".") -> PreTrainedModel:
    """Build a PreTrainedModel from ``<model_dir>/model/model_resnet50.keras``.

    HF Inference API convention: a top-level load entrypoint.

    Args:
        model_dir: Directory that contains the ``model/`` subdirectory.
    """
    import os

    # os.path.join uses the platform separator and avoids a doubled
    # separator when model_dir already ends with one.
    return PreTrainedModel(
        model_path=os.path.join(model_dir, "model", "model_resnet50.keras")
    )