Jerry75AI's picture
Upload 2 files
913b62c verified
import io
from typing import Dict
import numpy as np
import tensorflow as tf
from PIL import Image, UnidentifiedImageError
from robyn import Robyn, Request
LABELS = [
"battery", # 电池
"biological", # 生物垃圾/厨余垃圾
"brown-glass", # 棕色玻璃
"cardboard", # 纸板
"clothes", # 衣物
"green-glass", # 绿色玻璃
"metal", # 金属
"paper", # 纸张
"plastic", # 塑料
"shoes", # 鞋子
"trash", # 其他垃圾
"white-glass", # 白色玻璃
]
MODEL_PATH = "model/model_resnet50.keras"
def preprocess(image: Image.Image) -> np.ndarray:
"""
完整的图像预处理流程,将输入图像转换为模型可接受的格式
image: PIL Image 对象,输入的原始图像
返回: 预处理后的图像数组,形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255
"""
# 检查图像模式是否为 RGB,如果不是则进行转换
if image.mode != "RGB":
image = image.convert("RGB")
# 使用最近邻插值法将图像调整为 224x224 像素
# 224x224 是 ResNet50 模型的标准输入尺寸
image = image.resize((224, 224), Image.NEAREST)
# 将调整后的 PIL Image 转换回 NumPy 数组
# 形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255
# 再将图像数据类型转换为 float32,以便进行后续计算
rgb224 = np.asarray(image).astype("float32")
# 在第一个维度(批次维度)上扩展数组,使其形状变为 [1, 224, 224, 3]
# 这是为了匹配深度学习模型期望的输入格式(批次大小, 高度, 宽度, 通道数)
return np.expand_dims(rgb224, axis=0)
class PreTrainedModel:
"""
预训练模型包装类,用于加载和运行垃圾分类模型
"""
def __init__(self) -> None:
"""
初始化预训练模型
"""
self.model = tf.keras.models.load_model(MODEL_PATH)
def predict_image(self, image: Image.Image) -> Dict[str, float]:
"""
对输入图像进行分类预测
image: PIL Image 对象,待分类的图像
返回: 包含每个标签及其对应预测概率的字典
"""
# 对输入图像进行预处理,转换为模型可接受的格式
x = preprocess(image)
# 使用模型进行预测,返回预测结果
preds = self.model.predict(x)
# 如果预测结果是列表或元组(某些模型会返回多个输出),取第一个输出
if isinstance(preds, (list, tuple)):
preds = preds[0]
# 将预测结果转换为 NumPy 数组,去除多余的维度,并转换为 Python 列表
probs = np.asarray(preds).squeeze().tolist()
# 将标签与对应的预测概率组合成字典返回
return {label: score for label, score in zip(LABELS, probs)}
# 创建全局模型实例,程序启动时加载模型
# 这样做可以避免每次预测时重复加载模型,提高响应速度
model = PreTrainedModel()
def predict(image):
"""
预测函数,用于 Gradio 接口调用
image: 输入图像
返回: 包含预测标签、置信度和所有类别概率的字典
"""
# 调用模型进行预测,获取每个类别的概率
predictions = model.predict_image(image)
# 找出概率最高的类别作为预测结果
max_label = max(predictions, key=predictions.get)
return max_label
# 创建服务器
app = Robyn(__file__)
@app.post("/predict")
def predict_route(request: Request):
try:
# 读取 POST 二进制流
raw = request.body
if raw is None or len(raw) == 0:
return {"error": "empty request body"}, 400
# Robyn 的 request.body 可能是 str 或 bytes;
# 发图片二进制时通常会是 bytes。
if isinstance(raw, str):
raw = raw.encode("latin1")
image = Image.open(io.BytesIO(raw)).convert("RGB")
label = predict(image)
return label
except UnidentifiedImageError:
return {"error": "invalid image bytes"}, 400
except Exception as e:
return {"error": str(e)}, 500
@app.get("/")
def home():
return "POST /predict"
if __name__ == "__main__":
app.start(host="0.0.0.0", port=7860)