# FocusGuardBaseModel / ui/live_demo.py
# Committed by Kexin-251202 in c86c45b ("Deploy base model", verified).
import argparse
import os
import sys
import time
import cv2
import numpy as np
from mediapipe.tasks.python.vision import FaceLandmarksConnections
# Project root is the parent of the directory holding this file. It is put at
# the front of sys.path so the absolute `ui.*` / `models.*` imports that follow
# resolve when this file is executed directly as a script.
_PROJECT_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
from ui.pipeline import (
FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, GRUPipeline,
_load_gru_artifacts, _latest_model_artifacts,
)
from models.face_mesh import FaceMeshDetector
# ---------------------------------------------------------------------------
# Drawing palette. OpenCV colour tuples are ordered (B, G, R).
# ---------------------------------------------------------------------------
FONT = cv2.FONT_HERSHEY_SIMPLEX
WHITE = (255, 255, 255)
RED = (0, 0, 255)
GREEN = (0, 255, 0)
LIGHT_GREEN = (144, 238, 144)
CYAN = (255, 255, 0)
YELLOW = (0, 255, 255)
MAGENTA = (255, 0, 255)
ORANGE = (0, 165, 255)

# MediaPipe face-mesh edge sets, converted once at import time into plain
# (start_index, end_index) tuples for the draw_* helpers.
_TESSELATION = [
    (edge.start, edge.end)
    for edge in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION
]
_CONTOURS = [
    (edge.start, edge.end)
    for edge in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS
]
# Face-mesh landmark index chains drawn as connected segments by
# draw_contours. The lip chains end on their first index so they close.
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
_LIPS_OUTER = [
    61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
    409, 270, 269, 267, 0, 37, 39, 40, 185, 61,
]
_LIPS_INNER = [
    78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308,
    415, 310, 311, 312, 13, 82, 81, 80, 191, 78,
]
# Six landmarks per eye drawn as dots by draw_eyes_and_irises.
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]

# Mesh overlay modes (cycled with the "m" key); _MESH_NAMES is indexed by mode.
MESH_FULL, MESH_CONTOURS, MESH_OFF = range(3)
_MESH_NAMES = ["FULL MESH", "CONTOURS", "MESH OFF"]

# Scoring pipelines; _MODE_NAMES is indexed by mode, and digit keys "1".."4"
# select MODE_GEO..MODE_HYBRID respectively.
MODE_GEO, MODE_MLP, MODE_GRU, MODE_HYBRID = range(4)
_MODE_NAMES = ["GEOMETRIC", "MLP", "GRU", "HYBRID"]
_MODE_KEYS = {
    ord(str(mode + 1)): mode
    for mode in (MODE_GEO, MODE_MLP, MODE_GRU, MODE_HYBRID)
}
def _lm_to_px(landmarks, idx, w, h):
    """Map landmark ``idx`` from frame-normalised coordinates to pixel ints.

    Args:
        landmarks: array indexed as ``landmarks[idx, 0]`` (x) and
            ``landmarks[idx, 1]`` (y), each scaled by the frame size here.
        idx: landmark row to convert.
        w: frame width in pixels.
        h: frame height in pixels.

    Returns:
        ``(x_px, y_px)`` tuple of Python ints, truncated toward zero by ``int``.
    """
    norm_x = landmarks[idx, 0]
    norm_y = landmarks[idx, 1]
    x_px = int(norm_x * w)
    y_px = int(norm_y * h)
    return x_px, y_px
def draw_tessellation(frame, landmarks, w, h):
    """Blend the full face-mesh tessellation into ``frame`` in place.

    The edges are drawn in light grey on a copy of the frame. The copy is then
    mixed back into ``frame`` at 30% weight, so the mesh appears faint.

    Args:
        frame: BGR image, modified in place.
        landmarks: normalised landmarks accepted by ``_lm_to_px``.
        w: frame width in pixels.
        h: frame height in pixels.
    """
    layer = frame.copy()
    edge_colour = (200, 200, 200)
    for start_idx, end_idx in _TESSELATION:
        cv2.line(
            layer,
            _lm_to_px(landmarks, start_idx, w, h),
            _lm_to_px(landmarks, end_idx, w, h),
            edge_colour,
            1,
            cv2.LINE_AA,
        )
    # dst=frame writes the 0.3/0.7 blend back into the caller's image.
    cv2.addWeighted(layer, 0.3, frame, 0.7, 0, frame)
