Yingtao-Zheng committed on
Commit
8bbb872
·
1 Parent(s): 05616fb

Upload partially updated files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +41 -0
  2. Dockerfile +27 -0
  3. api/history +0 -0
  4. api/import +0 -0
  5. api/sessions +0 -0
  6. app.py +1 -0
  7. checkpoints/hybrid_focus_config.json +10 -0
  8. checkpoints/meta_best.npz +3 -0
  9. checkpoints/mlp_best.pt +3 -0
  10. checkpoints/model_best.joblib +3 -0
  11. checkpoints/scaler_best.joblib +3 -0
  12. checkpoints/xgboost_face_orientation_best.json +0 -0
  13. docker-compose.yml +5 -0
  14. eslint.config.js +29 -0
  15. index.html +17 -0
  16. main.py +964 -0
  17. models/README.md +53 -0
  18. models/__init__.py +1 -0
  19. models/cnn/CNN_MODEL/.claude/settings.local.json +7 -0
  20. models/cnn/CNN_MODEL/.gitattributes +1 -0
  21. models/cnn/CNN_MODEL/.gitignore +4 -0
  22. models/cnn/CNN_MODEL/README.md +74 -0
  23. models/cnn/CNN_MODEL/notebooks/eye_classifier_colab.ipynb +0 -0
  24. models/cnn/CNN_MODEL/scripts/focus_infer.py +199 -0
  25. models/cnn/CNN_MODEL/scripts/predict_image.py +49 -0
  26. models/cnn/CNN_MODEL/scripts/video_infer.py +281 -0
  27. models/cnn/CNN_MODEL/scripts/webcam_live.py +184 -0
  28. models/cnn/CNN_MODEL/weights/yolo11s-cls.pt +3 -0
  29. models/cnn/__init__.py +0 -0
  30. models/cnn/eye_attention/__init__.py +1 -0
  31. models/cnn/eye_attention/classifier.py +169 -0
  32. models/cnn/eye_attention/crop.py +70 -0
  33. models/cnn/eye_attention/train.py +0 -0
  34. models/cnn/notebooks/EyeCNN.ipynb +107 -0
  35. models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb +0 -0
  36. models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb +0 -0
  37. models/cnn/notebooks/README.md +1 -0
  38. models/collect_features.py +356 -0
  39. models/eye_classifier.py +69 -0
  40. models/eye_crop.py +77 -0
  41. models/eye_scorer.py +168 -0
  42. models/face_mesh.py +94 -0
  43. models/head_pose.py +121 -0
  44. models/mlp/__init__.py +0 -0
  45. models/mlp/eval_accuracy.py +54 -0
  46. models/mlp/sweep.py +66 -0
  47. models/mlp/train.py +232 -0
  48. models/xgboost/add_accuracy.py +60 -0
  49. models/xgboost/checkpoints/face_orientation_best.json +0 -0
  50. models/xgboost/eval_accuracy.py +54 -0
.gitignore ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules/
11
+ dist/
12
+ dist-ssr/
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/
17
+ .idea/
18
+ .DS_Store
19
+ *.suo
20
+ *.ntvs*
21
+ *.njsproj
22
+ *.sln
23
+ *.sw?
24
+ *.py[cod]
25
+ *$py.class
26
+ *.so
27
+ .Python
28
+ venv/
29
+ .venv/
30
+ env/
31
+ .env
32
+ *.egg-info/
33
+ .eggs/
34
+ build/
35
+ Thumbs.db
36
+
37
+ # Project specific
38
+ focus_guard.db
39
+ static/
40
+ __pycache__/
41
+ docs/
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ RUN useradd -m -u 1000 user
4
+ ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
5
+
6
+ ENV PYTHONUNBUFFERED=1
7
+
8
+ WORKDIR /app
9
+
10
+ RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev build-essential nodejs npm && rm -rf /var/lib/apt/lists/*
11
+
12
+ COPY requirements.txt ./
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ COPY . .
16
+
17
+ RUN npm install && npm run build && mkdir -p /app/static && cp -R dist/* /app/static/
18
+
19
+ ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
20
+ RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"
21
+
22
+ RUN mkdir -p /app/data && chown -R user:user /app
23
+
24
+ USER user
25
+ EXPOSE 7860
26
+
27
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--log-level", "debug"]
api/history ADDED
File without changes
api/import ADDED
File without changes
api/sessions ADDED
File without changes
app.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from main import app
checkpoints/hybrid_focus_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "w_mlp": 0.6000000000000001,
3
+ "w_geo": 0.3999999999999999,
4
+ "threshold": 0.35,
5
+ "use_yawn_veto": true,
6
+ "geo_face_weight": 0.4,
7
+ "geo_eye_weight": 0.6,
8
+ "mar_yawn_threshold": 0.55,
9
+ "metric": "f1"
10
+ }
checkpoints/meta_best.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
3
+ size 4196
checkpoints/mlp_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2f55129785b6882c304483aa5399f5bf6c9ed6e73dfec7ca6f36cd0436156c8
3
+ size 14497
checkpoints/model_best.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:183f2d4419e0eb1e58704e5a7312eb61e331523566d4dc551054a07b3aac7557
3
+ size 5775881
checkpoints/scaler_best.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02ed6b4c0d99e0254c6a740a949da2384db58ec7d3e6df6432b9bfcd3a296c71
3
+ size 783
checkpoints/xgboost_face_orientation_best.json ADDED
The diff for this file is too large to render. See raw diff
 
docker-compose.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ services:
2
+ focus-guard:
3
+ build: .
4
+ ports:
5
+ - "7860:7860"
eslint.config.js ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import { defineConfig, globalIgnores } from 'eslint/config'
6
+
7
+ export default defineConfig([
8
+ globalIgnores(['dist']),
9
+ {
10
+ files: ['**/*.{js,jsx}'],
11
+ extends: [
12
+ js.configs.recommended,
13
+ reactHooks.configs.flat.recommended,
14
+ reactRefresh.configs.vite,
15
+ ],
16
+ languageOptions: {
17
+ ecmaVersion: 2020,
18
+ globals: globals.browser,
19
+ parserOptions: {
20
+ ecmaVersion: 'latest',
21
+ ecmaFeatures: { jsx: true },
22
+ sourceType: 'module',
23
+ },
24
+ },
25
+ rules: {
26
+ 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
+ },
28
+ },
29
+ ])
index.html ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
8
+ <title>Focus Guard</title>
9
+ <link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
10
+ </head>
11
+
12
+ <body>
13
+ <div id="root"></div>
14
+ <script type="module" src="/src/main.jsx"></script>
15
+ </body>
16
+
17
+ </html>
main.py ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from typing import Optional, List, Any
7
+ import base64
8
+ import cv2
9
+ import numpy as np
10
+ import aiosqlite
11
+ import json
12
+ from datetime import datetime, timedelta
13
+ import math
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Callable
17
+ import asyncio
18
+ import concurrent.futures
19
+ import threading
20
+
21
+ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
22
+ from av import VideoFrame
23
+
24
+ from mediapipe.tasks.python.vision import FaceLandmarksConnections
25
+ from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
26
+ from models.face_mesh import FaceMeshDetector
27
+
28
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================

# OpenCV font used for all HUD text.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
# Drawing colors in BGR order (OpenCV convention).
_CYAN = (255, 255, 0)
_GREEN = (0, 255, 0)
_MAGENTA = (255, 0, 255)
_ORANGE = (0, 165, 255)
_RED = (0, 0, 255)
_WHITE = (255, 255, 255)
_LIGHT_GREEN = (144, 238, 144)

# (start, end) landmark-index pairs for MediaPipe's tessellation and contour edges.
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
# FaceMesh landmark index groups for individual facial features.
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
# Six points per eye highlighted as Eye Aspect Ratio (EAR) key points.
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
48
+
49
+
50
+ def _lm_px(lm, idx, w, h):
51
+ return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
52
+
53
+
54
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
    """Draw anti-aliased segments joining consecutive landmark indices."""
    for a, b in zip(indices, indices[1:]):
        cv2.line(frame, _lm_px(lm, a, w, h), _lm_px(lm, b, w, h), color, thickness, cv2.LINE_AA)
57
+
58
+
59
def _draw_face_mesh(frame, lm, w, h):
    """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines.

    Args:
        frame: BGR image, modified in place.
        lm: landmark array indexed lm[i, 0]=x, lm[i, 1]=y in normalized
            [0, 1] coordinates (MediaPipe FaceMesh layout).
        w, h: frame width and height in pixels.
    """
    # Tessellation (gray triangular grid, semi-transparent):
    # draw on a copy, then blend at 30% so it does not obscure the video.
    overlay = frame.copy()
    for s, e in _TESSELATION_CONNS:
        cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
    # Contours
    for s, e in _CONTOUR_CONNS:
        cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
    # Eyebrows
    _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    # Nose
    _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
    # Lips
    _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
    _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
    # Eyes (closed polygons around each eye outline)
    left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
    right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
    # EAR key points
    for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
        for idx in indices:
            cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
    # Irises + gaze lines
    for iris_idx, eye_inner, eye_outer in [
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    ]:
        iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
        # Point 0 is taken as the pupil center; points 1-4 lie on the rim.
        center = iris_pts[0]
        if len(iris_pts) >= 5:
            radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
            radius = max(int(np.mean(radii)), 2)  # clamp so the ring stays visible
            cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
        cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
        # Gaze line: pupil offset from the eye's midpoint, extended 3x.
        eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
        eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
        dx, dy = center[0] - eye_cx, center[1] - eye_cy
        cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
