Spaces:
Sleeping
Sleeping
| import io | |
| from typing import Dict | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image, UnidentifiedImageError | |
| from robyn import Robyn, Request | |
| LABELS = [ | |
| "battery", # 电池 | |
| "biological", # 生物垃圾/厨余垃圾 | |
| "brown-glass", # 棕色玻璃 | |
| "cardboard", # 纸板 | |
| "clothes", # 衣物 | |
| "green-glass", # 绿色玻璃 | |
| "metal", # 金属 | |
| "paper", # 纸张 | |
| "plastic", # 塑料 | |
| "shoes", # 鞋子 | |
| "trash", # 其他垃圾 | |
| "white-glass", # 白色玻璃 | |
| ] | |
| MODEL_PATH = "model/model_resnet50.keras" | |
| def preprocess(image: Image.Image) -> np.ndarray: | |
| """ | |
| 完整的图像预处理流程,将输入图像转换为模型可接受的格式 | |
| image: PIL Image 对象,输入的原始图像 | |
| 返回: 预处理后的图像数组,形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255 | |
| """ | |
| # 检查图像模式是否为 RGB,如果不是则进行转换 | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| # 使用最近邻插值法将图像调整为 224x224 像素 | |
| # 224x224 是 ResNet50 模型的标准输入尺寸 | |
| image = image.resize((224, 224), Image.NEAREST) | |
| # 将调整后的 PIL Image 转换回 NumPy 数组 | |
| # 形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255 | |
| # 再将图像数据类型转换为 float32,以便进行后续计算 | |
| rgb224 = np.asarray(image).astype("float32") | |
| # 在第一个维度(批次维度)上扩展数组,使其形状变为 [1, 224, 224, 3] | |
| # 这是为了匹配深度学习模型期望的输入格式(批次大小, 高度, 宽度, 通道数) | |
| return np.expand_dims(rgb224, axis=0) | |
| class PreTrainedModel: | |
| """ | |
| 预训练模型包装类,用于加载和运行垃圾分类模型 | |
| """ | |
| def __init__(self) -> None: | |
| """ | |
| 初始化预训练模型 | |
| """ | |
| self.model = tf.keras.models.load_model(MODEL_PATH) | |
| def predict_image(self, image: Image.Image) -> Dict[str, float]: | |
| """ | |
| 对输入图像进行分类预测 | |
| image: PIL Image 对象,待分类的图像 | |
| 返回: 包含每个标签及其对应预测概率的字典 | |
| """ | |
| # 对输入图像进行预处理,转换为模型可接受的格式 | |
| x = preprocess(image) | |
| # 使用模型进行预测,返回预测结果 | |
| preds = self.model.predict(x) | |
| # 如果预测结果是列表或元组(某些模型会返回多个输出),取第一个输出 | |
| if isinstance(preds, (list, tuple)): | |
| preds = preds[0] | |
| # 将预测结果转换为 NumPy 数组,去除多余的维度,并转换为 Python 列表 | |
| probs = np.asarray(preds).squeeze().tolist() | |
| # 将标签与对应的预测概率组合成字典返回 | |
| return {label: score for label, score in zip(LABELS, probs)} | |
| # 创建全局模型实例,程序启动时加载模型 | |
| # 这样做可以避免每次预测时重复加载模型,提高响应速度 | |
| model = PreTrainedModel() | |
| def predict(image): | |
| """ | |
| 预测函数,用于 Gradio 接口调用 | |
| image: 输入图像 | |
| 返回: 包含预测标签、置信度和所有类别概率的字典 | |
| """ | |
| # 调用模型进行预测,获取每个类别的概率 | |
| predictions = model.predict_image(image) | |
| # 找出概率最高的类别作为预测结果 | |
| max_label = max(predictions, key=predictions.get) | |
| return max_label | |
| # 创建服务器 | |
| app = Robyn(__file__) | |
| def predict_route(request: Request): | |
| try: | |
| # 读取 POST 二进制流 | |
| raw = request.body | |
| if raw is None or len(raw) == 0: | |
| return {"error": "empty request body"}, 400 | |
| # Robyn 的 request.body 可能是 str 或 bytes; | |
| # 发图片二进制时通常会是 bytes。 | |
| if isinstance(raw, str): | |
| raw = raw.encode("latin1") | |
| image = Image.open(io.BytesIO(raw)).convert("RGB") | |
| label = predict(image) | |
| return label | |
| except UnidentifiedImageError: | |
| return {"error": "invalid image bytes"}, 400 | |
| except Exception as e: | |
| return {"error": str(e)}, 500 | |
| def home(): | |
| return "POST /predict" | |
| if __name__ == "__main__": | |
| app.start(host="0.0.0.0", port=7860) | |