Jerry75AI committed on
Commit
5fea29a
·
verified ·
1 Parent(s): 18742db

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +84 -57
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,86 +1,113 @@
1
- import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import tensorflow as tf
5
- from typing import List, Dict, Any
6
  import io
 
7
 
8
- # Labels must mirror src/classification-model/index.ts
9
- LABELS: List[str] = [
10
- "battery",
11
- "biological",
12
- "brown-glass",
13
- "cardboard",
14
- "clothes",
15
- "green-glass",
16
- "metal",
17
- "paper",
18
- "plastic",
19
- "shoes",
20
- "trash",
21
- "white-glass",
 
 
 
 
22
  ]
 
23
 
24
 
25
- def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
 
 
 
 
 
 
26
  if image.mode != "RGB":
27
  image = image.convert("RGB")
28
- return np.asarray(image)
29
-
30
-
31
- def _resize_224(img_rgb: np.ndarray) -> np.ndarray:
32
- im = Image.fromarray(img_rgb)
33
- im = im.resize((224, 224), Image.NEAREST)
34
- return np.asarray(im)
35
-
36
-
37
- def _preprocess(image: Image.Image) -> np.ndarray:
38
- rgb = _load_image_to_rgb(image)
39
- rgb224 = _resize_224(rgb)
40
- # shape [1,224,224,3], float32 in 0..255
41
- arr = rgb224.astype("float32")
42
- return np.expand_dims(arr, axis=0)
43
 
44
 
45
  class PreTrainedModel:
46
- def __init__(self, model_path: str = "model/model_resnet50.keras") -> None:
47
- self.model = tf.keras.models.load_model(model_path)
 
 
 
 
 
 
48
 
49
  def predict_image(self, image: Image.Image) -> Dict[str, float]:
50
- x = _preprocess(image)
 
 
 
 
 
 
 
51
  preds = self.model.predict(x)
 
52
  if isinstance(preds, (list, tuple)):
53
  preds = preds[0]
 
54
  probs = np.asarray(preds).squeeze().tolist()
55
-
56
  return {label: score for label, score in zip(LABELS, probs)}
57
 
58
-
 
59
  model = PreTrainedModel()
60
 
61
-
62
  def predict(image):
 
 
 
 
 
 
63
  predictions = model.predict_image(image)
64
 
65
- probs_percent = {label: round(p * 100, 2)
66
- for label, p in predictions.items()}
 
67
 
68
- max_label = max(probs_percent, key=probs_percent.get)
 
69
 
70
- return {
71
- "label": max_label,
72
- "percentage": probs_percent[max_label],
73
- "probabilities": probs_percent,
74
- }
 
 
 
75
 
 
 
 
76
 
77
- iface = gr.Interface(
78
- fn=predict,
79
- inputs=gr.Image(type="pil"),
80
- outputs=gr.JSON(),
81
- title="Waste Classification",
82
- description="Upload an image of waste to classify it.",
83
- )
84
 
85
  if __name__ == "__main__":
86
- iface.launch()
 
 
 
 
 
 
1
  import io
2
+ from typing import Dict
3
 
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from PIL import Image
7
+ from bustapi import BustAPI, Request, Response
8
+
9
# Class names; predict_image pairs them positionally with the model's output scores.
+ LABELS = [
10
+ "battery", # batteries
11
+ "biological", # biological / kitchen waste
12
+ "brown-glass", # brown glass
13
+ "cardboard", # cardboard
14
+ "clothes", # clothes
15
+ "green-glass", # green glass
16
+ "metal", # metal
17
+ "paper", # paper
18
+ "plastic", # plastic
19
+ "shoes", # shoes
20
+ "trash", # other waste
21
+ "white-glass", # white glass
22
  ]
23
# Path that PreTrainedModel passes to tf.keras.models.load_model.
+ MODEL_PATH = "model/model_resnet50.keras"
24
 
25
 