102
+
103
+
104
def _draw_hud(frame, result, model_name):
    """Overlay the focus status bar and metric read-outs (mirrors live_demo.py)."""
    h, w = frame.shape[:2]
    is_focused = result["is_focused"]
    if is_focused:
        status, color = "FOCUSED", _GREEN
    else:
        status, color = "NOT FOCUSED", _RED

    # Black status bar across the top: verdict on the left, model on the right.
    cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
    cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
    cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)

    # Detail line: confidence plus face/eye sub-scores (and MAR when present).
    conf = result.get("mlp_prob", result.get("raw_score", 0.0))
    if result.get("mar") is not None:
        mar_s = f" MAR:{result['mar']:.2f}"
    else:
        mar_s = ""
    sf = result.get("s_face", 0)
    se = result.get("s_eye", 0)
    detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)

    # Head-pose angles in the top-right corner, when the estimator produced them.
    if result.get("yaw") is not None:
        pose_text = f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}"
        cv2.putText(frame, pose_text, (w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)

    # Yawn indicator just below the bar.
    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
132
+
133
# Landmark indices used for face mesh drawing on client (union of all groups).
# Sending only these instead of all 478 saves ~60% of the landmarks payload.
_MESH_INDICES = sorted(set(
    [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
    + [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
    + [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
    + [468,469,470,471,472, 473,474,475,476,477] # irises
    + [70,63,105,66,107,55,65,52,53,46] # left eyebrow
    + [300,293,334,296,336,285,295,282,283,276] # right eyebrow
    + [6,197,195,5,4,1,19,94,2] # nose bridge
    + [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
    + [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
    + [33,160,158,133,153,145] # left EAR key points
    + [362,385,387,263,373,380] # right EAR key points
))
# Fast membership test: "is this landmark index part of the sparse mesh?"
# (Note: a set, not an original_index -> sparse-position mapping.)
_MESH_INDEX_SET = set(_MESH_INDICES)
150
+
151
# Initialize FastAPI app
app = FastAPI(title="Focus Guard API")

# Add CORS middleware — any origin is allowed, since the frontend may be
# served from a different host/port during development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables
db_path = "focus_guard.db"  # SQLite file; tables created by init_database()
pcs = set()  # live RTCPeerConnection objects, referenced so they are not GC'd
_cached_model_name = "mlp"  # in-memory cache, updated via /api/settings
167
+
168
+ async def _wait_for_ice_gathering(pc: RTCPeerConnection):
169
+ if pc.iceGatheringState == "complete":
170
+ return
171
+ done = asyncio.Event()
172
+
173
+ @pc.on("icegatheringstatechange")
174
+ def _on_state_change():
175
+ if pc.iceGatheringState == "complete":
176
+ done.set()
177
+
178
+ await done.wait()
179
+
180
+ # ================ DATABASE MODELS ================
181
+
182
async def init_database():
    """Initialize SQLite database with required tables.

    Creates the sessions, per-frame events, and singleton settings tables if
    absent, and seeds the default settings row. Idempotent — safe to call on
    every startup.
    """
    async with aiosqlite.connect(db_path) as db:
        # FocusSessions table: one row per recording session; frame counters
        # are incremented as events arrive, summary fields set on end.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                start_time TIMESTAMP NOT NULL,
                end_time TIMESTAMP,
                duration_seconds INTEGER DEFAULT 0,
                focus_score REAL DEFAULT 0.0,
                total_frames INTEGER DEFAULT 0,
                focused_frames INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # FocusEvents table: one row per processed frame, with the raw
        # detection payload stored as JSON text in detection_data.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id INTEGER NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                is_focused BOOLEAN NOT NULL,
                confidence REAL NOT NULL,
                detection_data TEXT,
                FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
            )
        """)

        # UserSettings table: CHECK (id = 1) enforces a single settings row.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS user_settings (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                sensitivity INTEGER DEFAULT 6,
                notification_enabled BOOLEAN DEFAULT 1,
                notification_threshold INTEGER DEFAULT 30,
                frame_rate INTEGER DEFAULT 30,
                model_name TEXT DEFAULT 'mlp'
            )
        """)

        # Insert default settings if not exists
        await db.execute("""
            INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
            VALUES (1, 6, 1, 30, 30, 'mlp')
        """)

        await db.commit()
231
+
232
+ # ================ PYDANTIC MODELS ================
233
+
234
class SessionCreate(BaseModel):
    """Request body for creating a session — intentionally empty."""
    pass
236
+
237
class SessionEnd(BaseModel):
    """Request body identifying the session to end."""
    # id of the focus_sessions row to close
    session_id: int
239
+
240
class SettingsUpdate(BaseModel):
    """Partial settings update — fields left as None are not changed."""
    sensitivity: Optional[int] = None
    notification_enabled: Optional[bool] = None
    notification_threshold: Optional[int] = None
    frame_rate: Optional[int] = None
    # one of the keys in `pipelines` ("geometric", "mlp", "hybrid", "xgboost")
    model_name: Optional[str] = None
246
+
247
class VideoTransformTrack(VideoStreamTrack):
    """WebRTC track that runs focus inference on incoming frames and returns
    the annotated frames (face mesh + HUD) to the peer.

    Per inference run, the result is stored as a focus event for the session
    and pushed over the peer's data channel (when open).
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        super().__init__()
        self.track = track  # upstream track frames are pulled from
        self.session_id = session_id  # DB session receiving focus events
        # Lazy accessor: the data channel may open after this track is created.
        self.get_channel = get_channel
        self.last_inference_time = 0
        # Throttle inference to at most 60 runs/second; between runs the last
        # annotated frame is re-used so the output stream stays smooth.
        self.min_inference_interval = 1 / 60
        self.last_frame = None

    async def recv(self):
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame

        # Normalize size for inference/drawing
        img = cv2.resize(img, (640, 480))

        now = datetime.now().timestamp()
        do_infer = (now - self.last_inference_time) >= self.min_inference_interval

        if do_infer:
            self.last_inference_time = now

            # Fall back to the MLP pipeline if the selected model is unavailable.
            model_name = _cached_model_name
            if model_name not in pipelines or pipelines.get(model_name) is None:
                model_name = 'mlp'
            active_pipeline = pipelines.get(model_name)

            if active_pipeline is not None:
                # Run CPU-bound inference off the event loop.
                loop = asyncio.get_event_loop()
                out = await loop.run_in_executor(
                    _inference_executor,
                    _process_frame_safe,
                    active_pipeline,
                    img,
                    model_name,
                )
                is_focused = out["is_focused"]
                confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
                metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}

                # Draw face mesh + HUD on the video frame
                h_f, w_f = img.shape[:2]
                lm = out.get("landmarks")
                if lm is not None:
                    _draw_face_mesh(img, lm, w_f, h_f)
                _draw_hud(img, out, model_name)
            else:
                # No pipeline could be loaded at all — emit an error banner.
                is_focused = False
                confidence = 0.0
                metadata = {"model": model_name}
                cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
                cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)

            if self.session_id:
                await store_focus_event(self.session_id, is_focused, confidence, metadata)

            channel = self.get_channel()
            if channel and channel.readyState == "open":
                try:
                    # BUG FIX: this send previously referenced an undefined
                    # name `detections` (NameError, silently swallowed by the
                    # except below), so the client never received detection
                    # messages. Send the per-frame metadata dict instead.
                    channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": metadata}))
                except Exception:
                    # Best-effort: the channel may close mid-send.
                    pass

            self.last_frame = img
        elif self.last_frame is not None:
            img = self.last_frame

        # Re-wrap with the original timestamps so A/V timing is preserved.
        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
321
+
322
+ # ================ DATABASE OPERATIONS ================
323
+
324
async def create_session():
    """Insert a new focus-session row and return its autoincrement id."""
    async with aiosqlite.connect(db_path) as db:
        started_at = datetime.now().isoformat()
        cur = await db.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (started_at,),
        )
        await db.commit()
        return cur.lastrowid
332
+
333
async def end_session(session_id: int):
    """Close a session: compute duration and focus score, persist, return a summary.

    Returns:
        A summary dict for the session, or None if the id does not exist.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,)
        )
        row = await cursor.fetchone()

        if not row:
            return None

        start_time_str, total_frames, focused_frames = row
        start_time = datetime.fromisoformat(start_time_str)
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        # Focus score = fraction of processed frames classified as focused.
        focus_score = focused_frames / total_frames if total_frames > 0 else 0.0

        await db.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (end_time.isoformat(), int(duration), focus_score, session_id))

        await db.commit()

        return {
            'session_id': session_id,
            'start_time': start_time_str,
            'end_time': end_time.isoformat(),
            'duration_seconds': int(duration),
            'focus_score': round(focus_score, 3),
            'total_frames': total_frames,
            'focused_frames': focused_frames
        }
367
+
368
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Persist one focus event and bump the session's frame counters.

    Opens a connection and commits per call; the websocket path batches
    writes via _EventBuffer instead of calling this per frame.
    """
    async with aiosqlite.connect(db_path) as db:
        # Raw detection payload goes into detection_data as JSON text.
        await db.execute("""
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
        """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))

        # Keep per-session counters in sync for the end-of-session score.
        await db.execute("""
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
        """, (1 if is_focused else 0, session_id))
        await db.commit()
382
+
383
+
384
class _EventBuffer:
    """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes.

    add() is synchronous and is expected to run on the event-loop thread;
    the asyncio.Lock serializes the periodic and final flush paths.
    """

    def __init__(self, flush_interval: float = 2.0):
        # Pending rows, in focus_events column order:
        # (session_id, timestamp, is_focused, confidence, detection_data_json)
        self._buf: list = []
        self._lock = asyncio.Lock()
        self._flush_interval = flush_interval
        self._task: asyncio.Task | None = None  # periodic flush loop; None when stopped
        # Frame counters accumulated since the last flush.
        self._total_frames = 0
        self._focused_frames = 0

    def start(self):
        # Idempotent: a second start() while running is a no-op.
        if self._task is None:
            self._task = asyncio.create_task(self._flush_loop())

    async def stop(self):
        # Cancel the loop first, then flush whatever is still buffered.
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        await self._flush()

    def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
        # Cheap per-frame append; DB I/O happens only in _flush().
        self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
        self._total_frames += 1
        if is_focused:
            self._focused_frames += 1

    async def _flush_loop(self):
        # Runs until cancelled by stop().
        while True:
            await asyncio.sleep(self._flush_interval)
            await self._flush()

    async def _flush(self):
        """Write all buffered events and counter deltas to the database."""
        # Swap the buffer out under the lock, then do DB I/O outside it so
        # add() can keep appending while the write is in flight.
        async with self._lock:
            if not self._buf:
                return
            batch = self._buf[:]
            total = self._total_frames
            focused = self._focused_frames
            self._buf.clear()
            self._total_frames = 0
            self._focused_frames = 0

        if not batch:
            return

        # NOTE(review): the counter UPDATE attributes the whole batch to the
        # session of the first event — assumes one buffer per session (true
        # for the current websocket usage); verify if buffers are ever shared.
        session_id = batch[0][0]
        try:
            async with aiosqlite.connect(db_path) as db:
                await db.executemany("""
                    INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
                    VALUES (?, ?, ?, ?, ?)
                """, batch)
                await db.execute("""
                    UPDATE focus_sessions
                    SET total_frames = total_frames + ?,
                        focused_frames = focused_frames + ?
                    WHERE id = ?
                """, (total, focused, session_id))
                await db.commit()
        except Exception as e:
            # Best-effort: log and drop the batch rather than crash the stream.
            print(f"[DB] Flush error: {e}")
450
+
451
# ================ STARTUP/SHUTDOWN ================

# Inference pipelines keyed by model name; populated in startup_event.
# None means that pipeline failed to load and must not be selected.
pipelines = {
    "geometric": None,
    "mlp": None,
    "hybrid": None,
    "xgboost": None,
}

# Thread pool for CPU-bound inference so the event loop stays responsive.
_inference_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=4,
    thread_name_prefix="inference",
)
# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
# multiple frames are processed in parallel by the thread pool.
_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
468
+
469
+
470
def _process_frame_safe(pipeline, frame, model_name: str):
    """Run process_frame under the pipeline's lock (called on executor threads)."""
    lock = _pipeline_locks[model_name]
    with lock:
        result = pipeline.process_frame(frame)
    return result
474
+
475
@app.on_event("startup")
async def startup_event():
    """Initialize the DB, restore the selected model, and load all pipelines.

    Each pipeline loads independently: a failure is logged and its slot in
    `pipelines` stays None, letting the runtime fall back to another model.
    """
    global pipelines, _cached_model_name
    print(" Starting Focus Guard API...")
    await init_database()
    # Load cached model name from DB
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        if row:
            _cached_model_name = row[0]
    print("[OK] Database initialized")

    try:
        pipelines["geometric"] = FaceMeshPipeline()
        print("[OK] FaceMeshPipeline (geometric) loaded")
    except Exception as e:
        print(f"[WARN] FaceMeshPipeline unavailable: {e}")

    try:
        pipelines["mlp"] = MLPPipeline()
        print("[OK] MLPPipeline loaded")
    except Exception as e:
        print(f"[ERR] Failed to load MLPPipeline: {e}")

    try:
        pipelines["hybrid"] = HybridFocusPipeline()
        print("[OK] HybridFocusPipeline loaded")
    except Exception as e:
        print(f"[WARN] HybridFocusPipeline unavailable: {e}")

    try:
        pipelines["xgboost"] = XGBoostPipeline()
        print("[OK] XGBoostPipeline loaded")
    except Exception as e:
        print(f"[ERR] Failed to load XGBoostPipeline: {e}")
511
+
512
@app.on_event("shutdown")
async def shutdown_event():
    """Release background resources on server shutdown."""
    # wait=False: do not block shutdown on in-flight inference jobs.
    _inference_executor.shutdown(wait=False)
    print(" Shutting down Focus Guard API...")
516
+
517
+ # ================ WEBRTC SIGNALING ================
518
+
519
@app.post("/api/webrtc/offer")
async def webrtc_offer(offer: dict):
    """WebRTC signaling endpoint: accept an SDP offer and return an answer.

    Creates a DB focus session tied to the peer connection, routes the
    incoming video track through VideoTransformTrack (inference + overlay),
    and answers once ICE gathering completes (non-trickle ICE).

    Args:
        offer: browser offer dict with "sdp" and "type" keys.
    """
    try:
        print(f"Received WebRTC offer")

        pc = RTCPeerConnection()
        # Keep a strong reference so the connection is not garbage-collected.
        pcs.add(pc)

        session_id = await create_session()
        print(f"Created session: {session_id}")

        # The data channel arrives asynchronously; the track captures this
        # mutable holder via a lambda so it always sees the latest channel.
        channel_ref = {"channel": None}

        @pc.on("datachannel")
        def on_datachannel(channel):
            print(f"Data channel opened")
            channel_ref["channel"] = channel

        @pc.on("track")
        def on_track(track):
            print(f"Received track: {track.kind}")
            if track.kind == "video":
                # Echo the processed (annotated) video back to the client.
                local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
                pc.addTrack(local_track)
                print(f"Video track added")

                @track.on("ended")
                async def on_ended():
                    print(f"Track ended")

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            print(f"Connection state changed: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed", "disconnected"):
                # Finalize the DB session before tearing the connection down.
                try:
                    await end_session(session_id)
                except Exception as e:
                    print(f"⚠Error ending session: {e}")
                pcs.discard(pc)
                await pc.close()

        await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
        print(f"Remote description set")

        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        print(f"Answer created")

        # Wait for all ICE candidates so the returned SDP is complete.
        await _wait_for_ice_gathering(pc)
        print(f"ICE gathering complete")

        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}

    except Exception as e:
        print(f"WebRTC offer error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
577
+
578
+ # ================ WEBSOCKET ================
579
+
580
@app.websocket("/ws/video")
async def websocket_endpoint(websocket: WebSocket):
    """Real-time focus-detection WebSocket.

    Protocol:
      - Binary messages are raw JPEG frames (fast path, no base64).
      - Text messages are JSON control commands: {"type": "frame"|"start_session"|"end_session"}.
    Two coroutines run concurrently: a receive loop that only stores the most
    recent frame, and a process loop that decodes/classifies it. Older frames
    are intentionally dropped so inference never falls behind the camera.
    """
    await websocket.accept()
    session_id = None      # DB id of the active session, or None when idle
    frame_count = 0        # echoed back as "fc" so the client can detect drops
    running = True         # shared shutdown flag for both loops
    event_buffer = _EventBuffer(flush_interval=2.0)

    # Latest frame slot — only the most recent frame is kept, older ones are dropped.
    # Using a dict so nested functions can mutate without nonlocal issues.
    _slot = {"frame": None}
    _frame_ready = asyncio.Event()

    async def _receive_loop():
        """Receive messages as fast as possible. Binary = frame, text = control."""
        nonlocal session_id, running
        try:
            while running:
                msg = await websocket.receive()
                msg_type = msg.get("type", "")

                if msg_type == "websocket.disconnect":
                    running = False
                    # Wake the process loop so it can observe running == False.
                    _frame_ready.set()
                    return

                # Binary message → JPEG frame (fast path, no base64)
                raw_bytes = msg.get("bytes")
                if raw_bytes is not None and len(raw_bytes) > 0:
                    _slot["frame"] = raw_bytes
                    _frame_ready.set()
                    continue

                # Text message → JSON control command (or legacy base64 frame)
                text = msg.get("text")
                if not text:
                    continue
                data = json.loads(text)

                if data["type"] == "frame":
                    # Legacy base64 path (fallback)
                    _slot["frame"] = base64.b64decode(data["image"])
                    _frame_ready.set()

                elif data["type"] == "start_session":
                    session_id = await create_session()
                    event_buffer.start()
                    # Clear per-session temporal state on every loaded pipeline.
                    for p in pipelines.values():
                        if p is not None and hasattr(p, "reset_session"):
                            p.reset_session()
                    await websocket.send_json({"type": "session_started", "session_id": session_id})

                elif data["type"] == "end_session":
                    if session_id:
                        await event_buffer.stop()
                        summary = await end_session(session_id)
                        if summary:
                            await websocket.send_json({"type": "session_ended", "summary": summary})
                        session_id = None
        except WebSocketDisconnect:
            running = False
            _frame_ready.set()
        except Exception as e:
            print(f"[WS] receive error: {e}")
            running = False
            _frame_ready.set()

    async def _process_loop():
        """Process only the latest frame, dropping stale ones."""
        nonlocal frame_count, running
        loop = asyncio.get_event_loop()
        while running:
            await _frame_ready.wait()
            _frame_ready.clear()
            if not running:
                return

            # Grab latest frame and clear slot
            raw = _slot["frame"]
            _slot["frame"] = None
            if raw is None:
                continue

            try:
                nparr = np.frombuffer(raw, np.uint8)
                frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if frame is None:
                    continue
                # Normalize resolution so all pipelines see the same input size.
                frame = cv2.resize(frame, (640, 480))

                model_name = _cached_model_name
                # Fall back to the default model if the cached choice is unloaded.
                if model_name not in pipelines or pipelines.get(model_name) is None:
                    model_name = "mlp"
                active_pipeline = pipelines.get(model_name)

                landmarks_list = None
                if active_pipeline is not None:
                    # Run inference off the event loop; _process_frame_safe wraps
                    # the pipeline call (defined elsewhere in this file).
                    out = await loop.run_in_executor(
                        _inference_executor,
                        _process_frame_safe,
                        active_pipeline,
                        frame,
                        model_name,
                    )
                    is_focused = out["is_focused"]
                    confidence = out.get("mlp_prob", out.get("raw_score", 0.0))

                    lm = out.get("landmarks")
                    if lm is not None:
                        # Send all 478 landmarks as flat array for tessellation drawing
                        landmarks_list = [
                            [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                            for i in range(lm.shape[0])
                        ]

                    if session_id:
                        event_buffer.add(session_id, is_focused, confidence, {
                            "s_face": out.get("s_face", 0.0),
                            "s_eye": out.get("s_eye", 0.0),
                            "mar": out.get("mar", 0.0),
                            "model": model_name,
                        })
                else:
                    # No pipeline loaded at all — report "not focused" defaults.
                    is_focused = False
                    confidence = 0.0

                resp = {
                    "type": "detection",
                    "focused": is_focused,
                    "confidence": round(confidence, 3),
                    "model": model_name,
                    "fc": frame_count,
                }
                if active_pipeline is not None:
                    # Send detailed metrics for HUD
                    if out.get("yaw") is not None:
                        resp["yaw"] = round(out["yaw"], 1)
                        resp["pitch"] = round(out["pitch"], 1)
                        resp["roll"] = round(out["roll"], 1)
                    if out.get("mar") is not None:
                        resp["mar"] = round(out["mar"], 3)
                    resp["sf"] = round(out.get("s_face", 0), 3)
                    resp["se"] = round(out.get("s_eye", 0), 3)
                if landmarks_list is not None:
                    resp["lm"] = landmarks_list
                await websocket.send_json(resp)
                frame_count += 1
            except Exception as e:
                print(f"[WS] process error: {e}")

    try:
        await asyncio.gather(_receive_loop(), _process_loop())
    except Exception:
        pass
    finally:
        # Make sure an abandoned session is still flushed and closed.
        running = False
        if session_id:
            await event_buffer.stop()
            await end_session(session_id)
739
+
740
+ # ================ API ENDPOINTS ================
741
+
742
@app.post("/api/sessions/start")
async def api_start_session():
    """Create a new focus session and return its database id."""
    new_session_id = await create_session()
    return {"session_id": new_session_id}
746
+
747
@app.post("/api/sessions/end")
async def api_end_session(data: SessionEnd):
    """Finalize the given session and return its summary; 404 if unknown."""
    summary = await end_session(data.session_id)
    if summary:
        return summary
    raise HTTPException(status_code=404, detail="Session not found")
752
+
753
@app.get("/api/sessions")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List sessions, newest first.

    filter: "today" | "week" | "month" — sessions whose start_time falls in
            that window — or "all", which returns only completed sessions
            (end_time set).
    limit:  page size; pass -1 to return every matching row (export mode).
    offset: pagination offset (ignored when limit == -1).
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row

        # NEW: If importing/exporting all, remove limit if special flag or high limit
        # For simplicity: if limit is -1, return all
        limit_clause = "LIMIT ? OFFSET ?"
        params = []

        base_query = "SELECT * FROM focus_sessions"
        where_clause = ""

        if filter == "today":
            # Midnight of the current day in local time.
            date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
            where_clause = " WHERE start_time >= ?"
            params.append(date_filter.isoformat())
        elif filter == "week":
            date_filter = datetime.now() - timedelta(days=7)
            where_clause = " WHERE start_time >= ?"
            params.append(date_filter.isoformat())
        elif filter == "month":
            date_filter = datetime.now() - timedelta(days=30)
            where_clause = " WHERE start_time >= ?"
            params.append(date_filter.isoformat())
        elif filter == "all":
            # Just ensure we only get completed sessions or all sessions
            where_clause = " WHERE end_time IS NOT NULL"

        query = f"{base_query}{where_clause} ORDER BY start_time DESC"

        # Handle Limit for Exports
        if limit == -1:
            # No limit clause for export
            pass
        else:
            # params order matters: WHERE placeholders first, then LIMIT/OFFSET.
            query += f" {limit_clause}"
            params.extend([limit, offset])

        cursor = await db.execute(query, tuple(params))
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]
795
+
796
+ # --- NEW: Import Endpoint ---
797
@app.post("/api/import")
async def import_sessions(sessions: List[dict]):
    """Bulk-insert previously exported session rows.

    Returns {"status": "success", "count": N} on success, or
    {"status": "error", "message": ...} on any failure (the transaction is
    only committed after every row inserts cleanly).
    """
    count = 0
    try:
        async with aiosqlite.connect(db_path) as db:
            for session in sessions:
                # Use .get() to handle potential missing fields from older versions or edits
                await db.execute("""
                    INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    session.get('start_time'),
                    session.get('end_time'),
                    session.get('duration_seconds', 0),
                    session.get('focus_score', 0.0),
                    session.get('total_frames', 0),
                    session.get('focused_frames', 0),
                    # Fall back to start_time when created_at was not exported.
                    session.get('created_at', session.get('start_time'))
                ))
                count += 1
            await db.commit()
        return {"status": "success", "count": count}
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
822
+
823
+ # --- NEW: Clear History Endpoint ---
824
@app.delete("/api/history")
async def clear_history():
    """Delete every stored focus event and session (irreversible)."""
    try:
        async with aiosqlite.connect(db_path) as db:
            # Delete events first (foreign key good practice)
            await db.execute("DELETE FROM focus_events")
            await db.execute("DELETE FROM focus_sessions")
            await db.commit()
        return {"status": "success", "message": "History cleared"}
    except Exception as e:
        return {"status": "error", "message": str(e)}
835
+
836
@app.get("/api/sessions/{session_id}")
async def get_session(session_id: int):
    """Fetch a single session plus its time-ordered focus events; 404 if missing."""
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
        row = await cursor.fetchone()
        if not row: raise HTTPException(status_code=404, detail="Session not found")
        session = dict(row)
        cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
        events = [dict(r) for r in await cursor.fetchall()]
        session['events'] = events
        return session
848
+
849
@app.get("/api/settings")
async def get_settings():
    """Return the singleton settings row (id=1), or defaults if it doesn't exist yet."""
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        if row: return dict(row)
        # No row yet: hand back the defaults without creating one (PUT creates it).
        else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
857
+
858
@app.put("/api/settings")
async def update_settings(settings: SettingsUpdate):
    """Partially update user settings.

    Only non-None fields are written; numeric fields are clamped to safe
    ranges. model_name is accepted only if that pipeline is actually loaded.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
        exists = await cursor.fetchone()
        if not exists:
            # Lazily create the singleton settings row with a default sensitivity.
            await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
            await db.commit()

        updates = []
        params = []
        if settings.sensitivity is not None:
            updates.append("sensitivity = ?")
            params.append(max(1, min(10, settings.sensitivity)))  # clamp to 1..10
        if settings.notification_enabled is not None:
            updates.append("notification_enabled = ?")
            params.append(settings.notification_enabled)
        if settings.notification_threshold is not None:
            updates.append("notification_threshold = ?")
            params.append(max(5, min(300, settings.notification_threshold)))  # clamp to 5..300
        if settings.frame_rate is not None:
            updates.append("frame_rate = ?")
            params.append(max(5, min(60, settings.frame_rate)))  # clamp to 5..60
        if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
            updates.append("model_name = ?")
            params.append(settings.model_name)
            # Keep the in-process cache in sync so the WS loop picks up the change
            # without re-reading the DB on every frame.
            global _cached_model_name
            _cached_model_name = settings.model_name

        if updates:
            query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
            await db.execute(query, params)
            await db.commit()
        return {"status": "success", "updated": len(updates) > 0}
892
+
893
@app.get("/api/stats/summary")
async def get_stats_summary():
    """Aggregate statistics over completed sessions: totals, average score, and streak."""
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
        total_sessions = (await cursor.fetchone())[0]
        cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
        # SUM/AVG return NULL on an empty table, hence the `or` fallbacks.
        total_focus_time = (await cursor.fetchone())[0] or 0
        cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
        avg_focus_score = (await cursor.fetchone())[0] or 0.0
        cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
        dates = [row[0] for row in await cursor.fetchall()]

        # Streak: consecutive calendar days ending today that each have at least
        # one completed session. A gap (or no session today) stops the count.
        streak_days = 0
        if dates:
            current_date = datetime.now().date()
            for i, date_str in enumerate(dates):
                session_date = datetime.fromisoformat(date_str).date()
                expected_date = current_date - timedelta(days=i)
                if session_date == expected_date: streak_days += 1
                else: break
        return {
            'total_sessions': total_sessions,
            'total_focus_time': int(total_focus_time),
            'avg_focus_score': round(avg_focus_score, 3),
            'streak_days': streak_days
        }
919
+
920
@app.get("/api/models")
async def get_available_models():
    """Return list of loaded model names and which is currently active."""
    available = [name for name, p in pipelines.items() if p is not None]
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        current = row[0] if row else "mlp"
        # Fall back to any loaded model if the stored choice is not available.
        if current not in available and available:
            current = available[0]
        return {"available": available, "current": current}
931
+
932
@app.get("/api/mesh-topology")
async def get_mesh_topology():
    """Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
    # _TESSELATION_CONNS is a module-level constant built at startup.
    return {"tessellation": _TESSELATION_CONNS}
936
+
937
@app.get("/health")
async def health_check():
    """Liveness probe: reports loaded pipelines and whether the DB file exists."""
    available = [name for name, p in pipelines.items() if p is not None]
    return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
941
+
942
+ # ================ STATIC FILES (SPA SUPPORT) ================
943
+
944
# Resolve static dir from this file so it works regardless of cwd
_STATIC_DIR = Path(__file__).resolve().parent / "static"
_ASSETS_DIR = _STATIC_DIR / "assets"

# 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
# (mounts registered here take precedence over the /{full_path:path} route below)
if _ASSETS_DIR.is_dir():
    app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
951
+
952
+ # 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
953
@app.get("/{full_path:path}")
async def serve_react_app(full_path: str, request: Request):
    """SPA catch-all: serve index.html for client-side routes.

    API/WebSocket paths and asset paths are excluded so they return a real
    404 instead of HTML (serving HTML for /assets/* would break module
    script loading via a wrong MIME type).
    """
    # FIX: the original also tested startswith("assets/"), which is dead code —
    # any path starting with "assets/" already starts with "assets". The three
    # prefix checks collapse into one tuple-form startswith (same behavior).
    if full_path.startswith(("api", "ws", "assets")):
        raise HTTPException(status_code=404, detail="Not Found")

    index_path = _STATIC_DIR / "index.html"
    if index_path.is_file():
        return FileResponse(str(index_path))
    return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
models/README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/
2
+
3
+ Feature extraction modules and model training scripts.
4
+
5
+ ## 1. Feature Extraction
6
+
7
+ Root-level modules form the real-time inference pipeline:
8
+
9
+ | Module | Input | Output |
10
+ |--------|-------|--------|
11
+ | `face_mesh.py` | BGR frame | 478 MediaPipe landmarks |
12
+ | `head_pose.py` | Landmarks, frame size | yaw, pitch, roll, face/eye score, gaze offset, head deviation |
13
+ | `eye_scorer.py` | Landmarks | EAR (left/right/avg), gaze ratio (h/v), MAR |
14
+ | `eye_crop.py` | Landmarks, frame | Cropped eye region images |
15
+ | `eye_classifier.py` | Eye crops or landmarks | Eye open/closed prediction (geometric fallback) |
16
+ | `collect_features.py` | BGR frame | 17-d feature vector + temporal features (PERCLOS, blink rate, etc.) |
17
+
18
+ ## 2. Training Scripts
19
+
20
+ | Folder | Model | Command |
21
+ |--------|-------|---------|
22
+ | `mlp/` | PyTorch MLP (64→32, 2-class) | `python -m models.mlp.train` |
23
+ | `xgboost/` | XGBoost (600 trees, depth 8) | `python -m models.xgboost.train` |
24
+
25
+ ### mlp/
26
+
27
+ - `train.py` — training loop with early stopping, ClearML opt-in
28
+ - `sweep.py` — hyperparameter search (Optuna: lr, batch_size)
29
+ - `eval_accuracy.py` — load checkpoint and print test metrics
30
+ - Saves to **`checkpoints/mlp_best.pt`**
31
+
32
+ ### xgboost/
33
+
34
+ - `train.py` — training with eval-set logging
35
+ - `sweep.py` / `sweep_local.py` — hyperparameter search (Optuna + ClearML)
36
+ - `eval_accuracy.py` — load checkpoint and print test metrics
37
+ - Saves to **`checkpoints/xgboost_face_orientation_best.json`**
38
+
39
+ ## 3. Data Loading
40
+
41
+ All training scripts import from `data_preparation.prepare_dataset`:
42
+
43
+ ```python
44
+ from data_preparation.prepare_dataset import get_numpy_splits # XGBoost
45
+ from data_preparation.prepare_dataset import get_dataloaders # MLP (PyTorch)
46
+ ```
47
+
48
+ ## 4. Results
49
+
50
+ | Model | Test Accuracy | F1 | ROC-AUC |
51
+ |-------|--------------|-----|---------|
52
+ | XGBoost | 95.87% | 0.959 | 0.991 |
53
+ | MLP | 92.92% | 0.929 | 0.971 |
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/cnn/CNN_MODEL/.claude/settings.local.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(# Check Dataset_subset counts echo \"\"=== Dataset_subset/train/open ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/train/open/ | wc -l && echo \"\"=== Dataset_subset/train/closed ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/train/closed/ | wc -l && echo \"\"=== Dataset_subset/val/open ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/val/open/ | wc -l && echo \"\"=== Dataset_subset/val/closed ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/val/closed/)"
5
+ ]
6
+ }
7
+ }
models/cnn/CNN_MODEL/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ DATA/** filter=lfs diff=lfs merge=lfs -text
models/cnn/CNN_MODEL/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Dataset/train/
2
+ Dataset/val/
3
+ Dataset/test/
4
+ .DS_Store
models/cnn/CNN_MODEL/README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eye Open / Closed Classifier (YOLOv11-CLS)
2
+
3
+
4
+ Binary classifier: **open** vs **closed** eyes.
5
+ Used as a baseline for eye-tracking, drowsiness, or focus detection.
6
+
7
+ ---
8
+
9
+ ## Model team task
10
+
11
+ - **Train** the YOLOv11s-cls eye classifier in a **separate notebook** (data split, epochs, GPU, export `best.pt`).
12
+ - Provide **trained weights** (`best.pt`) for this repo’s evaluation and inference scripts.
13
+
14
+
15
+
16
+ ---
17
+
18
+ ## Repo contents
19
+
20
+ - **notebooks/eye_classifier_colab.ipynb** — Data download (Kaggle), clean, split, undersample, **evaluate** (needs `best.pt` from model team), export.
21
+ - **scripts/predict_image.py** — Run classifier on single images (needs `best.pt`).
22
+ - **scripts/webcam_live.py** — Live webcam open/closed (needs `best.pt` + optional `weights/face_landmarker.task`).
23
+ - **scripts/video_infer.py** — Run on video files.
24
+ - **scripts/focus_infer.py** — Focus/attention inference.
25
+ - **weights/** — Put `best.pt` here; `face_landmarker.task` is downloaded on first webcam run if missing.
26
+ - **docs/** — Extra docs (e.g. UNNECESSARY_FILES.md if present).
27
+
28
+ ---
29
+
30
+ ## Dataset
31
+
32
+ - **Source:** [Kaggle — open/closed eyes](https://www.kaggle.com/datasets/sehriyarmemmedli/open-closed-eyes-dataset)
33
+ - The Colab notebook downloads it via `kagglehub`; no local copy in repo.
34
+
35
+ ---
36
+
37
+ ## Weights
38
+
39
+ - Put **best.pt** from the model team in **weights/best.pt** (or `runs/classify/runs_cls/eye_open_closed_cpu/weights/best.pt`).
40
+ - For webcam: **face_landmarker.task** is downloaded into **weights/** on first run if missing.
41
+
42
+ ---
43
+
44
+ ## Local setup
45
+
46
+ ```bash
47
+ pip install ultralytics opencv-python mediapipe "numpy<2"
48
+ ```
49
+
50
+ Optional: use a venv. From repo root:
51
+ - `python scripts/predict_image.py <image.png>`
52
+ - `python scripts/webcam_live.py`
53
+ - `python scripts/video_infer.py` (expects 1.mp4 / 2.mp4 in repo root or set `VIDEOS` env)
54
+ - `python scripts/focus_infer.py`
55
+
56
+ ---
57
+
58
+ ## Project structure
59
+
60
+ ```
61
+ ├── notebooks/
62
+ │ └── eye_classifier_colab.ipynb # Data + eval (no training)
63
+ ├── scripts/
64
+ │ ├── predict_image.py
65
+ │ ├── webcam_live.py
66
+ │ ├── video_infer.py
67
+ │ └── focus_infer.py
68
+ ├── weights/ # best.pt, face_landmarker.task
69
+ ├── docs/ # extra docs
70
+ ├── README.md
71
+ └── venv/ # optional
72
+ ```
73
+
74
+ Training and weight generation: **model team, separate notebook.**
models/cnn/CNN_MODEL/notebooks/eye_classifier_colab.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/cnn/CNN_MODEL/scripts/focus_infer.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+
10
+
11
+ def list_images(folder: Path):
12
+ exts = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
13
+ return sorted([p for p in folder.iterdir() if p.suffix.lower() in exts])
14
+
15
+
16
+ def find_weights(project_root: Path) -> Path | None:
17
+ candidates = [
18
+ project_root / "weights" / "best.pt",
19
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
20
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
21
+ project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
22
+ project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
23
+ ]
24
+ return next((p for p in candidates if p.is_file()), None)
25
+
26
+
27
+ def detect_eyelid_boundary(gray: np.ndarray) -> np.ndarray | None:
28
+ """
29
+ Returns an ellipse fit to the largest contour near the eye boundary.
30
+ Output format: (center(x,y), (axis1, axis2), angle) or None.
31
+ """
32
+ blur = cv2.GaussianBlur(gray, (5, 5), 0)
33
+ edges = cv2.Canny(blur, 40, 120)
34
+ edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
35
+ contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
36
+ if not contours:
37
+ return None
38
+ contours = sorted(contours, key=cv2.contourArea, reverse=True)
39
+ for c in contours:
40
+ if len(c) >= 5 and cv2.contourArea(c) > 50:
41
+ return cv2.fitEllipse(c)
42
+ return None
43
+
44
+
45
+ def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
46
+ """
47
+ More robust pupil detection:
48
+ - enhance contrast (CLAHE)
49
+ - find dark blobs
50
+ - score by circularity and proximity to center
51
+ """
52
+ h, w = gray.shape
53
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
54
+ eq = clahe.apply(gray)
55
+ blur = cv2.GaussianBlur(eq, (7, 7), 0)
56
+
57
+ # Focus on the central region to avoid eyelashes/edges
58
+ cx, cy = w // 2, h // 2
59
+ rx, ry = int(w * 0.3), int(h * 0.3)
60
+ x0, x1 = max(cx - rx, 0), min(cx + rx, w)
61
+ y0, y1 = max(cy - ry, 0), min(cy + ry, h)
62
+ roi = blur[y0:y1, x0:x1]
63
+
64
+ # Inverted threshold to capture dark pupil
65
+ _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
66
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
67
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
68
+
69
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
70
+ if not contours:
71
+ return None
72
+
73
+ best = None
74
+ best_score = -1.0
75
+ for c in contours:
76
+ area = cv2.contourArea(c)
77
+ if area < 15:
78
+ continue
79
+ perimeter = cv2.arcLength(c, True)
80
+ if perimeter <= 0:
81
+ continue
82
+ circularity = 4 * np.pi * (area / (perimeter * perimeter))
83
+ if circularity < 0.3:
84
+ continue
85
+ m = cv2.moments(c)
86
+ if m["m00"] == 0:
87
+ continue
88
+ px = int(m["m10"] / m["m00"]) + x0
89
+ py = int(m["m01"] / m["m00"]) + y0
90
+
91
+ # Score by circularity and distance to center
92
+ dist = np.hypot(px - cx, py - cy) / max(w, h)
93
+ score = circularity - dist
94
+ if score > best_score:
95
+ best_score = score
96
+ best = (px, py)
97
+
98
+ return best
99
+
100
+
101
+ def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
102
+ """
103
+ Decide focus based on pupil offset from image center.
104
+ """
105
+ h, w = img_shape
106
+ cx, cy = w // 2, h // 2
107
+ px, py = pupil_center
108
+ dx = abs(px - cx) / max(w, 1)
109
+ dy = abs(py - cy) / max(h, 1)
110
+ return (dx < 0.12) and (dy < 0.12)
111
+
112
+
113
+ def annotate(img_bgr: np.ndarray, ellipse, pupil_center, focused: bool, cls_label: str, conf: float):
114
+ out = img_bgr.copy()
115
+ if ellipse is not None:
116
+ cv2.ellipse(out, ellipse, (0, 255, 255), 2)
117
+ if pupil_center is not None:
118
+ cv2.circle(out, pupil_center, 4, (0, 0, 255), -1)
119
+ label = f"{cls_label} ({conf:.2f}) | focused={int(focused)}"
120
+ cv2.putText(out, label, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
121
+ return out
122
+
123
+
124
+ def main():
125
+ project_root = Path(__file__).resolve().parent.parent
126
+ data_dir = project_root / "Dataset"
127
+ alt_data_dir = project_root / "DATA"
128
+ out_dir = project_root / "runs_focus"
129
+ out_dir.mkdir(parents=True, exist_ok=True)
130
+
131
+ weights = find_weights(project_root)
132
+ if weights is None:
133
+ print("Weights not found. Train first.")
134
+ return
135
+
136
+ # Support both Dataset/test/{open,closed} and Dataset/{open,closed}
137
+ def resolve_test_dirs(root: Path):
138
+ test_open = root / "test" / "open"
139
+ test_closed = root / "test" / "closed"
140
+ if test_open.exists() and test_closed.exists():
141
+ return test_open, test_closed
142
+ test_open = root / "open"
143
+ test_closed = root / "closed"
144
+ if test_open.exists() and test_closed.exists():
145
+ return test_open, test_closed
146
+ alt_closed = root / "close"
147
+ if test_open.exists() and alt_closed.exists():
148
+ return test_open, alt_closed
149
+ return None, None
150
+
151
+ test_open, test_closed = resolve_test_dirs(data_dir)
152
+ if (test_open is None or test_closed is None) and alt_data_dir.exists():
153
+ test_open, test_closed = resolve_test_dirs(alt_data_dir)
154
+
155
+ if not test_open.exists() or not test_closed.exists():
156
+ print("Test folders missing. Expected:")
157
+ print(test_open)
158
+ print(test_closed)
159
+ return
160
+
161
+ test_files = list_images(test_open) + list_images(test_closed)
162
+ print("Total test images:", len(test_files))
163
+ max_images = int(os.getenv("MAX_IMAGES", "0"))
164
+ if max_images > 0:
165
+ test_files = test_files[:max_images]
166
+ print("Limiting to MAX_IMAGES:", max_images)
167
+
168
+ model = YOLO(str(weights))
169
+ results = model.predict(test_files, imgsz=224, device="cpu", verbose=False)
170
+
171
+ names = model.names
172
+ for r in results:
173
+ probs = r.probs
174
+ top_idx = int(probs.top1)
175
+ top_conf = float(probs.top1conf)
176
+ pred_label = names[top_idx]
177
+
178
+ img = cv2.imread(r.path)
179
+ if img is None:
180
+ continue
181
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
182
+
183
+ ellipse = detect_eyelid_boundary(gray)
184
+ pupil_center = detect_pupil_center(gray)
185
+ focused = False
186
+ if pred_label.lower() == "open" and pupil_center is not None:
187
+ focused = is_focused(pupil_center, gray.shape)
188
+
189
+ annotated = annotate(img, ellipse, pupil_center, focused, pred_label, top_conf)
190
+ out_path = out_dir / (Path(r.path).stem + "_annotated.jpg")
191
+ cv2.imwrite(str(out_path), annotated)
192
+
193
+ print(f"{Path(r.path).name}: pred={pred_label} conf={top_conf:.3f} focused={focused}")
194
+
195
+ print(f"\nAnnotated outputs saved to: {out_dir}")
196
+
197
+
198
+ if __name__ == "__main__":
199
+ main()
models/cnn/CNN_MODEL/scripts/predict_image.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run the eye open/closed model on one or more images."""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ from ultralytics import YOLO
6
+
7
+
8
+ def main():
9
+ project_root = Path(__file__).resolve().parent.parent
10
+ weight_candidates = [
11
+ project_root / "weights" / "best.pt",
12
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
13
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
14
+ ]
15
+ weights = next((p for p in weight_candidates if p.is_file()), None)
16
+ if weights is None:
17
+ print("Weights not found. Put best.pt in weights/ or runs/.../weights/ (from model team).")
18
+ sys.exit(1)
19
+
20
+ if len(sys.argv) < 2:
21
+ print("Usage: python scripts/predict_image.py <image1> [image2 ...]")
22
+ print("Example: python scripts/predict_image.py path/to/image.png")
23
+ sys.exit(0)
24
+
25
+ model = YOLO(str(weights))
26
+ names = model.names
27
+
28
+ for path in sys.argv[1:]:
29
+ p = Path(path)
30
+ if not p.is_file():
31
+ print(p, "- file not found")
32
+ continue
33
+ try:
34
+ results = model.predict(str(p), imgsz=224, device="cpu", verbose=False)
35
+ except Exception as e:
36
+ print(p, "- error:", e)
37
+ continue
38
+ if not results:
39
+ print(p, "- no result")
40
+ continue
41
+ r = results[0]
42
+ top_idx = int(r.probs.top1)
43
+ conf = float(r.probs.top1conf)
44
+ label = names[top_idx]
45
+ print(f"{p.name}: {label} ({conf:.2%})")
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
models/cnn/CNN_MODEL/scripts/video_infer.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from ultralytics import YOLO
9
+
10
+ try:
11
+ import mediapipe as mp
12
+ except Exception: # pragma: no cover
13
+ mp = None
14
+
15
+
16
+ def find_weights(project_root: Path) -> Path | None:
17
+ candidates = [
18
+ project_root / "weights" / "best.pt",
19
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
20
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
21
+ project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
22
+ project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
23
+ ]
24
+ return next((p for p in candidates if p.is_file()), None)
25
+
26
+
27
+ def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
28
+ h, w = gray.shape
29
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
30
+ eq = clahe.apply(gray)
31
+ blur = cv2.GaussianBlur(eq, (7, 7), 0)
32
+
33
+ cx, cy = w // 2, h // 2
34
+ rx, ry = int(w * 0.3), int(h * 0.3)
35
+ x0, x1 = max(cx - rx, 0), min(cx + rx, w)
36
+ y0, y1 = max(cy - ry, 0), min(cy + ry, h)
37
+ roi = blur[y0:y1, x0:x1]
38
+
39
+ _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
40
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
41
+ thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
42
+
43
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
44
+ if not contours:
45
+ return None
46
+
47
+ best = None
48
+ best_score = -1.0
49
+ for c in contours:
50
+ area = cv2.contourArea(c)
51
+ if area < 15:
52
+ continue
53
+ perimeter = cv2.arcLength(c, True)
54
+ if perimeter <= 0:
55
+ continue
56
+ circularity = 4 * np.pi * (area / (perimeter * perimeter))
57
+ if circularity < 0.3:
58
+ continue
59
+ m = cv2.moments(c)
60
+ if m["m00"] == 0:
61
+ continue
62
+ px = int(m["m10"] / m["m00"]) + x0
63
+ py = int(m["m01"] / m["m00"]) + y0
64
+
65
+ dist = np.hypot(px - cx, py - cy) / max(w, h)
66
+ score = circularity - dist
67
+ if score > best_score:
68
+ best_score = score
69
+ best = (px, py)
70
+
71
+ return best
72
+
73
+
74
+ def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
75
+ h, w = img_shape
76
+ cx = w // 2
77
+ px, _ = pupil_center
78
+ dx = abs(px - cx) / max(w, 1)
79
+ return dx < 0.12
80
+
81
+
82
def classify_frame(model: YOLO, frame: np.ndarray) -> tuple[str, float]:
    """Run the open/closed classifier on a frame assumed to be an eye crop.

    Returns (top-1 class name, top-1 confidence).
    """
    # Classifier is applied to the frame directly; no detection step here.
    result = model.predict(frame, imgsz=224, device="cpu", verbose=False)[0]
    class_probs = result.probs
    best_idx = int(class_probs.top1)
    return model.names[best_idx], float(class_probs.top1conf)
91
+
92
+
93
def annotate_frame(frame: np.ndarray, label: str, focused: bool, conf: float, time_sec: float):
    """Return a copy of *frame* with a status banner drawn in the top-left."""
    annotated = frame.copy()
    banner = f"{label} | focused={int(focused)} | conf={conf:.2f} | t={time_sec:.2f}s"
    cv2.putText(annotated, banner, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return annotated
98
+
99
+
100
def write_segments(path: Path, segments: list[tuple[float, float, str]]):
    """Dump (start, end, label) segments to *path*, one CSV line per segment."""
    rows = [f"{start:.2f},{end:.2f},{label}\n" for start, end, label in segments]
    with path.open("w") as handle:
        handle.writelines(rows)
104
+
105
+
106
def process_video(video_path: Path, model: YOLO | None):
    """Annotate one video with per-frame focus labels.

    Writes three artifacts next to the input file:
      *_pred.mp4        -- annotated copy of the video,
      *_predictions.csv -- per-frame time_sec,label,focused,conf rows,
      *_segments.txt    -- contiguous same-label segments (start,end,label).

    When the module-level ``mp`` (mediapipe) import is available, uses
    FaceMesh landmarks (EAR for openness, iris position for focus);
    otherwise falls back to the YOLO classifier plus the classical pupil
    detector applied to the whole frame.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Failed to open {video_path}")
        return

    # `or 30.0` guards against capture backends that report fps == 0.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = video_path.with_name(video_path.stem + "_pred.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))

    csv_path = video_path.with_name(video_path.stem + "_predictions.csv")
    seg_path = video_path.with_name(video_path.stem + "_segments.txt")

    frame_idx = 0
    last_label = None  # label of the segment currently being built
    seg_start = 0.0
    segments: list[tuple[float, float, str]] = []

    with csv_path.open("w") as fcsv:
        fcsv.write("time_sec,label,focused,conf\n")
        # `mp` is the module-level mediapipe import (None when unavailable).
        if mp is None:
            print("mediapipe is not installed. Falling back to classifier-only mode.")
        use_mp = mp is not None
        if use_mp:
            mp_face_mesh = mp.solutions.face_mesh
            face_mesh = mp_face_mesh.FaceMesh(
                static_image_mode=False,
                max_num_faces=1,
                refine_landmarks=True,  # iris landmarks (468+) are used below
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
            )

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            time_sec = frame_idx / fps
            conf = 0.0  # only set by the classifier fallback path below
            pred_label = "open"
            focused = False

            if use_mp:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                res = face_mesh.process(rgb)
                if res.multi_face_landmarks:
                    lm = res.multi_face_landmarks[0].landmark
                    h, w = frame.shape[:2]

                    # Eye landmarks (MediaPipe FaceMesh)
                    left_eye = [33, 160, 158, 133, 153, 144]
                    right_eye = [362, 385, 387, 263, 373, 380]
                    left_iris = [468, 469, 470, 471]
                    right_iris = [473, 474, 475, 476]

                    def pts(idxs):
                        # Normalized landmarks -> integer pixel coordinates.
                        return np.array([(int(lm[i].x * w), int(lm[i].y * h)) for i in idxs])

                    def ear(eye_pts):
                        # EAR using 6 points
                        p1, p2, p3, p4, p5, p6 = eye_pts
                        v1 = np.linalg.norm(p2 - p6)
                        v2 = np.linalg.norm(p3 - p5)
                        h1 = np.linalg.norm(p1 - p4)
                        # +1e-6 avoids division by zero for degenerate eyes.
                        return (v1 + v2) / (2.0 * h1 + 1e-6)

                    le = pts(left_eye)
                    re = pts(right_eye)
                    le_ear = ear(le)
                    re_ear = ear(re)
                    ear_avg = (le_ear + re_ear) / 2.0

                    # openness threshold
                    pred_label = "open" if ear_avg > 0.22 else "closed"

                    # iris centers
                    li = pts(left_iris)
                    ri = pts(right_iris)
                    li_c = li.mean(axis=0).astype(int)
                    ri_c = ri.mean(axis=0).astype(int)

                    # eye centers (midpoint of corners)
                    le_c = ((le[0] + le[3]) / 2).astype(int)
                    re_c = ((re[0] + re[3]) / 2).astype(int)

                    # focus = iris close to eye center horizontally for both eyes
                    # (offset is normalised by the corner-to-corner eye width)
                    le_dx = abs(li_c[0] - le_c[0]) / max(np.linalg.norm(le[0] - le[3]), 1)
                    re_dx = abs(ri_c[0] - re_c[0]) / max(np.linalg.norm(re[0] - re[3]), 1)
                    focused = (pred_label == "open") and (le_dx < 0.18) and (re_dx < 0.18)

                    # draw eye boundaries
                    cv2.polylines(frame, [le], True, (0, 255, 255), 1)
                    cv2.polylines(frame, [re], True, (0, 255, 255), 1)
                    # draw iris centers
                    cv2.circle(frame, tuple(li_c), 3, (0, 0, 255), -1)
                    cv2.circle(frame, tuple(ri_c), 3, (0, 0, 255), -1)
                else:
                    # No face detected in this frame.
                    pred_label = "closed"
                    focused = False
            else:
                if model is not None:
                    pred_label, conf = classify_frame(model, frame)
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                pupil_center = detect_pupil_center(gray) if pred_label.lower() == "open" else None
                focused = False
                if pred_label.lower() == "open" and pupil_center is not None:
                    focused = is_focused(pupil_center, gray.shape)

            if pred_label.lower() != "open":
                focused = False

            # Collapse (openness, focus) into the three output labels.
            label = "open_focused" if (pred_label.lower() == "open" and focused) else "open_not_focused"
            if pred_label.lower() != "open":
                label = "closed_not_focused"

            fcsv.write(f"{time_sec:.2f},{label},{int(focused)},{conf:.4f}\n")

            # Segment bookkeeping: close a segment whenever the label changes.
            if last_label is None:
                last_label = label
                seg_start = time_sec
            elif label != last_label:
                segments.append((seg_start, time_sec, last_label))
                seg_start = time_sec
                last_label = label

            annotated = annotate_frame(frame, label, focused, conf, time_sec)
            writer.write(annotated)
            frame_idx += 1

    # Flush the final open segment, then persist all segments.
    if last_label is not None:
        end_time = frame_idx / fps
        segments.append((seg_start, end_time, last_label))
    write_segments(seg_path, segments)

    cap.release()
    writer.release()
    print(f"Saved: {out_path}")
    print(f"CSV: {csv_path}")
    print(f"Segments: {seg_path}")
249
+
250
+
251
def main():
    """Run focus inference on the default demo videos plus any from $VIDEOS."""
    project_root = Path(__file__).resolve().parent.parent
    weights = find_weights(project_root)
    model = None if weights is None else YOLO(str(weights))

    # Default inputs: 1.mp4 / 2.mp4 in the project root.
    videos = [
        candidate
        for candidate in (project_root / name for name in ["1.mp4", "2.mp4"])
        if candidate.exists()
    ]

    # Extra inputs via VIDEOS="a.mp4,b.mp4"; relative paths resolve to the root.
    for raw in os.getenv("VIDEOS", "").split(","):
        raw = raw.strip()
        if not raw:
            continue
        extra_path = Path(raw)
        if not extra_path.is_absolute():
            extra_path = project_root / extra_path
        if extra_path.exists():
            videos.append(extra_path)

    if not videos:
        print("No videos found. Expected 1.mp4 / 2.mp4 in project root.")
        return

    for video in videos:
        process_video(video, model)
278
+
279
+
280
+ if __name__ == "__main__":
281
+ main()
models/cnn/CNN_MODEL/scripts/webcam_live.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Live webcam: detect face, crop each eye, run open/closed classifier, show on screen.
3
+ Requires: opencv-python, ultralytics, mediapipe (pip install mediapipe).
4
+ Press 'q' to quit.
5
+ """
6
+ import urllib.request
7
+ from pathlib import Path
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from ultralytics import YOLO
12
+
13
+ try:
14
+ import mediapipe as mp
15
+ _mp_has_solutions = hasattr(mp, "solutions")
16
+ except ImportError:
17
+ mp = None
18
+ _mp_has_solutions = False
19
+
20
+ # New MediaPipe Tasks API (Face Landmarker) eye indices
21
+ LEFT_EYE_INDICES_NEW = [263, 249, 390, 373, 374, 380, 381, 382, 362, 466, 388, 387, 386, 385, 384, 398]
22
+ RIGHT_EYE_INDICES_NEW = [33, 7, 163, 144, 145, 153, 154, 155, 133, 246, 161, 160, 159, 158, 157, 173]
23
+ # Old Face Mesh (solutions) indices
24
+ LEFT_EYE_INDICES_OLD = [33, 160, 158, 133, 153, 144]
25
+ RIGHT_EYE_INDICES_OLD = [362, 385, 387, 263, 373, 380]
26
+ EYE_PADDING = 0.35
27
+
28
+
29
+ def find_weights(project_root: Path) -> Path | None:
30
+ candidates = [
31
+ project_root / "weights" / "best.pt",
32
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
33
+ project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
34
+ ]
35
+ return next((p for p in candidates if p.is_file()), None)
36
+
37
+
38
+ def get_eye_roi(frame: np.ndarray, landmarks, indices: list[int]) -> np.ndarray | None:
39
+ h, w = frame.shape[:2]
40
+ pts = np.array([(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in indices])
41
+ x_min, y_min = pts.min(axis=0)
42
+ x_max, y_max = pts.max(axis=0)
43
+ dx = max(int((x_max - x_min) * EYE_PADDING), 8)
44
+ dy = max(int((y_max - y_min) * EYE_PADDING), 8)
45
+ x0 = max(0, x_min - dx)
46
+ y0 = max(0, y_min - dy)
47
+ x1 = min(w, x_max + dx)
48
+ y1 = min(h, y_max + dy)
49
+ if x1 <= x0 or y1 <= y0:
50
+ return None
51
+ return frame[y0:y1, x0:x1].copy()
52
+
53
+
54
def _run_with_solutions(mp, model, cap):
    """Webcam loop using the legacy mediapipe.solutions FaceMesh API.

    mp: the imported mediapipe module.
    model: ultralytics YOLO classifier applied to each eye crop.
    cap: an opened cv2.VideoCapture.
    Runs until the stream ends or the user presses 'q'.
    """
    face_mesh = mp.solutions.face_mesh.FaceMesh(
        static_image_mode=False,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = face_mesh.process(rgb)
        # Em dash is the "no prediction yet" placeholder shown on screen.
        left_label, left_conf = "—", 0.0
        right_label, right_conf = "—", 0.0
        if results.multi_face_landmarks:
            lm = results.multi_face_landmarks[0].landmark
            for roi, indices, side in [
                (get_eye_roi(frame, lm, LEFT_EYE_INDICES_OLD), LEFT_EYE_INDICES_OLD, "left"),
                (get_eye_roi(frame, lm, RIGHT_EYE_INDICES_OLD), RIGHT_EYE_INDICES_OLD, "right"),
            ]:
                if roi is not None and roi.size > 0:
                    try:
                        pred = model.predict(roi, imgsz=224, device="cpu", verbose=False)
                        if pred:
                            r = pred[0]
                            label = model.names[int(r.probs.top1)]
                            conf = float(r.probs.top1conf)
                            if side == "left":
                                left_label, left_conf = label, conf
                            else:
                                right_label, right_conf = label, conf
                    except Exception:
                        # Best-effort: a failed prediction keeps the placeholder.
                        pass
        cv2.putText(frame, f"L: {left_label} ({left_conf:.0%})", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.putText(frame, f"R: {right_label} ({right_conf:.0%})", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow("Eye open/closed (q to quit)", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
+
95
+
96
def _run_with_tasks(project_root: Path, model, cap):
    """Webcam loop using the newer MediaPipe Tasks FaceLandmarker API.

    Downloads face_landmarker.task into weights/ on first use.
    model: ultralytics YOLO classifier applied to each eye crop.
    cap: an opened cv2.VideoCapture.
    Runs until the stream ends or the user presses 'q'.
    """
    from mediapipe.tasks.python import BaseOptions
    from mediapipe.tasks.python.vision import FaceLandmarker, FaceLandmarkerOptions
    from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode
    from mediapipe.tasks.python.vision.core import image as image_lib

    model_path = project_root / "weights" / "face_landmarker.task"
    if not model_path.is_file():
        print("Downloading face_landmarker.task ...")
        url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
        urllib.request.urlretrieve(url, model_path)
        print("Done.")

    options = FaceLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=str(model_path)),
        running_mode=running_mode.VisionTaskRunningMode.IMAGE,
        num_faces=1,
    )
    face_landmarker = FaceLandmarker.create_from_options(options)
    ImageFormat = image_lib.ImageFormat

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Em dash is the "no prediction yet" placeholder shown on screen.
        left_label, left_conf = "—", 0.0
        right_label, right_conf = "—", 0.0

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Ensure a C-contiguous buffer before wrapping in MediaPipe's Image.
        rgb_contiguous = np.ascontiguousarray(rgb)
        mp_image = image_lib.Image(ImageFormat.SRGB, rgb_contiguous)
        result = face_landmarker.detect(mp_image)

        if result.face_landmarks:
            lm = result.face_landmarks[0]
            for roi, side in [
                (get_eye_roi(frame, lm, LEFT_EYE_INDICES_NEW), "left"),
                (get_eye_roi(frame, lm, RIGHT_EYE_INDICES_NEW), "right"),
            ]:
                if roi is not None and roi.size > 0:
                    try:
                        pred = model.predict(roi, imgsz=224, device="cpu", verbose=False)
                        if pred:
                            r = pred[0]
                            label = model.names[int(r.probs.top1)]
                            conf = float(r.probs.top1conf)
                            if side == "left":
                                left_label, left_conf = label, conf
                            else:
                                right_label, right_conf = label, conf
                    except Exception:
                        # Best-effort: a failed prediction keeps the placeholder.
                        pass

        cv2.putText(frame, f"L: {left_label} ({left_conf:.0%})", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.putText(frame, f"R: {right_label} ({right_conf:.0%})", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow("Eye open/closed (q to quit)", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
154
+
155
+
156
def main():
    """Entry point: load weights, open the webcam, and run live inference."""
    root = Path(__file__).resolve().parent.parent

    checkpoint = find_weights(root)
    if checkpoint is None:
        print("Weights not found. Put best.pt in weights/ or runs/.../weights/ (from model team).")
        return
    if mp is None:
        print("MediaPipe required. Install: pip install mediapipe")
        return

    classifier = YOLO(str(checkpoint))
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Could not open webcam.")
        return

    print("Live eye open/closed on your face. Press 'q' to quit.")
    # Prefer the legacy solutions API when available; otherwise fall back to
    # the newer Tasks API (which may download the landmarker model).
    runner = (
        (lambda: _run_with_solutions(mp, classifier, cap))
        if _mp_has_solutions
        else (lambda: _run_with_tasks(root, classifier, cap))
    )
    try:
        runner()
    finally:
        cap.release()
        cv2.destroyAllWindows()
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
models/cnn/CNN_MODEL/weights/yolo11s-cls.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2b605d1c8c212b434a75a32759a6f7adf1d2b29c35f76bdccd4c794cb653cf2
3
+ size 13630112
models/cnn/__init__.py ADDED
File without changes
models/cnn/eye_attention/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/cnn/eye_attention/classifier.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+
6
+ import numpy as np
7
+
8
+
9
class EyeClassifier(ABC):
    """Interface for eye-state classifier backends.

    Implementations score a batch of BGR eye crops; see the concrete
    subclasses in this module for examples.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short backend identifier (e.g. "yolo", "cnn", "geometric")."""
        pass

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return an attentiveness score averaged over the given BGR eye crops."""
        pass
18
+
19
+
20
class GeometricOnlyClassifier(EyeClassifier):
    """No-op backend: relies on geometric cues elsewhere in the pipeline."""

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Always report fully attentive; the crops are ignored."""
        return 1.0

    @property
    def name(self) -> str:
        return "geometric"
27
+
28
+
29
class YOLOv11Classifier(EyeClassifier):
    """Ultralytics YOLO classification backend for open/closed eye crops."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        names = self._model.names
        # The attentive class is named either "open" or "attentive"; fall
        # back to the highest class index when neither name is present.
        found = next(
            (idx for idx, cls_name in names.items() if cls_name in ("open", "attentive")),
            None,
        )
        self._attentive_idx = max(names.keys()) if found is None else found
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Mean probability of the attentive class over all crops (1.0 if none)."""
        if not crops_bgr:
            return 1.0
        predictions = self._model.predict(crops_bgr, device=self._device, verbose=False)
        probs = [float(p.probs.data[self._attentive_idx]) for p in predictions]
        return sum(probs) / len(probs) if probs else 1.0
56
+
57
+
58
class EyeCNNClassifier(EyeClassifier):
    """Loader for the custom PyTorch EyeCNN (trained on Kaggle eye crops)."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        import torch
        import torch.nn as nn

        # The architecture is redefined inline so inference has no dependency
        # on the training notebook; the attribute names conv_layers/fc_layers
        # must match the checkpoint's state_dict keys exactly.
        class EyeCNN(nn.Module):
            def __init__(self, num_classes=2, dropout_rate=0.3):
                super().__init__()
                self.conv_layers = nn.Sequential(
                    nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
                )
                self.fc_layers = nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
                    nn.Linear(256, 512), nn.ReLU(), nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes),
                )

            def forward(self, x):
                return self.fc_layers(self.conv_layers(x))

        self._device = torch.device(device)
        # weights_only=False: the checkpoint carries a config dict alongside
        # the tensors, so a full unpickle is required.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        dropout_rate = checkpoint.get("config", {}).get("dropout_rate", 0.35)
        self._model = EyeCNN(num_classes=2, dropout_rate=dropout_rate)
        self._model.load_state_dict(checkpoint["model_state_dict"])
        self._model.to(self._device)
        self._model.eval()

        self._transform = None  # built lazily

    def _get_transform(self):
        # Lazily build the preprocessing pipeline (96x96 resize + ImageNet
        # stats), deferring the torchvision import until first use.
        if self._transform is None:
            from torchvision import transforms
            self._transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        return self._transform

    @property
    def name(self) -> str:
        return "eye_cnn"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Mean probability that the eyes are open across all crops.

        Empty or invalid crops count as fully open (1.0) so they never
        penalise the average.
        """
        if not crops_bgr:
            return 1.0

        import torch
        import cv2

        transform = self._get_transform()
        scores = []
        for crop in crops_bgr:
            if crop is None or crop.size == 0:
                scores.append(1.0)
                continue
            rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            tensor = transform(rgb).unsqueeze(0).to(self._device)
            with torch.no_grad():
                output = self._model(tensor)
                prob = torch.softmax(output, dim=1)[0, 1].item()  # prob of "open"
            scores.append(prob)
        return sum(scores) / len(scores)
131
+
132
+
133
# Checkpoint file extension -> classifier backend (.pth = custom CNN, .pt = YOLO).
_EXT_TO_BACKEND = {".pth": "cnn", ".pt": "yolo"}
134
+
135
+
136
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build an EyeClassifier for *backend*, auto-correcting from the extension.

    Falls back to the geometric backend when no checkpoint path is given.
    Raises ValueError for an unknown backend name.
    """
    if backend == "geometric":
        return GeometricOnlyClassifier()
    if path is None:
        print(f"[CLASSIFIER] No model path for backend {backend!r}, falling back to geometric")
        return GeometricOnlyClassifier()

    # The checkpoint's extension is authoritative: .pth -> cnn, .pt -> yolo.
    extension = os.path.splitext(path)[1].lower()
    implied = _EXT_TO_BACKEND.get(extension)
    if implied and implied != backend:
        print(
            f"[CLASSIFIER] File extension {extension!r} implies backend {implied!r}, "
            f"overriding requested {backend!r}"
        )
        backend = implied

    print(f"[CLASSIFIER] backend={backend!r}, path={path!r}")

    if backend == "cnn":
        return EyeCNNClassifier(path, device=device)
    if backend == "yolo":
        try:
            return YOLOv11Classifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
            raise
    raise ValueError(
        f"Unknown eye backend {backend!r}. Choose from: yolo, cnn, geometric"
    )
+ )
models/cnn/eye_attention/crop.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from models.pretrained.face_mesh.face_mesh import FaceMeshDetector
5
+
6
+ LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
7
+ RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
8
+
9
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_STD = (0.229, 0.224, 0.225)
11
+
12
+ CROP_SIZE = 96
13
+
14
+
15
+ def _bbox_from_landmarks(
16
+ landmarks: np.ndarray,
17
+ indices: list[int],
18
+ frame_w: int,
19
+ frame_h: int,
20
+ expand: float = 0.4,
21
+ ) -> tuple[int, int, int, int]:
22
+ pts = landmarks[indices, :2]
23
+ px = pts[:, 0] * frame_w
24
+ py = pts[:, 1] * frame_h
25
+
26
+ x_min, x_max = px.min(), px.max()
27
+ y_min, y_max = py.min(), py.max()
28
+ w = x_max - x_min
29
+ h = y_max - y_min
30
+ cx = (x_min + x_max) / 2
31
+ cy = (y_min + y_max) / 2
32
+
33
+ size = max(w, h) * (1 + expand)
34
+ half = size / 2
35
+
36
+ x1 = int(max(cx - half, 0))
37
+ y1 = int(max(cy - half, 0))
38
+ x2 = int(min(cx + half, frame_w))
39
+ y2 = int(min(cy + half, frame_h))
40
+
41
+ return x1, y1, x2, y2
42
+
43
+
44
def extract_eye_crops(
    frame: np.ndarray,
    landmarks: np.ndarray,
    expand: float = 0.4,
    crop_size: int = CROP_SIZE,
) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
    """Cut out both eye patches and resize them to (crop_size, crop_size).

    Returns (left_crop, right_crop, left_bbox, right_bbox); each bbox is
    (x1, y1, x2, y2) in frame pixel coordinates.
    """
    frame_h, frame_w = frame.shape[:2]

    def _crop_one(contour_indices):
        bbox = _bbox_from_landmarks(landmarks, contour_indices, frame_w, frame_h, expand)
        x1, y1, x2, y2 = bbox
        patch = frame[y1:y2, x1:x2]
        patch = cv2.resize(patch, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
        return patch, bbox

    left_crop, left_bbox = _crop_one(LEFT_EYE_CONTOUR)
    right_crop, right_bbox = _crop_one(RIGHT_EYE_CONTOUR)
    return left_crop, right_crop, left_bbox, right_bbox
62
+
63
+
64
def crop_to_tensor(crop_bgr: np.ndarray):
    """Convert a BGR uint8 eye crop to a normalized float32 CHW torch tensor.

    Applies ImageNet mean/std normalization (see module constants), matching
    the preprocessing in crop pipelines elsewhere in this package.
    """
    import torch

    rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Vectorized per-channel normalization (broadcast over H and W) instead
    # of the previous explicit per-channel Python loop.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    rgb = (rgb - mean) / std
    return torch.from_numpy(rgb.transpose(2, 0, 1))
models/cnn/eye_attention/train.py ADDED
File without changes
models/cnn/notebooks/EyeCNN.ipynb ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "source": [
22
+ "import os\n",
23
+ "import torch\n",
24
+ "import torch.nn as nn\n",
25
+ "import torch.optim as optim\n",
26
+ "from torch.utils.data import DataLoader\n",
27
+ "from torchvision import datasets, transforms\n",
28
+ "\n",
29
+ "from google.colab import drive\n",
30
+ "drive.mount('/content/drive')\n",
31
+ "!cp -r /content/drive/MyDrive/Dataset_clean /content/\n",
32
+ "\n",
33
+ "#Verify structure\n",
34
+ "for split in ['train', 'val', 'test']:\n",
35
+ " path = f'/content/Dataset_clean/{split}'\n",
36
+ " classes = os.listdir(path)\n",
37
+ " total = sum(len(os.listdir(os.path.join(path, c))) for c in classes)\n",
38
+ " print(f'{split}: {total} images | classes: {classes}')"
39
+ ],
40
+ "metadata": {
41
+ "colab": {
42
+ "base_uri": "https://localhost:8080/"
43
+ },
44
+ "id": "sE1F3em-V5go",
45
+ "outputId": "2c73a9a6-a198-468c-a2cc-253b2de7cc3f"
46
+ },
47
+ "execution_count": null,
48
+ "outputs": [
49
+ {
50
+ "output_type": "stream",
51
+ "name": "stdout",
52
+ "text": [
53
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
54
+ ]
55
+ }
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {
62
+ "id": "nG2bh66rQ56G"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "class EyeCNN(nn.Module):\n",
67
+ " def __init__(self, num_classes=2):\n",
68
+ " super(EyeCNN, self).__init__()\n",
69
+ " self.conv_layers = nn.Sequential(\n",
70
+ " nn.Conv2d(3, 32, 3, 1, 1),\n",
71
+ " nn.BatchNorm2d(32),\n",
72
+ " nn.ReLU(),\n",
73
+ " nn.MaxPool2d(2, 2),\n",
74
+ "\n",
75
+ " nn.Conv2d(32, 64, 3, 1, 1),\n",
76
+ " nn.BatchNorm2d(64),\n",
77
+ " nn.ReLU(),\n",
78
+ " nn.MaxPool2d(2, 2),\n",
79
+ "\n",
80
+ " nn.Conv2d(64, 128, 3, 1, 1),\n",
81
+ " nn.BatchNorm2d(128),\n",
82
+ " nn.ReLU(),\n",
83
+ " nn.MaxPool2d(2, 2),\n",
84
+ "\n",
85
+ " nn.Conv2d(128, 256, 3, 1, 1),\n",
86
+ " nn.BatchNorm2d(256),\n",
87
+ " nn.ReLU(),\n",
88
+ " nn.MaxPool2d(2, 2)\n",
89
+ " )\n",
90
+ "\n",
91
+ " self.fc_layers = nn.Sequential(\n",
92
+ " nn.AdaptiveAvgPool2d((1, 1)),\n",
93
+ " nn.Flatten(),\n",
94
+ " nn.Linear(256, 512),\n",
95
+ " nn.ReLU(),\n",
96
+ " nn.Dropout(0.35),\n",
97
+ " nn.Linear(512, num_classes)\n",
98
+ " )\n",
99
+ "\n",
100
+ " def forward(self, x):\n",
101
+ " x = self.conv_layers(x)\n",
102
+ " x = self.fc_layers(x)\n",
103
+ " return x"
104
+ ]
105
+ }
106
+ ]
107
+ }
models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/cnn/notebooks/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # GAP Large Project
models/collect_features.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import argparse
3
+ import collections
4
+ import math
5
+ import os
6
+ import sys
7
+ import time
8
+
9
+ import cv2
10
+ import numpy as np
11
+
12
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
+ if _PROJECT_ROOT not in sys.path:
14
+ sys.path.insert(0, _PROJECT_ROOT)
15
+
16
+ from models.face_mesh import FaceMeshDetector
17
+ from models.head_pose import HeadPoseEstimator
18
+ from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
19
+
20
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
21
+ GREEN = (0, 255, 0)
22
+ RED = (0, 0, 255)
23
+ WHITE = (255, 255, 255)
24
+ YELLOW = (0, 255, 255)
25
+ ORANGE = (0, 165, 255)
26
+ GRAY = (120, 120, 120)
27
+
28
# Per-frame feature vector layout; extract_features() fills its output array
# in exactly this order, so the order must not change silently.
FEATURE_NAMES = [
    "ear_left", "ear_right", "ear_avg", "h_gaze", "v_gaze", "mar",
    "yaw", "pitch", "roll", "s_face", "s_eye", "gaze_offset", "head_deviation",
    "perclos", "blink_rate", "closure_duration", "yawn_duration",
]

NUM_FEATURES = len(FEATURE_NAMES)
assert NUM_FEATURES == 17  # guard against accidental edits to the layout
36
+
37
+
38
class TemporalTracker:
    """Tracks blink/closure/yawn statistics across consecutive frames."""

    EAR_BLINK_THRESH = 0.21   # eyes count as closed below this EAR
    MAR_YAWN_THRESH = 0.55    # mouth counts as yawning above this MAR
    PERCLOS_WINDOW = 60       # rolling window (frames) for PERCLOS
    BLINK_WINDOW_SEC = 30.0   # rolling window (seconds) for blink rate

    def __init__(self):
        self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
        self.blink_timestamps = collections.deque()
        self._eyes_closed = False
        self._closure_start = None
        self._yawn_start = None

    def _perclos(self, closed):
        # Fraction of recent frames with eyes closed.
        self.ear_history.append(1.0 if closed else 0.0)
        if not self.ear_history:
            return 0.0
        return sum(self.ear_history) / len(self.ear_history)

    def _blink_rate(self, closed, now):
        # A blink completes on the closed -> open transition.
        if self._eyes_closed and not closed:
            self.blink_timestamps.append(now)
        self._eyes_closed = closed
        oldest_allowed = now - self.BLINK_WINDOW_SEC
        while self.blink_timestamps and self.blink_timestamps[0] < oldest_allowed:
            self.blink_timestamps.popleft()
        return len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)

    def _closure_duration(self, closed, now):
        # Seconds the eyes have been continuously closed (0 when open).
        if not closed:
            self._closure_start = None
            return 0.0
        if self._closure_start is None:
            self._closure_start = now
        return now - self._closure_start

    def _yawn_duration(self, mar, now):
        # Seconds the mouth has been continuously wide open (0 otherwise).
        if mar <= self.MAR_YAWN_THRESH:
            self._yawn_start = None
            return 0.0
        if self._yawn_start is None:
            self._yawn_start = now
        return now - self._yawn_start

    def update(self, ear_avg, mar, now=None):
        """Feed one frame's EAR/MAR; return (perclos, blink_rate, closure_dur, yawn_dur)."""
        if now is None:
            now = time.time()
        closed = ear_avg < self.EAR_BLINK_THRESH
        return (
            self._perclos(closed),
            self._blink_rate(closed, now),
            self._closure_duration(closed, now),
            self._yawn_duration(mar, now),
        )
86
+
87
+
88
def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
                     *, _pre=None):
    """Build the 17-element feature vector for one frame.

    landmarks: face-mesh landmarks for the current frame.
    w, h: frame width/height in pixels.
    head_pose: estimator providing .estimate(...) and .score(...).
    eye_scorer: scorer providing .score(...).
    temporal: TemporalTracker carrying blink/closure state across frames.
    _pre: optional dict of already-computed values (ear_left, ear_right,
        h_gaze, v_gaze, mar, angles, s_face, s_eye) so callers that have
        them can skip recomputation.

    Returns a float32 numpy array ordered exactly as FEATURE_NAMES.

    Fix: the previous ``p.get(key, compute(...))`` pattern evaluated the
    expensive fallback eagerly even when the cached value was present,
    defeating the purpose of ``_pre``; fallbacks are now computed lazily.
    """
    from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear

    p = _pre or {}

    ear_left = p["ear_left"] if "ear_left" in p else compute_ear(landmarks, _LEFT_EYE_EAR)
    ear_right = p["ear_right"] if "ear_right" in p else compute_ear(landmarks, _RIGHT_EYE_EAR)
    ear_avg = (ear_left + ear_right) / 2.0

    if "h_gaze" in p and "v_gaze" in p:
        h_gaze, v_gaze = p["h_gaze"], p["v_gaze"]
    else:
        h_gaze, v_gaze = compute_gaze_ratio(landmarks)

    mar = p["mar"] if "mar" in p else compute_mar(landmarks)

    angles = p.get("angles")
    if angles is None:
        angles = head_pose.estimate(landmarks, w, h)
    # estimate() may yield nothing; treat a missing pose as neutral (0, 0, 0).
    yaw = angles[0] if angles else 0.0
    pitch = angles[1] if angles else 0.0
    roll = angles[2] if angles else 0.0

    s_face = p["s_face"] if "s_face" in p else head_pose.score(landmarks, w, h)
    s_eye = p["s_eye"] if "s_eye" in p else eye_scorer.score(landmarks)

    # Radial distance of the gaze point from the (0.5, 0.5) screen centre.
    gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
    head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)  # cleaned downstream

    perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)

    return np.array([
        ear_left, ear_right, ear_avg,
        h_gaze, v_gaze,
        mar,
        yaw, pitch, roll,
        s_face, s_eye,
        gaze_offset,
        head_deviation,
        perclos, blink_rate, closure_dur, yawn_dur,
    ], dtype=np.float32)
130
+
131
+
132
+ def quality_report(labels):
133
+ n = len(labels)
134
+ n1 = int((labels == 1).sum())
135
+ n0 = n - n1
136
+ transitions = int(np.sum(np.diff(labels) != 0))
137
+ duration_sec = n / 30.0 # approximate at 30fps
138
+
139
+ warnings = []
140
+
141
+ print(f"\n{'='*50}")
142
+ print(f" DATA QUALITY REPORT")
143
+ print(f"{'='*50}")
144
+ print(f" Total samples : {n}")
145
+ print(f" Focused : {n1} ({n1/max(n,1)*100:.1f}%)")
146
+ print(f" Unfocused : {n0} ({n0/max(n,1)*100:.1f}%)")
147
+ print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
148
+ print(f" Transitions : {transitions}")
149
+ if transitions > 0:
150
+ print(f" Avg segment : {n/transitions:.0f} frames ({n/transitions/30:.1f}s)")
151
+
152
+ # checks
153
+ if duration_sec < 120:
154
+ warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")
155
+
156
+ if n < 3000:
157
+ warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")
158
+
159
+ balance = n1 / max(n, 1)
160
+ if balance < 0.3 or balance > 0.7:
161
+ warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")
162
+
163
+ if transitions < 10:
164
+ warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")
165
+
166
+ if transitions == 1:
167
+ warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
168
+ "model will learn temporal position, not focus patterns")
169
+
170
+ if warnings:
171
+ print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
172
+ for w in warnings:
173
+ print(f" • {w}")
174
+ print(f"\n Consider re-recording this session.")
175
+ else:
176
+ print(f"\n ✅ All checks passed!")
177
+
178
+ print(f"{'='*50}\n")
179
+ return len(warnings) == 0
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Main
184
def main():
    """Interactive webcam tool for collecting labelled focus/unfocus feature data.

    Shows a live camera feed; the operator presses 1/0 to label the current
    frames as focused/not-focused (p pauses, q saves and quits). Each frame
    with a detected face and an active label is converted to a feature vector
    via extract_features() and, on exit, all samples are written to a
    timestamped .npz file followed by a quality report.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="session",
                        help="Your name or session ID")
    parser.add_argument("--camera", type=int, default=0,
                        help="Camera index")
    parser.add_argument("--duration", type=int, default=600,
                        help="Max recording time (seconds, default 10 min)")
    parser.add_argument("--output-dir", type=str,
                        default=os.path.join(_PROJECT_ROOT, "data", "collected_data"),
                        help="Where to save .npz files")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Pipeline components shared with inference (features must match training-time layout).
    detector = FaceMeshDetector()
    head_pose = HeadPoseEstimator()
    eye_scorer = EyeBehaviourScorer()
    temporal = TemporalTracker()

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print("[COLLECT] ERROR: can't open camera")
        return

    print("[COLLECT] Data Collection Tool")
    print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
    print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
    print("[COLLECT] Controls:")
    print(" 1 = FOCUSED (looking at screen normally)")
    print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
    print(" p = pause")
    print(" q = save & quit")
    print()
    print("[COLLECT] TIPS for good data:")
    print(" • Switch between 1 and 0 every 10-30 seconds")
    print(" • Aim for 20+ transitions total")
    print(" • Act out varied scenarios: reading, phone, talking, drowsy")
    print(" • Record at least 5 minutes")
    print()

    # Collected per-frame feature vectors and their labels.
    features_list = []
    labels_list = []
    label = None  # None = paused
    transitions = 0  # count label switches
    prev_label = None
    status = "PAUSED -- press 1 (focused) or 0 (not focused)"
    t_start = time.time()
    prev_time = time.time()
    fps = 0.0

    try:
        while True:
            elapsed = time.time() - t_start
            if elapsed > args.duration:
                print(f"[COLLECT] Time limit ({args.duration}s)")
                break

            ret, frame = cap.read()
            if not ret:
                # Camera stream ended/failed; stop and save what we have.
                break

            h, w = frame.shape[:2]
            landmarks = detector.process(frame)
            face_ok = landmarks is not None

            # Record only while a face is visible and a label key is active.
            if face_ok and label is not None:
                vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
                features_list.append(vec)
                labels_list.append(label)

                # Count label switches only over recorded frames.
                if prev_label is not None and label != prev_label:
                    transitions += 1
                prev_label = label

            # Exponential moving average of the frame rate.
            now = time.time()
            fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
            prev_time = now

            # --- draw UI ---
            n = len(labels_list)
            n1 = sum(1 for x in labels_list if x == 1)
            n0 = n - n1
            remaining = max(0, args.duration - elapsed)

            # Status bar color reflects the active label (gray when paused).
            bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
            cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
            cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
            cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
                        (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)

            # Class-balance progress bar (green portion = focused share).
            if n > 0:
                bar_w = min(w - 20, 300)
                bar_x = w - bar_w - 10
                bar_y = 58
                frac = n1 / n
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
                cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
                            FONT, 0.3, GRAY, 1, cv2.LINE_AA)

            if not face_ok:
                cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)

            # red dot = recording
            if label is not None and face_ok:
                cv2.circle(frame, (w - 20, 80), 8, RED, -1)

            # live warnings
            warn_y = h - 35
            if n > 100 and transitions < 3:
                cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
                            (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                warn_y -= 18
            if elapsed > 30 and n > 0:
                bal = n1 / n
                if bal < 0.25 or bal > 0.75:
                    cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
                                (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                    warn_y -= 18

            cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
                        (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)

            cv2.imshow("FocusGuard -- Data Collection", frame)

            # Keyboard controls (cv2 returns -1 when no key; mask to 8 bits).
            key = cv2.waitKey(1) & 0xFF
            if key == ord("1"):
                label = 1
                status = "Recording: FOCUSED"
                print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("0"):
                label = 0
                status = "Recording: NOT FOCUSED"
                print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("p"):
                label = None
                status = "PAUSED"
                print(f"[COLLECT] paused (n={n})")
            elif key == ord("q"):
                break

    finally:
        # Always release camera/GUI/detector resources, even on exceptions.
        cap.release()
        cv2.destroyAllWindows()
        detector.close()

    if len(features_list) > 0:
        feats = np.stack(features_list)
        labs = np.array(labels_list, dtype=np.int64)

        # Timestamped filename so repeated sessions never overwrite each other.
        ts = time.strftime("%Y%m%d_%H%M%S")
        fname = f"{args.name}_{ts}.npz"
        fpath = os.path.join(args.output_dir, fname)
        np.savez(fpath,
                 features=feats,
                 labels=labs,
                 feature_names=np.array(FEATURE_NAMES))

        print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
        print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")

        quality_report(labs)
    else:
        print("\n[COLLECT] No data collected")

    print("[COLLECT] Done")
353
+
354
+
355
# CLI entry point: run the interactive data-collection loop.
if __name__ == "__main__":
    main()
models/eye_classifier.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ import numpy as np
6
+
7
+
8
class EyeClassifier(ABC):
    """Interface for backends that score eye attentiveness from image crops."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short backend identifier (e.g. for logging)."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return an attentiveness score in [0, 1] for a batch of BGR eye crops."""
17
+
18
+
19
class GeometricOnlyClassifier(EyeClassifier):
    """Fallback backend with no learned model: always reports full attentiveness."""

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        # The geometric EAR/gaze pipeline is the sole signal; never dampen it.
        return 1.0
26
+
27
+
28
class YOLOv11Classifier(EyeClassifier):
    """Eye-state classifier backed by an Ultralytics YOLOv11 classification model."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        # Lazy import: ultralytics is only required for this backend.
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        # Find the class index meaning "eyes open / attentive" by label name.
        names = self._model.names
        self._attentive_idx = next(
            (idx for idx, cls_name in names.items() if cls_name in ("open", "attentive")),
            None,
        )
        if self._attentive_idx is None:
            # No known label present: fall back to the highest class id.
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Average the attentive-class probability over all crops; 1.0 when empty."""
        if not crops_bgr:
            return 1.0
        preds = self._model.predict(crops_bgr, device=self._device, verbose=False)
        scores = [float(r.probs.data[self._attentive_idx]) for r in preds]
        if not scores:
            return 1.0
        return sum(scores) / len(scores)
55
+
56
+
57
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build an eye-classifier backend.

    Uses the geometric-only stub when no checkpoint path is given or the
    geometric backend is requested explicitly; otherwise loads a YOLO model
    from `path` on `device`.
    """
    wants_geometric = path is None or backend == "geometric"
    if wants_geometric:
        return GeometricOnlyClassifier()

    try:
        return YOLOv11Classifier(path, device=device)
    except ImportError:
        # ultralytics is optional; print an actionable hint, then propagate.
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise
models/eye_crop.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from models.face_mesh import FaceMeshDetector
5
+
6
+ LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
7
+ RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
8
+
9
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_STD = (0.229, 0.224, 0.225)
11
+
12
+ CROP_SIZE = 96
13
+
14
+
15
+ def _bbox_from_landmarks(
16
+ landmarks: np.ndarray,
17
+ indices: list[int],
18
+ frame_w: int,
19
+ frame_h: int,
20
+ expand: float = 0.4,
21
+ ) -> tuple[int, int, int, int]:
22
+ pts = landmarks[indices, :2]
23
+ px = pts[:, 0] * frame_w
24
+ py = pts[:, 1] * frame_h
25
+
26
+ x_min, x_max = px.min(), px.max()
27
+ y_min, y_max = py.min(), py.max()
28
+ w = x_max - x_min
29
+ h = y_max - y_min
30
+ cx = (x_min + x_max) / 2
31
+ cy = (y_min + y_max) / 2
32
+
33
+ size = max(w, h) * (1 + expand)
34
+ half = size / 2
35
+
36
+ x1 = int(max(cx - half, 0))
37
+ y1 = int(max(cy - half, 0))
38
+ x2 = int(min(cx + half, frame_w))
39
+ y2 = int(min(cy + half, frame_h))
40
+
41
+ return x1, y1, x2, y2
42
+
43
+
44
def extract_eye_crops(
    frame: np.ndarray,
    landmarks: np.ndarray,
    expand: float = 0.4,
    crop_size: int = CROP_SIZE,
) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
    """Cut out both eye regions and resize each to a crop_size square.

    Returns:
        (left_crop, right_crop, left_bbox, right_bbox) where crops are
        BGR uint8 images of shape (crop_size, crop_size, 3).
    """
    h, w = frame.shape[:2]

    def _cut(bbox):
        # Slice the frame by the bbox; degenerate boxes yield a black square.
        x1, y1, x2, y2 = bbox
        patch = frame[y1:y2, x1:x2]
        if patch.size == 0:
            return np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
        return cv2.resize(patch, (crop_size, crop_size), interpolation=cv2.INTER_AREA)

    left_bbox = _bbox_from_landmarks(landmarks, LEFT_EYE_CONTOUR, w, h, expand)
    right_bbox = _bbox_from_landmarks(landmarks, RIGHT_EYE_CONTOUR, w, h, expand)

    return _cut(left_bbox), _cut(right_bbox), left_bbox, right_bbox
69
+
70
+
71
def crop_to_tensor(crop_bgr: np.ndarray):
    """Convert a BGR uint8 crop to a normalized CHW float32 torch tensor.

    Applies BGR->RGB conversion, [0, 1] scaling and per-channel ImageNet
    normalization, then transposes HWC -> CHW.
    """
    import torch

    rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Per-channel normalization, channel by channel (same arithmetic as before).
    for channel, (mean, std) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
        rgb[:, :, channel] = (rgb[:, :, channel] - mean) / std
    return torch.from_numpy(rgb.transpose(2, 0, 1))
models/eye_scorer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+
5
+ _LEFT_EYE_EAR = [33, 160, 158, 133, 153, 145]
6
+ _RIGHT_EYE_EAR = [362, 385, 387, 263, 373, 380]
7
+
8
+ _LEFT_IRIS_CENTER = 468
9
+ _RIGHT_IRIS_CENTER = 473
10
+
11
+ _LEFT_EYE_INNER = 133
12
+ _LEFT_EYE_OUTER = 33
13
+ _RIGHT_EYE_INNER = 362
14
+ _RIGHT_EYE_OUTER = 263
15
+
16
+ _LEFT_EYE_TOP = 159
17
+ _LEFT_EYE_BOTTOM = 145
18
+ _RIGHT_EYE_TOP = 386
19
+ _RIGHT_EYE_BOTTOM = 374
20
+
21
+ _MOUTH_TOP = 13
22
+ _MOUTH_BOTTOM = 14
23
+ _MOUTH_LEFT = 78
24
+ _MOUTH_RIGHT = 308
25
+ _MOUTH_UPPER_1 = 82
26
+ _MOUTH_UPPER_2 = 312
27
+ _MOUTH_LOWER_1 = 87
28
+ _MOUTH_LOWER_2 = 317
29
+
30
+ MAR_YAWN_THRESHOLD = 0.55
31
+
32
+
33
+ def _distance(p1: np.ndarray, p2: np.ndarray) -> float:
34
+ return float(np.linalg.norm(p1 - p2))
35
+
36
+
37
+ def compute_ear(landmarks: np.ndarray, eye_indices: list[int]) -> float:
38
+ p1 = landmarks[eye_indices[0], :2]
39
+ p2 = landmarks[eye_indices[1], :2]
40
+ p3 = landmarks[eye_indices[2], :2]
41
+ p4 = landmarks[eye_indices[3], :2]
42
+ p5 = landmarks[eye_indices[4], :2]
43
+ p6 = landmarks[eye_indices[5], :2]
44
+
45
+ vertical1 = _distance(p2, p6)
46
+ vertical2 = _distance(p3, p5)
47
+ horizontal = _distance(p1, p4)
48
+
49
+ if horizontal < 1e-6:
50
+ return 0.0
51
+
52
+ return (vertical1 + vertical2) / (2.0 * horizontal)
53
+
54
+
55
+ def compute_avg_ear(landmarks: np.ndarray) -> float:
56
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
57
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
58
+ return (left_ear + right_ear) / 2.0
59
+
60
+
61
+ def compute_gaze_ratio(landmarks: np.ndarray) -> tuple[float, float]:
62
+ left_iris = landmarks[_LEFT_IRIS_CENTER, :2]
63
+ left_inner = landmarks[_LEFT_EYE_INNER, :2]
64
+ left_outer = landmarks[_LEFT_EYE_OUTER, :2]
65
+ left_top = landmarks[_LEFT_EYE_TOP, :2]
66
+ left_bottom = landmarks[_LEFT_EYE_BOTTOM, :2]
67
+
68
+ right_iris = landmarks[_RIGHT_IRIS_CENTER, :2]
69
+ right_inner = landmarks[_RIGHT_EYE_INNER, :2]
70
+ right_outer = landmarks[_RIGHT_EYE_OUTER, :2]
71
+ right_top = landmarks[_RIGHT_EYE_TOP, :2]
72
+ right_bottom = landmarks[_RIGHT_EYE_BOTTOM, :2]
73
+
74
+ left_h_total = _distance(left_inner, left_outer)
75
+ right_h_total = _distance(right_inner, right_outer)
76
+
77
+ if left_h_total < 1e-6 or right_h_total < 1e-6:
78
+ return 0.5, 0.5
79
+
80
+ left_h_ratio = _distance(left_outer, left_iris) / left_h_total
81
+ right_h_ratio = _distance(right_outer, right_iris) / right_h_total
82
+ h_ratio = (left_h_ratio + right_h_ratio) / 2.0
83
+
84
+ left_v_total = _distance(left_top, left_bottom)
85
+ right_v_total = _distance(right_top, right_bottom)
86
+
87
+ if left_v_total < 1e-6 or right_v_total < 1e-6:
88
+ return h_ratio, 0.5
89
+
90
+ left_v_ratio = _distance(left_top, left_iris) / left_v_total
91
+ right_v_ratio = _distance(right_top, right_iris) / right_v_total
92
+ v_ratio = (left_v_ratio + right_v_ratio) / 2.0
93
+
94
+ return float(np.clip(h_ratio, 0, 1)), float(np.clip(v_ratio, 0, 1))
95
+
96
+
97
+ def compute_mar(landmarks: np.ndarray) -> float:
98
+ top = landmarks[_MOUTH_TOP, :2]
99
+ bottom = landmarks[_MOUTH_BOTTOM, :2]
100
+ left = landmarks[_MOUTH_LEFT, :2]
101
+ right = landmarks[_MOUTH_RIGHT, :2]
102
+ upper1 = landmarks[_MOUTH_UPPER_1, :2]
103
+ lower1 = landmarks[_MOUTH_LOWER_1, :2]
104
+ upper2 = landmarks[_MOUTH_UPPER_2, :2]
105
+ lower2 = landmarks[_MOUTH_LOWER_2, :2]
106
+
107
+ horizontal = _distance(left, right)
108
+ if horizontal < 1e-6:
109
+ return 0.0
110
+ v1 = _distance(upper1, lower1)
111
+ v2 = _distance(top, bottom)
112
+ v3 = _distance(upper2, lower2)
113
+ return (v1 + v2 + v3) / (2.0 * horizontal)
114
+
115
+
116
+ class EyeBehaviourScorer:
117
+ def __init__(
118
+ self,
119
+ ear_open: float = 0.30,
120
+ ear_closed: float = 0.16,
121
+ gaze_max_offset: float = 0.28,
122
+ ):
123
+ self.ear_open = ear_open
124
+ self.ear_closed = ear_closed
125
+ self.gaze_max_offset = gaze_max_offset
126
+
127
+ def _ear_score(self, ear: float) -> float:
128
+ if ear >= self.ear_open:
129
+ return 1.0
130
+ if ear <= self.ear_closed:
131
+ return 0.0
132
+ return (ear - self.ear_closed) / (self.ear_open - self.ear_closed)
133
+
134
+ def _gaze_score(self, h_ratio: float, v_ratio: float) -> float:
135
+ h_offset = abs(h_ratio - 0.5)
136
+ v_offset = abs(v_ratio - 0.5)
137
+ offset = math.sqrt(h_offset**2 + v_offset**2)
138
+ t = min(offset / self.gaze_max_offset, 1.0)
139
+ return 0.5 * (1.0 + math.cos(math.pi * t))
140
+
141
+ def score(self, landmarks: np.ndarray) -> float:
142
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
143
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
144
+ # Use minimum EAR so closing ONE eye is enough to drop the score
145
+ ear = min(left_ear, right_ear)
146
+ ear_s = self._ear_score(ear)
147
+ if ear_s < 0.3:
148
+ return ear_s
149
+ h_ratio, v_ratio = compute_gaze_ratio(landmarks)
150
+ gaze_s = self._gaze_score(h_ratio, v_ratio)
151
+ return ear_s * gaze_s
152
+
153
+ def detailed_score(self, landmarks: np.ndarray) -> dict:
154
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
155
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
156
+ ear = min(left_ear, right_ear)
157
+ ear_s = self._ear_score(ear)
158
+ h_ratio, v_ratio = compute_gaze_ratio(landmarks)
159
+ gaze_s = self._gaze_score(h_ratio, v_ratio)
160
+ s_eye = ear_s if ear_s < 0.3 else ear_s * gaze_s
161
+ return {
162
+ "ear": round(ear, 4),
163
+ "ear_score": round(ear_s, 4),
164
+ "h_gaze": round(h_ratio, 4),
165
+ "v_gaze": round(v_ratio, 4),
166
+ "gaze_score": round(gaze_s, 4),
167
+ "s_eye": round(s_eye, 4),
168
+ }
models/face_mesh.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+ from urllib.request import urlretrieve
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import mediapipe as mp
9
+ from mediapipe.tasks.python.vision import FaceLandmarkerOptions, FaceLandmarker, RunningMode
10
+ from mediapipe.tasks import python as mp_tasks
11
+
12
+ _MODEL_URL = (
13
+ "https://storage.googleapis.com/mediapipe-models/face_landmarker/"
14
+ "face_landmarker/float16/latest/face_landmarker.task"
15
+ )
16
+
17
+
18
+ def _ensure_model() -> str:
19
+ cache_dir = Path(os.environ.get(
20
+ "FOCUSGUARD_CACHE_DIR",
21
+ Path.home() / ".cache" / "focusguard",
22
+ ))
23
+ model_path = cache_dir / "face_landmarker.task"
24
+ if model_path.exists():
25
+ return str(model_path)
26
+ cache_dir.mkdir(parents=True, exist_ok=True)
27
+ print(f"[FACE_MESH] Downloading model to {model_path}...")
28
+ urlretrieve(_MODEL_URL, model_path)
29
+ print("[FACE_MESH] Download complete.")
30
+ return str(model_path)
31
+
32
+
33
+ class FaceMeshDetector:
34
+ LEFT_EYE_INDICES = [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]
35
+ RIGHT_EYE_INDICES = [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]
36
+ LEFT_IRIS_INDICES = [468, 469, 470, 471, 472]
37
+ RIGHT_IRIS_INDICES = [473, 474, 475, 476, 477]
38
+
39
+ def __init__(
40
+ self,
41
+ max_num_faces: int = 1,
42
+ min_detection_confidence: float = 0.5,
43
+ min_tracking_confidence: float = 0.5,
44
+ ):
45
+ model_path = _ensure_model()
46
+ options = FaceLandmarkerOptions(
47
+ base_options=mp_tasks.BaseOptions(model_asset_path=model_path),
48
+ num_faces=max_num_faces,
49
+ min_face_detection_confidence=min_detection_confidence,
50
+ min_face_presence_confidence=min_detection_confidence,
51
+ min_tracking_confidence=min_tracking_confidence,
52
+ running_mode=RunningMode.VIDEO,
53
+ )
54
+ self._landmarker = FaceLandmarker.create_from_options(options)
55
+ self._t0 = time.monotonic()
56
+ self._last_ts = 0
57
+
58
+ def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
59
+ # BGR in -> (478,3) norm x,y,z or None
60
+ rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
61
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
62
+ ts = max(int((time.monotonic() - self._t0) * 1000), self._last_ts + 1)
63
+ self._last_ts = ts
64
+ result = self._landmarker.detect_for_video(mp_image, ts)
65
+
66
+ if not result.face_landmarks:
67
+ return None
68
+
69
+ face = result.face_landmarks[0]
70
+ return np.array([(lm.x, lm.y, lm.z) for lm in face], dtype=np.float32)
71
+
72
+ def get_pixel_landmarks(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> np.ndarray:
73
+ # norm -> pixel (x,y)
74
+ pixel = np.zeros((landmarks.shape[0], 2), dtype=np.int32)
75
+ pixel[:, 0] = (landmarks[:, 0] * frame_w).astype(np.int32)
76
+ pixel[:, 1] = (landmarks[:, 1] * frame_h).astype(np.int32)
77
+ return pixel
78
+
79
+ def get_3d_landmarks(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> np.ndarray:
80
+ # norm -> pixel-scale x,y,z (z scaled by width)
81
+ pts = np.zeros_like(landmarks)
82
+ pts[:, 0] = landmarks[:, 0] * frame_w
83
+ pts[:, 1] = landmarks[:, 1] * frame_h
84
+ pts[:, 2] = landmarks[:, 2] * frame_w
85
+ return pts
86
+
87
+ def close(self):
88
+ self._landmarker.close()
89
+
90
+ def __enter__(self):
91
+ return self
92
+
93
+ def __exit__(self, *args):
94
+ self.close()
models/head_pose.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ _LANDMARK_INDICES = [1, 152, 33, 263, 61, 291]
7
+
8
+ _MODEL_POINTS = np.array(
9
+ [
10
+ [0.0, 0.0, 0.0],
11
+ [0.0, -330.0, -65.0],
12
+ [-225.0, 170.0, -135.0],
13
+ [225.0, 170.0, -135.0],
14
+ [-150.0, -150.0, -125.0],
15
+ [150.0, -150.0, -125.0],
16
+ ],
17
+ dtype=np.float64,
18
+ )
19
+
20
+
21
+ class HeadPoseEstimator:
22
+ def __init__(self, max_angle: float = 30.0, roll_weight: float = 0.5):
23
+ self.max_angle = max_angle
24
+ self.roll_weight = roll_weight
25
+ self._camera_matrix = None
26
+ self._frame_size = None
27
+ self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
28
+ self._cache_key = None
29
+ self._cache_result = None
30
+
31
+ def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
32
+ if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
33
+ return self._camera_matrix
34
+ focal_length = float(frame_w)
35
+ cx, cy = frame_w / 2.0, frame_h / 2.0
36
+ self._camera_matrix = np.array(
37
+ [[focal_length, 0, cx], [0, focal_length, cy], [0, 0, 1]],
38
+ dtype=np.float64,
39
+ )
40
+ self._frame_size = (frame_w, frame_h)
41
+ return self._camera_matrix
42
+
43
+ def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
44
+ key = (landmarks.data.tobytes(), frame_w, frame_h)
45
+ if self._cache_key == key:
46
+ return self._cache_result
47
+
48
+ image_points = np.array(
49
+ [
50
+ [landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
51
+ for i in _LANDMARK_INDICES
52
+ ],
53
+ dtype=np.float64,
54
+ )
55
+ camera_matrix = self._get_camera_matrix(frame_w, frame_h)
56
+ success, rvec, tvec = cv2.solvePnP(
57
+ _MODEL_POINTS,
58
+ image_points,
59
+ camera_matrix,
60
+ self._dist_coeffs,
61
+ flags=cv2.SOLVEPNP_ITERATIVE,
62
+ )
63
+ result = (success, rvec, tvec, image_points)
64
+ self._cache_key = key
65
+ self._cache_result = result
66
+ return result
67
+
68
+ def estimate(
69
+ self, landmarks: np.ndarray, frame_w: int, frame_h: int
70
+ ) -> tuple[float, float, float] | None:
71
+ success, rvec, tvec, _ = self._solve(landmarks, frame_w, frame_h)
72
+ if not success:
73
+ return None
74
+
75
+ rmat, _ = cv2.Rodrigues(rvec)
76
+ nose_dir = rmat @ np.array([0.0, 0.0, 1.0])
77
+ face_up = rmat @ np.array([0.0, 1.0, 0.0])
78
+
79
+ yaw = math.degrees(math.atan2(nose_dir[0], -nose_dir[2]))
80
+ pitch = math.degrees(math.asin(np.clip(-nose_dir[1], -1.0, 1.0)))
81
+ roll = math.degrees(math.atan2(face_up[0], -face_up[1]))
82
+
83
+ return (yaw, pitch, roll)
84
+
85
+ def score(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> float:
86
+ angles = self.estimate(landmarks, frame_w, frame_h)
87
+ if angles is None:
88
+ return 0.0
89
+
90
+ yaw, pitch, roll = angles
91
+ deviation = math.sqrt(yaw**2 + pitch**2 + (self.roll_weight * roll) ** 2)
92
+ t = min(deviation / self.max_angle, 1.0)
93
+ return 0.5 * (1.0 + math.cos(math.pi * t))
94
+
95
+ def draw_axes(
96
+ self,
97
+ frame: np.ndarray,
98
+ landmarks: np.ndarray,
99
+ axis_length: float = 50.0,
100
+ ) -> np.ndarray:
101
+ h, w = frame.shape[:2]
102
+ success, rvec, tvec, image_points = self._solve(landmarks, w, h)
103
+ if not success:
104
+ return frame
105
+
106
+ camera_matrix = self._get_camera_matrix(w, h)
107
+ nose = tuple(image_points[0].astype(int))
108
+
109
+ axes_3d = np.float64(
110
+ [[axis_length, 0, 0], [0, axis_length, 0], [0, 0, axis_length]]
111
+ )
112
+ projected, _ = cv2.projectPoints(
113
+ axes_3d, rvec, tvec, camera_matrix, self._dist_coeffs
114
+ )
115
+
116
+ colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
117
+ for i, color in enumerate(colors):
118
+ pt = tuple(projected[i].ravel().astype(int))
119
+ cv2.line(frame, nose, pt, color, 2)
120
+
121
+ return frame
models/mlp/__init__.py ADDED
File without changes
models/mlp/eval_accuracy.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load saved MLP checkpoint and print test accuracy, F1, AUC."""
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ import torch
7
+ from sklearn.metrics import f1_score, roc_auc_score
8
+
9
+ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
10
+ if REPO_ROOT not in sys.path:
11
+ sys.path.insert(0, REPO_ROOT)
12
+
13
+ from data_preparation.prepare_dataset import get_dataloaders
14
+ from models.mlp.train import BaseModel
15
+
16
+ CKPT_PATH = os.path.join(REPO_ROOT, "checkpoints", "mlp_best.pt")
17
+
18
+
19
def main():
    """Evaluate the saved MLP checkpoint on the test split and print metrics.

    Loads checkpoints/mlp_best.pt, runs it over the face_orientation test
    split and prints accuracy, weighted F1 and ROC-AUC. Prints a hint and
    returns early when no checkpoint exists.
    """
    if not os.path.isfile(CKPT_PATH):
        print(f"No checkpoint at {CKPT_PATH}. Train first: python -m models.mlp.train")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only the test split is evaluated; the train/val loaders are discarded.
    _, _, test_loader, num_features, num_classes, _ = get_dataloaders(
        model_name="face_orientation",
        batch_size=32,
        split_ratios=(0.7, 0.15, 0.15),
        seed=42,
    )

    model = BaseModel(num_features, num_classes).to(device)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=device, weights_only=True))
    model.eval()

    criterion = torch.nn.CrossEntropyLoss()
    _, test_acc, test_probs, test_preds, test_labels = model.test_step(
        test_loader, criterion, device
    )

    f1 = float(f1_score(test_labels, test_preds, average="weighted"))
    # Multiclass AUC needs one-vs-rest averaging; binary uses the positive-class prob.
    if num_classes > 2:
        auc = float(roc_auc_score(test_labels, test_probs, multi_class="ovr", average="weighted"))
    else:
        auc = float(roc_auc_score(test_labels, test_probs[:, 1]))

    print("MLP (face_orientation) — test set")
    print(" Accuracy: {:.2%}".format(test_acc))
    print(" F1: {:.4f}".format(f1))
    print(" ROC-AUC: {:.4f}".format(auc))
51
+
52
+
53
# CLI entry point: evaluate the saved checkpoint.
if __name__ == "__main__":
    main()
models/mlp/sweep.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MLP Hyperparameter Sweep (Optuna).
3
+ Run: python -m models.mlp.sweep
4
+ """
5
+ import os
6
+ import sys
7
+
8
+ import optuna
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+
13
+ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
14
+ if REPO_ROOT not in sys.path:
15
+ sys.path.insert(0, REPO_ROOT)
16
+
17
+ from data_preparation.prepare_dataset import get_dataloaders
18
+ from models.mlp.train import BaseModel, set_seed
19
+
20
+ SEED = 42
21
+ N_TRIALS = 20
22
+ EPOCHS_PER_TRIAL = 15
23
+
24
+
25
def objective(trial):
    """Optuna objective: train an MLP with sampled hyperparameters.

    Returns:
        1 - best validation accuracy (Optuna minimizes this value).
    """
    set_seed(SEED)

    # Hyperparameter search space.
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    train_loader, val_loader, _, num_features, num_classes, _ = get_dataloaders(
        model_name="face_orientation",
        batch_size=batch_size,
        split_ratios=(0.7, 0.15, 0.15),
        seed=SEED,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0.0
    for _epoch in range(EPOCHS_PER_TRIAL):
        model.training_step(train_loader, optimizer, criterion, device)
        _val_loss, val_acc = model.validation_step(val_loader, criterion, device)
        best_val_acc = max(best_val_acc, val_acc)

    # Report the complement so lower is better.
    return 1.0 - best_val_acc
51
+
52
+
53
def main():
    """Run the Optuna sweep and print the five best trials by validation accuracy."""
    study = optuna.create_study(direction="minimize", study_name="mlp_sweep")
    print(f"[SWEEP] MLP Optuna sweep: {N_TRIALS} trials, {EPOCHS_PER_TRIAL} epochs each")
    study.optimize(objective, n_trials=N_TRIALS)

    print("\n[SWEEP] Top-5 trials by validation accuracy")
    # Failed/pruned trials have value None; drop them instead of crashing on
    # `1.0 - None` when they end up in the top-5 slice.
    completed = [t for t in study.trials if t.value is not None]
    best = sorted(completed, key=lambda t: t.value)[:5]
    for i, t in enumerate(best, 1):
        acc = (1.0 - t.value) * 100
        print(f" #{i} Val Acc: {acc:.2f}% params={t.params}")
63
+
64
+
65
# CLI entry point: launch the hyperparameter sweep.
if __name__ == "__main__":
    main()
models/mlp/train.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os, sys
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from sklearn.metrics import f1_score, roc_auc_score
10
+
11
+ from data_preparation.prepare_dataset import get_dataloaders
12
+
13
# Set True to log this run to ClearML (requires the `clearml` package and credentials).
USE_CLEARML = False

# Repository root (two levels above models/mlp/), used for checkpoint/log paths.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))

# Training configuration; connected to the ClearML task when tracking is enabled.
CFG = {
    "model_name": "face_orientation",
    "epochs": 30,
    "batch_size": 32,
    "lr": 1e-3,
    "seed": 42,
    "split_ratios": (0.7, 0.15, 0.15),
    "checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
    "logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
}


# ==== ClearML (opt-in) =============================================
task = None
if USE_CLEARML:
    # Imported lazily so the dependency is only needed when tracking is on.
    from clearml import Task
    task = Task.init(
        project_name="Focus Guard",
        task_name="MLP Model Training",
        tags=["training", "mlp_model"]
    )
    task.connect(CFG)
38
+
39
+
40
+
41
+ # ==== Model =============================================
42
def set_seed(seed: int):
    """Seed every RNG used in training (random, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
48
+
49
+
50
class BaseModel(nn.Module):
    """Small fully-connected classifier (64 -> 32 hidden units) with train/eval helpers."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(num_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        return self.network(x)

    def training_step(self, loader, optimizer, criterion, device):
        """One optimization pass over `loader`.

        Returns:
            (mean loss, accuracy) over the whole pass.
        """
        self.train()
        loss_sum, hits, seen = 0.0, 0, 0

        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            logits = self(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the mean is over samples, not batches.
            batch_n = batch_x.size(0)
            loss_sum += loss.item() * batch_n
            hits += (logits.argmax(dim=1) == batch_y).sum().item()
            seen += batch_n

        return loss_sum / seen, hits / seen

    @torch.no_grad()
    def validation_step(self, loader, criterion, device):
        """Forward-only pass; returns (mean loss, accuracy)."""
        self.eval()
        loss_sum, hits, seen = 0.0, 0, 0

        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = self(batch_x)
            loss = criterion(logits, batch_y)

            batch_n = batch_x.size(0)
            loss_sum += loss.item() * batch_n
            hits += (logits.argmax(dim=1) == batch_y).sum().item()
            seen += batch_n

        return loss_sum / seen, hits / seen

    @torch.no_grad()
    def test_step(self, loader, criterion, device):
        """Forward pass collecting everything metric computations need.

        Returns:
            (mean loss, accuracy, probs, preds, labels) with the last three
            as numpy arrays over the whole loader.
        """
        self.eval()
        loss_sum, hits, seen = 0.0, 0, 0
        probs_out, preds_out, labels_out = [], [], []

        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = self(batch_x)
            loss = criterion(logits, batch_y)

            batch_n = batch_x.size(0)
            loss_sum += loss.item() * batch_n
            preds = logits.argmax(dim=1)
            hits += (preds == batch_y).sum().item()
            seen += batch_n

            probs_out.extend(torch.softmax(logits, dim=1).cpu().numpy())
            preds_out.extend(preds.cpu().numpy())
            labels_out.extend(batch_y.cpu().numpy())

        return (
            loss_sum / seen,
            hits / seen,
            np.array(probs_out),
            np.array(preds_out),
            np.array(labels_out),
        )
130
+
131
+
132
def main():
    """Train the MLP, checkpoint the best-validation epoch, then evaluate on test.

    Side effects: writes `mlp_best.pt` under CFG["checkpoints_dir"] and a JSON
    training history under CFG["logs_dir"]; prints a per-epoch progress table.
    """
    set_seed(CFG["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[TRAIN] Device: {device}")
    print(f"[TRAIN] Model: {CFG['model_name']}")

    # `scaler` is returned by the data pipeline but unused here — presumably
    # persisted elsewhere for inference; verify against prepare_dataset.
    train_loader, val_loader, test_loader, num_features, num_classes, scaler = get_dataloaders(
        model_name=CFG["model_name"],
        batch_size=CFG["batch_size"],
        split_ratios=CFG["split_ratios"],
        seed=CFG["seed"],
    )

    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CFG["lr"])

    param_count = sum(p.numel() for p in model.parameters())
    print(f"[TRAIN] Parameters: {param_count:,}")

    ckpt_dir = CFG["checkpoints_dir"]
    os.makedirs(ckpt_dir, exist_ok=True)
    best_ckpt_path = os.path.join(ckpt_dir, "mlp_best.pt")

    # Per-epoch metrics, rounded to 4 decimals and dumped to JSON at the end.
    history = {
        "model_name": CFG["model_name"],
        "param_count": param_count,
        "epochs": [],
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    best_val_acc = 0.0

    print(f"\n{'Epoch':>6} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>10} | {'Val Acc':>9}")
    print("-" * 60)

    for epoch in range(1, CFG["epochs"] + 1):
        train_loss, train_acc = model.training_step(train_loader, optimizer, criterion, device)
        val_loss, val_acc = model.validation_step(val_loader, criterion, device)

        history["epochs"].append(epoch)
        history["train_loss"].append(round(train_loss, 4))
        history["train_acc"].append(round(train_acc, 4))
        history["val_loss"].append(round(val_loss, 4))
        history["val_acc"].append(round(val_acc, 4))

        # Mirror metrics to ClearML when enabled (`task` is None otherwise).
        current_lr = optimizer.param_groups[0]['lr']
        if task is not None:
            task.logger.report_scalar("Loss", "Train", float(train_loss), iteration=epoch)
            task.logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
            task.logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
            task.logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
            task.logger.report_scalar("Learning Rate", "LR", float(current_lr), iteration=epoch)
            task.logger.flush()

        # Checkpoint whenever validation accuracy improves; "*" marks the row.
        marker = ""
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_ckpt_path)
            marker = " *"

        print(f"{epoch:>6} | {train_loss:>10.4f} | {train_acc:>8.2%} | {val_loss:>10.4f} | {val_acc:>8.2%}{marker}")

    print(f"\nBest validation accuracy: {best_val_acc:.2%}")
    print(f"Checkpoint saved to: {best_ckpt_path}")

    # Restore the best epoch's weights before the final test-set evaluation.
    model.load_state_dict(torch.load(best_ckpt_path, weights_only=True))
    test_loss, test_acc, test_probs, test_preds, test_labels = model.test_step(test_loader, criterion, device)

    test_f1 = f1_score(test_labels, test_preds, average='weighted')
    # Handle potentially >2 classes for AUC
    if num_classes > 2:
        test_auc = roc_auc_score(test_labels, test_probs, multi_class='ovr', average='weighted')
    else:
        test_auc = roc_auc_score(test_labels, test_probs[:, 1])

    print(f"\n[TEST] Loss: {test_loss:.4f} | Accuracy: {test_acc:.2%}")
    print(f"[TEST] F1: {test_f1:.4f} | ROC-AUC: {test_auc:.4f}")

    history["test_loss"] = round(test_loss, 4)
    history["test_acc"] = round(test_acc, 4)
    history["test_f1"] = round(test_f1, 4)
    history["test_auc"] = round(test_auc, 4)

    logs_dir = CFG["logs_dir"]
    os.makedirs(logs_dir, exist_ok=True)
    log_path = os.path.join(logs_dir, f"{CFG['model_name']}_training_log.json")

    with open(log_path, "w") as f:
        json.dump(history, f, indent=2)

    print(f"[LOG] Training history saved to: {log_path}")
229
+
230
+
231
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
models/xgboost/add_accuracy.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""One-shot maintenance script: re-train each swept XGBoost configuration and
add a 'val_accuracy' column to the sweep results CSV, rewriting it in place."""
import pandas as pd
import numpy as np
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
from data_preparation.prepare_dataset import get_numpy_splits
import os

print("Loading dataset for evaluation...")
# Same split parameters as the sweep (scale=False) so each re-trained model
# sees identical data — presumably matching the sweep script; verify there.
splits, _, _, _ = get_numpy_splits(
    model_name="face_orientation",
    split_ratios=(0.7, 0.15, 0.15),
    seed=42,
    scale=False
)
X_train, y_train = splits["X_train"], splits["y_train"]
X_val, y_val = splits["X_val"], splits["y_val"]

# Path is relative to the repo root — run this script from there.
csv_path = 'models/xgboost/sweep_results_all_40.csv'
df = pd.read_csv(csv_path)

# We will calculate accuracy for each row
accuracies = []

print(f"Re-evaluating {len(df)} configurations for accuracy. This will take a few minutes...")
for idx, row in df.iterrows():
    # Rebuild the exact hyperparameters recorded for this sweep trial.
    params = {
        "n_estimators": int(row["n_estimators"]),
        "max_depth": int(row["max_depth"]),
        "learning_rate": float(row["learning_rate"]),
        "subsample": float(row["subsample"]),
        "colsample_bytree": float(row["colsample_bytree"]),
        "reg_alpha": float(row["reg_alpha"]),
        "reg_lambda": float(row["reg_lambda"]),
        "random_state": 42,
        # NOTE(review): `use_label_encoder` was removed in xgboost >= 2.0 —
        # confirm the installed xgboost version still accepts this kwarg.
        "use_label_encoder": False,
        "verbosity": 0,
        "eval_metric": "logloss"
    }

    # Train the exact same model quickly
    model = XGBClassifier(**params)
    model.fit(X_train, y_train)

    # Get validation predictions and calculate accuracy
    val_preds = model.predict(X_val)
    acc = accuracy_score(y_val, val_preds)
    accuracies.append(round(acc, 4))

    # Progress heartbeat every 5 trials.
    if (idx + 1) % 5 == 0:
        print(f"Processed {idx + 1}/{len(df)} trials...")

# Add accuracy column and save back to CSV
df.insert(2, 'val_accuracy', accuracies)
df.to_csv(csv_path, index=False)

print(f"\nDone! Updated {csv_path} with 'val_accuracy'.")
# Display the top 5 by accuracy now just to see
top5_acc = df.nlargest(5, 'val_accuracy')[['task_id', 'val_accuracy', 'val_f1', 'val_loss']]
print("\nTop 5 Trials by Accuracy:")
print(top5_acc.to_string(index=False))
models/xgboost/checkpoints/face_orientation_best.json ADDED
The diff for this file is too large to render. See raw diff
 
models/xgboost/eval_accuracy.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load saved XGBoost checkpoint and print test accuracy, F1, AUC."""
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ from sklearn.metrics import f1_score, roc_auc_score
7
+ from xgboost import XGBClassifier
8
+
9
+ # run from repo root: python -m models.xgboost.eval_accuracy
10
+ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
11
+ if REPO_ROOT not in sys.path:
12
+ sys.path.insert(0, REPO_ROOT)
13
+
14
+ from data_preparation.prepare_dataset import get_numpy_splits
15
+
16
+ MODEL_NAME = "face_orientation"
17
+ CKPT_DIR = os.path.join(REPO_ROOT, "checkpoints")
18
+ MODEL_PATH = os.path.join(CKPT_DIR, f"xgboost_{MODEL_NAME}_best.json")
19
+
20
+
21
def main():
    """Evaluate the saved XGBoost checkpoint on the held-out test split.

    Prints accuracy, weighted F1 and ROC-AUC. Exits early with a hint when
    no checkpoint file exists yet.
    """
    if not os.path.isfile(MODEL_PATH):
        print(f"No checkpoint at {MODEL_PATH}. Train first: python -m models.xgboost.train")
        return

    # Same split parameters as training so the test set is identical.
    splits, num_features, num_classes, _ = get_numpy_splits(
        model_name=MODEL_NAME,
        split_ratios=(0.7, 0.15, 0.15),
        seed=42,
        scale=False,
    )
    X_test = splits["X_test"]
    y_test = splits["y_test"]

    model = XGBClassifier()
    model.load_model(MODEL_PATH)

    preds = model.predict(X_test)
    probs = model.predict_proba(X_test)
    acc = float(np.mean(preds == y_test))
    f1 = float(f1_score(y_test, preds, average="weighted"))
    # Multiclass AUC needs one-vs-rest averaging; binary uses the positive
    # class' probability column.
    if num_classes > 2:
        auc = float(roc_auc_score(y_test, probs, multi_class="ovr", average="weighted"))
    else:
        auc = float(roc_auc_score(y_test, probs[:, 1]))

    # Use MODEL_NAME rather than a hard-coded label so the header stays
    # correct if the module-level constant changes.
    print(f"XGBoost ({MODEL_NAME}) — test set")
    print(" Accuracy: {:.2%}".format(acc))
    print(" F1: {:.4f}".format(f1))
    print(" ROC-AUC: {:.4f}".format(auc))
51
+
52
+
53
# Entry point when run directly or via `python -m models.xgboost.eval_accuracy`.
if __name__ == "__main__":
    main()