def _draw_landmark_chain(frame, landmarks, indices, w, h, color, thickness):
    """Draw anti-aliased segments between consecutive landmarks of ``indices``.

    The chain is open. A closed outline is expressed by repeating the first
    index at the end, as ``_LIPS_OUTER`` / ``_LIPS_INNER`` do.
    """
    for start_idx, end_idx in zip(indices, indices[1:]):
        cv2.line(
            frame,
            _lm_to_px(landmarks, start_idx, w, h),
            _lm_to_px(landmarks, end_idx, w, h),
            color,
            thickness,
            cv2.LINE_AA,
        )


def draw_contours(frame, landmarks, w, h):
    """Draw face contours, eyebrows, nose bridge and lips onto ``frame`` in place.

    The MediaPipe contour edges are drawn in cyan. The hand-picked landmark
    chains (eyebrows, nose bridge, outer/inner lips) are drawn on top in their
    own colours.

    Args:
        frame: BGR image, modified in place.
        landmarks: normalised landmarks accepted by ``_lm_to_px``.
        w: frame width in pixels.
        h: frame height in pixels.
    """
    for start_idx, end_idx in _CONTOURS:
        cv2.line(
            frame,
            _lm_to_px(landmarks, start_idx, w, h),
            _lm_to_px(landmarks, end_idx, w, h),
            CYAN,
            1,
            cv2.LINE_AA,
        )
    for brow in (_LEFT_EYEBROW, _RIGHT_EYEBROW):
        _draw_landmark_chain(frame, landmarks, brow, w, h, LIGHT_GREEN, 2)
    _draw_landmark_chain(frame, landmarks, _NOSE_BRIDGE, w, h, ORANGE, 1)
    _draw_landmark_chain(frame, landmarks, _LIPS_OUTER, w, h, MAGENTA, 1)
    _draw_landmark_chain(frame, landmarks, _LIPS_INNER, w, h, (200, 0, 200), 1)
def draw_eyes_and_irises(frame, landmarks, w, h):
    """Draw eye outlines, per-eye landmark dots, iris circles and gaze rays.

    Args:
        frame: BGR image, modified in place.
        landmarks: normalised landmarks accepted by ``_lm_to_px`` (also read
            directly as ``landmarks[i, 0]`` / ``landmarks[i, 1]``).
        w: frame width in pixels.
        h: frame height in pixels.
    """
    # Closed green outlines for both eyes.
    for eye_indices in (
        FaceMeshDetector.LEFT_EYE_INDICES,
        FaceMeshDetector.RIGHT_EYE_INDICES,
    ):
        outline = np.array(
            [_lm_to_px(landmarks, i, w, h) for i in eye_indices],
            dtype=np.int32,
        )
        cv2.polylines(frame, [outline], True, GREEN, 2, cv2.LINE_AA)

    # Filled yellow dots on the six listed landmarks of each eye.
    for ear_indices in (_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS):
        for idx in ear_indices:
            cv2.circle(frame, _lm_to_px(landmarks, idx, w, h), 3, YELLOW, -1, cv2.LINE_AA)

    for iris_indices, eye_inner, eye_outer in (
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    ):
        iris_pts = np.array(
            [_lm_to_px(landmarks, i, w, h) for i in iris_indices],
            dtype=np.int32,
        )
        # The first iris landmark is used as the centre. Convert it to plain
        # Python ints: tuple() of an np.int32 array yields numpy scalars, which
        # some OpenCV Python bindings reject as point coordinates.
        cx, cy = int(iris_pts[0, 0]), int(iris_pts[0, 1])
        center = (cx, cy)
        if len(iris_pts) >= 5:
            # Radius = mean distance from the centre to the next four iris
            # landmarks, floored at 2 px so the circle stays visible.
            radii = np.linalg.norm(iris_pts[1:5] - iris_pts[0], axis=1)
            radius = max(int(radii.mean()), 2)
            cv2.circle(frame, center, radius, MAGENTA, 2, cv2.LINE_AA)
        cv2.circle(frame, center, 2, WHITE, -1, cv2.LINE_AA)

        # Gaze ray: the iris offset from the midpoint of the two eye-corner
        # landmarks, extended by 3x from the iris centre.
        eye_cx = int((landmarks[eye_inner, 0] + landmarks[eye_outer, 0]) / 2.0 * w)
        eye_cy = int((landmarks[eye_inner, 1] + landmarks[eye_outer, 1]) / 2.0 * h)
        dx, dy = cx - eye_cx, cy - eye_cy
        cv2.line(frame, center, (cx + dx * 3, cy + dy * 3), RED, 1, cv2.LINE_AA)
