g4tes commited on
Commit
f2f4624
·
verified ·
1 Parent(s): b98f9a6

Upload 7 files

Browse files
Files changed (7) hide show
  1. .gitattributes +2 -34
  2. README.md +3 -1
  3. app.py +76 -0
  4. config.json +19 -0
  5. inference.py +68 -0
  6. model/model_resnet50.keras +3 -0
  7. requirements.txt +5 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.keras filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  *.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -9,4 +9,6 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # HF Model: Classification
13
+
14
+ This folder contains files to load a Keras (.keras) image classification model on Hugging Face Inference.
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tensorflow as tf
5
+ from typing import List, Dict, Any
6
+ import io
7
+
8
+ # Labels must mirror src/classification-model/index.ts
9
+ LABELS: List[str] = [
10
+ "battery",
11
+ "biological",
12
+ "brown-glass",
13
+ "cardboard",
14
+ "clothes",
15
+ "green-glass",
16
+ "metal",
17
+ "paper",
18
+ "plastic",
19
+ "shoes",
20
+ "trash",
21
+ "white-glass",
22
+ ]
23
+
24
+
25
+ def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
26
+ if image.mode != "RGB":
27
+ image = image.convert("RGB")
28
+ return np.asarray(image)
29
+
30
+
31
+ def _resize_224(img_rgb: np.ndarray) -> np.ndarray:
32
+ im = Image.fromarray(img_rgb)
33
+ im = im.resize((224, 224), Image.NEAREST)
34
+ return np.asarray(im)
35
+
36
+
37
+ def _preprocess(image: Image.Image) -> np.ndarray:
38
+ rgb = _load_image_to_rgb(image)
39
+ rgb224 = _resize_224(rgb)
40
+ # shape [1,224,224,3], float32 in 0..255
41
+ arr = rgb224.astype("float32")
42
+ return np.expand_dims(arr, axis=0)
43
+
44
+
45
+ class PreTrainedModel:
46
+ def __init__(self, model_path: str = "model/model_resnet50.keras") -> None:
47
+ self.model = tf.keras.models.load_model(model_path)
48
+
49
+ def predict_image(self, image: Image.Image) -> Dict[str, float]:
50
+ x = _preprocess(image)
51
+ preds = self.model.predict(x)
52
+ if isinstance(preds, (list, tuple)):
53
+ preds = preds[0]
54
+ probs = np.asarray(preds).squeeze().tolist()
55
+
56
+ return {label: score for label, score in zip(LABELS, probs)}
57
+
58
+
59
+ model = PreTrainedModel()
60
+
61
+
62
+ def predict(image):
63
+ predictions = model.predict_image(image)
64
+ return predictions
65
+
66
+
67
+ iface = gr.Interface(
68
+ fn=predict,
69
+ inputs=gr.Image(type="pil"),
70
+ outputs=gr.Label(num_top_classes=3),
71
+ title="Waste Classification",
72
+ description="Upload an image of waste to classify it.",
73
+ )
74
+
75
+ if __name__ == "__main__":
76
+ iface.launch()
config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": ["CustomKerasClassifier"],
3
+ "framework": "keras",
4
+ "image_size": 224,
5
+ "labels": [
6
+ "battery",
7
+ "biological",
8
+ "brown-glass",
9
+ "cardboard",
10
+ "clothes",
11
+ "green-glass",
12
+ "metal",
13
+ "paper",
14
+ "plastic",
15
+ "shoes",
16
+ "trash",
17
+ "white-glass"
18
+ ]
19
+ }
inference.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any
2
+ import io
3
+ import numpy as np
4
+ from PIL import Image
5
+ import requests
6
+
7
+ import tensorflow as tf
8
+
9
+ # Labels must mirror src/classification-model/index.ts
10
+ LABELS: List[str] = [
11
+ "battery",
12
+ "biological",
13
+ "brown-glass",
14
+ "cardboard",
15
+ "clothes",
16
+ "green-glass",
17
+ "metal",
18
+ "paper",
19
+ "plastic",
20
+ "shoes",
21
+ "trash",
22
+ "white-glass",
23
+ ]
24
+
25
+
26
+ def _load_image_to_rgb(image: Image.Image) -> np.ndarray:
27
+ if image.mode != "RGB":
28
+ image = image.convert("RGB")
29
+ return np.asarray(image)
30
+
31
+
32
+ def _resize_224(img_rgb: np.ndarray) -> np.ndarray:
33
+ im = Image.fromarray(img_rgb)
34
+ im = im.resize((224, 224), Image.NEAREST)
35
+ return np.asarray(im)
36
+
37
+
38
+ def _preprocess(image_bytes: bytes) -> np.ndarray:
39
+ # Mirror TS: ensure JPEG-like decode and resize 224x224, keep 0..255 range
40
+ image = Image.open(io.BytesIO(image_bytes))
41
+ rgb = _load_image_to_rgb(image)
42
+ rgb224 = _resize_224(rgb)
43
+ # shape [1,224,224,3], float32 in 0..255
44
+ arr = rgb224.astype("float32")
45
+ return np.expand_dims(arr, axis=0)
46
+
47
+
48
+ class PreTrainedModel:
49
+ def __init__(self, model_path: str = "model/model_resnet50.keras") -> None:
50
+ self.model = tf.keras.models.load_model(model_path)
51
+
52
+ def predict(self, inputs: bytes) -> List[Dict[str, Any]]:
53
+ x = _preprocess(inputs)
54
+ preds = self.model.predict(x)
55
+ if isinstance(preds, (list, tuple)):
56
+ preds = preds[0]
57
+ probs = np.asarray(preds).squeeze().tolist()
58
+ # Top-1 output following TS behavior
59
+ idx = int(np.argmax(probs))
60
+ return [
61
+ {"label": LABELS[idx], "score": float(probs[idx])},
62
+ ]
63
+
64
+
65
+ def load_model(model_dir: str = ".") -> PreTrainedModel:
66
+ # HF Inference API convention: a top-level load entrypoint
67
+ return PreTrainedModel(model_path=f"{model_dir}/model/model_resnet50.keras")
68
+
model/model_resnet50.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93aca6d248291878520c966415f3a23a2320370e809ca4b45c6358e332518052
3
+ size 243395061
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ tensorflow==2.16.1
2
+ numpy
3
+ Pillow
4
+ requests
5
+ gradio