Spaces:
Sleeping
Sleeping
File size: 5,753 Bytes
b138cbf a3c0988 8275409 b138cbf a3c0988 b138cbf a3c0988 b138cbf a3c0988 789b3f0 b138cbf a3c0988 b138cbf a3c0988 b138cbf 789b3f0 a3c0988 789b3f0 a3c0988 789b3f0 a3c0988 789b3f0 a3c0988 b138cbf 789b3f0 a3c0988 789b3f0 a3c0988 789b3f0 a3c0988 789b3f0 b138cbf a3c0988 b138cbf 789b3f0 a3c0988 b138cbf a3c0988 b138cbf a3c0988 b138cbf a3c0988 b138cbf a3c0988 b138cbf 5f259e2 a3c0988 789b3f0 a3c0988 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import json
from pathlib import Path

import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image, ImageOps
# -------------------------
# Page config
# -------------------------
st.set_page_config(
    page_title='Facial Keypoints Predictor (CNN)',
    page_icon='🙂',
    layout='centered',
)
st.title('🙂 Facial Keypoints Predictor (CNN)')
st.write('Upload a face image and the model will predict 15 facial keypoints (30 values: x/y).')
# -------------------------
# Paths (HuggingFace friendly)
# Every artifact is resolved relative to this script's own directory,
# so all of them must sit side by side in /src.
# -------------------------
BASE_DIR = Path(__file__).resolve().parent
MODEL_KERAS_PATH, MODEL_H5_PATH = (
    BASE_DIR.joinpath('final_keypoints_cnn.keras'),
    BASE_DIR.joinpath('final_keypoints_cnn.h5'),
)
TARGET_COLS_PATH = BASE_DIR.joinpath('target_cols.json')
PREPROCESS_PATH = BASE_DIR.joinpath('preprocess_config.json')
# -------------------------
# Load assets
# -------------------------
@st.cache_resource
def load_assets():
    """Load the keypoint model and its two JSON side-car files from BASE_DIR.

    Streamlit caches the returned objects (``st.cache_resource``), so reruns
    reuse them instead of reloading the model.

    Returns:
        tuple: ``(model, target_cols, preprocess_cfg, model_source)`` where
        ``model_source`` is the file name the model was loaded from.

    Raises:
        FileNotFoundError: if a required JSON file or both model files are missing.
        ValueError: if ``target_cols.json`` is not a list with an even number of
            entries (the UI consumes it as consecutive x/y column pairs).
    """
    # Check and read the small JSON files first so configuration problems are
    # reported before spending time deserialising the model.
    if not TARGET_COLS_PATH.exists():
        raise FileNotFoundError('Missing file: target_cols.json (put it in /src)')
    if not PREPROCESS_PATH.exists():
        raise FileNotFoundError('Missing file: preprocess_config.json (put it in /src)')

    # JSON is UTF-8 by spec; don't depend on the server's locale encoding.
    target_cols = json.loads(TARGET_COLS_PATH.read_text(encoding='utf-8'))
    preprocess_cfg = json.loads(PREPROCESS_PATH.read_text(encoding='utf-8'))

    if not isinstance(target_cols, list) or len(target_cols) % 2:
        raise ValueError(
            'target_cols.json must be a JSON list with an even number of column names (x/y pairs).'
        )

    # ✅ IMPORTANT: Keras 3 does NOT load SavedModel folders via load_model(),
    # so only the single-file .keras (preferred) or .h5 formats are tried, in that order.
    for candidate in (MODEL_KERAS_PATH, MODEL_H5_PATH):
        if candidate.exists():
            model = tf.keras.models.load_model(str(candidate), compile=False)
            model_source = candidate.name
            break
    else:
        raise FileNotFoundError(
            'Model not found. Upload `final_keypoints_cnn.keras` (recommended) or `final_keypoints_cnn.h5` into /src.'
        )

    return model, target_cols, preprocess_cfg, model_source
# -------------------------
# Helpers
# -------------------------
def preprocess_image(pil_img: Image.Image, img_size=(96, 96)) -> np.ndarray:
    """Turn a PIL image into a model-ready batch containing one grayscale frame.

    Mirrors the Kaggle facial-keypoints input format: a single-channel image
    resized to ``img_size`` (PIL order: width, height) with pixel values
    scaled into [0, 1].

    Returns:
        float32 array of shape ``(1, img_size[1], img_size[0], 1)``.
    """
    grayscale = pil_img.convert('L').resize(img_size)
    # An 'L' image becomes a (height, width) grid; scale 0..255 down to 0..1.
    pixels = np.asarray(grayscale, dtype=np.float32) / 255.0
    # Prepend the batch axis and append the channel axis in one indexing step.
    return pixels[np.newaxis, ..., np.newaxis]