def _parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``camera``, ``mlp_dir``, ``max_angle``,
        ``eye_model``, ``eye_backend`` and ``eye_blend``.
    """
    parser = argparse.ArgumentParser(description="FocusGuard live webcam demo")
    parser.add_argument("--camera", type=int, default=0,
                        help="camera index passed to cv2.VideoCapture")
    parser.add_argument("--mlp-dir", type=str, default=None,
                        help="checkpoint directory (default: <project root>/checkpoints)")
    parser.add_argument("--max-angle", type=float, default=22.0,
                        help="max_angle for the geometric and hybrid pipelines")
    parser.add_argument("--eye-model", type=str, default=None,
                        help="eye_model_path for the geometric and hybrid pipelines")
    parser.add_argument("--eye-backend", type=str, default="yolo",
                        choices=["yolo", "geometric"],
                        help="eye_backend for the geometric and hybrid pipelines")
    parser.add_argument("--eye-blend", type=float, default=0.5,
                        help="eye_blend for the geometric and hybrid pipelines")
    return parser.parse_args(argv)


def _build_pipelines(args, model_dir, detector, pipelines):
    """Construct every loadable pipeline into ``pipelines`` (mode -> pipeline).

    ``pipelines`` is filled in place rather than returned. That way, if a
    constructor raises, the caller still holds the pipelines built before it
    and can close them. The geometric pipeline is required and its errors
    propagate. MLP and Hybrid are only attempted when an MLP checkpoint is
    found; GRU only when GRU artifacts are found. Failures of these optional
    modes are reported and skipped.
    """
    pipelines[MODE_GEO] = FaceMeshPipeline(
        max_angle=args.max_angle,
        eye_model_path=args.eye_model,
        eye_backend=args.eye_backend,
        eye_blend=args.eye_blend,
        detector=detector,
    )

    mlp_path, _, _ = _latest_model_artifacts(model_dir)
    if mlp_path is not None:
        try:
            pipelines[MODE_MLP] = MLPPipeline(model_dir=model_dir, detector=detector)
        except Exception as e:  # optional mode: report and continue
            print(f"[DEMO] MLP unavailable: {e}")
        try:
            pipelines[MODE_HYBRID] = HybridFocusPipeline(
                model_dir=model_dir,
                eye_model_path=args.eye_model,
                eye_backend=args.eye_backend,
                eye_blend=args.eye_blend,
                max_angle=args.max_angle,
                detector=detector,
            )
        except Exception as e:  # optional mode: report and continue
            print(f"[DEMO] Hybrid unavailable: {e}")

    gru_arts = _load_gru_artifacts(model_dir)
    if gru_arts[0] is not None:
        try:
            pipelines[MODE_GRU] = GRUPipeline(model_dir=model_dir, detector=detector)
        except Exception as e:  # optional mode: report and continue
            print(f"[DEMO] GRU unavailable: {e}")


def _draw_landmark_overlays(frame, result, pipeline, mesh_mode):
    """Draw mesh, eyes/irises, head-pose axes and eye boxes for one frame."""
    h, w = frame.shape[:2]
    lm = result["landmarks"]
    if lm is None:
        return
    if mesh_mode == MESH_FULL:
        draw_tessellation(frame, lm, w, h)
        draw_contours(frame, lm, w, h)
    elif mesh_mode == MESH_CONTOURS:
        draw_contours(frame, lm, w, h)
    draw_eyes_and_irises(frame, lm, w, h)
    if hasattr(pipeline, "head_pose"):
        pipeline.head_pose.draw_axes(frame, lm)
    if result.get("left_bbox") and result.get("right_bbox"):
        lx1, ly1, lx2, ly2 = result["left_bbox"]
        rx1, ry1, rx2, ry2 = result["right_bbox"]
        cv2.rectangle(frame, (lx1, ly1), (lx2, ly2), YELLOW, 1)
        cv2.rectangle(frame, (rx1, ry1), (rx2, ry2), YELLOW, 1)


def _hud_detail(result, mode):
    """Return the per-mode score line shown under the focus banner."""
    mar = result.get("mar")
    mar_s = f" MAR:{mar:.2f}" if mar is not None else ""
    if mode == MODE_GEO:
        sf = result.get("s_face", 0)
        se = result.get("s_eye", 0)
        rs = result.get("raw_score", 0)
        return f"S_face:{sf:.2f} S_eye:{se:.2f}{mar_s} score:{rs:.2f}"
    if mode == MODE_MLP:
        mp = result.get("mlp_prob", 0)
        rs = result.get("raw_score", 0)
        return f"mlp_prob:{mp:.2f} score:{rs:.2f}{mar_s}"
    if mode == MODE_GRU:
        gp = result.get("gru_prob", 0)
        rs = result.get("raw_score", 0)
        return f"gru_prob:{gp:.2f} score:{rs:.2f}{mar_s}"
    if mode == MODE_HYBRID:
        mp = result.get("mlp_prob", 0)
        gs = result.get("geo_score", 0)
        fs = result.get("focus_score", 0)
        return f"focus:{fs:.2f} mlp:{mp:.2f} geo:{gs:.2f}{mar_s}"
    return ""


def _draw_hud(frame, result, mode, mesh_mode, fps, mode_hint):
    """Draw the status banner, scores, yawn/head-pose text and key help."""
    h, w = frame.shape[:2]
    focused = result["is_focused"]
    status = "FOCUSED" if focused else "NOT FOCUSED"
    status_color = GREEN if focused else RED
    cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
    cv2.putText(frame, status, (10, 28), FONT, 0.8, status_color, 2, cv2.LINE_AA)
    cv2.putText(frame, f"{_MODE_NAMES[mode]} {_MESH_NAMES[mesh_mode]} FPS:{fps:.0f}",
                (w - 340, 28), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
    cv2.putText(frame, _hud_detail(result, mode), (10, 48), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), FONT, 0.7, ORANGE, 2, cv2.LINE_AA)
    yaw, pitch, roll = (result.get(k) for k in ("yaw", "pitch", "roll"))
    # Require all three angles; formatting a missing one would raise.
    if yaw is not None and pitch is not None and roll is not None:
        cv2.putText(
            frame,
            f"yaw:{yaw:+.0f} pitch:{pitch:+.0f} roll:{roll:+.0f}",
            (w - 280, 48), FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA,
        )
    cv2.putText(frame, f"q:quit m:mesh {mode_hint}",
                (10, h - 10), FONT, 0.35, (150, 150, 150), 1, cv2.LINE_AA)


def _run_loop(cap, pipelines):
    """Read, score, annotate and show frames until 'q' or the stream ends."""
    current_mode = MODE_GEO
    pipeline = pipelines[current_mode]
    # List modes in key order ("1:... 2:... 3:..."), independent of load order.
    mode_hint = " ".join(f"{k+1}:{_MODE_NAMES[k]}" for k in sorted(pipelines))
    print(f"[DEMO] Available modes: {mode_hint}")
    print(f"[DEMO] Active: {_MODE_NAMES[current_mode]}")
    print("[DEMO] q=quit m=mesh 1-4=switch mode")

    prev_time = time.perf_counter()
    fps = 0.0
    mesh_mode = MESH_FULL
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        result = pipeline.process_frame(frame)

        # Exponential moving average of the instantaneous frame rate;
        # perf_counter is monotonic, so wall-clock jumps cannot skew it.
        now = time.perf_counter()
        fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
        prev_time = now

        _draw_landmark_overlays(frame, result, pipeline, mesh_mode)
        _draw_hud(frame, result, current_mode, mesh_mode, fps, mode_hint)
        cv2.imshow("FocusGuard", frame)

        key = cv2.waitKey(1) & 0xFF
        if key == ord("q"):
            break
        if key == ord("m"):
            mesh_mode = (mesh_mode + 1) % len(_MESH_NAMES)
            print(f"[DEMO] Mesh: {_MESH_NAMES[mesh_mode]}")
        elif key in _MODE_KEYS:
            requested = _MODE_KEYS[key]
            if requested in pipelines:
                current_mode = requested
                pipeline = pipelines[current_mode]
                print(f"[DEMO] Switched to {_MODE_NAMES[current_mode]}")
            else:
                print(f"[DEMO] {_MODE_NAMES[requested]} not available (no checkpoint)")


def main():
    """Run the interactive FocusGuard webcam demo.

    Keys: ``q`` quits, ``m`` cycles the mesh overlay, and ``1``-``4`` switch
    between the loaded scoring pipelines. Pipelines and the shared detector
    are closed on every exit path, including camera-open failure and
    exceptions raised while constructing pipelines.
    """
    args = _parse_args()
    model_dir = args.mlp_dir or os.path.join(_PROJECT_ROOT, "checkpoints")
    detector = FaceMeshDetector()
    pipelines = {}
    try:
        _build_pipelines(args, model_dir, detector, pipelines)
        cap = cv2.VideoCapture(args.camera)
        if not cap.isOpened():
            print("[DEMO] ERROR: Cannot open camera")
            cap.release()
            return
        try:
            _run_loop(cap, pipelines)
        finally:
            cap.release()
            cv2.destroyAllWindows()
    finally:
        for p in pipelines.values():
            p.close()
        detector.close()
        print("[DEMO] Done")
if __name__ == "__main__":
    # Launch the demo only when run as a script, never on import.
    main()