Spaces:
Build error
Build error
| import io | |
| from typing import Dict | |
| import numpy as np | |
| import tensorflow as tf | |
| from PIL import Image | |
| from bustapi import BustAPI, request, Response | |
| LABELS = [ | |
| "battery", # 电池 | |
| "biological", # 生物垃圾/厨余垃圾 | |
| "brown-glass", # 棕色玻璃 | |
| "cardboard", # 纸板 | |
| "clothes", # 衣物 | |
| "green-glass", # 绿色玻璃 | |
| "metal", # 金属 | |
| "paper", # 纸张 | |
| "plastic", # 塑料 | |
| "shoes", # 鞋子 | |
| "trash", # 其他垃圾 | |
| "white-glass", # 白色玻璃 | |
| ] | |
| MODEL_PATH = "model/model_resnet50.keras" | |
| def preprocess(image: Image.Image) -> np.ndarray: | |
| """ | |
| 完整的图像预处理流程,将输入图像转换为模型可接受的格式 | |
| image: PIL Image 对象,输入的原始图像 | |
| 返回: 预处理后的图像数组,形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255 | |
| """ | |
| # 检查图像模式是否为 RGB,如果不是则进行转换 | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| # 使用最近邻插值法将图像调整为 224x224 像素 | |
| # 224x224 是 ResNet50 模型的标准输入尺寸 | |
| image = image.resize((224, 224), Image.NEAREST) | |
| # 将调整后的 PIL Image 转换回 NumPy 数组 | |
| # 形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255 | |
| # 再将图像数据类型转换为 float32,以便进行后续计算 | |
| rgb224 = np.asarray(image).astype("float32") | |
| # 在第一个维度(批次维度)上扩展数组,使其形状变为 [1, 224, 224, 3] | |
| # 这是为了匹配深度学习模型期望的输入格式(批次大小, 高度, 宽度, 通道数) | |
| return np.expand_dims(rgb224, axis=0) | |
| class PreTrainedModel: | |
| """ | |
| 预训练模型包装类,用于加载和运行垃圾分类模型 | |
| """ | |
| def __init__(self) -> None: | |
| """ | |
| 初始化预训练模型 | |
| """ | |
| self.model = tf.keras.models.load_model(MODEL_PATH) | |
| def predict_image(self, image: Image.Image) -> Dict[str, float]: | |
| """ | |
| 对输入图像进行分类预测 | |
| image: PIL Image 对象,待分类的图像 | |
| 返回: 包含每个标签及其对应预测概率的字典 | |
| """ | |
| # 对输入图像进行预处理,转换为模型可接受的格式 | |
| x = preprocess(image) | |
| # 使用模型进行预测,返回预测结果 | |
| preds = self.model.predict(x) | |
| # 如果预测结果是列表或元组(某些模型会返回多个输出),取第一个输出 | |
| if isinstance(preds, (list, tuple)): | |
| preds = preds[0] | |
| # 将预测结果转换为 NumPy 数组,去除多余的维度,并转换为 Python 列表 | |
| probs = np.asarray(preds).squeeze().tolist() | |
| # 将标签与对应的预测概率组合成字典返回 | |
| return {label: score for label, score in zip(LABELS, probs)} | |
| # 创建全局模型实例,程序启动时加载模型 | |
| # 这样做可以避免每次预测时重复加载模型,提高响应速度 | |
| model = PreTrainedModel() | |
| def predict(image): | |
| """ | |
| 预测函数,用于 Gradio 接口调用 | |
| image: 输入图像 | |
| 返回: 包含预测标签、置信度和所有类别概率的字典 | |
| """ | |
| # 调用模型进行预测,获取每个类别的概率 | |
| predictions = model.predict_image(image) | |
| # 找出概率最高的类别作为预测结果 | |
| max_label = max(predictions, key=predictions.get) | |
| return max_label | |
| # 创建服务器 | |
| app = BustAPI() | |
| async def predict_api(): | |
| # 读取 POST 二进制流 | |
| img_bytes = await request.body() | |
| # 转 PIL Image | |
| image = Image.open(io.BytesIO(img_bytes)).convert("RGB") | |
| # 推理 | |
| label = predict(image) | |
| return Response.json({ | |
| "label": label | |
| }) | |
| async def home(): | |
| return "POST /predict" | |
| if __name__ == "__main__": | |
| app.run(host="0.0.0.0", port=7860) | |