def draw_keypoints(pil_img: Image.Image, keypoints_xy: np.ndarray) -> Image.Image:
    """Return an RGB copy of ``pil_img`` with a red circle at every keypoint.

    The image is annotated at its own size (it is no longer forced to 96x96),
    so ``keypoints_xy`` must already be in that image's pixel coordinates.
    Previously a full-size image was shrunk to 96x96 while its points stayed
    scaled to the full size, which placed the markers off the canvas.

    Args:
        pil_img: Image to annotate; the caller's image is not modified.
        keypoints_xy: Iterable of ``(x, y)`` pairs, e.g. a ``(15, 2)`` array
            or a list of tuples.

    Returns:
        A new RGB image with the markers drawn.
    """
    import PIL.ImageDraw as ImageDraw

    # convert() returns a new image, so drawing never touches pil_img itself.
    img = pil_img.convert('RGB')
    draw = ImageDraw.Draw(img)
    # Grow the markers with the image so they stay visible on large photos;
    # for a 96x96 image this yields the original radius 2 / width 2.
    r = max(2, round(min(img.size) / 48))
    line_width = max(2, r // 2)
    for (x, y) in keypoints_xy:
        draw.ellipse((x - r, y - r, x + r, y + r), outline='red', width=line_width)
    return img
def to_xy(pred_30: np.ndarray) -> np.ndarray:
    """Group a flat ``[x0, y0, x1, y1, ...]`` vector into ``(x, y)`` rows.

    Args:
        pred_30: Array with an even total size (30 for the 15-keypoint model).

    Returns:
        View of shape ``(size // 2, 2)``; column 0 holds x, column 1 holds y.

    Raises:
        ValueError: if the input has an odd number of elements.
    """
    n_points = pred_30.size // 2
    return pred_30.reshape(n_points, 2)
# -------------------------
# UI: checklist
# -------------------------
# Collapsible reminder of which files must be deployed next to this script.
# Sub-items use two spaces so Markdown renders them nested under "Required".
with st.expander('Model files checklist'):
    st.markdown(
        '- Put files inside **`/src`** in your HuggingFace Space.\n'
        '- Required:\n'
        '  - `final_keypoints_cnn.keras` (recommended) **or** `final_keypoints_cnn.h5`\n'
        '  - `target_cols.json`\n'
        '  - `preprocess_config.json`\n'
        '- Optional: `history.pkl` (not needed for inference)\n'
        '\n'
        '✅ Tip: If you still have a folder `final_keypoints_cnn_savedmodel/`, remove it or ignore it. '
        'This app does **not** load SavedModel folders.'
    )
# -------------------------
# Load model + configs
# -------------------------
# Top-level error boundary: any failure while loading is shown in the UI and
# the script run is halted with st.stop() instead of rendering a traceback.
try:
    assets = load_assets()
    model, target_cols, preprocess_cfg, model_source = assets
    st.success(f'Model loaded: {model_source}')
except Exception as load_err:
    st.error(str(load_err))
    st.stop()
# -------------------------
# Upload + Predict
# -------------------------
uploaded = st.file_uploader('Upload an image (jpg/png)', type=['jpg', 'jpeg', 'png'])
if uploaded is not None:
    try:
        pil_img = Image.open(uploaded)
        # Image.open is lazy; decode now so a corrupt upload fails inside this try.
        pil_img.load()
        # Phone cameras often store rotation in EXIF rather than in the pixels;
        # apply it so both the preview and the model see the upright image.
        pil_img = ImageOps.exif_transpose(pil_img)
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f'Could not read the uploaded file as an image: {err}')
        st.stop()

    st.subheader('Input image')
    st.image(pil_img, use_container_width=True)

    model_input = preprocess_image(pil_img, img_size=(96, 96))

    # Predict on the one-image batch and flatten the first (only) sample.
    pred = np.asarray(model.predict(model_input, verbose=0)[0]).reshape(-1)

    # Every output value must line up with one column name, in x/y pairs.
    if pred.size != len(target_cols) or pred.size % 2:
        st.error(
            f'Model returned {pred.size} values but target_cols.json lists '
            f'{len(target_cols)} column names; they must match and form x/y pairs.'
        )
        st.stop()

    # Training normalised targets as (y - 48) / 48, so invert that to get pixel
    # coordinates in the 96x96 frame: y = y_pred * 48 + 48.
    pred = pred * 48.0 + 48.0
    # Keep every point on the 96x96 canvas even if the model overshoots.
    pred = np.clip(pred, 0.0, 96.0)
    pts = to_xy(pred)

    st.subheader('Prediction (keypoints on original image)')
    width, height = pil_img.size  # original (EXIF-corrected) size
    scale_x = width / 96.0
    scale_y = height / 96.0
    pts_scaled = [(px * scale_x, py * scale_y) for (px, py) in pts]
    overlay = draw_keypoints(pil_img, pts_scaled)
    st.image(overlay, use_container_width=True)

    st.subheader('Prediction (keypoints on 96×96)')
    img96 = pil_img.resize((96, 96)).convert('RGB')
    overlay96 = draw_keypoints(img96, pts)
    st.image(overlay96, use_container_width=False)

    st.subheader('Keypoints table (x, y)')
    # target_cols is typically like ['left_eye_center_x', 'left_eye_center_y', ...]:
    # even positions name x columns, odd positions the matching y columns.
    rows = [
        {
            # Strip only a trailing '_x'; str.replace would also hit inner matches.
            'keypoint': name_x[:-2] if name_x.endswith('_x') else name_x,
            'x_name': name_x,
            'y_name': name_y,
            'x': float(x_val),
            'y': float(y_val),
        }
        for name_x, name_y, x_val, y_val in zip(
            target_cols[0::2], target_cols[1::2], pred[0::2], pred[1::2]
        )
    ]
    st.dataframe(rows, use_container_width=True)
else:
    st.info('Upload an image to get predictions.')