import io from typing import Dict import numpy as np import tensorflow as tf from PIL import Image, UnidentifiedImageError from robyn import Robyn, Request LABELS = [ "battery", # 电池 "biological", # 生物垃圾/厨余垃圾 "brown-glass", # 棕色玻璃 "cardboard", # 纸板 "clothes", # 衣物 "green-glass", # 绿色玻璃 "metal", # 金属 "paper", # 纸张 "plastic", # 塑料 "shoes", # 鞋子 "trash", # 其他垃圾 "white-glass", # 白色玻璃 ] MODEL_PATH = "model/model_resnet50.keras" def preprocess(image: Image.Image) -> np.ndarray: """ 完整的图像预处理流程,将输入图像转换为模型可接受的格式 image: PIL Image 对象,输入的原始图像 返回: 预处理后的图像数组,形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255 """ # 检查图像模式是否为 RGB,如果不是则进行转换 if image.mode != "RGB": image = image.convert("RGB") # 使用最近邻插值法将图像调整为 224x224 像素 # 224x224 是 ResNet50 模型的标准输入尺寸 image = image.resize((224, 224), Image.NEAREST) # 将调整后的 PIL Image 转换回 NumPy 数组 # 形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255 # 再将图像数据类型转换为 float32,以便进行后续计算 rgb224 = np.asarray(image).astype("float32") # 在第一个维度(批次维度)上扩展数组,使其形状变为 [1, 224, 224, 3] # 这是为了匹配深度学习模型期望的输入格式(批次大小, 高度, 宽度, 通道数) return np.expand_dims(rgb224, axis=0) class PreTrainedModel: """ 预训练模型包装类,用于加载和运行垃圾分类模型 """ def __init__(self) -> None: """ 初始化预训练模型 """ self.model = tf.keras.models.load_model(MODEL_PATH) def predict_image(self, image: Image.Image) -> Dict[str, float]: """ 对输入图像进行分类预测 image: PIL Image 对象,待分类的图像 返回: 包含每个标签及其对应预测概率的字典 """ # 对输入图像进行预处理,转换为模型可接受的格式 x = preprocess(image) # 使用模型进行预测,返回预测结果 preds = self.model.predict(x) # 如果预测结果是列表或元组(某些模型会返回多个输出),取第一个输出 if isinstance(preds, (list, tuple)): preds = preds[0] # 将预测结果转换为 NumPy 数组,去除多余的维度,并转换为 Python 列表 probs = np.asarray(preds).squeeze().tolist() # 将标签与对应的预测概率组合成字典返回 return {label: score for label, score in zip(LABELS, probs)} # 创建全局模型实例,程序启动时加载模型 # 这样做可以避免每次预测时重复加载模型,提高响应速度 model = PreTrainedModel() def predict(image): """ 预测函数,用于 Gradio 接口调用 image: 输入图像 返回: 包含预测标签、置信度和所有类别概率的字典 """ # 调用模型进行预测,获取每个类别的概率 predictions = model.predict_image(image) # 找出概率最高的类别作为预测结果 max_label = max(predictions, key=predictions.get) return max_label # 创建服务器 app = Robyn(__file__) @app.post("/predict") def predict_route(request: Request): try: # 读取 POST 二进制流 raw = request.body if raw is None or len(raw) == 0: return {"error": "empty request body"}, 400 # Robyn 的 request.body 可能是 str 或 bytes; # 发图片二进制时通常会是 bytes。 if isinstance(raw, str): raw = raw.encode("latin1") image = Image.open(io.BytesIO(raw)).convert("RGB") label = predict(image) return label except UnidentifiedImageError: return {"error": "invalid image bytes"}, 400 except Exception as e: return {"error": str(e)}, 500 @app.get("/") def home(): return "POST /predict" if __name__ == "__main__": app.start(host="0.0.0.0", port=7860)