def preprocess(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into the single-image batch fed to the model.

    The image is coerced to RGB, resized to 224x224 with nearest-neighbour
    sampling, cast to float32 and given a leading batch axis.

    Args:
        image: Source picture as a PIL ``Image``.

    Returns:
        Array of shape ``(1, 224, 224, 3)`` with dtype float32. Pixel values
        stay in the 0-255 range; no scaling is applied here.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    resized = rgb.resize((224, 224), Image.NEAREST)
    pixels = np.asarray(resized).astype("float32")
    # Prepend the batch axis: (224, 224, 3) -> (1, 224, 224, 3).
    return pixels[np.newaxis, ...]
 
 
 
 
 
45
 
46
 
class PreTrainedModel:
    """Wrapper around the saved Keras waste-classification model.

    The model is loaded once in the constructor and reused by every call to
    :meth:`predict_image`.
    """

    def __init__(self, model_path: str = MODEL_PATH) -> None:
        """Load the Keras model from disk.

        Args:
            model_path: Location of the saved model. Defaults to the
                module-level ``MODEL_PATH``, so ``PreTrainedModel()`` keeps
                working unchanged.
        """
        self.model = tf.keras.models.load_model(model_path)

    def predict_image(self, image: Image.Image) -> Dict[str, float]:
        """Classify one image and return a score for every label.

        Args:
            image: PIL image to classify.

        Returns:
            Mapping from each entry of ``LABELS`` to the model's score for
            that class.

        Raises:
            ValueError: If the model emits a different number of scores than
                there are labels.
        """
        batch = preprocess(image)
        # verbose=0 stops Keras from printing a progress bar on every request.
        preds = self.model.predict(batch, verbose=0)
        # Multi-output models return a list/tuple; only the first output is used.
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        # Flatten the single-row batch output into plain Python floats.
        scores = np.asarray(preds).reshape(-1).tolist()
        if len(scores) != len(LABELS):
            # zip() would silently truncate to the shorter sequence and hide
            # a mismatch between the model and the label list.
            raise ValueError(
                f"model returned {len(scores)} scores "
                f"but {len(LABELS)} labels are defined"
            )
        return dict(zip(LABELS, scores))
74
 
75
+ # Create the module-level model instance so the model is loaded when the program starts.
76
+ # Reusing this instance avoids reloading the model on every prediction.
77
  model = PreTrainedModel()
78
 
 
def predict(image):
    """Return the label the model scores highest for ``image``.

    Args:
        image: Image passed straight through to ``model.predict_image``.

    Returns:
        The key of the prediction mapping that has the largest score.
    """
    scores = model.predict_image(image)
    best_label, _ = max(scores.items(), key=lambda pair: pair[1])
    return best_label
91
 
92
+ # Create the server application object that the route decorators below attach to.
93
+ app = BustAPI()
94
 
95
# POST /predict: classify the image sent as the raw request body and reply with JSON {"label": ...}.
+ @app.post("/predict")
96
+ async def predict_api(req: Request):
97
+ # Read the raw bytes of the POST body.
98
+ img_bytes = await req.body()
99
+ # Decode the bytes into a PIL image. NOTE(review): nothing here handles an empty or undecodable body - confirm how such requests should be answered.
100
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
101
+ # Run inference; predict() returns the top-scoring label. NOTE(review): it is called synchronously (no await) inside this async handler - confirm that is acceptable for the server.
102
+ label = predict(image)
103
 
104
+ return Response.json({
105
+ "label": label
106
+ })
107
 
108
# GET /: return the string "POST /predict" as a usage hint pointing at the prediction endpoint.
+ @app.get("/")
109
+ async def home():
110
+ return "POST /predict"
 
 
 
 
111
 
112
  # When run as a script, start the server on host 0.0.0.0, port 8000.
  if __name__ == "__main__":
113
+ app.run(host="0.0.0.0", port=8000)
requirements.txt CHANGED
@@ -2,4 +2,4 @@ tensorflow==2.16.1
2
  numpy
3
  Pillow
4
  requests
5
- gradio
 
2
  numpy
3
  Pillow
4
  requests
5
+ bustapi