Spaces:
Build error
Build error
File size: 3,977 Bytes
18742db 5fea29a 18742db 5fea29a cc2f78d 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 5fea29a cc2f78d 5fea29a cc2f78d 5fea29a 18742db 5fea29a 18742db 5fea29a 18742db 802c1db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | import io
from typing import Dict
import numpy as np
import tensorflow as tf
from PIL import Image
from bustapi import BustAPI, request, Response
LABELS = [
"battery", # 电池
"biological", # 生物垃圾/厨余垃圾
"brown-glass", # 棕色玻璃
"cardboard", # 纸板
"clothes", # 衣物
"green-glass", # 绿色玻璃
"metal", # 金属
"paper", # 纸张
"plastic", # 塑料
"shoes", # 鞋子
"trash", # 其他垃圾
"white-glass", # 白色玻璃
]
MODEL_PATH = "model/model_resnet50.keras"
def preprocess(image: Image.Image) -> np.ndarray:
"""
完整的图像预处理流程,将输入图像转换为模型可接受的格式
image: PIL Image 对象,输入的原始图像
返回: 预处理后的图像数组,形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255
"""
# 检查图像模式是否为 RGB,如果不是则进行转换
if image.mode != "RGB":
image = image.convert("RGB")
# 使用最近邻插值法将图像调整为 224x224 像素
# 224x224 是 ResNet50 模型的标准输入尺寸
image = image.resize((224, 224), Image.NEAREST)
# 将调整后的 PIL Image 转换回 NumPy 数组
# 形状为 [1, 224, 224, 3],数据类型为 float32,像素值范围 0-255
# 再将图像数据类型转换为 float32,以便进行后续计算
rgb224 = np.asarray(image).astype("float32")
# 在第一个维度(批次维度)上扩展数组,使其形状变为 [1, 224, 224, 3]
# 这是为了匹配深度学习模型期望的输入格式(批次大小, 高度, 宽度, 通道数)
return np.expand_dims(rgb224, axis=0)
class PreTrainedModel:
"""
预训练模型包装类,用于加载和运行垃圾分类模型
"""
def __init__(self) -> None:
"""
初始化预训练模型
"""
self.model = tf.keras.models.load_model(MODEL_PATH)
def predict_image(self, image: Image.Image) -> Dict[str, float]:
"""
对输入图像进行分类预测
image: PIL Image 对象,待分类的图像
返回: 包含每个标签及其对应预测概率的字典
"""
# 对输入图像进行预处理,转换为模型可接受的格式
x = preprocess(image)
# 使用模型进行预测,返回预测结果
preds = self.model.predict(x)
# 如果预测结果是列表或元组(某些模型会返回多个输出),取第一个输出
if isinstance(preds, (list, tuple)):
preds = preds[0]
# 将预测结果转换为 NumPy 数组,去除多余的维度,并转换为 Python 列表
probs = np.asarray(preds).squeeze().tolist()
# 将标签与对应的预测概率组合成字典返回
return {label: score for label, score in zip(LABELS, probs)}
# 创建全局模型实例,程序启动时加载模型
# 这样做可以避免每次预测时重复加载模型,提高响应速度
model = PreTrainedModel()
def predict(image):
"""
预测函数,用于 Gradio 接口调用
image: 输入图像
返回: 包含预测标签、置信度和所有类别概率的字典
"""
# 调用模型进行预测,获取每个类别的概率
predictions = model.predict_image(image)
# 找出概率最高的类别作为预测结果
max_label = max(predictions, key=predictions.get)
return max_label
# 创建服务器
app = BustAPI()
@app.post("/predict")
async def predict_api():
# 读取 POST 二进制流
img_bytes = await request.body()
# 转 PIL Image
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
# 推理
label = predict(image)
return Response.json({
"label": label
})
@app.get("/")
async def home():
return "POST /predict"
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)
|