Spaces:
Sleeping
Sleeping
Commit ·
8bbb872
1
Parent(s): 05616fb
Upload partially updated files
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +41 -0
- Dockerfile +27 -0
- api/history +0 -0
- api/import +0 -0
- api/sessions +0 -0
- app.py +1 -0
- checkpoints/hybrid_focus_config.json +10 -0
- checkpoints/meta_best.npz +3 -0
- checkpoints/mlp_best.pt +3 -0
- checkpoints/model_best.joblib +3 -0
- checkpoints/scaler_best.joblib +3 -0
- checkpoints/xgboost_face_orientation_best.json +0 -0
- docker-compose.yml +5 -0
- eslint.config.js +29 -0
- index.html +17 -0
- main.py +964 -0
- models/README.md +53 -0
- models/__init__.py +1 -0
- models/cnn/CNN_MODEL/.claude/settings.local.json +7 -0
- models/cnn/CNN_MODEL/.gitattributes +1 -0
- models/cnn/CNN_MODEL/.gitignore +4 -0
- models/cnn/CNN_MODEL/README.md +74 -0
- models/cnn/CNN_MODEL/notebooks/eye_classifier_colab.ipynb +0 -0
- models/cnn/CNN_MODEL/scripts/focus_infer.py +199 -0
- models/cnn/CNN_MODEL/scripts/predict_image.py +49 -0
- models/cnn/CNN_MODEL/scripts/video_infer.py +281 -0
- models/cnn/CNN_MODEL/scripts/webcam_live.py +184 -0
- models/cnn/CNN_MODEL/weights/yolo11s-cls.pt +3 -0
- models/cnn/__init__.py +0 -0
- models/cnn/eye_attention/__init__.py +1 -0
- models/cnn/eye_attention/classifier.py +169 -0
- models/cnn/eye_attention/crop.py +70 -0
- models/cnn/eye_attention/train.py +0 -0
- models/cnn/notebooks/EyeCNN.ipynb +107 -0
- models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb +0 -0
- models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb +0 -0
- models/cnn/notebooks/README.md +1 -0
- models/collect_features.py +356 -0
- models/eye_classifier.py +69 -0
- models/eye_crop.py +77 -0
- models/eye_scorer.py +168 -0
- models/face_mesh.py +94 -0
- models/head_pose.py +121 -0
- models/mlp/__init__.py +0 -0
- models/mlp/eval_accuracy.py +54 -0
- models/mlp/sweep.py +66 -0
- models/mlp/train.py +232 -0
- models/xgboost/add_accuracy.py +60 -0
- models/xgboost/checkpoints/face_orientation_best.json +0 -0
- models/xgboost/eval_accuracy.py +54 -0
.gitignore
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Logs
|
| 2 |
+
logs
|
| 3 |
+
*.log
|
| 4 |
+
npm-debug.log*
|
| 5 |
+
yarn-debug.log*
|
| 6 |
+
yarn-error.log*
|
| 7 |
+
pnpm-debug.log*
|
| 8 |
+
lerna-debug.log*
|
| 9 |
+
|
| 10 |
+
node_modules/
|
| 11 |
+
dist/
|
| 12 |
+
dist-ssr/
|
| 13 |
+
*.local
|
| 14 |
+
|
| 15 |
+
# Editor directories and files
|
| 16 |
+
.vscode/
|
| 17 |
+
.idea/
|
| 18 |
+
.DS_Store
|
| 19 |
+
*.suo
|
| 20 |
+
*.ntvs*
|
| 21 |
+
*.njsproj
|
| 22 |
+
*.sln
|
| 23 |
+
*.sw?
|
| 24 |
+
*.py[cod]
|
| 25 |
+
*$py.class
|
| 26 |
+
*.so
|
| 27 |
+
.Python
|
| 28 |
+
venv/
|
| 29 |
+
.venv/
|
| 30 |
+
env/
|
| 31 |
+
.env
|
| 32 |
+
*.egg-info/
|
| 33 |
+
.eggs/
|
| 34 |
+
build/
|
| 35 |
+
Thumbs.db
|
| 36 |
+
|
| 37 |
+
# Project specific
|
| 38 |
+
focus_guard.db
|
| 39 |
+
static/
|
| 40 |
+
__pycache__/
|
| 41 |
+
docs/
|
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim

# Run as a non-root user with UID 1000 (presumably required by the Spaces
# host this is deployed to -- confirm).
RUN useradd -m -u 1000 user
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH

# Unbuffered stdout/stderr so container logs appear immediately.
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# System packages: OpenCV runtime libs (libgl1, libglib2.0-0, ...), ffmpeg +
# codec/SRTP libs for aiortc WebRTC, build tools, and Node.js for the
# frontend build below.
RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev build-essential nodejs npm && rm -rf /var/lib/apt/lists/*

# Install Python deps before copying the source so this layer is cached
# across code-only changes.
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Build the JS frontend and publish it under /app/static for the API to serve.
RUN npm install && npm run build && mkdir -p /app/static && cp -R dist/* /app/static/

# Pre-fetch the face-mesh model into the image cache so startup needs no network.
ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"

# Writable data dir; hand the whole tree to the unprivileged user.
RUN mkdir -p /app/data && chown -R user:user /app

USER user
EXPOSE 7860

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--log-level", "debug"]
|
api/history
ADDED
|
File without changes
|
api/import
ADDED
|
File without changes
|
api/sessions
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from main import app
|
checkpoints/hybrid_focus_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"w_mlp": 0.6000000000000001,
|
| 3 |
+
"w_geo": 0.3999999999999999,
|
| 4 |
+
"threshold": 0.35,
|
| 5 |
+
"use_yawn_veto": true,
|
| 6 |
+
"geo_face_weight": 0.4,
|
| 7 |
+
"geo_eye_weight": 0.6,
|
| 8 |
+
"mar_yawn_threshold": 0.55,
|
| 9 |
+
"metric": "f1"
|
| 10 |
+
}
|
checkpoints/meta_best.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
|
| 3 |
+
size 4196
|
checkpoints/mlp_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2f55129785b6882c304483aa5399f5bf6c9ed6e73dfec7ca6f36cd0436156c8
|
| 3 |
+
size 14497
|
checkpoints/model_best.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:183f2d4419e0eb1e58704e5a7312eb61e331523566d4dc551054a07b3aac7557
|
| 3 |
+
size 5775881
|
checkpoints/scaler_best.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02ed6b4c0d99e0254c6a740a949da2384db58ec7d3e6df6432b9bfcd3a296c71
|
| 3 |
+
size 783
|
checkpoints/xgboost_face_orientation_best.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
focus-guard:
|
| 3 |
+
build: .
|
| 4 |
+
ports:
|
| 5 |
+
- "7860:7860"
|
eslint.config.js
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ESLint flat config for the React/Vite frontend.
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import { defineConfig, globalIgnores } from 'eslint/config'

export default defineConfig([
  // Never lint build output.
  globalIgnores(['dist']),
  {
    files: ['**/*.{js,jsx}'],
    extends: [
      js.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
      parserOptions: {
        ecmaVersion: 'latest',
        ecmaFeatures: { jsx: true },
        sourceType: 'module',
      },
    },
    rules: {
      // Tolerate unused bindings whose name starts with an uppercase letter
      // or underscore (components, constants).
      'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
    },
  },
])
|
index.html
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8" />
|
| 6 |
+
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 8 |
+
<title>Focus Guard</title>
|
| 9 |
+
<link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
|
| 10 |
+
</head>
|
| 11 |
+
|
| 12 |
+
<body>
|
| 13 |
+
<div id="root"></div>
|
| 14 |
+
<script type="module" src="/src/main.jsx"></script>
|
| 15 |
+
</body>
|
| 16 |
+
|
| 17 |
+
</html>
|
main.py
ADDED
|
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
|
| 2 |
+
from fastapi.staticfiles import StaticFiles
|
| 3 |
+
from fastapi.responses import FileResponse
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional, List, Any
|
| 7 |
+
import base64
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import aiosqlite
|
| 11 |
+
import json
|
| 12 |
+
from datetime import datetime, timedelta
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Callable
|
| 17 |
+
import asyncio
|
| 18 |
+
import concurrent.futures
|
| 19 |
+
import threading
|
| 20 |
+
|
| 21 |
+
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
|
| 22 |
+
from av import VideoFrame
|
| 23 |
+
|
| 24 |
+
from mediapipe.tasks.python.vision import FaceLandmarksConnections
|
| 25 |
+
from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
|
| 26 |
+
from models.face_mesh import FaceMeshDetector
|
| 27 |
+
|
| 28 |
+
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================
|
| 29 |
+
|
| 30 |
+
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 31 |
+
_CYAN = (255, 255, 0)
|
| 32 |
+
_GREEN = (0, 255, 0)
|
| 33 |
+
_MAGENTA = (255, 0, 255)
|
| 34 |
+
_ORANGE = (0, 165, 255)
|
| 35 |
+
_RED = (0, 0, 255)
|
| 36 |
+
_WHITE = (255, 255, 255)
|
| 37 |
+
_LIGHT_GREEN = (144, 238, 144)
|
| 38 |
+
|
| 39 |
+
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
|
| 40 |
+
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
|
| 41 |
+
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
|
| 42 |
+
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
|
| 43 |
+
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
|
| 44 |
+
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
|
| 45 |
+
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
|
| 46 |
+
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
|
| 47 |
+
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _lm_px(lm, idx, w, h):
|
| 51 |
+
return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
    """Draw anti-aliased line segments connecting consecutive landmark indices."""
    for a, b in zip(indices, indices[1:]):
        cv2.line(frame, _lm_px(lm, a, w, h), _lm_px(lm, b, w, h), color, thickness, cv2.LINE_AA)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _draw_face_mesh(frame, lm, w, h):
    """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines.

    Mutates *frame* in place. *lm* holds normalized landmark coordinates;
    *w*/*h* are the frame dimensions in pixels.
    """
    # Semi-transparent gray tessellation: draw on a copy, then alpha-blend back.
    mesh_layer = frame.copy()
    for start, end in _TESSELATION_CONNS:
        cv2.line(mesh_layer, _lm_px(lm, start, w, h), _lm_px(lm, end, w, h), (200, 200, 200), 1, cv2.LINE_AA)
    cv2.addWeighted(mesh_layer, 0.3, frame, 0.7, 0, frame)

    # Opaque face contours on top of the blended tessellation.
    for start, end in _CONTOUR_CONNS:
        cv2.line(frame, _lm_px(lm, start, w, h), _lm_px(lm, end, w, h), _CYAN, 1, cv2.LINE_AA)

    # Eyebrows, nose bridge and lips as open polylines.
    _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
    _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
    _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)

    # Closed outlines around both eyes (left drawn first, matching original order).
    for eye_indices in (FaceMeshDetector.LEFT_EYE_INDICES, FaceMeshDetector.RIGHT_EYE_INDICES):
        eye_pts = np.array([_lm_px(lm, i, w, h) for i in eye_indices], dtype=np.int32)
        cv2.polylines(frame, [eye_pts], True, _GREEN, 2, cv2.LINE_AA)

    # Yellow dots on the EAR (eye-aspect-ratio) key points.
    for group in (_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS):
        for idx in group:
            cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)

    # Iris circles plus a gaze line extending from eye center through the pupil.
    for iris_idx, eye_inner, eye_outer in (
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    ):
        iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
        center = iris_pts[0]  # first iris landmark is the pupil center
        if len(iris_pts) >= 5:
            # Iris radius = mean distance from center to the four rim points.
            mean_r = np.mean([np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)])
            cv2.circle(frame, tuple(center), max(int(mean_r), 2), _MAGENTA, 2, cv2.LINE_AA)
        cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
        eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
        eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
        dx, dy = center[0] - eye_cx, center[1] - eye_cy
        # Extend the eye-center -> pupil vector 3x past the pupil as the gaze ray.
        cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _draw_hud(frame, result, model_name):
    """Draw status bar and detail overlay matching live_demo.py."""
    h, w = frame.shape[:2]
    focused = result["is_focused"]
    status_text = "FOCUSED" if focused else "NOT FOCUSED"
    status_color = _GREEN if focused else _RED

    # Black top bar: verdict on the left, active model name on the right.
    cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
    cv2.putText(frame, status_text, (10, 28), _FONT, 0.8, status_color, 2, cv2.LINE_AA)
    cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)

    # Detail line: confidence plus face/eye sub-scores (and MAR when present).
    conf = result.get("mlp_prob", result.get("raw_score", 0.0))
    mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
    sf = result.get("s_face", 0)
    se = result.get("s_eye", 0)
    detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)

    # Head pose angles in the top-right corner, when the pipeline produced them.
    if result.get("yaw") is not None:
        pose_text = f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}"
        cv2.putText(frame, pose_text, (w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)

    # Yawn indicator just below the bar.
    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
|
| 132 |
+
|
| 133 |
+
# Landmark indices used for face mesh drawing on client (union of all groups).
|
| 134 |
+
# Sending only these instead of all 478 saves ~60% of the landmarks payload.
|
| 135 |
+
_MESH_INDICES = sorted(set(
|
| 136 |
+
[10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
|
| 137 |
+
+ [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
|
| 138 |
+
+ [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
|
| 139 |
+
+ [468,469,470,471,472, 473,474,475,476,477] # irises
|
| 140 |
+
+ [70,63,105,66,107,55,65,52,53,46] # left eyebrow
|
| 141 |
+
+ [300,293,334,296,336,285,295,282,283,276] # right eyebrow
|
| 142 |
+
+ [6,197,195,5,4,1,19,94,2] # nose bridge
|
| 143 |
+
+ [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
|
| 144 |
+
+ [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
|
| 145 |
+
+ [33,160,158,133,153,145] # left EAR key points
|
| 146 |
+
+ [362,385,387,263,373,380] # right EAR key points
|
| 147 |
+
))
|
| 148 |
+
# Build a lookup: original_index -> position in sparse array, so client can reconstruct.
|
| 149 |
+
_MESH_INDEX_SET = set(_MESH_INDICES)
|
| 150 |
+
|
| 151 |
+
# Initialize FastAPI app
|
| 152 |
+
app = FastAPI(title="Focus Guard API")
|
| 153 |
+
|
| 154 |
+
# Add CORS middleware
|
| 155 |
+
app.add_middleware(
|
| 156 |
+
CORSMiddleware,
|
| 157 |
+
allow_origins=["*"],
|
| 158 |
+
allow_credentials=True,
|
| 159 |
+
allow_methods=["*"],
|
| 160 |
+
allow_headers=["*"],
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Global variables
|
| 164 |
+
db_path = "focus_guard.db"
|
| 165 |
+
pcs = set()
|
| 166 |
+
_cached_model_name = "mlp" # in-memory cache, updated via /api/settings
|
| 167 |
+
|
| 168 |
+
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
    """Block until the peer connection has finished gathering ICE candidates."""
    if pc.iceGatheringState == "complete":
        return

    gathered = asyncio.Event()

    @pc.on("icegatheringstatechange")
    def _notify():
        # aiortc fires this on every state transition; we only care about "complete".
        if pc.iceGatheringState == "complete":
            gathered.set()

    await gathered.wait()
|
| 179 |
+
|
| 180 |
+
# ================ DATABASE MODELS ================
|
| 181 |
+
|
| 182 |
+
async def init_database():
    """Initialize SQLite database with required tables"""
    async with aiosqlite.connect(db_path) as db:
        # FocusSessions table: one row per tracking session; the frame
        # counters are incremented as events arrive and summarized on end.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                start_time TIMESTAMP NOT NULL,
                end_time TIMESTAMP,
                duration_seconds INTEGER DEFAULT 0,
                focus_score REAL DEFAULT 0.0,
                total_frames INTEGER DEFAULT 0,
                focused_frames INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # FocusEvents table: one row per processed frame, linked to its session.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id INTEGER NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                is_focused BOOLEAN NOT NULL,
                confidence REAL NOT NULL,
                detection_data TEXT,
                FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
            )
        """)

        # UserSettings table: singleton row, id is forced to 1 by the CHECK.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS user_settings (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                sensitivity INTEGER DEFAULT 6,
                notification_enabled BOOLEAN DEFAULT 1,
                notification_threshold INTEGER DEFAULT 30,
                frame_rate INTEGER DEFAULT 30,
                model_name TEXT DEFAULT 'mlp'
            )
        """)

        # Insert default settings if not exists
        await db.execute("""
            INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
            VALUES (1, 6, 1, 30, 30, 'mlp')
        """)

        await db.commit()
|
| 231 |
+
|
| 232 |
+
# ================ PYDANTIC MODELS ================
|
| 233 |
+
|
| 234 |
+
class SessionCreate(BaseModel):
    """Request body for creating a focus session; carries no fields yet."""
    pass
|
| 236 |
+
|
| 237 |
+
class SessionEnd(BaseModel):
    """Request body identifying the session to close."""
    # id of the focus_sessions row to finalize
    session_id: int
|
| 239 |
+
|
| 240 |
+
class SettingsUpdate(BaseModel):
    """Partial update of user settings.

    All fields are optional; fields left as None are presumably ignored by
    the settings endpoint (endpoint not visible here -- confirm).
    """
    sensitivity: Optional[int] = None
    notification_enabled: Optional[bool] = None
    notification_threshold: Optional[int] = None
    frame_rate: Optional[int] = None
    model_name: Optional[str] = None
|
| 246 |
+
|
| 247 |
+
class VideoTransformTrack(VideoStreamTrack):
    """WebRTC video track that runs focus inference on incoming frames.

    Wraps an upstream track; at most every ``min_inference_interval`` seconds
    it runs the active pipeline on the frame, draws the face-mesh/HUD overlay,
    stores a focus event, and pushes a detection message over the data
    channel resolved by *get_channel*. Throttled frames reuse the last
    annotated image.
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        super().__init__()
        self.track = track                    # upstream aiortc media track
        self.session_id = session_id          # falsy -> skip event storage
        self.get_channel = get_channel        # lazily resolves the RTC data channel
        self.last_inference_time = 0
        self.min_inference_interval = 1 / 60  # cap inference at ~60 runs/sec
        self.last_frame = None                # last annotated frame, reused when throttled

    async def recv(self):
        """Receive one frame, run inference when due, return the annotated frame."""
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame

        # Normalize size for inference/drawing
        img = cv2.resize(img, (640, 480))

        now = datetime.now().timestamp()
        do_infer = (now - self.last_inference_time) >= self.min_inference_interval

        if do_infer:
            self.last_inference_time = now

            # Fall back to the MLP pipeline when the configured one is missing.
            model_name = _cached_model_name
            if model_name not in pipelines or pipelines.get(model_name) is None:
                model_name = 'mlp'
            active_pipeline = pipelines.get(model_name)

            if active_pipeline is not None:
                # Run the CPU-heavy pipeline off the event loop thread.
                # get_running_loop() is the correct call inside a coroutine
                # (get_event_loop() is deprecated here since 3.10).
                loop = asyncio.get_running_loop()
                out = await loop.run_in_executor(
                    _inference_executor,
                    _process_frame_safe,
                    active_pipeline,
                    img,
                    model_name,
                )
                is_focused = out["is_focused"]
                confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
                metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}

                # Draw face mesh + HUD on the video frame
                h_f, w_f = img.shape[:2]
                lm = out.get("landmarks")
                if lm is not None:
                    _draw_face_mesh(img, lm, w_f, h_f)
                _draw_hud(img, out, model_name)
            else:
                is_focused = False
                confidence = 0.0
                metadata = {"model": model_name}
                cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
                cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)

            if self.session_id:
                await store_focus_event(self.session_id, is_focused, confidence, metadata)

            channel = self.get_channel()
            if channel and channel.readyState == "open":
                try:
                    # BUG FIX: the original sent an undefined name `detections`,
                    # so every send raised NameError and was silently swallowed
                    # by the except below -- clients never received detections.
                    # Send the computed metadata under the same protocol key.
                    channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": metadata}))
                except Exception:
                    # Best-effort: the channel can close at any moment mid-send.
                    pass

            self.last_frame = img
        elif self.last_frame is not None:
            # Throttled: re-send the previous annotated frame.
            img = self.last_frame

        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
|
| 321 |
+
|
| 322 |
+
# ================ DATABASE OPERATIONS ================
|
| 323 |
+
|
| 324 |
+
async def create_session():
    """Insert a new focus_sessions row stamped with the current local time.

    Returns the new session's rowid.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (datetime.now().isoformat(),)
        )
        await db.commit()
        return cursor.lastrowid
|
| 332 |
+
|
| 333 |
+
async def end_session(session_id: int):
    """Finalize a session: compute duration and focus score, persist them.

    Returns a summary dict, or None when the session id does not exist.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,)
        )
        row = await cursor.fetchone()

        if not row:
            return None

        start_time_str, total_frames, focused_frames = row
        # NOTE(review): naive local datetimes are used throughout; consistent
        # as long as all timestamps originate from this machine -- confirm.
        start_time = datetime.fromisoformat(start_time_str)
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        # Focus score = fraction of processed frames classified as focused.
        focus_score = focused_frames / total_frames if total_frames > 0 else 0.0

        await db.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (end_time.isoformat(), int(duration), focus_score, session_id))

        await db.commit()

        return {
            'session_id': session_id,
            'start_time': start_time_str,
            'end_time': end_time.isoformat(),
            'duration_seconds': int(duration),
            'focus_score': round(focus_score, 3),
            'total_frames': total_frames,
            'focused_frames': focused_frames
        }
|
| 367 |
+
|
| 368 |
+
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Persist one detection event and bump the session's frame counters."""
    async with aiosqlite.connect(db_path) as db:
        await db.execute("""
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
        """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))

        # Keep the session aggregates in sync with the event stream.
        await db.execute("""
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
        """, (1 if is_focused else 0, session_id))
        await db.commit()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class _EventBuffer:
    """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""

    def __init__(self, flush_interval: float = 2.0):
        # Pending rows: (session_id, timestamp_iso, is_focused, confidence, json_metadata)
        self._buf: list = []
        self._lock = asyncio.Lock()
        self._flush_interval = flush_interval
        self._task: asyncio.Task | None = None  # background flush loop
        # Aggregates accumulated since the last flush.
        self._total_frames = 0
        self._focused_frames = 0

    def start(self):
        # Idempotent: at most one flush loop is ever started.
        if self._task is None:
            self._task = asyncio.create_task(self._flush_loop())

    async def stop(self):
        # Cancel the loop, then flush whatever is still buffered.
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        await self._flush()

    def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
        # Synchronous on purpose: no await between mutations, so this is safe
        # without the lock when called from the event-loop thread only.
        self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
        self._total_frames += 1
        if is_focused:
            self._focused_frames += 1

    async def _flush_loop(self):
        # Periodic flush until cancelled by stop().
        while True:
            await asyncio.sleep(self._flush_interval)
            await self._flush()

    async def _flush(self):
        # Swap the buffer out under the lock; do the DB write outside it.
        async with self._lock:
            if not self._buf:
                return
            batch = self._buf[:]
            total = self._total_frames
            focused = self._focused_frames
            self._buf.clear()
            self._total_frames = 0
            self._focused_frames = 0

        if not batch:
            return

        # NOTE(review): all buffered events are attributed to the session of
        # the first row -- assumes a single active session per buffer; confirm.
        session_id = batch[0][0]
        try:
            async with aiosqlite.connect(db_path) as db:
                await db.executemany("""
                    INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
                    VALUES (?, ?, ?, ?, ?)
                """, batch)
                await db.execute("""
                    UPDATE focus_sessions
                    SET total_frames = total_frames + ?,
                        focused_frames = focused_frames + ?
                    WHERE id = ?
                """, (total, focused, session_id))
                await db.commit()
        except Exception as e:
            # Best-effort persistence: log and keep streaming frames.
            print(f"[DB] Flush error: {e}")
|
| 450 |
+
|
| 451 |
+
# ================ STARTUP/SHUTDOWN ================

# Model registry, populated during startup. A value of None means that pipeline
# failed to load (or has not loaded yet); consumers must check before use.
pipelines = {
    "geometric": None,
    "mlp": None,
    "hybrid": None,
    "xgboost": None,
}

# Thread pool for CPU-bound inference so the event loop stays responsive.
_inference_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=4,
    thread_name_prefix="inference",
)
# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
# multiple frames are processed in parallel by the thread pool.
_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def _process_frame_safe(pipeline, frame, model_name: str):
    """Run ``pipeline.process_frame`` while holding that pipeline's lock.

    Intended to be submitted to the inference thread pool; the lock keeps the
    pipeline's internal state consistent across parallel frame submissions.
    """
    lock = _pipeline_locks[model_name]
    with lock:
        result = pipeline.process_frame(frame)
    return result
|
| 474 |
+
|
| 475 |
+
@app.on_event("startup")
|
| 476 |
+
async def startup_event():
|
| 477 |
+
global pipelines, _cached_model_name
|
| 478 |
+
print(" Starting Focus Guard API...")
|
| 479 |
+
await init_database()
|
| 480 |
+
# Load cached model name from DB
|
| 481 |
+
async with aiosqlite.connect(db_path) as db:
|
| 482 |
+
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 483 |
+
row = await cursor.fetchone()
|
| 484 |
+
if row:
|
| 485 |
+
_cached_model_name = row[0]
|
| 486 |
+
print("[OK] Database initialized")
|
| 487 |
+
|
| 488 |
+
try:
|
| 489 |
+
pipelines["geometric"] = FaceMeshPipeline()
|
| 490 |
+
print("[OK] FaceMeshPipeline (geometric) loaded")
|
| 491 |
+
except Exception as e:
|
| 492 |
+
print(f"[WARN] FaceMeshPipeline unavailable: {e}")
|
| 493 |
+
|
| 494 |
+
try:
|
| 495 |
+
pipelines["mlp"] = MLPPipeline()
|
| 496 |
+
print("[OK] MLPPipeline loaded")
|
| 497 |
+
except Exception as e:
|
| 498 |
+
print(f"[ERR] Failed to load MLPPipeline: {e}")
|
| 499 |
+
|
| 500 |
+
try:
|
| 501 |
+
pipelines["hybrid"] = HybridFocusPipeline()
|
| 502 |
+
print("[OK] HybridFocusPipeline loaded")
|
| 503 |
+
except Exception as e:
|
| 504 |
+
print(f"[WARN] HybridFocusPipeline unavailable: {e}")
|
| 505 |
+
|
| 506 |
+
try:
|
| 507 |
+
pipelines["xgboost"] = XGBoostPipeline()
|
| 508 |
+
print("[OK] XGBoostPipeline loaded")
|
| 509 |
+
except Exception as e:
|
| 510 |
+
print(f"[ERR] Failed to load XGBoostPipeline: {e}")
|
| 511 |
+
|
| 512 |
+
@app.on_event("shutdown")
|
| 513 |
+
async def shutdown_event():
|
| 514 |
+
_inference_executor.shutdown(wait=False)
|
| 515 |
+
print(" Shutting down Focus Guard API...")
|
| 516 |
+
|
| 517 |
+
# ================ WEBRTC SIGNALING ================
|
| 518 |
+
|
| 519 |
+
@app.post("/api/webrtc/offer")
|
| 520 |
+
async def webrtc_offer(offer: dict):
|
| 521 |
+
try:
|
| 522 |
+
print(f"Received WebRTC offer")
|
| 523 |
+
|
| 524 |
+
pc = RTCPeerConnection()
|
| 525 |
+
pcs.add(pc)
|
| 526 |
+
|
| 527 |
+
session_id = await create_session()
|
| 528 |
+
print(f"Created session: {session_id}")
|
| 529 |
+
|
| 530 |
+
channel_ref = {"channel": None}
|
| 531 |
+
|
| 532 |
+
@pc.on("datachannel")
|
| 533 |
+
def on_datachannel(channel):
|
| 534 |
+
print(f"Data channel opened")
|
| 535 |
+
channel_ref["channel"] = channel
|
| 536 |
+
|
| 537 |
+
@pc.on("track")
|
| 538 |
+
def on_track(track):
|
| 539 |
+
print(f"Received track: {track.kind}")
|
| 540 |
+
if track.kind == "video":
|
| 541 |
+
local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
|
| 542 |
+
pc.addTrack(local_track)
|
| 543 |
+
print(f"Video track added")
|
| 544 |
+
|
| 545 |
+
@track.on("ended")
|
| 546 |
+
async def on_ended():
|
| 547 |
+
print(f"Track ended")
|
| 548 |
+
|
| 549 |
+
@pc.on("connectionstatechange")
|
| 550 |
+
async def on_connectionstatechange():
|
| 551 |
+
print(f"Connection state changed: {pc.connectionState}")
|
| 552 |
+
if pc.connectionState in ("failed", "closed", "disconnected"):
|
| 553 |
+
try:
|
| 554 |
+
await end_session(session_id)
|
| 555 |
+
except Exception as e:
|
| 556 |
+
print(f"⚠Error ending session: {e}")
|
| 557 |
+
pcs.discard(pc)
|
| 558 |
+
await pc.close()
|
| 559 |
+
|
| 560 |
+
await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
|
| 561 |
+
print(f"Remote description set")
|
| 562 |
+
|
| 563 |
+
answer = await pc.createAnswer()
|
| 564 |
+
await pc.setLocalDescription(answer)
|
| 565 |
+
print(f"Answer created")
|
| 566 |
+
|
| 567 |
+
await _wait_for_ice_gathering(pc)
|
| 568 |
+
print(f"ICE gathering complete")
|
| 569 |
+
|
| 570 |
+
return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
|
| 571 |
+
|
| 572 |
+
except Exception as e:
|
| 573 |
+
print(f"WebRTC offer error: {e}")
|
| 574 |
+
import traceback
|
| 575 |
+
traceback.print_exc()
|
| 576 |
+
raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
|
| 577 |
+
|
| 578 |
+
# ================ WEBSOCKET ================
|
| 579 |
+
|
| 580 |
+
@app.websocket("/ws/video")
|
| 581 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 582 |
+
await websocket.accept()
|
| 583 |
+
session_id = None
|
| 584 |
+
frame_count = 0
|
| 585 |
+
running = True
|
| 586 |
+
event_buffer = _EventBuffer(flush_interval=2.0)
|
| 587 |
+
|
| 588 |
+
# Latest frame slot — only the most recent frame is kept, older ones are dropped.
|
| 589 |
+
# Using a dict so nested functions can mutate without nonlocal issues.
|
| 590 |
+
_slot = {"frame": None}
|
| 591 |
+
_frame_ready = asyncio.Event()
|
| 592 |
+
|
| 593 |
+
async def _receive_loop():
|
| 594 |
+
"""Receive messages as fast as possible. Binary = frame, text = control."""
|
| 595 |
+
nonlocal session_id, running
|
| 596 |
+
try:
|
| 597 |
+
while running:
|
| 598 |
+
msg = await websocket.receive()
|
| 599 |
+
msg_type = msg.get("type", "")
|
| 600 |
+
|
| 601 |
+
if msg_type == "websocket.disconnect":
|
| 602 |
+
running = False
|
| 603 |
+
_frame_ready.set()
|
| 604 |
+
return
|
| 605 |
+
|
| 606 |
+
# Binary message → JPEG frame (fast path, no base64)
|
| 607 |
+
raw_bytes = msg.get("bytes")
|
| 608 |
+
if raw_bytes is not None and len(raw_bytes) > 0:
|
| 609 |
+
_slot["frame"] = raw_bytes
|
| 610 |
+
_frame_ready.set()
|
| 611 |
+
continue
|
| 612 |
+
|
| 613 |
+
# Text message → JSON control command (or legacy base64 frame)
|
| 614 |
+
text = msg.get("text")
|
| 615 |
+
if not text:
|
| 616 |
+
continue
|
| 617 |
+
data = json.loads(text)
|
| 618 |
+
|
| 619 |
+
if data["type"] == "frame":
|
| 620 |
+
# Legacy base64 path (fallback)
|
| 621 |
+
_slot["frame"] = base64.b64decode(data["image"])
|
| 622 |
+
_frame_ready.set()
|
| 623 |
+
|
| 624 |
+
elif data["type"] == "start_session":
|
| 625 |
+
session_id = await create_session()
|
| 626 |
+
event_buffer.start()
|
| 627 |
+
for p in pipelines.values():
|
| 628 |
+
if p is not None and hasattr(p, "reset_session"):
|
| 629 |
+
p.reset_session()
|
| 630 |
+
await websocket.send_json({"type": "session_started", "session_id": session_id})
|
| 631 |
+
|
| 632 |
+
elif data["type"] == "end_session":
|
| 633 |
+
if session_id:
|
| 634 |
+
await event_buffer.stop()
|
| 635 |
+
summary = await end_session(session_id)
|
| 636 |
+
if summary:
|
| 637 |
+
await websocket.send_json({"type": "session_ended", "summary": summary})
|
| 638 |
+
session_id = None
|
| 639 |
+
except WebSocketDisconnect:
|
| 640 |
+
running = False
|
| 641 |
+
_frame_ready.set()
|
| 642 |
+
except Exception as e:
|
| 643 |
+
print(f"[WS] receive error: {e}")
|
| 644 |
+
running = False
|
| 645 |
+
_frame_ready.set()
|
| 646 |
+
|
| 647 |
+
async def _process_loop():
|
| 648 |
+
"""Process only the latest frame, dropping stale ones."""
|
| 649 |
+
nonlocal frame_count, running
|
| 650 |
+
loop = asyncio.get_event_loop()
|
| 651 |
+
while running:
|
| 652 |
+
await _frame_ready.wait()
|
| 653 |
+
_frame_ready.clear()
|
| 654 |
+
if not running:
|
| 655 |
+
return
|
| 656 |
+
|
| 657 |
+
# Grab latest frame and clear slot
|
| 658 |
+
raw = _slot["frame"]
|
| 659 |
+
_slot["frame"] = None
|
| 660 |
+
if raw is None:
|
| 661 |
+
continue
|
| 662 |
+
|
| 663 |
+
try:
|
| 664 |
+
nparr = np.frombuffer(raw, np.uint8)
|
| 665 |
+
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 666 |
+
if frame is None:
|
| 667 |
+
continue
|
| 668 |
+
frame = cv2.resize(frame, (640, 480))
|
| 669 |
+
|
| 670 |
+
model_name = _cached_model_name
|
| 671 |
+
if model_name not in pipelines or pipelines.get(model_name) is None:
|
| 672 |
+
model_name = "mlp"
|
| 673 |
+
active_pipeline = pipelines.get(model_name)
|
| 674 |
+
|
| 675 |
+
landmarks_list = None
|
| 676 |
+
if active_pipeline is not None:
|
| 677 |
+
out = await loop.run_in_executor(
|
| 678 |
+
_inference_executor,
|
| 679 |
+
_process_frame_safe,
|
| 680 |
+
active_pipeline,
|
| 681 |
+
frame,
|
| 682 |
+
model_name,
|
| 683 |
+
)
|
| 684 |
+
is_focused = out["is_focused"]
|
| 685 |
+
confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
|
| 686 |
+
|
| 687 |
+
lm = out.get("landmarks")
|
| 688 |
+
if lm is not None:
|
| 689 |
+
# Send all 478 landmarks as flat array for tessellation drawing
|
| 690 |
+
landmarks_list = [
|
| 691 |
+
[round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
|
| 692 |
+
for i in range(lm.shape[0])
|
| 693 |
+
]
|
| 694 |
+
|
| 695 |
+
if session_id:
|
| 696 |
+
event_buffer.add(session_id, is_focused, confidence, {
|
| 697 |
+
"s_face": out.get("s_face", 0.0),
|
| 698 |
+
"s_eye": out.get("s_eye", 0.0),
|
| 699 |
+
"mar": out.get("mar", 0.0),
|
| 700 |
+
"model": model_name,
|
| 701 |
+
})
|
| 702 |
+
else:
|
| 703 |
+
is_focused = False
|
| 704 |
+
confidence = 0.0
|
| 705 |
+
|
| 706 |
+
resp = {
|
| 707 |
+
"type": "detection",
|
| 708 |
+
"focused": is_focused,
|
| 709 |
+
"confidence": round(confidence, 3),
|
| 710 |
+
"model": model_name,
|
| 711 |
+
"fc": frame_count,
|
| 712 |
+
}
|
| 713 |
+
if active_pipeline is not None:
|
| 714 |
+
# Send detailed metrics for HUD
|
| 715 |
+
if out.get("yaw") is not None:
|
| 716 |
+
resp["yaw"] = round(out["yaw"], 1)
|
| 717 |
+
resp["pitch"] = round(out["pitch"], 1)
|
| 718 |
+
resp["roll"] = round(out["roll"], 1)
|
| 719 |
+
if out.get("mar") is not None:
|
| 720 |
+
resp["mar"] = round(out["mar"], 3)
|
| 721 |
+
resp["sf"] = round(out.get("s_face", 0), 3)
|
| 722 |
+
resp["se"] = round(out.get("s_eye", 0), 3)
|
| 723 |
+
if landmarks_list is not None:
|
| 724 |
+
resp["lm"] = landmarks_list
|
| 725 |
+
await websocket.send_json(resp)
|
| 726 |
+
frame_count += 1
|
| 727 |
+
except Exception as e:
|
| 728 |
+
print(f"[WS] process error: {e}")
|
| 729 |
+
|
| 730 |
+
try:
|
| 731 |
+
await asyncio.gather(_receive_loop(), _process_loop())
|
| 732 |
+
except Exception:
|
| 733 |
+
pass
|
| 734 |
+
finally:
|
| 735 |
+
running = False
|
| 736 |
+
if session_id:
|
| 737 |
+
await event_buffer.stop()
|
| 738 |
+
await end_session(session_id)
|
| 739 |
+
|
| 740 |
+
# ================ API ENDPOINTS ================
|
| 741 |
+
|
| 742 |
+
@app.post("/api/sessions/start")
|
| 743 |
+
async def api_start_session():
|
| 744 |
+
session_id = await create_session()
|
| 745 |
+
return {"session_id": session_id}
|
| 746 |
+
|
| 747 |
+
@app.post("/api/sessions/end")
|
| 748 |
+
async def api_end_session(data: SessionEnd):
|
| 749 |
+
summary = await end_session(data.session_id)
|
| 750 |
+
if not summary: raise HTTPException(status_code=404, detail="Session not found")
|
| 751 |
+
return summary
|
| 752 |
+
|
| 753 |
+
@app.get("/api/sessions")
|
| 754 |
+
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
|
| 755 |
+
async with aiosqlite.connect(db_path) as db:
|
| 756 |
+
db.row_factory = aiosqlite.Row
|
| 757 |
+
|
| 758 |
+
# NEW: If importing/exporting all, remove limit if special flag or high limit
|
| 759 |
+
# For simplicity: if limit is -1, return all
|
| 760 |
+
limit_clause = "LIMIT ? OFFSET ?"
|
| 761 |
+
params = []
|
| 762 |
+
|
| 763 |
+
base_query = "SELECT * FROM focus_sessions"
|
| 764 |
+
where_clause = ""
|
| 765 |
+
|
| 766 |
+
if filter == "today":
|
| 767 |
+
date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
| 768 |
+
where_clause = " WHERE start_time >= ?"
|
| 769 |
+
params.append(date_filter.isoformat())
|
| 770 |
+
elif filter == "week":
|
| 771 |
+
date_filter = datetime.now() - timedelta(days=7)
|
| 772 |
+
where_clause = " WHERE start_time >= ?"
|
| 773 |
+
params.append(date_filter.isoformat())
|
| 774 |
+
elif filter == "month":
|
| 775 |
+
date_filter = datetime.now() - timedelta(days=30)
|
| 776 |
+
where_clause = " WHERE start_time >= ?"
|
| 777 |
+
params.append(date_filter.isoformat())
|
| 778 |
+
elif filter == "all":
|
| 779 |
+
# Just ensure we only get completed sessions or all sessions
|
| 780 |
+
where_clause = " WHERE end_time IS NOT NULL"
|
| 781 |
+
|
| 782 |
+
query = f"{base_query}{where_clause} ORDER BY start_time DESC"
|
| 783 |
+
|
| 784 |
+
# Handle Limit for Exports
|
| 785 |
+
if limit == -1:
|
| 786 |
+
# No limit clause for export
|
| 787 |
+
pass
|
| 788 |
+
else:
|
| 789 |
+
query += f" {limit_clause}"
|
| 790 |
+
params.extend([limit, offset])
|
| 791 |
+
|
| 792 |
+
cursor = await db.execute(query, tuple(params))
|
| 793 |
+
rows = await cursor.fetchall()
|
| 794 |
+
return [dict(row) for row in rows]
|
| 795 |
+
|
| 796 |
+
# --- NEW: Import Endpoint ---
@app.post("/api/import")
async def import_sessions(sessions: List[dict]):
    """Bulk-import session rows (e.g. restoring an exported history).

    Missing numeric fields default to 0/0.0; ``created_at`` falls back to
    ``start_time``. Returns {"status": "success", "count": N} on success or a
    {"status": "error", ...} payload on failure.
    """
    # Use .get() to tolerate fields missing from older export versions.
    rows = [
        (
            session.get('start_time'),
            session.get('end_time'),
            session.get('duration_seconds', 0),
            session.get('focus_score', 0.0),
            session.get('total_frames', 0),
            session.get('focused_frames', 0),
            session.get('created_at', session.get('start_time')),
        )
        for session in sessions
    ]
    try:
        async with aiosqlite.connect(db_path) as db:
            # Single batched statement instead of one round-trip per row.
            await db.executemany("""
                INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, rows)
            await db.commit()
        return {"status": "success", "count": len(rows)}
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
|
| 822 |
+
|
| 823 |
+
# --- NEW: Clear History Endpoint ---
@app.delete("/api/history")
async def clear_history():
    """Delete all stored sessions and their events."""
    try:
        async with aiosqlite.connect(db_path) as db:
            # Events reference sessions, so clear the child table first.
            await db.execute("DELETE FROM focus_events")
            await db.execute("DELETE FROM focus_sessions")
            await db.commit()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "History cleared"}
|
| 835 |
+
|
| 836 |
+
@app.get("/api/sessions/{session_id}")
|
| 837 |
+
async def get_session(session_id: int):
|
| 838 |
+
async with aiosqlite.connect(db_path) as db:
|
| 839 |
+
db.row_factory = aiosqlite.Row
|
| 840 |
+
cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
|
| 841 |
+
row = await cursor.fetchone()
|
| 842 |
+
if not row: raise HTTPException(status_code=404, detail="Session not found")
|
| 843 |
+
session = dict(row)
|
| 844 |
+
cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
|
| 845 |
+
events = [dict(r) for r in await cursor.fetchall()]
|
| 846 |
+
session['events'] = events
|
| 847 |
+
return session
|
| 848 |
+
|
| 849 |
+
@app.get("/api/settings")
|
| 850 |
+
async def get_settings():
|
| 851 |
+
async with aiosqlite.connect(db_path) as db:
|
| 852 |
+
db.row_factory = aiosqlite.Row
|
| 853 |
+
cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
|
| 854 |
+
row = await cursor.fetchone()
|
| 855 |
+
if row: return dict(row)
|
| 856 |
+
else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
|
| 857 |
+
|
| 858 |
+
@app.put("/api/settings")
|
| 859 |
+
async def update_settings(settings: SettingsUpdate):
|
| 860 |
+
async with aiosqlite.connect(db_path) as db:
|
| 861 |
+
cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
|
| 862 |
+
exists = await cursor.fetchone()
|
| 863 |
+
if not exists:
|
| 864 |
+
await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
|
| 865 |
+
await db.commit()
|
| 866 |
+
|
| 867 |
+
updates = []
|
| 868 |
+
params = []
|
| 869 |
+
if settings.sensitivity is not None:
|
| 870 |
+
updates.append("sensitivity = ?")
|
| 871 |
+
params.append(max(1, min(10, settings.sensitivity)))
|
| 872 |
+
if settings.notification_enabled is not None:
|
| 873 |
+
updates.append("notification_enabled = ?")
|
| 874 |
+
params.append(settings.notification_enabled)
|
| 875 |
+
if settings.notification_threshold is not None:
|
| 876 |
+
updates.append("notification_threshold = ?")
|
| 877 |
+
params.append(max(5, min(300, settings.notification_threshold)))
|
| 878 |
+
if settings.frame_rate is not None:
|
| 879 |
+
updates.append("frame_rate = ?")
|
| 880 |
+
params.append(max(5, min(60, settings.frame_rate)))
|
| 881 |
+
if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
|
| 882 |
+
updates.append("model_name = ?")
|
| 883 |
+
params.append(settings.model_name)
|
| 884 |
+
global _cached_model_name
|
| 885 |
+
_cached_model_name = settings.model_name
|
| 886 |
+
|
| 887 |
+
if updates:
|
| 888 |
+
query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
|
| 889 |
+
await db.execute(query, params)
|
| 890 |
+
await db.commit()
|
| 891 |
+
return {"status": "success", "updated": len(updates) > 0}
|
| 892 |
+
|
| 893 |
+
@app.get("/api/stats/summary")
|
| 894 |
+
async def get_stats_summary():
|
| 895 |
+
async with aiosqlite.connect(db_path) as db:
|
| 896 |
+
cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 897 |
+
total_sessions = (await cursor.fetchone())[0]
|
| 898 |
+
cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 899 |
+
total_focus_time = (await cursor.fetchone())[0] or 0
|
| 900 |
+
cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 901 |
+
avg_focus_score = (await cursor.fetchone())[0] or 0.0
|
| 902 |
+
cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
|
| 903 |
+
dates = [row[0] for row in await cursor.fetchall()]
|
| 904 |
+
|
| 905 |
+
streak_days = 0
|
| 906 |
+
if dates:
|
| 907 |
+
current_date = datetime.now().date()
|
| 908 |
+
for i, date_str in enumerate(dates):
|
| 909 |
+
session_date = datetime.fromisoformat(date_str).date()
|
| 910 |
+
expected_date = current_date - timedelta(days=i)
|
| 911 |
+
if session_date == expected_date: streak_days += 1
|
| 912 |
+
else: break
|
| 913 |
+
return {
|
| 914 |
+
'total_sessions': total_sessions,
|
| 915 |
+
'total_focus_time': int(total_focus_time),
|
| 916 |
+
'avg_focus_score': round(avg_focus_score, 3),
|
| 917 |
+
'streak_days': streak_days
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
@app.get("/api/models")
|
| 921 |
+
async def get_available_models():
|
| 922 |
+
"""Return list of loaded model names and which is currently active."""
|
| 923 |
+
available = [name for name, p in pipelines.items() if p is not None]
|
| 924 |
+
async with aiosqlite.connect(db_path) as db:
|
| 925 |
+
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 926 |
+
row = await cursor.fetchone()
|
| 927 |
+
current = row[0] if row else "mlp"
|
| 928 |
+
if current not in available and available:
|
| 929 |
+
current = available[0]
|
| 930 |
+
return {"available": available, "current": current}
|
| 931 |
+
|
| 932 |
+
@app.get("/api/mesh-topology")
|
| 933 |
+
async def get_mesh_topology():
|
| 934 |
+
"""Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
|
| 935 |
+
return {"tessellation": _TESSELATION_CONNS}
|
| 936 |
+
|
| 937 |
+
@app.get("/health")
|
| 938 |
+
async def health_check():
|
| 939 |
+
available = [name for name, p in pipelines.items() if p is not None]
|
| 940 |
+
return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
|
| 941 |
+
|
| 942 |
+
# ================ STATIC FILES (SPA SUPPORT) ================

# Resolve static dir from this file so it works regardless of cwd
_STATIC_DIR = Path(__file__).resolve().parent / "static"
_ASSETS_DIR = _STATIC_DIR / "assets"

# 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
# (mount is skipped entirely if the frontend build has not been copied in yet).
if _ASSETS_DIR.is_dir():
    app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
|
| 951 |
+
|
| 952 |
+
# 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
@app.get("/{full_path:path}")
async def serve_react_app(full_path: str, request: Request):
    """SPA catch-all: return index.html for client-side routes.

    Backend ("api", "ws") and asset paths 404 here instead of receiving HTML —
    serving HTML for a missing JS module would break strict MIME checking.
    """
    # Single tuple startswith replaces the previous three checks; the old
    # extra `startswith("assets/")` clause was redundant (already covered by
    # `startswith("assets")`). Matching semantics are unchanged.
    if full_path.startswith(("api", "ws", "assets")):
        raise HTTPException(status_code=404, detail="Not Found")

    index_path = _STATIC_DIR / "index.html"
    if index_path.is_file():
        return FileResponse(str(index_path))
    return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
|
models/README.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/
|
| 2 |
+
|
| 3 |
+
Feature extraction modules and model training scripts.
|
| 4 |
+
|
| 5 |
+
## 1. Feature Extraction
|
| 6 |
+
|
| 7 |
+
Root-level modules form the real-time inference pipeline:
|
| 8 |
+
|
| 9 |
+
| Module | Input | Output |
|
| 10 |
+
|--------|-------|--------|
|
| 11 |
+
| `face_mesh.py` | BGR frame | 478 MediaPipe landmarks |
|
| 12 |
+
| `head_pose.py` | Landmarks, frame size | yaw, pitch, roll, face/eye score, gaze offset, head deviation |
|
| 13 |
+
| `eye_scorer.py` | Landmarks | EAR (left/right/avg), gaze ratio (h/v), MAR |
|
| 14 |
+
| `eye_crop.py` | Landmarks, frame | Cropped eye region images |
|
| 15 |
+
| `eye_classifier.py` | Eye crops or landmarks | Eye open/closed prediction (geometric fallback) |
|
| 16 |
+
| `collect_features.py` | BGR frame | 17-d feature vector + temporal features (PERCLOS, blink rate, etc.) |
|
| 17 |
+
|
| 18 |
+
## 2. Training Scripts
|
| 19 |
+
|
| 20 |
+
| Folder | Model | Command |
|
| 21 |
+
|--------|-------|---------|
|
| 22 |
+
| `mlp/` | PyTorch MLP (64→32, 2-class) | `python -m models.mlp.train` |
|
| 23 |
+
| `xgboost/` | XGBoost (600 trees, depth 8) | `python -m models.xgboost.train` |
|
| 24 |
+
|
| 25 |
+
### mlp/
|
| 26 |
+
|
| 27 |
+
- `train.py` — training loop with early stopping, ClearML opt-in
|
| 28 |
+
- `sweep.py` — hyperparameter search (Optuna: lr, batch_size)
|
| 29 |
+
- `eval_accuracy.py` — load checkpoint and print test metrics
|
| 30 |
+
- Saves to **`checkpoints/mlp_best.pt`**
|
| 31 |
+
|
| 32 |
+
### xgboost/
|
| 33 |
+
|
| 34 |
+
- `train.py` — training with eval-set logging
|
| 35 |
+
- `sweep.py` / `sweep_local.py` — hyperparameter search (Optuna + ClearML)
|
| 36 |
+
- `eval_accuracy.py` — load checkpoint and print test metrics
|
| 37 |
+
- Saves to **`checkpoints/xgboost_face_orientation_best.json`**
|
| 38 |
+
|
| 39 |
+
## 3. Data Loading
|
| 40 |
+
|
| 41 |
+
All training scripts import from `data_preparation.prepare_dataset`:
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
from data_preparation.prepare_dataset import get_numpy_splits # XGBoost
|
| 45 |
+
from data_preparation.prepare_dataset import get_dataloaders # MLP (PyTorch)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## 4. Results
|
| 49 |
+
|
| 50 |
+
| Model | Test Accuracy | F1 | ROC-AUC |
|
| 51 |
+
|-------|--------------|-----|---------|
|
| 52 |
+
| XGBoost | 95.87% | 0.959 | 0.991 |
|
| 53 |
+
| MLP | 92.92% | 0.929 | 0.971 |
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
models/cnn/CNN_MODEL/.claude/settings.local.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(# Check Dataset_subset counts echo \"\"=== Dataset_subset/train/open ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/train/open/ | wc -l && echo \"\"=== Dataset_subset/train/closed ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/train/closed/ | wc -l && echo \"\"=== Dataset_subset/val/open ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/val/open/ | wc -l && echo \"\"=== Dataset_subset/val/closed ===\"\" && ls /Users/mohammedalketbi22/Downloads/GAP_Large_project-feature-dataset-model-test-92_30-clean/Dataset_subset/val/closed/)"
|
| 5 |
+
]
|
| 6 |
+
}
|
| 7 |
+
}
|
models/cnn/CNN_MODEL/.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
DATA/** filter=lfs diff=lfs merge=lfs -text
|
models/cnn/CNN_MODEL/.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Dataset/train/
|
| 2 |
+
Dataset/val/
|
| 3 |
+
Dataset/test/
|
| 4 |
+
.DS_Store
|
models/cnn/CNN_MODEL/README.md
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eye Open / Closed Classifier (YOLOv11-CLS)
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
Binary classifier: **open** vs **closed** eyes.
|
| 5 |
+
Used as a baseline for eye-tracking, drowsiness, or focus detection.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Model team task
|
| 10 |
+
|
| 11 |
+
- **Train** the YOLOv11s-cls eye classifier in a **separate notebook** (data split, epochs, GPU, export `best.pt`).
|
| 12 |
+
- Provide **trained weights** (`best.pt`) for this repo’s evaluation and inference scripts.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## Repo contents
|
| 19 |
+
|
| 20 |
+
- **notebooks/eye_classifier_colab.ipynb** — Data download (Kaggle), clean, split, undersample, **evaluate** (needs `best.pt` from model team), export.
|
| 21 |
+
- **scripts/predict_image.py** — Run classifier on single images (needs `best.pt`).
|
| 22 |
+
- **scripts/webcam_live.py** — Live webcam open/closed (needs `best.pt` + optional `weights/face_landmarker.task`).
|
| 23 |
+
- **scripts/video_infer.py** — Run on video files.
|
| 24 |
+
- **scripts/focus_infer.py** — Focus/attention inference.
|
| 25 |
+
- **weights/** — Put `best.pt` here; `face_landmarker.task` is downloaded on first webcam run if missing.
|
| 26 |
+
- **docs/** — Extra docs (e.g. UNNECESSARY_FILES.md if present).
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## Dataset
|
| 31 |
+
|
| 32 |
+
- **Source:** [Kaggle — open/closed eyes](https://www.kaggle.com/datasets/sehriyarmemmedli/open-closed-eyes-dataset)
|
| 33 |
+
- The Colab notebook downloads it via `kagglehub`; no local copy in repo.
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Weights
|
| 38 |
+
|
| 39 |
+
- Put **best.pt** from the model team in **weights/best.pt** (or `runs/classify/runs_cls/eye_open_closed_cpu/weights/best.pt`).
|
| 40 |
+
- For webcam: **face_landmarker.task** is downloaded into **weights/** on first run if missing.
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Local setup
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
pip install ultralytics opencv-python mediapipe "numpy<2"
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Optional: use a venv. From repo root:
|
| 51 |
+
- `python scripts/predict_image.py <image.png>`
|
| 52 |
+
- `python scripts/webcam_live.py`
|
| 53 |
+
- `python scripts/video_infer.py` (expects 1.mp4 / 2.mp4 in repo root or set `VIDEOS` env)
|
| 54 |
+
- `python scripts/focus_infer.py`
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## Project structure
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
├── notebooks/
|
| 62 |
+
│ └── eye_classifier_colab.ipynb # Data + eval (no training)
|
| 63 |
+
├── scripts/
|
| 64 |
+
│ ├── predict_image.py
|
| 65 |
+
│ ├── webcam_live.py
|
| 66 |
+
│ ├── video_infer.py
|
| 67 |
+
│ └── focus_infer.py
|
| 68 |
+
├── weights/ # best.pt, face_landmarker.task
|
| 69 |
+
├── docs/ # extra docs
|
| 70 |
+
├── README.md
|
| 71 |
+
└── venv/ # optional
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Training and weight generation: **model team, separate notebook.**
|
models/cnn/CNN_MODEL/notebooks/eye_classifier_colab.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/cnn/CNN_MODEL/scripts/focus_infer.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from ultralytics import YOLO
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def list_images(folder: Path):
    """Return every image file directly inside *folder*, sorted by path."""
    allowed = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
    images = [entry for entry in folder.iterdir() if entry.suffix.lower() in allowed]
    images.sort()
    return images
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def find_weights(project_root: Path) -> Path | None:
|
| 17 |
+
candidates = [
|
| 18 |
+
project_root / "weights" / "best.pt",
|
| 19 |
+
project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
|
| 20 |
+
project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
|
| 21 |
+
project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
|
| 22 |
+
project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
|
| 23 |
+
]
|
| 24 |
+
return next((p for p in candidates if p.is_file()), None)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def detect_eyelid_boundary(gray: np.ndarray) -> np.ndarray | None:
    """
    Returns an ellipse fit to the largest contour near the eye boundary.
    Output format: (center(x,y), (axis1, axis2), angle) or None.
    """
    # Edge map: blur to suppress texture, Canny, then dilate to close small gaps.
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blur, 40, 120)
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    # Try largest contours first; fitEllipse needs at least 5 points, and
    # contours with area <= 50 are treated as noise.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)
    for c in contours:
        if len(c) >= 5 and cv2.contourArea(c) > 50:
            return cv2.fitEllipse(c)
    return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """
    More robust pupil detection:
    - enhance contrast (CLAHE)
    - find dark blobs
    - score by circularity and proximity to center

    Returns the pupil centre in full-image pixel coordinates, or None
    when no plausible dark circular blob is found.
    """
    h, w = gray.shape
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    eq = clahe.apply(gray)
    blur = cv2.GaussianBlur(eq, (7, 7), 0)

    # Focus on the central region to avoid eyelashes/edges
    cx, cy = w // 2, h // 2
    rx, ry = int(w * 0.3), int(h * 0.3)
    x0, x1 = max(cx - rx, 0), min(cx + rx, w)
    y0, y1 = max(cy - ry, 0), min(cy + ry, h)
    roi = blur[y0:y1, x0:x1]

    # Inverted threshold to capture dark pupil
    _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Open removes speckle noise; close fills holes inside the pupil blob.
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    best = None
    best_score = -1.0
    for c in contours:
        area = cv2.contourArea(c)
        if area < 15:  # too small to be a pupil
            continue
        perimeter = cv2.arcLength(c, True)
        if perimeter <= 0:
            continue
        # 1.0 for a perfect circle, smaller for ragged shapes.
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:
            continue
        m = cv2.moments(c)
        if m["m00"] == 0:
            continue
        # Centroid, shifted from ROI back to full-image coordinates.
        px = int(m["m10"] / m["m00"]) + x0
        py = int(m["m01"] / m["m00"]) + y0

        # Score by circularity and distance to center
        dist = np.hypot(px - cx, py - cy) / max(w, h)
        score = circularity - dist
        if score > best_score:
            best_score = score
            best = (px, py)

    return best
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
    """Gaze is "focused" when the pupil sits near the image centre.

    Both the horizontal and vertical offsets, normalised by width/height,
    must be strictly under 12%.
    """
    height, width = img_shape
    pupil_x, pupil_y = pupil_center
    offset_x = abs(pupil_x - width // 2) / max(width, 1)
    offset_y = abs(pupil_y - height // 2) / max(height, 1)
    return offset_x < 0.12 and offset_y < 0.12
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def annotate(img_bgr: np.ndarray, ellipse, pupil_center, focused: bool, cls_label: str, conf: float):
    """Return a copy of *img_bgr* with eyelid ellipse, pupil dot and a status line."""
    canvas = img_bgr.copy()
    if ellipse is not None:
        # Eyelid boundary in yellow.
        cv2.ellipse(canvas, ellipse, (0, 255, 255), 2)
    if pupil_center is not None:
        # Pupil centre as a filled red dot.
        cv2.circle(canvas, pupil_center, 4, (0, 0, 255), -1)
    banner = f"{cls_label} ({conf:.2f}) | focused={int(focused)}"
    cv2.putText(canvas, banner, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return canvas
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main():
    """Run the open/closed classifier over the test split and annotate images.

    For each test image: predict open/closed with the YOLO classifier, detect
    the eyelid boundary and pupil centre with classic CV, decide "focused" for
    open eyes, and write an annotated copy into runs_focus/.
    """
    project_root = Path(__file__).resolve().parent.parent
    data_dir = project_root / "Dataset"
    alt_data_dir = project_root / "DATA"
    out_dir = project_root / "runs_focus"
    out_dir.mkdir(parents=True, exist_ok=True)

    weights = find_weights(project_root)
    if weights is None:
        print("Weights not found. Train first.")
        return

    # Support both Dataset/test/{open,closed} and Dataset/{open,closed}
    def resolve_test_dirs(root: Path):
        """Return (open_dir, closed_dir) for the first matching layout, else (None, None)."""
        test_open = root / "test" / "open"
        test_closed = root / "test" / "closed"
        if test_open.exists() and test_closed.exists():
            return test_open, test_closed
        test_open = root / "open"
        test_closed = root / "closed"
        if test_open.exists() and test_closed.exists():
            return test_open, test_closed
        alt_closed = root / "close"  # some datasets spell the class "close"
        if test_open.exists() and alt_closed.exists():
            return test_open, alt_closed
        return None, None

    test_open, test_closed = resolve_test_dirs(data_dir)
    if (test_open is None or test_closed is None) and alt_data_dir.exists():
        test_open, test_closed = resolve_test_dirs(alt_data_dir)

    # BUG FIX: resolve_test_dirs can return (None, None); the old code then
    # called test_open.exists() and crashed with AttributeError.
    if test_open is None or test_closed is None:
        print("Test folders missing. Expected:")
        print(data_dir / "test" / "open")
        print(data_dir / "test" / "closed")
        return

    test_files = list_images(test_open) + list_images(test_closed)
    print("Total test images:", len(test_files))
    # MAX_IMAGES=0 (default) means "no limit".
    max_images = int(os.getenv("MAX_IMAGES", "0"))
    if max_images > 0:
        test_files = test_files[:max_images]
        print("Limiting to MAX_IMAGES:", max_images)

    model = YOLO(str(weights))
    results = model.predict(test_files, imgsz=224, device="cpu", verbose=False)

    names = model.names
    for r in results:
        probs = r.probs
        top_idx = int(probs.top1)
        top_conf = float(probs.top1conf)
        pred_label = names[top_idx]

        img = cv2.imread(r.path)
        if img is None:
            # Unreadable image: skip silently rather than abort the batch.
            continue
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        ellipse = detect_eyelid_boundary(gray)
        pupil_center = detect_pupil_center(gray)
        focused = False
        if pred_label.lower() == "open" and pupil_center is not None:
            focused = is_focused(pupil_center, gray.shape)

        annotated = annotate(img, ellipse, pupil_center, focused, pred_label, top_conf)
        out_path = out_dir / (Path(r.path).stem + "_annotated.jpg")
        cv2.imwrite(str(out_path), annotated)

        print(f"{Path(r.path).name}: pred={pred_label} conf={top_conf:.3f} focused={focused}")

    print(f"\nAnnotated outputs saved to: {out_dir}")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
main()
|
models/cnn/CNN_MODEL/scripts/predict_image.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run the eye open/closed model on one or more images."""
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from ultralytics import YOLO
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
    """CLI: classify each image argument as eye open/closed and print the result.

    Exits 1 when no trained weights are found, 0 when called without images.
    """
    project_root = Path(__file__).resolve().parent.parent
    # Prefer a hand-placed weights/best.pt, then the training-run outputs.
    weight_candidates = [
        project_root / "weights" / "best.pt",
        project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
        project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
    ]
    weights = next((p for p in weight_candidates if p.is_file()), None)
    if weights is None:
        print("Weights not found. Put best.pt in weights/ or runs/.../weights/ (from model team).")
        sys.exit(1)

    if len(sys.argv) < 2:
        # No images given: print usage and exit successfully.
        print("Usage: python scripts/predict_image.py <image1> [image2 ...]")
        print("Example: python scripts/predict_image.py path/to/image.png")
        sys.exit(0)

    model = YOLO(str(weights))
    names = model.names

    for path in sys.argv[1:]:
        p = Path(path)
        if not p.is_file():
            print(p, "- file not found")
            continue
        try:
            results = model.predict(str(p), imgsz=224, device="cpu", verbose=False)
        except Exception as e:
            # Report unreadable/corrupt images without aborting the batch.
            print(p, "- error:", e)
            continue
        if not results:
            print(p, "- no result")
            continue
        r = results[0]
        top_idx = int(r.probs.top1)
        conf = float(r.probs.top1conf)
        label = names[top_idx]
        print(f"{p.name}: {label} ({conf:.2%})")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
main()
|
models/cnn/CNN_MODEL/scripts/video_infer.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from ultralytics import YOLO
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import mediapipe as mp
|
| 12 |
+
except Exception: # pragma: no cover
|
| 13 |
+
mp = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def find_weights(project_root: Path) -> Path | None:
|
| 17 |
+
candidates = [
|
| 18 |
+
project_root / "weights" / "best.pt",
|
| 19 |
+
project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
|
| 20 |
+
project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
|
| 21 |
+
project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
|
| 22 |
+
project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
|
| 23 |
+
]
|
| 24 |
+
return next((p for p in candidates if p.is_file()), None)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """Locate the darkest roughly-circular blob near the frame centre.

    CLAHE + blur, then an inverted Otsu threshold on the central 60% window;
    the candidate contour with the best (circularity - centre-distance) score
    wins. Returns (x, y) in full-image pixels, or None.
    """
    h, w = gray.shape
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    eq = clahe.apply(gray)
    blur = cv2.GaussianBlur(eq, (7, 7), 0)

    # Restrict the search to the central region to avoid eyelashes/edges.
    cx, cy = w // 2, h // 2
    rx, ry = int(w * 0.3), int(h * 0.3)
    x0, x1 = max(cx - rx, 0), min(cx + rx, w)
    y0, y1 = max(cy - ry, 0), min(cy + ry, h)
    roi = blur[y0:y1, x0:x1]

    # Inverted threshold captures the dark pupil as foreground.
    _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Open removes speckle; close fills holes inside the pupil blob.
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    best = None
    best_score = -1.0
    for c in contours:
        area = cv2.contourArea(c)
        if area < 15:  # too small to be a pupil
            continue
        perimeter = cv2.arcLength(c, True)
        if perimeter <= 0:
            continue
        # 1.0 for a perfect circle, smaller for ragged shapes.
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:
            continue
        m = cv2.moments(c)
        if m["m00"] == 0:
            continue
        # Centroid, shifted from ROI back to full-image coordinates.
        px = int(m["m10"] / m["m00"]) + x0
        py = int(m["m01"] / m["m00"]) + y0

        dist = np.hypot(px - cx, py - cy) / max(w, h)
        score = circularity - dist
        if score > best_score:
            best_score = score
            best = (px, py)

    return best
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
    """Horizontal-only focus test: pupil x must be within 12% of the frame centre."""
    _, width = img_shape
    pupil_x = pupil_center[0]
    horizontal_offset = abs(pupil_x - width // 2) / max(width, 1)
    return horizontal_offset < 0.12
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def classify_frame(model: YOLO, frame: np.ndarray) -> tuple[str, float]:
    """Run the YOLO classifier on one frame; return (label, confidence).

    NOTE(review): the inline comment below assumes *frame* is already an eye
    crop — confirm callers pass crops rather than full frames.
    """
    # Use classifier directly on frame (assumes frame is eye crop)
    results = model.predict(frame, imgsz=224, device="cpu", verbose=False)
    r = results[0]
    probs = r.probs
    top_idx = int(probs.top1)
    top_conf = float(probs.top1conf)
    pred_label = model.names[top_idx]
    return pred_label, top_conf
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def annotate_frame(frame: np.ndarray, label: str, focused: bool, conf: float, time_sec: float):
    """Return a copy of *frame* with a one-line status overlay in the top-left."""
    annotated = frame.copy()
    overlay = f"{label} | focused={int(focused)} | conf={conf:.2f} | t={time_sec:.2f}s"
    cv2.putText(annotated, overlay, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return annotated
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def write_segments(path: Path, segments: list[tuple[float, float, str]]):
    """Write segments to *path* as "start,end,label" lines (seconds, 2 dp)."""
    lines = [f"{start:.2f},{end:.2f},{label}\n" for start, end, label in segments]
    with path.open("w") as handle:
        handle.writelines(lines)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def process_video(video_path: Path, model: YOLO | None):
    """Annotate *video_path* frame by frame with open/closed + focus state.

    Writes three files next to the input video:
      - <stem>_pred.mp4        annotated copy of the video
      - <stem>_predictions.csv per-frame time/label/focused/conf rows
      - <stem>_segments.txt    contiguous runs of the same label

    Uses MediaPipe FaceMesh (EAR + iris offsets) when mediapipe imported,
    otherwise falls back to the YOLO classifier plus classic pupil detection.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Failed to open {video_path}")
        return

    # Fall back to 30 fps when the container does not report a rate (0/None).
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = video_path.with_name(video_path.stem + "_pred.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))

    csv_path = video_path.with_name(video_path.stem + "_predictions.csv")
    seg_path = video_path.with_name(video_path.stem + "_segments.txt")

    frame_idx = 0
    last_label = None  # label of the segment currently being accumulated
    seg_start = 0.0    # start time (s) of that segment
    segments: list[tuple[float, float, str]] = []

    with csv_path.open("w") as fcsv:
        fcsv.write("time_sec,label,focused,conf\n")
        if mp is None:
            print("mediapipe is not installed. Falling back to classifier-only mode.")
        use_mp = mp is not None
        if use_mp:
            mp_face_mesh = mp.solutions.face_mesh
            face_mesh = mp_face_mesh.FaceMesh(
                static_image_mode=False,
                max_num_faces=1,
                refine_landmarks=True,  # required for iris landmarks 468-476
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
            )

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            time_sec = frame_idx / fps
            conf = 0.0  # stays 0.0 in the MediaPipe path (no classifier conf there)
            pred_label = "open"
            focused = False

            if use_mp:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB
                res = face_mesh.process(rgb)
                if res.multi_face_landmarks:
                    lm = res.multi_face_landmarks[0].landmark
                    h, w = frame.shape[:2]

                    # Eye landmarks (MediaPipe FaceMesh)
                    left_eye = [33, 160, 158, 133, 153, 144]
                    right_eye = [362, 385, 387, 263, 373, 380]
                    left_iris = [468, 469, 470, 471]
                    right_iris = [473, 474, 475, 476]

                    def pts(idxs):
                        # Landmark indices -> integer pixel coordinates.
                        return np.array([(int(lm[i].x * w), int(lm[i].y * h)) for i in idxs])

                    def ear(eye_pts):
                        # EAR using 6 points
                        p1, p2, p3, p4, p5, p6 = eye_pts
                        v1 = np.linalg.norm(p2 - p6)
                        v2 = np.linalg.norm(p3 - p5)
                        h1 = np.linalg.norm(p1 - p4)
                        return (v1 + v2) / (2.0 * h1 + 1e-6)

                    le = pts(left_eye)
                    re = pts(right_eye)
                    le_ear = ear(le)
                    re_ear = ear(re)
                    ear_avg = (le_ear + re_ear) / 2.0

                    # openness threshold
                    pred_label = "open" if ear_avg > 0.22 else "closed"

                    # iris centers
                    li = pts(left_iris)
                    ri = pts(right_iris)
                    li_c = li.mean(axis=0).astype(int)
                    ri_c = ri.mean(axis=0).astype(int)

                    # eye centers (midpoint of corners)
                    le_c = ((le[0] + le[3]) / 2).astype(int)
                    re_c = ((re[0] + re[3]) / 2).astype(int)

                    # focus = iris close to eye center horizontally for both eyes
                    le_dx = abs(li_c[0] - le_c[0]) / max(np.linalg.norm(le[0] - le[3]), 1)
                    re_dx = abs(ri_c[0] - re_c[0]) / max(np.linalg.norm(re[0] - re[3]), 1)
                    focused = (pred_label == "open") and (le_dx < 0.18) and (re_dx < 0.18)

                    # draw eye boundaries
                    cv2.polylines(frame, [le], True, (0, 255, 255), 1)
                    cv2.polylines(frame, [re], True, (0, 255, 255), 1)
                    # draw iris centers
                    cv2.circle(frame, tuple(li_c), 3, (0, 0, 255), -1)
                    cv2.circle(frame, tuple(ri_c), 3, (0, 0, 255), -1)
                else:
                    # No face detected: treat the frame as closed / not focused.
                    pred_label = "closed"
                    focused = False
            else:
                if model is not None:
                    pred_label, conf = classify_frame(model, frame)
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                pupil_center = detect_pupil_center(gray) if pred_label.lower() == "open" else None
                focused = False
                if pred_label.lower() == "open" and pupil_center is not None:
                    focused = is_focused(pupil_center, gray.shape)

            if pred_label.lower() != "open":
                focused = False

            # Collapse (label, focused) into a single CSV/segment label.
            label = "open_focused" if (pred_label.lower() == "open" and focused) else "open_not_focused"
            if pred_label.lower() != "open":
                label = "closed_not_focused"

            fcsv.write(f"{time_sec:.2f},{label},{int(focused)},{conf:.4f}\n")

            # Segment bookkeeping: close the running segment on a label change.
            if last_label is None:
                last_label = label
                seg_start = time_sec
            elif label != last_label:
                segments.append((seg_start, time_sec, last_label))
                seg_start = time_sec
                last_label = label

            annotated = annotate_frame(frame, label, focused, conf, time_sec)
            writer.write(annotated)
            frame_idx += 1

        # Flush the final open segment before writing the segments file.
        if last_label is not None:
            end_time = frame_idx / fps
            segments.append((seg_start, end_time, last_label))
        write_segments(seg_path, segments)

    cap.release()
    writer.release()
    print(f"Saved: {out_path}")
    print(f"CSV: {csv_path}")
    print(f"Segments: {seg_path}")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def main():
    """Process the default demo videos (1.mp4 / 2.mp4) plus any listed in $VIDEOS."""
    project_root = Path(__file__).resolve().parent.parent
    weights = find_weights(project_root)
    # Weights are optional here: without them the MediaPipe path still works.
    model = YOLO(str(weights)) if weights is not None else None

    # Default to 1.mp4 and 2.mp4 in project root
    videos = []
    for name in ["1.mp4", "2.mp4"]:
        p = project_root / name
        if p.exists():
            videos.append(p)

    # Also allow passing paths via env var
    extra = os.getenv("VIDEOS", "")
    for v in [x.strip() for x in extra.split(",") if x.strip()]:
        vp = Path(v)
        if not vp.is_absolute():
            vp = project_root / vp  # relative paths resolve against the repo root
        if vp.exists():
            videos.append(vp)

    if not videos:
        print("No videos found. Expected 1.mp4 / 2.mp4 in project root.")
        return

    for v in videos:
        process_video(v, model)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|
models/cnn/CNN_MODEL/scripts/webcam_live.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Live webcam: detect face, crop each eye, run open/closed classifier, show on screen.
|
| 3 |
+
Requires: opencv-python, ultralytics, mediapipe (pip install mediapipe).
|
| 4 |
+
Press 'q' to quit.
|
| 5 |
+
"""
|
| 6 |
+
import urllib.request
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from ultralytics import YOLO
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import mediapipe as mp
|
| 15 |
+
_mp_has_solutions = hasattr(mp, "solutions")
|
| 16 |
+
except ImportError:
|
| 17 |
+
mp = None
|
| 18 |
+
_mp_has_solutions = False
|
| 19 |
+
|
| 20 |
+
# New MediaPipe Tasks API (Face Landmarker) eye indices
|
| 21 |
+
LEFT_EYE_INDICES_NEW = [263, 249, 390, 373, 374, 380, 381, 382, 362, 466, 388, 387, 386, 385, 384, 398]
|
| 22 |
+
RIGHT_EYE_INDICES_NEW = [33, 7, 163, 144, 145, 153, 154, 155, 133, 246, 161, 160, 159, 158, 157, 173]
|
| 23 |
+
# Old Face Mesh (solutions) indices
|
| 24 |
+
LEFT_EYE_INDICES_OLD = [33, 160, 158, 133, 153, 144]
|
| 25 |
+
RIGHT_EYE_INDICES_OLD = [362, 385, 387, 263, 373, 380]
|
| 26 |
+
EYE_PADDING = 0.35
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def find_weights(project_root: Path) -> Path | None:
|
| 30 |
+
candidates = [
|
| 31 |
+
project_root / "weights" / "best.pt",
|
| 32 |
+
project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
|
| 33 |
+
project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
|
| 34 |
+
]
|
| 35 |
+
return next((p for p in candidates if p.is_file()), None)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_eye_roi(frame: np.ndarray, landmarks, indices: list[int]) -> np.ndarray | None:
    """Crop a padded bounding box around the given eye landmarks.

    Landmark coordinates are normalised [0, 1]; the tight box is padded by
    EYE_PADDING (minimum 8 px) and clamped to the frame. Returns a copy of
    the region, or None when the clamped box is degenerate.
    """
    frame_h, frame_w = frame.shape[:2]
    points = np.array(
        [(int(landmarks[i].x * frame_w), int(landmarks[i].y * frame_h)) for i in indices]
    )
    left, top = points.min(axis=0)
    right, bottom = points.max(axis=0)
    pad_x = max(int((right - left) * EYE_PADDING), 8)
    pad_y = max(int((bottom - top) * EYE_PADDING), 8)
    x_start = max(0, left - pad_x)
    y_start = max(0, top - pad_y)
    x_end = min(frame_w, right + pad_x)
    y_end = min(frame_h, bottom + pad_y)
    if x_end <= x_start or y_end <= y_start:
        return None
    return frame[y_start:y_end, x_start:x_end].copy()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _run_with_solutions(mp, model, cap):
    """Webcam loop using the legacy mp.solutions FaceMesh API.

    Per frame: detect face landmarks, crop each eye, classify open/closed
    with the YOLO model, and overlay the per-eye label/confidence.
    Runs until the stream ends or 'q' is pressed.
    """
    face_mesh = mp.solutions.face_mesh.FaceMesh(
        static_image_mode=False,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB
        results = face_mesh.process(rgb)
        # Placeholders shown when no face / no prediction is available.
        left_label, left_conf = "—", 0.0
        right_label, right_conf = "—", 0.0
        if results.multi_face_landmarks:
            lm = results.multi_face_landmarks[0].landmark
            for roi, indices, side in [
                (get_eye_roi(frame, lm, LEFT_EYE_INDICES_OLD), LEFT_EYE_INDICES_OLD, "left"),
                (get_eye_roi(frame, lm, RIGHT_EYE_INDICES_OLD), RIGHT_EYE_INDICES_OLD, "right"),
            ]:
                if roi is not None and roi.size > 0:
                    try:
                        pred = model.predict(roi, imgsz=224, device="cpu", verbose=False)
                        if pred:
                            r = pred[0]
                            label = model.names[int(r.probs.top1)]
                            conf = float(r.probs.top1conf)
                            if side == "left":
                                left_label, left_conf = label, conf
                            else:
                                right_label, right_conf = label, conf
                    except Exception:
                        # Best-effort per-eye: keep the placeholder on failure.
                        pass
        cv2.putText(frame, f"L: {left_label} ({left_conf:.0%})", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.putText(frame, f"R: {right_label} ({right_conf:.0%})", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow("Eye open/closed (q to quit)", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _run_with_tasks(project_root: Path, model, cap):
    """Webcam loop using the newer MediaPipe Tasks FaceLandmarker API.

    Downloads the face_landmarker.task bundle into weights/ on first run,
    then per frame crops and classifies both eyes with the YOLO model and
    overlays the results until the stream ends or 'q' is pressed.
    """
    from mediapipe.tasks.python import BaseOptions
    from mediapipe.tasks.python.vision import FaceLandmarker, FaceLandmarkerOptions
    from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode
    from mediapipe.tasks.python.vision.core import image as image_lib

    model_path = project_root / "weights" / "face_landmarker.task"
    if not model_path.is_file():
        # ROBUSTNESS FIX: weights/ may not exist on a fresh clone;
        # urlretrieve would otherwise fail with FileNotFoundError.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        print("Downloading face_landmarker.task ...")
        url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
        urllib.request.urlretrieve(url, model_path)
        print("Done.")

    options = FaceLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=str(model_path)),
        running_mode=running_mode.VisionTaskRunningMode.IMAGE,
        num_faces=1,
    )
    face_landmarker = FaceLandmarker.create_from_options(options)
    ImageFormat = image_lib.ImageFormat

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Placeholders shown when no face / no prediction is available.
        left_label, left_conf = "—", 0.0
        right_label, right_conf = "—", 0.0

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        rgb_contiguous = np.ascontiguousarray(rgb)  # mp.Image needs contiguous memory
        mp_image = image_lib.Image(ImageFormat.SRGB, rgb_contiguous)
        result = face_landmarker.detect(mp_image)

        if result.face_landmarks:
            lm = result.face_landmarks[0]
            for roi, side in [
                (get_eye_roi(frame, lm, LEFT_EYE_INDICES_NEW), "left"),
                (get_eye_roi(frame, lm, RIGHT_EYE_INDICES_NEW), "right"),
            ]:
                if roi is not None and roi.size > 0:
                    try:
                        pred = model.predict(roi, imgsz=224, device="cpu", verbose=False)
                        if pred:
                            r = pred[0]
                            label = model.names[int(r.probs.top1)]
                            conf = float(r.probs.top1conf)
                            if side == "left":
                                left_label, left_conf = label, conf
                            else:
                                right_label, right_conf = label, conf
                    except Exception:
                        # Best-effort per-eye: keep the placeholder on failure.
                        pass

        cv2.putText(frame, f"L: {left_label} ({left_conf:.0%})", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.putText(frame, f"R: {right_label} ({right_conf:.0%})", (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow("Eye open/closed (q to quit)", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def main():
    """Locate trained weights, open the default webcam, and run the live
    eye open/closed demo until the user quits (or a precondition fails)."""
    root = Path(__file__).resolve().parent.parent

    weights = find_weights(root)
    if weights is None:
        print("Weights not found. Put best.pt in weights/ or runs/.../weights/ (from model team).")
        return
    if mp is None:
        print("MediaPipe required. Install: pip install mediapipe")
        return

    classifier = YOLO(str(weights))
    capture = cv2.VideoCapture(0)
    if not capture.isOpened():
        print("Could not open webcam.")
        return

    print("Live eye open/closed on your face. Press 'q' to quit.")
    try:
        # Prefer the legacy `solutions` API when available; otherwise fall
        # back to the newer MediaPipe Tasks API.
        if _mp_has_solutions:
            _run_with_solutions(mp, classifier, capture)
        else:
            _run_with_tasks(root, classifier, capture)
    finally:
        # Always release the camera and close windows, even on error/quit.
        capture.release()
        cv2.destroyAllWindows()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
main()
|
models/cnn/CNN_MODEL/weights/yolo11s-cls.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e2b605d1c8c212b434a75a32759a6f7adf1d2b29c35f76bdccd4c794cb653cf2
|
| 3 |
+
size 13630112
|
models/cnn/__init__.py
ADDED
|
File without changes
|
models/cnn/eye_attention/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
models/cnn/eye_attention/classifier.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EyeClassifier(ABC):
    """Interface for eye-state backends that score attentiveness from eye crops."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short backend identifier (e.g. "yolo", "cnn", "geometric")."""
        pass

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return an attentiveness score (expected in [0, 1]) averaged over
        the given BGR eye crops."""
        pass
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GeometricOnlyClassifier(EyeClassifier):
    """No-model fallback backend: always reports the neutral score 1.0,
    presumably deferring attentiveness to landmark-based geometric signals
    computed elsewhere (TODO confirm against caller)."""

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        # Neutral score regardless of input.
        return 1.0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class YOLOv11Classifier(EyeClassifier):
    """Eye-state backend backed by an Ultralytics YOLO classification checkpoint."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        # Find the class index that represents an attentive/open eye.
        class_names = self._model.names
        attentive = next(
            (idx for idx, label in class_names.items() if label in ("open", "attentive")),
            None,
        )
        if attentive is None:
            # No recognized label: fall back to the highest class index.
            attentive = max(class_names.keys())
        self._attentive_idx = attentive
        print(f"[YOLO] Classes: {class_names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Mean probability of the attentive class over the given crops."""
        if not crops_bgr:
            return 1.0
        results = self._model.predict(crops_bgr, device=self._device, verbose=False)
        probs = [float(r.probs.data[self._attentive_idx]) for r in results]
        return sum(probs) / len(probs) if probs else 1.0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class EyeCNNClassifier(EyeClassifier):
    """Loader for the custom PyTorch EyeCNN (trained on Kaggle eye crops)."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        import torch
        import torch.nn as nn

        # Architecture must mirror the training notebook exactly, or
        # load_state_dict below will fail on mismatched keys/shapes.
        class EyeCNN(nn.Module):
            def __init__(self, num_classes=2, dropout_rate=0.3):
                super().__init__()
                # Four conv blocks: 3 -> 32 -> 64 -> 128 -> 256 channels,
                # each halving spatial resolution via MaxPool2d.
                self.conv_layers = nn.Sequential(
                    nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
                )
                # Global average pool then a small MLP head to num_classes logits.
                self.fc_layers = nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
                    nn.Linear(256, 512), nn.ReLU(), nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes),
                )

            def forward(self, x):
                return self.fc_layers(self.conv_layers(x))

        self._device = torch.device(device)
        # NOTE(review): weights_only=False unpickles arbitrary objects -- only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        # Dropout rate is read from the checkpoint's training config when present.
        dropout_rate = checkpoint.get("config", {}).get("dropout_rate", 0.35)
        self._model = EyeCNN(num_classes=2, dropout_rate=dropout_rate)
        self._model.load_state_dict(checkpoint["model_state_dict"])
        self._model.to(self._device)
        self._model.eval()

        self._transform = None  # preprocessing pipeline, built lazily

    def _get_transform(self):
        # Lazy build so torchvision is only imported when inference is used.
        if self._transform is None:
            from torchvision import transforms
            self._transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                # ImageNet mean/std normalization, matching training.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        return self._transform

    @property
    def name(self) -> str:
        return "eye_cnn"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Mean "open" probability over the crops; 1.0 when no crops given.

        Empty/None crops contribute a neutral 1.0 instead of being skipped.
        """
        if not crops_bgr:
            return 1.0

        import torch
        import cv2

        transform = self._get_transform()
        scores = []
        for crop in crops_bgr:
            if crop is None or crop.size == 0:
                scores.append(1.0)
                continue
            rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            tensor = transform(rgb).unsqueeze(0).to(self._device)
            with torch.no_grad():
                output = self._model(tensor)
            # Class index 1 is taken as "open" -- assumed from training label
            # order; TODO confirm against the dataset class mapping.
            prob = torch.softmax(output, dim=1)[0, 1].item()  # prob of "open"
            scores.append(prob)
        return sum(scores) / len(scores)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Checkpoint file extension -> backend implied by it.
_EXT_TO_BACKEND = {".pth": "cnn", ".pt": "yolo"}


def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build the requested eye-state backend.

    Falls back to the geometric backend when no path is given; the file
    extension of *path* may override the requested backend.
    """
    if backend == "geometric":
        return GeometricOnlyClassifier()

    if path is None:
        print(f"[CLASSIFIER] No model path for backend {backend!r}, falling back to geometric")
        return GeometricOnlyClassifier()

    ext = os.path.splitext(path)[1].lower()
    implied = _EXT_TO_BACKEND.get(ext)
    if implied is not None and implied != backend:
        print(f"[CLASSIFIER] File extension {ext!r} implies backend {implied!r}, "
              f"overriding requested {backend!r}")
        backend = implied

    print(f"[CLASSIFIER] backend={backend!r}, path={path!r}")

    if backend == "cnn":
        return EyeCNNClassifier(path, device=device)

    if backend == "yolo":
        try:
            return YOLOv11Classifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
            raise

    raise ValueError(
        f"Unknown eye backend {backend!r}. Choose from: yolo, cnn, geometric"
    )
|
models/cnn/eye_attention/crop.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from models.pretrained.face_mesh.face_mesh import FaceMeshDetector
|
| 5 |
+
|
| 6 |
+
LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
|
| 7 |
+
RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
|
| 8 |
+
|
| 9 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 10 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 11 |
+
|
| 12 |
+
CROP_SIZE = 96
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _bbox_from_landmarks(
|
| 16 |
+
landmarks: np.ndarray,
|
| 17 |
+
indices: list[int],
|
| 18 |
+
frame_w: int,
|
| 19 |
+
frame_h: int,
|
| 20 |
+
expand: float = 0.4,
|
| 21 |
+
) -> tuple[int, int, int, int]:
|
| 22 |
+
pts = landmarks[indices, :2]
|
| 23 |
+
px = pts[:, 0] * frame_w
|
| 24 |
+
py = pts[:, 1] * frame_h
|
| 25 |
+
|
| 26 |
+
x_min, x_max = px.min(), px.max()
|
| 27 |
+
y_min, y_max = py.min(), py.max()
|
| 28 |
+
w = x_max - x_min
|
| 29 |
+
h = y_max - y_min
|
| 30 |
+
cx = (x_min + x_max) / 2
|
| 31 |
+
cy = (y_min + y_max) / 2
|
| 32 |
+
|
| 33 |
+
size = max(w, h) * (1 + expand)
|
| 34 |
+
half = size / 2
|
| 35 |
+
|
| 36 |
+
x1 = int(max(cx - half, 0))
|
| 37 |
+
y1 = int(max(cy - half, 0))
|
| 38 |
+
x2 = int(min(cx + half, frame_w))
|
| 39 |
+
y2 = int(min(cy + half, frame_h))
|
| 40 |
+
|
| 41 |
+
return x1, y1, x2, y2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def extract_eye_crops(
    frame: np.ndarray,
    landmarks: np.ndarray,
    expand: float = 0.4,
    crop_size: int = CROP_SIZE,
) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
    """Cut both eye regions out of a BGR frame and resize them to squares.

    Returns (left_crop, right_crop, left_bbox, right_bbox); each bbox is
    (x1, y1, x2, y2) in pixel coordinates.
    """
    frame_h, frame_w = frame.shape[:2]

    boxes = [
        _bbox_from_landmarks(landmarks, contour, frame_w, frame_h, expand)
        for contour in (LEFT_EYE_CONTOUR, RIGHT_EYE_CONTOUR)
    ]
    crops = [
        cv2.resize(
            frame[y1:y2, x1:x2],
            (crop_size, crop_size),
            interpolation=cv2.INTER_AREA,
        )
        for (x1, y1, x2, y2) in boxes
    ]
    return crops[0], crops[1], boxes[0], boxes[1]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def crop_to_tensor(crop_bgr: np.ndarray):
    """Convert a BGR eye crop to a CHW float32 torch tensor, ImageNet-normalized.

    Returns a tensor of shape (3, H, W) with values (pixel/255 - mean) / std
    per channel.
    """
    import torch

    rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Vectorized normalization over all channels at once (replaces the
    # per-channel Python loop); float32 constants keep the arithmetic
    # identical to the previous in-place float32 computation.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    rgb = (rgb - mean) / std
    return torch.from_numpy(rgb.transpose(2, 0, 1))
|
models/cnn/eye_attention/train.py
ADDED
|
File without changes
|
models/cnn/notebooks/EyeCNN.ipynb
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"language_info": {
|
| 14 |
+
"name": "python"
|
| 15 |
+
},
|
| 16 |
+
"accelerator": "GPU"
|
| 17 |
+
},
|
| 18 |
+
"cells": [
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"source": [
|
| 22 |
+
"import os\n",
|
| 23 |
+
"import torch\n",
|
| 24 |
+
"import torch.nn as nn\n",
|
| 25 |
+
"import torch.optim as optim\n",
|
| 26 |
+
"from torch.utils.data import DataLoader\n",
|
| 27 |
+
"from torchvision import datasets, transforms\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"from google.colab import drive\n",
|
| 30 |
+
"drive.mount('/content/drive')\n",
|
| 31 |
+
"!cp -r /content/drive/MyDrive/Dataset_clean /content/\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"#Verify structure\n",
|
| 34 |
+
"for split in ['train', 'val', 'test']:\n",
|
| 35 |
+
" path = f'/content/Dataset_clean/{split}'\n",
|
| 36 |
+
" classes = os.listdir(path)\n",
|
| 37 |
+
" total = sum(len(os.listdir(os.path.join(path, c))) for c in classes)\n",
|
| 38 |
+
" print(f'{split}: {total} images | classes: {classes}')"
|
| 39 |
+
],
|
| 40 |
+
"metadata": {
|
| 41 |
+
"colab": {
|
| 42 |
+
"base_uri": "https://localhost:8080/"
|
| 43 |
+
},
|
| 44 |
+
"id": "sE1F3em-V5go",
|
| 45 |
+
"outputId": "2c73a9a6-a198-468c-a2cc-253b2de7cc3f"
|
| 46 |
+
},
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"outputs": [
|
| 49 |
+
{
|
| 50 |
+
"output_type": "stream",
|
| 51 |
+
"name": "stdout",
|
| 52 |
+
"text": [
|
| 53 |
+
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
|
| 54 |
+
]
|
| 55 |
+
}
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {
|
| 62 |
+
"id": "nG2bh66rQ56G"
|
| 63 |
+
},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": [
|
| 66 |
+
"class EyeCNN(nn.Module):\n",
|
| 67 |
+
" def __init__(self, num_classes=2):\n",
|
| 68 |
+
" super(EyeCNN, self).__init__()\n",
|
| 69 |
+
" self.conv_layers = nn.Sequential(\n",
|
| 70 |
+
" nn.Conv2d(3, 32, 3, 1, 1),\n",
|
| 71 |
+
" nn.BatchNorm2d(32),\n",
|
| 72 |
+
" nn.ReLU(),\n",
|
| 73 |
+
" nn.MaxPool2d(2, 2),\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" nn.Conv2d(32, 64, 3, 1, 1),\n",
|
| 76 |
+
" nn.BatchNorm2d(64),\n",
|
| 77 |
+
" nn.ReLU(),\n",
|
| 78 |
+
" nn.MaxPool2d(2, 2),\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" nn.Conv2d(64, 128, 3, 1, 1),\n",
|
| 81 |
+
" nn.BatchNorm2d(128),\n",
|
| 82 |
+
" nn.ReLU(),\n",
|
| 83 |
+
" nn.MaxPool2d(2, 2),\n",
|
| 84 |
+
"\n",
|
| 85 |
+
" nn.Conv2d(128, 256, 3, 1, 1),\n",
|
| 86 |
+
" nn.BatchNorm2d(256),\n",
|
| 87 |
+
" nn.ReLU(),\n",
|
| 88 |
+
" nn.MaxPool2d(2, 2)\n",
|
| 89 |
+
" )\n",
|
| 90 |
+
"\n",
|
| 91 |
+
" self.fc_layers = nn.Sequential(\n",
|
| 92 |
+
" nn.AdaptiveAvgPool2d((1, 1)),\n",
|
| 93 |
+
" nn.Flatten(),\n",
|
| 94 |
+
" nn.Linear(256, 512),\n",
|
| 95 |
+
" nn.ReLU(),\n",
|
| 96 |
+
" nn.Dropout(0.35),\n",
|
| 97 |
+
" nn.Linear(512, num_classes)\n",
|
| 98 |
+
" )\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" def forward(self, x):\n",
|
| 101 |
+
" x = self.conv_layers(x)\n",
|
| 102 |
+
" x = self.fc_layers(x)\n",
|
| 103 |
+
" return x"
|
| 104 |
+
]
|
| 105 |
+
}
|
| 106 |
+
]
|
| 107 |
+
}
|
models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/cnn/notebooks/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# GAP Large Project
|
models/collect_features.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse
|
| 3 |
+
import collections
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
if _PROJECT_ROOT not in sys.path:
|
| 14 |
+
sys.path.insert(0, _PROJECT_ROOT)
|
| 15 |
+
|
| 16 |
+
from models.face_mesh import FaceMeshDetector
|
| 17 |
+
from models.head_pose import HeadPoseEstimator
|
| 18 |
+
from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
|
| 19 |
+
|
| 20 |
+
FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 21 |
+
GREEN = (0, 255, 0)
|
| 22 |
+
RED = (0, 0, 255)
|
| 23 |
+
WHITE = (255, 255, 255)
|
| 24 |
+
YELLOW = (0, 255, 255)
|
| 25 |
+
ORANGE = (0, 165, 255)
|
| 26 |
+
GRAY = (120, 120, 120)
|
| 27 |
+
|
| 28 |
+
FEATURE_NAMES = [
|
| 29 |
+
"ear_left", "ear_right", "ear_avg", "h_gaze", "v_gaze", "mar",
|
| 30 |
+
"yaw", "pitch", "roll", "s_face", "s_eye", "gaze_offset", "head_deviation",
|
| 31 |
+
"perclos", "blink_rate", "closure_duration", "yawn_duration",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
NUM_FEATURES = len(FEATURE_NAMES)
|
| 35 |
+
assert NUM_FEATURES == 17
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TemporalTracker:
    """Tracks temporal eye/mouth statistics across frames.

    Maintains a rolling PERCLOS window, recent blink timestamps, and the
    start times of the current eye-closure and yawn episodes.
    """

    EAR_BLINK_THRESH = 0.21   # below this EAR the eyes count as closed
    MAR_YAWN_THRESH = 0.55    # above this MAR the mouth counts as yawning
    PERCLOS_WINDOW = 60       # frames in the PERCLOS rolling window
    BLINK_WINDOW_SEC = 30.0   # seconds of blink history retained

    def __init__(self):
        self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
        self.blink_timestamps = collections.deque()
        self._eyes_closed = False
        self._closure_start = None
        self._yawn_start = None

    def update(self, ear_avg, mar, now=None):
        """Feed one frame's EAR/MAR values.

        Returns (perclos, blink_rate_per_min, closure_duration_sec,
        yawn_duration_sec) for the current frame.
        """
        now = time.time() if now is None else now

        eyes_closed = ear_avg < self.EAR_BLINK_THRESH

        # PERCLOS: fraction of recent frames with closed eyes.
        self.ear_history.append(1.0 if eyes_closed else 0.0)
        perclos = (sum(self.ear_history) / len(self.ear_history)) if self.ear_history else 0.0

        # A closed -> open transition counts as one completed blink.
        if self._eyes_closed and not eyes_closed:
            self.blink_timestamps.append(now)
        self._eyes_closed = eyes_closed

        # Drop blinks that fell out of the window, then scale to per-minute.
        horizon = now - self.BLINK_WINDOW_SEC
        while self.blink_timestamps and self.blink_timestamps[0] < horizon:
            self.blink_timestamps.popleft()
        blink_rate = len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)

        closure_dur = self._episode_duration(eyes_closed, now, "_closure_start")
        yawn_dur = self._episode_duration(mar > self.MAR_YAWN_THRESH, now, "_yawn_start")

        return perclos, blink_rate, closure_dur, yawn_dur

    def _episode_duration(self, active, now, attr):
        # Duration of the ongoing episode tracked by *attr*; 0.0 when inactive.
        if not active:
            setattr(self, attr, None)
            return 0.0
        if getattr(self, attr) is None:
            setattr(self, attr, now)
        return now - getattr(self, attr)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
                     *, _pre=None):
    """Build the 17-dim float32 feature vector for one frame.

    *_pre* optionally carries values already computed by the caller (keyed
    by feature name) so the scorers are not run twice for the same frame.
    """
    from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear

    cache = _pre or {}

    ear_left = cache.get("ear_left", compute_ear(landmarks, _LEFT_EYE_EAR))
    ear_right = cache.get("ear_right", compute_ear(landmarks, _RIGHT_EYE_EAR))
    ear_avg = (ear_left + ear_right) / 2.0

    if "h_gaze" in cache and "v_gaze" in cache:
        h_gaze, v_gaze = cache["h_gaze"], cache["v_gaze"]
    else:
        h_gaze, v_gaze = compute_gaze_ratio(landmarks)

    mar = cache.get("mar", compute_mar(landmarks))

    # Head pose angles; a missing/empty estimate degrades to zeros.
    angles = cache.get("angles")
    if angles is None:
        angles = head_pose.estimate(landmarks, w, h)
    if angles:
        yaw, pitch, roll = angles[0], angles[1], angles[2]
    else:
        yaw, pitch, roll = 0.0, 0.0, 0.0

    s_face = cache.get("s_face", head_pose.score(landmarks, w, h))
    s_eye = cache.get("s_eye", eye_scorer.score(landmarks))

    # Distance of the gaze point from screen center (0.5, 0.5).
    gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
    head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)  # cleaned downstream

    perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)

    feature_vector = [
        ear_left, ear_right, ear_avg,
        h_gaze, v_gaze,
        mar,
        yaw, pitch, roll,
        s_face, s_eye,
        gaze_offset,
        head_deviation,
        perclos, blink_rate, closure_dur, yawn_dur,
    ]
    return np.array(feature_vector, dtype=np.float32)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def quality_report(labels):
    """Print a quality summary of a labelled session.

    Returns True when duration, sample count, class balance and label
    transitions all pass the collection guidelines.
    """
    n = len(labels)
    focused = int((labels == 1).sum())
    unfocused = n - focused
    transitions = int(np.sum(np.diff(labels) != 0))
    duration_sec = n / 30.0  # approximate at 30fps

    print(f"\n{'='*50}")
    print(f" DATA QUALITY REPORT")
    print(f"{'='*50}")
    print(f" Total samples : {n}")
    print(f" Focused : {focused} ({focused/max(n,1)*100:.1f}%)")
    print(f" Unfocused : {unfocused} ({unfocused/max(n,1)*100:.1f}%)")
    print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
    print(f" Transitions : {transitions}")
    if transitions > 0:
        print(f" Avg segment : {n/transitions:.0f} frames ({n/transitions/30:.1f}s)")

    # Accumulate guideline violations.
    warnings = []
    if duration_sec < 120:
        warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")
    if n < 3000:
        warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")
    balance = focused / max(n, 1)
    if not 0.3 <= balance <= 0.7:
        warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")
    if transitions < 10:
        warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")
    if transitions == 1:
        warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
                        "model will learn temporal position, not focus patterns")

    if warnings:
        print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
        for w in warnings:
            print(f" • {w}")
        print(f"\n Consider re-recording this session.")
    else:
        print(f"\n ✅ All checks passed!")

    print(f"{'='*50}\n")
    return not warnings
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
# Main
|
| 184 |
+
def main():
    """Interactive webcam tool: record per-frame feature vectors with live
    focused/unfocused key-labels, then save the session as a .npz file and
    print a data-quality report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="session",
                        help="Your name or session ID")
    parser.add_argument("--camera", type=int, default=0,
                        help="Camera index")
    parser.add_argument("--duration", type=int, default=600,
                        help="Max recording time (seconds, default 10 min)")
    parser.add_argument("--output-dir", type=str,
                        default=os.path.join(_PROJECT_ROOT, "data", "collected_data"),
                        help="Where to save .npz files")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Per-session detector/scorer state (TemporalTracker accumulates across frames).
    detector = FaceMeshDetector()
    head_pose = HeadPoseEstimator()
    eye_scorer = EyeBehaviourScorer()
    temporal = TemporalTracker()

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print("[COLLECT] ERROR: can't open camera")
        return

    print("[COLLECT] Data Collection Tool")
    print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
    print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
    print("[COLLECT] Controls:")
    print("  1 = FOCUSED (looking at screen normally)")
    print("  0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
    print("  p = pause")
    print("  q = save & quit")
    print()
    print("[COLLECT] TIPS for good data:")
    print("  • Switch between 1 and 0 every 10-30 seconds")
    print("  • Aim for 20+ transitions total")
    print("  • Act out varied scenarios: reading, phone, talking, drowsy")
    print("  • Record at least 5 minutes")
    print()

    features_list = []
    labels_list = []
    label = None  # None = paused; 1 = focused; 0 = not focused
    transitions = 0  # count label switches
    prev_label = None
    status = "PAUSED -- press 1 (focused) or 0 (not focused)"
    t_start = time.time()
    prev_time = time.time()
    fps = 0.0

    try:
        while True:
            elapsed = time.time() - t_start
            if elapsed > args.duration:
                print(f"[COLLECT] Time limit ({args.duration}s)")
                break

            ret, frame = cap.read()
            if not ret:
                break

            h, w = frame.shape[:2]
            landmarks = detector.process(frame)
            face_ok = landmarks is not None

            # Only record frames that have a face AND an active label.
            if face_ok and label is not None:
                vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
                features_list.append(vec)
                labels_list.append(label)

                # Count recorded label switches (ignores pauses).
                if prev_label is not None and label != prev_label:
                    transitions += 1
                prev_label = label

            # Exponential moving average of FPS for the overlay.
            now = time.time()
            fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
            prev_time = now

            # --- draw UI ---
            n = len(labels_list)
            n1 = sum(1 for x in labels_list if x == 1)
            n0 = n - n1
            remaining = max(0, args.duration - elapsed)

            # Top status bar, colored by the current label.
            bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
            cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
            cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
            cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
                        (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)

            # Class-balance progress bar (fraction of focused samples).
            if n > 0:
                bar_w = min(w - 20, 300)
                bar_x = w - bar_w - 10
                bar_y = 58
                frac = n1 / n
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
                cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
                            FONT, 0.3, GRAY, 1, cv2.LINE_AA)

            if not face_ok:
                cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)

            # red dot = recording
            if label is not None and face_ok:
                cv2.circle(frame, (w - 20, 80), 8, RED, -1)

            # live warnings
            warn_y = h - 35
            if n > 100 and transitions < 3:
                cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
                            (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                warn_y -= 18
            if elapsed > 30 and n > 0:
                bal = n1 / n
                if bal < 0.25 or bal > 0.75:
                    cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
                                (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                    warn_y -= 18

            cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
                        (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)

            cv2.imshow("FocusGuard -- Data Collection", frame)

            # Keyboard controls: 1/0 set the label, p pauses, q saves & quits.
            key = cv2.waitKey(1) & 0xFF
            if key == ord("1"):
                label = 1
                status = "Recording: FOCUSED"
                print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("0"):
                label = 0
                status = "Recording: NOT FOCUSED"
                print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("p"):
                label = None
                status = "PAUSED"
                print(f"[COLLECT] paused (n={n})")
            elif key == ord("q"):
                break

    finally:
        # Release camera/UI resources even if the loop raised.
        cap.release()
        cv2.destroyAllWindows()
        detector.close()

    if len(features_list) > 0:
        feats = np.stack(features_list)
        labs = np.array(labels_list, dtype=np.int64)

        # Timestamped filename keeps multiple sessions from the same user distinct.
        ts = time.strftime("%Y%m%d_%H%M%S")
        fname = f"{args.name}_{ts}.npz"
        fpath = os.path.join(args.output_dir, fname)
        np.savez(fpath,
                 features=feats,
                 labels=labs,
                 feature_names=np.array(FEATURE_NAMES))

        print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
        print(f"  Shape: {feats.shape} ({NUM_FEATURES} features)")

        quality_report(labs)
    else:
        print("\n[COLLECT] No data collected")

    print("[COLLECT] Done")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
if __name__ == "__main__":
|
| 356 |
+
main()
|
models/eye_classifier.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class EyeClassifier(ABC):
    """Interface for backends that score eye crops for attentiveness."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of this backend (e.g. "yolo")."""
        ...

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return an attentiveness score in [0, 1] for the given BGR eye crops."""
        ...


class GeometricOnlyClassifier(EyeClassifier):
    """No-model backend: defers to geometric features by always scoring 1."""

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        # No learned model available, so report full attentiveness and let
        # the geometric (EAR/gaze) features drive the final decision.
        return 1.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class YOLOv11Classifier(EyeClassifier):
    """Eye-state backend backed by an Ultralytics YOLO classification model."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        from ultralytics import YOLO  # deferred: heavy optional dependency

        self._model = YOLO(checkpoint_path)
        self._device = device

        # Find the class index representing attentive/open eyes.
        names = self._model.names
        self._attentive_idx = next(
            (idx for idx, cls_name in names.items() if cls_name in ("open", "attentive")),
            None,
        )
        if self._attentive_idx is None:
            # No known label present: fall back to the highest-numbered class.
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Mean attentive-class probability over all crops (1.0 when empty)."""
        if not crops_bgr:
            return 1.0
        results = self._model.predict(crops_bgr, device=self._device, verbose=False)
        probs = [float(r.probs.data[self._attentive_idx]) for r in results]
        if not probs:
            return 1.0
        return sum(probs) / len(probs)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Factory for eye-classifier backends.

    Falls back to the geometric-only backend when no checkpoint path is
    supplied or when the geometric backend is explicitly requested; any
    other configuration loads the YOLO classifier from ``path``.
    """
    if path is None or backend == "geometric":
        return GeometricOnlyClassifier()

    try:
        return YOLOv11Classifier(path, device=device)
    except ImportError:
        # Surface the missing optional dependency, then propagate.
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise
|
models/eye_crop.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from models.face_mesh import FaceMeshDetector
|
| 5 |
+
|
| 6 |
+
LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
|
| 7 |
+
RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
|
| 8 |
+
|
| 9 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 10 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 11 |
+
|
| 12 |
+
CROP_SIZE = 96
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _bbox_from_landmarks(
|
| 16 |
+
landmarks: np.ndarray,
|
| 17 |
+
indices: list[int],
|
| 18 |
+
frame_w: int,
|
| 19 |
+
frame_h: int,
|
| 20 |
+
expand: float = 0.4,
|
| 21 |
+
) -> tuple[int, int, int, int]:
|
| 22 |
+
pts = landmarks[indices, :2]
|
| 23 |
+
px = pts[:, 0] * frame_w
|
| 24 |
+
py = pts[:, 1] * frame_h
|
| 25 |
+
|
| 26 |
+
x_min, x_max = px.min(), px.max()
|
| 27 |
+
y_min, y_max = py.min(), py.max()
|
| 28 |
+
w = x_max - x_min
|
| 29 |
+
h = y_max - y_min
|
| 30 |
+
cx = (x_min + x_max) / 2
|
| 31 |
+
cy = (y_min + y_max) / 2
|
| 32 |
+
|
| 33 |
+
size = max(w, h) * (1 + expand)
|
| 34 |
+
half = size / 2
|
| 35 |
+
|
| 36 |
+
x1 = int(max(cx - half, 0))
|
| 37 |
+
y1 = int(max(cy - half, 0))
|
| 38 |
+
x2 = int(min(cx + half, frame_w))
|
| 39 |
+
y2 = int(min(cy + half, frame_h))
|
| 40 |
+
|
| 41 |
+
return x1, y1, x2, y2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def extract_eye_crops(
    frame: np.ndarray,
    landmarks: np.ndarray,
    expand: float = 0.4,
    crop_size: int = CROP_SIZE,
) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
    """Cut out both eye regions and resize each to a ``crop_size`` square.

    Returns (left_crop, right_crop, left_bbox, right_bbox). An empty region
    (face at the frame edge) yields an all-black crop so callers always get
    a valid image.
    """
    frame_h, frame_w = frame.shape[:2]

    def _crop(indices):
        bbox = _bbox_from_landmarks(landmarks, indices, frame_w, frame_h, expand)
        x1, y1, x2, y2 = bbox
        region = frame[y1:y2, x1:x2]
        if region.size == 0:
            # Degenerate bbox: return a black square placeholder.
            return np.zeros((crop_size, crop_size, 3), dtype=np.uint8), bbox
        resized = cv2.resize(region, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
        return resized, bbox

    left_crop, left_bbox = _crop(LEFT_EYE_CONTOUR)
    right_crop, right_bbox = _crop(RIGHT_EYE_CONTOUR)
    return left_crop, right_crop, left_bbox, right_bbox
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def crop_to_tensor(crop_bgr: np.ndarray):
    """Convert a BGR uint8 crop to a normalized CHW float32 torch tensor.

    Applies ImageNet mean/std normalization channel-wise. Returns a
    (3, H, W) tensor suitable for CNN backends.
    """
    import torch  # deferred: torch is only needed for CNN backends

    rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Vectorized replacement of the per-channel Python loop; float32
    # mean/std keep the arithmetic identical to the scalar version.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    rgb = (rgb - mean) / std
    return torch.from_numpy(rgb.transpose(2, 0, 1))
|
models/eye_scorer.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
_LEFT_EYE_EAR = [33, 160, 158, 133, 153, 145]
|
| 6 |
+
_RIGHT_EYE_EAR = [362, 385, 387, 263, 373, 380]
|
| 7 |
+
|
| 8 |
+
_LEFT_IRIS_CENTER = 468
|
| 9 |
+
_RIGHT_IRIS_CENTER = 473
|
| 10 |
+
|
| 11 |
+
_LEFT_EYE_INNER = 133
|
| 12 |
+
_LEFT_EYE_OUTER = 33
|
| 13 |
+
_RIGHT_EYE_INNER = 362
|
| 14 |
+
_RIGHT_EYE_OUTER = 263
|
| 15 |
+
|
| 16 |
+
_LEFT_EYE_TOP = 159
|
| 17 |
+
_LEFT_EYE_BOTTOM = 145
|
| 18 |
+
_RIGHT_EYE_TOP = 386
|
| 19 |
+
_RIGHT_EYE_BOTTOM = 374
|
| 20 |
+
|
| 21 |
+
_MOUTH_TOP = 13
|
| 22 |
+
_MOUTH_BOTTOM = 14
|
| 23 |
+
_MOUTH_LEFT = 78
|
| 24 |
+
_MOUTH_RIGHT = 308
|
| 25 |
+
_MOUTH_UPPER_1 = 82
|
| 26 |
+
_MOUTH_UPPER_2 = 312
|
| 27 |
+
_MOUTH_LOWER_1 = 87
|
| 28 |
+
_MOUTH_LOWER_2 = 317
|
| 29 |
+
|
| 30 |
+
MAR_YAWN_THRESHOLD = 0.55
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _distance(p1: np.ndarray, p2: np.ndarray) -> float:
|
| 34 |
+
return float(np.linalg.norm(p1 - p2))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def compute_ear(landmarks: np.ndarray, eye_indices: list[int]) -> float:
|
| 38 |
+
p1 = landmarks[eye_indices[0], :2]
|
| 39 |
+
p2 = landmarks[eye_indices[1], :2]
|
| 40 |
+
p3 = landmarks[eye_indices[2], :2]
|
| 41 |
+
p4 = landmarks[eye_indices[3], :2]
|
| 42 |
+
p5 = landmarks[eye_indices[4], :2]
|
| 43 |
+
p6 = landmarks[eye_indices[5], :2]
|
| 44 |
+
|
| 45 |
+
vertical1 = _distance(p2, p6)
|
| 46 |
+
vertical2 = _distance(p3, p5)
|
| 47 |
+
horizontal = _distance(p1, p4)
|
| 48 |
+
|
| 49 |
+
if horizontal < 1e-6:
|
| 50 |
+
return 0.0
|
| 51 |
+
|
| 52 |
+
return (vertical1 + vertical2) / (2.0 * horizontal)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compute_avg_ear(landmarks: np.ndarray) -> float:
|
| 56 |
+
left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
|
| 57 |
+
right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
|
| 58 |
+
return (left_ear + right_ear) / 2.0
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def compute_gaze_ratio(landmarks: np.ndarray) -> tuple[float, float]:
|
| 62 |
+
left_iris = landmarks[_LEFT_IRIS_CENTER, :2]
|
| 63 |
+
left_inner = landmarks[_LEFT_EYE_INNER, :2]
|
| 64 |
+
left_outer = landmarks[_LEFT_EYE_OUTER, :2]
|
| 65 |
+
left_top = landmarks[_LEFT_EYE_TOP, :2]
|
| 66 |
+
left_bottom = landmarks[_LEFT_EYE_BOTTOM, :2]
|
| 67 |
+
|
| 68 |
+
right_iris = landmarks[_RIGHT_IRIS_CENTER, :2]
|
| 69 |
+
right_inner = landmarks[_RIGHT_EYE_INNER, :2]
|
| 70 |
+
right_outer = landmarks[_RIGHT_EYE_OUTER, :2]
|
| 71 |
+
right_top = landmarks[_RIGHT_EYE_TOP, :2]
|
| 72 |
+
right_bottom = landmarks[_RIGHT_EYE_BOTTOM, :2]
|
| 73 |
+
|
| 74 |
+
left_h_total = _distance(left_inner, left_outer)
|
| 75 |
+
right_h_total = _distance(right_inner, right_outer)
|
| 76 |
+
|
| 77 |
+
if left_h_total < 1e-6 or right_h_total < 1e-6:
|
| 78 |
+
return 0.5, 0.5
|
| 79 |
+
|
| 80 |
+
left_h_ratio = _distance(left_outer, left_iris) / left_h_total
|
| 81 |
+
right_h_ratio = _distance(right_outer, right_iris) / right_h_total
|
| 82 |
+
h_ratio = (left_h_ratio + right_h_ratio) / 2.0
|
| 83 |
+
|
| 84 |
+
left_v_total = _distance(left_top, left_bottom)
|
| 85 |
+
right_v_total = _distance(right_top, right_bottom)
|
| 86 |
+
|
| 87 |
+
if left_v_total < 1e-6 or right_v_total < 1e-6:
|
| 88 |
+
return h_ratio, 0.5
|
| 89 |
+
|
| 90 |
+
left_v_ratio = _distance(left_top, left_iris) / left_v_total
|
| 91 |
+
right_v_ratio = _distance(right_top, right_iris) / right_v_total
|
| 92 |
+
v_ratio = (left_v_ratio + right_v_ratio) / 2.0
|
| 93 |
+
|
| 94 |
+
return float(np.clip(h_ratio, 0, 1)), float(np.clip(v_ratio, 0, 1))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def compute_mar(landmarks: np.ndarray) -> float:
|
| 98 |
+
top = landmarks[_MOUTH_TOP, :2]
|
| 99 |
+
bottom = landmarks[_MOUTH_BOTTOM, :2]
|
| 100 |
+
left = landmarks[_MOUTH_LEFT, :2]
|
| 101 |
+
right = landmarks[_MOUTH_RIGHT, :2]
|
| 102 |
+
upper1 = landmarks[_MOUTH_UPPER_1, :2]
|
| 103 |
+
lower1 = landmarks[_MOUTH_LOWER_1, :2]
|
| 104 |
+
upper2 = landmarks[_MOUTH_UPPER_2, :2]
|
| 105 |
+
lower2 = landmarks[_MOUTH_LOWER_2, :2]
|
| 106 |
+
|
| 107 |
+
horizontal = _distance(left, right)
|
| 108 |
+
if horizontal < 1e-6:
|
| 109 |
+
return 0.0
|
| 110 |
+
v1 = _distance(upper1, lower1)
|
| 111 |
+
v2 = _distance(top, bottom)
|
| 112 |
+
v3 = _distance(upper2, lower2)
|
| 113 |
+
return (v1 + v2 + v3) / (2.0 * horizontal)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class EyeBehaviourScorer:
    """Combine eyelid openness (EAR) and iris centering into one focus score.

    The score is EAR-gated: when the eyes are mostly closed the gaze term is
    ignored and the (low) EAR score is returned directly.
    """

    def __init__(
        self,
        ear_open: float = 0.30,
        ear_closed: float = 0.16,
        gaze_max_offset: float = 0.28,
    ):
        # EAR at/above ear_open counts as fully open; at/below ear_closed as shut.
        self.ear_open = ear_open
        self.ear_closed = ear_closed
        # Radial iris offset from the eye centre at which the gaze score hits 0.
        self.gaze_max_offset = gaze_max_offset

    def _ear_score(self, ear: float) -> float:
        """Linearly map EAR onto [0, 1] between the closed/open thresholds."""
        if ear >= self.ear_open:
            return 1.0
        if ear <= self.ear_closed:
            return 0.0
        return (ear - self.ear_closed) / (self.ear_open - self.ear_closed)

    def _gaze_score(self, h_ratio: float, v_ratio: float) -> float:
        """Cosine falloff of the iris offset from the eye centre (0.5, 0.5)."""
        dx = abs(h_ratio - 0.5)
        dy = abs(v_ratio - 0.5)
        offset = math.sqrt(dx**2 + dy**2)
        t = min(offset / self.gaze_max_offset, 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    def score(self, landmarks: np.ndarray) -> float:
        """Focus score in [0, 1] from face-mesh landmarks."""
        # Use minimum EAR so closing ONE eye is enough to drop the score.
        ear = min(
            compute_ear(landmarks, _LEFT_EYE_EAR),
            compute_ear(landmarks, _RIGHT_EYE_EAR),
        )
        ear_s = self._ear_score(ear)
        # Mostly-closed eyes: gaze direction is meaningless, return the gate.
        if ear_s < 0.3:
            return ear_s
        h_ratio, v_ratio = compute_gaze_ratio(landmarks)
        return ear_s * self._gaze_score(h_ratio, v_ratio)

    def detailed_score(self, landmarks: np.ndarray) -> dict:
        """Return all intermediate metrics plus the combined ``s_eye`` score."""
        ear = min(
            compute_ear(landmarks, _LEFT_EYE_EAR),
            compute_ear(landmarks, _RIGHT_EYE_EAR),
        )
        ear_s = self._ear_score(ear)
        h_ratio, v_ratio = compute_gaze_ratio(landmarks)
        gaze_s = self._gaze_score(h_ratio, v_ratio)
        s_eye = ear_s if ear_s < 0.3 else ear_s * gaze_s
        return {
            "ear": round(ear, 4),
            "ear_score": round(ear_s, 4),
            "h_gaze": round(h_ratio, 4),
            "v_gaze": round(v_ratio, 4),
            "gaze_score": round(gaze_s, 4),
            "s_eye": round(s_eye, 4),
        }
|
models/face_mesh.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from urllib.request import urlretrieve
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import mediapipe as mp
|
| 9 |
+
from mediapipe.tasks.python.vision import FaceLandmarkerOptions, FaceLandmarker, RunningMode
|
| 10 |
+
from mediapipe.tasks import python as mp_tasks
|
| 11 |
+
|
| 12 |
+
_MODEL_URL = (
|
| 13 |
+
"https://storage.googleapis.com/mediapipe-models/face_landmarker/"
|
| 14 |
+
"face_landmarker/float16/latest/face_landmarker.task"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _ensure_model() -> str:
    """Return the local path of the face-landmarker model, downloading it once.

    The model file is cached under $FOCUSGUARD_CACHE_DIR
    (default ~/.cache/focusguard).
    """
    default_dir = Path.home() / ".cache" / "focusguard"
    cache_dir = Path(os.environ.get("FOCUSGUARD_CACHE_DIR", default_dir))
    model_path = cache_dir / "face_landmarker.task"

    if not model_path.exists():
        cache_dir.mkdir(parents=True, exist_ok=True)
        print(f"[FACE_MESH] Downloading model to {model_path}...")
        urlretrieve(_MODEL_URL, model_path)
        print("[FACE_MESH] Download complete.")

    return str(model_path)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FaceMeshDetector:
    """Thin wrapper around MediaPipe's FaceLandmarker task in video mode."""

    # Landmark index groups for the eye contours and irises (478-point mesh).
    LEFT_EYE_INDICES = [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]
    RIGHT_EYE_INDICES = [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]
    LEFT_IRIS_INDICES = [468, 469, 470, 471, 472]
    RIGHT_IRIS_INDICES = [473, 474, 475, 476, 477]

    def __init__(
        self,
        max_num_faces: int = 1,
        min_detection_confidence: float = 0.5,
        min_tracking_confidence: float = 0.5,
    ):
        landmarker_options = FaceLandmarkerOptions(
            base_options=mp_tasks.BaseOptions(model_asset_path=_ensure_model()),
            num_faces=max_num_faces,
            min_face_detection_confidence=min_detection_confidence,
            min_face_presence_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
            running_mode=RunningMode.VIDEO,
        )
        self._landmarker = FaceLandmarker.create_from_options(landmarker_options)
        # Monotonic clock origin used to synthesize strictly increasing
        # millisecond timestamps, as required by MediaPipe video mode.
        self._t0 = time.monotonic()
        self._last_ts = 0

    def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
        """Detect one face; return normalized (478, 3) x,y,z landmarks or None."""
        rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
        # Guarantee a strictly increasing timestamp even for bursty frames.
        ts = max(int((time.monotonic() - self._t0) * 1000), self._last_ts + 1)
        self._last_ts = ts
        detection = self._landmarker.detect_for_video(mp_image, ts)

        if not detection.face_landmarks:
            return None

        first_face = detection.face_landmarks[0]
        return np.array([(lm.x, lm.y, lm.z) for lm in first_face], dtype=np.float32)

    def get_pixel_landmarks(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> np.ndarray:
        """Scale normalized landmarks to integer pixel (x, y) coordinates."""
        pixel = np.zeros((landmarks.shape[0], 2), dtype=np.int32)
        pixel[:, 0] = (landmarks[:, 0] * frame_w).astype(np.int32)
        pixel[:, 1] = (landmarks[:, 1] * frame_h).astype(np.int32)
        return pixel

    def get_3d_landmarks(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> np.ndarray:
        """Scale normalized landmarks to pixel units (z scaled by frame width)."""
        pts = np.zeros_like(landmarks)
        pts[:, 0] = landmarks[:, 0] * frame_w
        pts[:, 1] = landmarks[:, 1] * frame_h
        pts[:, 2] = landmarks[:, 2] * frame_w
        return pts

    def close(self):
        """Release the underlying MediaPipe landmarker."""
        self._landmarker.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
|
models/head_pose.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
_LANDMARK_INDICES = [1, 152, 33, 263, 61, 291]
|
| 7 |
+
|
| 8 |
+
_MODEL_POINTS = np.array(
|
| 9 |
+
[
|
| 10 |
+
[0.0, 0.0, 0.0],
|
| 11 |
+
[0.0, -330.0, -65.0],
|
| 12 |
+
[-225.0, 170.0, -135.0],
|
| 13 |
+
[225.0, 170.0, -135.0],
|
| 14 |
+
[-150.0, -150.0, -125.0],
|
| 15 |
+
[150.0, -150.0, -125.0],
|
| 16 |
+
],
|
| 17 |
+
dtype=np.float64,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HeadPoseEstimator:
    """Estimate head orientation (yaw/pitch/roll) from face-mesh landmarks via PnP."""

    def __init__(self, max_angle: float = 30.0, roll_weight: float = 0.5):
        # max_angle: combined deviation (degrees) at which score() reaches 0.
        # roll_weight: contribution of roll relative to yaw/pitch in score().
        self.max_angle = max_angle
        self.roll_weight = roll_weight
        self._camera_matrix = None
        self._frame_size = None
        # Assume an undistorted pinhole camera (zero distortion coefficients).
        self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
        # Single-entry memo so estimate() and draw_axes() on the same
        # landmarks/frame solve PnP only once.
        self._cache_key = None
        self._cache_result = None

    def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
        """Pinhole intrinsics approximated from the frame size (cached per size)."""
        if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
            return self._camera_matrix
        # Common approximation: focal length ~ frame width, principal point at centre.
        focal_length = float(frame_w)
        cx, cy = frame_w / 2.0, frame_h / 2.0
        self._camera_matrix = np.array(
            [[focal_length, 0, cx], [0, focal_length, cy], [0, 0, 1]],
            dtype=np.float64,
        )
        self._frame_size = (frame_w, frame_h)
        return self._camera_matrix

    def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
        """Run solvePnP for the six reference landmarks; memoized on the inputs.

        Returns (success, rvec, tvec, image_points).
        """
        # Raw landmark bytes identify the exact array contents for the memo.
        key = (landmarks.data.tobytes(), frame_w, frame_h)
        if self._cache_key == key:
            return self._cache_result

        # Project normalized landmark coordinates to pixel space.
        image_points = np.array(
            [
                [landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
                for i in _LANDMARK_INDICES
            ],
            dtype=np.float64,
        )
        camera_matrix = self._get_camera_matrix(frame_w, frame_h)
        success, rvec, tvec = cv2.solvePnP(
            _MODEL_POINTS,
            image_points,
            camera_matrix,
            self._dist_coeffs,
            flags=cv2.SOLVEPNP_ITERATIVE,
        )
        result = (success, rvec, tvec, image_points)
        self._cache_key = key
        self._cache_result = result
        return result

    def estimate(
        self, landmarks: np.ndarray, frame_w: int, frame_h: int
    ) -> tuple[float, float, float] | None:
        """Return (yaw, pitch, roll) in degrees, or None if PnP fails."""
        success, rvec, tvec, _ = self._solve(landmarks, frame_w, frame_h)
        if not success:
            return None

        rmat, _ = cv2.Rodrigues(rvec)
        # Rotate the model's canonical forward (+z) and up (+y) axes
        # into camera space, then read off the angles.
        nose_dir = rmat @ np.array([0.0, 0.0, 1.0])
        face_up = rmat @ np.array([0.0, 1.0, 0.0])

        yaw = math.degrees(math.atan2(nose_dir[0], -nose_dir[2]))
        pitch = math.degrees(math.asin(np.clip(-nose_dir[1], -1.0, 1.0)))
        roll = math.degrees(math.atan2(face_up[0], -face_up[1]))

        return (yaw, pitch, roll)

    def score(self, landmarks: np.ndarray, frame_w: int, frame_h: int) -> float:
        """Map head deviation onto [0, 1]: 1 facing camera, 0 at/past max_angle."""
        angles = self.estimate(landmarks, frame_w, frame_h)
        if angles is None:
            return 0.0

        yaw, pitch, roll = angles
        # Combined angular deviation with roll down-weighted.
        deviation = math.sqrt(yaw**2 + pitch**2 + (self.roll_weight * roll) ** 2)
        t = min(deviation / self.max_angle, 1.0)
        # Smooth cosine falloff from 1 to 0.
        return 0.5 * (1.0 + math.cos(math.pi * t))

    def draw_axes(
        self,
        frame: np.ndarray,
        landmarks: np.ndarray,
        axis_length: float = 50.0,
    ) -> np.ndarray:
        """Overlay the solved head axes (x red, y green, z blue) at the nose tip."""
        h, w = frame.shape[:2]
        success, rvec, tvec, image_points = self._solve(landmarks, w, h)
        if not success:
            return frame

        camera_matrix = self._get_camera_matrix(w, h)
        # image_points[0] corresponds to the nose-tip landmark.
        nose = tuple(image_points[0].astype(int))

        axes_3d = np.float64(
            [[axis_length, 0, 0], [0, axis_length, 0], [0, 0, axis_length]]
        )
        projected, _ = cv2.projectPoints(
            axes_3d, rvec, tvec, camera_matrix, self._dist_coeffs
        )

        # BGR colors: x axis red, y axis green, z axis blue.
        colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
        for i, color in enumerate(colors):
            pt = tuple(projected[i].ravel().astype(int))
            cv2.line(frame, nose, pt, color, 2)

        return frame
|
models/mlp/__init__.py
ADDED
|
File without changes
|
models/mlp/eval_accuracy.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load saved MLP checkpoint and print test accuracy, F1, AUC."""
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from sklearn.metrics import f1_score, roc_auc_score
|
| 8 |
+
|
| 9 |
+
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
| 10 |
+
if REPO_ROOT not in sys.path:
|
| 11 |
+
sys.path.insert(0, REPO_ROOT)
|
| 12 |
+
|
| 13 |
+
from data_preparation.prepare_dataset import get_dataloaders
|
| 14 |
+
from models.mlp.train import BaseModel
|
| 15 |
+
|
| 16 |
+
CKPT_PATH = os.path.join(REPO_ROOT, "checkpoints", "mlp_best.pt")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
    """Evaluate the saved MLP checkpoint on the held-out test split."""
    if not os.path.isfile(CKPT_PATH):
        print(f"No checkpoint at {CKPT_PATH}. Train first: python -m models.mlp.train")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Same split configuration as training so the test set is identical.
    train_loader, val_loader, test_loader, num_features, num_classes, _ = get_dataloaders(
        model_name="face_orientation",
        batch_size=32,
        split_ratios=(0.7, 0.15, 0.15),
        seed=42,
    )

    model = BaseModel(num_features, num_classes).to(device)
    state = torch.load(CKPT_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()

    criterion = torch.nn.CrossEntropyLoss()
    test_loss, test_acc, test_probs, test_preds, test_labels = model.test_step(
        test_loader, criterion, device
    )

    f1 = float(f1_score(test_labels, test_preds, average="weighted"))
    # ROC-AUC: one-vs-rest for multiclass, positive-class probability for binary.
    if num_classes > 2:
        auc = float(roc_auc_score(test_labels, test_probs, multi_class="ovr", average="weighted"))
    else:
        auc = float(roc_auc_score(test_labels, test_probs[:, 1]))

    print("MLP (face_orientation) — test set")
    print(" Accuracy: {:.2%}".format(test_acc))
    print(" F1: {:.4f}".format(f1))
    print(" ROC-AUC: {:.4f}".format(auc))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
models/mlp/sweep.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MLP Hyperparameter Sweep (Optuna).
|
| 3 |
+
Run: python -m models.mlp.sweep
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import optuna
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
|
| 13 |
+
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
| 14 |
+
if REPO_ROOT not in sys.path:
|
| 15 |
+
sys.path.insert(0, REPO_ROOT)
|
| 16 |
+
|
| 17 |
+
from data_preparation.prepare_dataset import get_dataloaders
|
| 18 |
+
from models.mlp.train import BaseModel, set_seed
|
| 19 |
+
|
| 20 |
+
SEED = 42
|
| 21 |
+
N_TRIALS = 20
|
| 22 |
+
EPOCHS_PER_TRIAL = 15
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def objective(trial):
    """Optuna objective: 1 - best validation accuracy for the sampled config."""
    set_seed(SEED)

    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    train_loader, val_loader, _, num_features, num_classes, _ = get_dataloaders(
        model_name="face_orientation",
        batch_size=batch_size,
        split_ratios=(0.7, 0.15, 0.15),
        seed=SEED,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0.0
    for _epoch in range(EPOCHS_PER_TRIAL):
        model.training_step(train_loader, optimizer, criterion, device)
        _val_loss, val_acc = model.validation_step(val_loader, criterion, device)
        best_val_acc = max(best_val_acc, val_acc)

    # Optuna minimizes, so report the complement of the best accuracy.
    return 1.0 - best_val_acc
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
    """Run the Optuna sweep and report the five best trials."""
    study = optuna.create_study(direction="minimize", study_name="mlp_sweep")
    print(f"[SWEEP] MLP Optuna sweep: {N_TRIALS} trials, {EPOCHS_PER_TRIAL} epochs each")
    study.optimize(objective, n_trials=N_TRIALS)

    print("\n[SWEEP] Top-5 trials by validation accuracy")
    # Lower objective value = higher accuracy; failed trials sort last.
    ranked = sorted(
        study.trials,
        key=lambda t: t.value if t.value is not None else float("inf"),
    )
    for i, t in enumerate(ranked[:5], 1):
        acc = (1.0 - t.value) * 100
        print(f" #{i} Val Acc: {acc:.2f}% params={t.params}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
main()
|
models/mlp/train.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training entry module for the MLP face-orientation classifier.

Module-level setup: imports, training configuration (CFG), and optional
ClearML experiment tracking.
"""
import json
import os
import random
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, roc_auc_score

from data_preparation.prepare_dataset import get_dataloaders

# Opt-in experiment tracking; requires the `clearml` package when True.
USE_CLEARML = False

# Repository root, resolved relative to this file (two levels up).
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))

# Training configuration. Connected to ClearML (below) so runs are reproducible.
CFG = {
    "model_name": "face_orientation",
    "epochs": 30,
    "batch_size": 32,
    "lr": 1e-3,
    "seed": 42,
    "split_ratios": (0.7, 0.15, 0.15),  # train / val / test fractions
    "checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
    "logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
}


# ==== ClearML (opt-in) =============================================
# `task` stays None when tracking is disabled; callers check `task is not None`.
task = None
if USE_CLEARML:
    from clearml import Task
    task = Task.init(
        project_name="Focus Guard",
        task_name="MLP Model Training",
        tags=["training", "mlp_model"]
    )
    task.connect(CFG)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ==== Model =============================================
|
| 42 |
+
def set_seed(seed: int):
    """Seed Python, NumPy, and Torch (CPU and, if present, CUDA) RNGs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class BaseModel(nn.Module):
    """Feed-forward classifier (num_features -> 64 -> 32 -> num_classes)
    with built-in epoch helpers that report mean loss and accuracy."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        # Attribute must stay named `network` so saved state_dicts keep loading.
        self.network = nn.Sequential(
            nn.Linear(num_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        return self.network(x)

    def training_step(self, loader, optimizer, criterion, device):
        """Run one optimization pass over `loader`; return (mean_loss, accuracy)."""
        self.train()
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            logits = self(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()

            batch_n = batch_x.size(0)
            # Weight the batch-mean loss by batch size so the epoch mean is exact.
            running_loss += loss.item() * batch_n
            n_correct += int((logits.argmax(dim=1) == batch_y).sum().item())
            n_seen += batch_n

        return running_loss / n_seen, n_correct / n_seen

    @torch.no_grad()
    def validation_step(self, loader, criterion, device):
        """Evaluate on `loader` without gradients; return (mean_loss, accuracy)."""
        self.eval()
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = self(batch_x)

            batch_n = batch_x.size(0)
            running_loss += criterion(logits, batch_y).item() * batch_n
            n_correct += int((logits.argmax(dim=1) == batch_y).sum().item())
            n_seen += batch_n

        return running_loss / n_seen, n_correct / n_seen

    @torch.no_grad()
    def test_step(self, loader, criterion, device):
        """Evaluate on `loader` and collect outputs.

        Returns (mean_loss, accuracy, probs, preds, labels), where the last
        three are NumPy arrays concatenated over the whole loader.
        """
        self.eval()
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        preds_out = []
        labels_out = []
        probs_out = []

        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = self(batch_x)

            batch_n = batch_x.size(0)
            running_loss += criterion(logits, batch_y).item() * batch_n
            batch_preds = logits.argmax(dim=1)
            n_correct += int((batch_preds == batch_y).sum().item())
            n_seen += batch_n

            probs_out.extend(torch.softmax(logits, dim=1).cpu().numpy())
            preds_out.extend(batch_preds.cpu().numpy())
            labels_out.extend(batch_y.cpu().numpy())

        return running_loss / n_seen, n_correct / n_seen, np.array(probs_out), np.array(preds_out), np.array(labels_out)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main():
    """Train the MLP on the face-orientation dataset, checkpoint the best
    epoch by validation accuracy, evaluate it on the test split, and write
    the full training history to a JSON log."""
    set_seed(CFG["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[TRAIN] Device: {device}")
    print(f"[TRAIN] Model: {CFG['model_name']}")

    # `scaler` is the feature scaler fitted by the data pipeline; unused here
    # but part of get_dataloaders' return contract.
    train_loader, val_loader, test_loader, num_features, num_classes, scaler = get_dataloaders(
        model_name=CFG["model_name"],
        batch_size=CFG["batch_size"],
        split_ratios=CFG["split_ratios"],
        seed=CFG["seed"],
    )

    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CFG["lr"])

    param_count = sum(p.numel() for p in model.parameters())
    print(f"[TRAIN] Parameters: {param_count:,}")

    ckpt_dir = CFG["checkpoints_dir"]
    os.makedirs(ckpt_dir, exist_ok=True)
    best_ckpt_path = os.path.join(ckpt_dir, "mlp_best.pt")

    # Per-epoch metrics accumulated here and dumped to JSON at the end.
    history = {
        "model_name": CFG["model_name"],
        "param_count": param_count,
        "epochs": [],
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    best_val_acc = 0.0

    print(f"\n{'Epoch':>6} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>10} | {'Val Acc':>9}")
    print("-" * 60)

    for epoch in range(1, CFG["epochs"] + 1):
        train_loss, train_acc = model.training_step(train_loader, optimizer, criterion, device)
        val_loss, val_acc = model.validation_step(val_loader, criterion, device)

        history["epochs"].append(epoch)
        history["train_loss"].append(round(train_loss, 4))
        history["train_acc"].append(round(train_acc, 4))
        history["val_loss"].append(round(val_loss, 4))
        history["val_acc"].append(round(val_acc, 4))


        # Report scalars to ClearML only when tracking is enabled (task set
        # at module import time).
        current_lr = optimizer.param_groups[0]['lr']
        if task is not None:
            task.logger.report_scalar("Loss", "Train", float(train_loss), iteration=epoch)
            task.logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
            task.logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
            task.logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
            task.logger.report_scalar("Learning Rate", "LR", float(current_lr), iteration=epoch)
            task.logger.flush()

        # Checkpoint whenever validation accuracy improves; "*" marks the row.
        marker = ""
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_ckpt_path)
            marker = " *"

        print(f"{epoch:>6} | {train_loss:>10.4f} | {train_acc:>8.2%} | {val_loss:>10.4f} | {val_acc:>8.2%}{marker}")

    print(f"\nBest validation accuracy: {best_val_acc:.2%}")
    print(f"Checkpoint saved to: {best_ckpt_path}")

    # Reload the best epoch's weights before the final test-set evaluation.
    model.load_state_dict(torch.load(best_ckpt_path, weights_only=True))
    test_loss, test_acc, test_probs, test_preds, test_labels = model.test_step(test_loader, criterion, device)

    test_f1 = f1_score(test_labels, test_preds, average='weighted')
    # Handle potentially >2 classes for AUC
    if num_classes > 2:
        test_auc = roc_auc_score(test_labels, test_probs, multi_class='ovr', average='weighted')
    else:
        test_auc = roc_auc_score(test_labels, test_probs[:, 1])

    print(f"\n[TEST] Loss: {test_loss:.4f} | Accuracy: {test_acc:.2%}")
    print(f"[TEST] F1: {test_f1:.4f} | ROC-AUC: {test_auc:.4f}")

    history["test_loss"] = round(test_loss, 4)
    history["test_acc"] = round(test_acc, 4)
    history["test_f1"] = round(test_f1, 4)
    history["test_auc"] = round(test_auc, 4)

    logs_dir = CFG["logs_dir"]
    os.makedirs(logs_dir, exist_ok=True)
    log_path = os.path.join(logs_dir, f"{CFG['model_name']}_training_log.json")

    with open(log_path, "w") as f:
        json.dump(history, f, indent=2)

    print(f"[LOG] Training history saved to: {log_path}")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Script entry point: train, evaluate, and log when run directly.
if __name__ == "__main__":
    main()
|
models/xgboost/add_accuracy.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Re-train every configuration from the XGBoost sweep results CSV and add
a `val_accuracy` column (the sweep originally recorded only F1/loss).

Run from the repo root so `data_preparation` and the CSV path resolve.
Side effects: rewrites models/xgboost/sweep_results_all_40.csv in place.
"""
import pandas as pd
import numpy as np
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
from data_preparation.prepare_dataset import get_numpy_splits
import os

print("Loading dataset for evaluation...")
# scale=False: XGBoost trees are invariant to monotone feature scaling.
splits, _, _, _ = get_numpy_splits(
    model_name="face_orientation",
    split_ratios=(0.7, 0.15, 0.15),
    seed=42,
    scale=False
)
X_train, y_train = splits["X_train"], splits["y_train"]
X_val, y_val = splits["X_val"], splits["y_val"]

csv_path = 'models/xgboost/sweep_results_all_40.csv'
df = pd.read_csv(csv_path)

# We will calculate accuracy for each row
accuracies = []

print(f"Re-evaluating {len(df)} configurations for accuracy. This will take a few minutes...")
for idx, row in df.iterrows():
    # Rebuild the exact hyperparameters of this sweep trial; same seed as the
    # sweep so the retrained model reproduces the original trial.
    # NOTE(review): `use_label_encoder` was removed in xgboost >= 2.0 — confirm
    # the pinned xgboost version still accepts it.
    params = {
        "n_estimators": int(row["n_estimators"]),
        "max_depth": int(row["max_depth"]),
        "learning_rate": float(row["learning_rate"]),
        "subsample": float(row["subsample"]),
        "colsample_bytree": float(row["colsample_bytree"]),
        "reg_alpha": float(row["reg_alpha"]),
        "reg_lambda": float(row["reg_lambda"]),
        "random_state": 42,
        "use_label_encoder": False,
        "verbosity": 0,
        "eval_metric": "logloss"
    }

    # Train the exact same model quickly
    model = XGBClassifier(**params)
    model.fit(X_train, y_train)

    # Get validation predictions and calculate accuracy
    val_preds = model.predict(X_val)
    acc = accuracy_score(y_val, val_preds)
    accuracies.append(round(acc, 4))

    # Progress heartbeat every 5 trials.
    if (idx + 1) % 5 == 0:
        print(f"Processed {idx + 1}/{len(df)} trials...")

# Add accuracy column and save back to CSV (in place, at column index 2)
df.insert(2, 'val_accuracy', accuracies)
df.to_csv(csv_path, index=False)

print(f"\nDone! Updated {csv_path} with 'val_accuracy'.")
# Display the top 5 by accuracy now just to see
top5_acc = df.nlargest(5, 'val_accuracy')[['task_id', 'val_accuracy', 'val_f1', 'val_loss']]
print("\nTop 5 Trials by Accuracy:")
print(top5_acc.to_string(index=False))
|
models/xgboost/checkpoints/face_orientation_best.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/xgboost/eval_accuracy.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load saved XGBoost checkpoint and print test accuracy, F1, AUC."""
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sklearn.metrics import f1_score, roc_auc_score
|
| 7 |
+
from xgboost import XGBClassifier
|
| 8 |
+
|
| 9 |
+
# run from repo root: python -m models.xgboost.eval_accuracy
|
| 10 |
+
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
| 11 |
+
if REPO_ROOT not in sys.path:
|
| 12 |
+
sys.path.insert(0, REPO_ROOT)
|
| 13 |
+
|
| 14 |
+
from data_preparation.prepare_dataset import get_numpy_splits
|
| 15 |
+
|
| 16 |
+
MODEL_NAME = "face_orientation"
|
| 17 |
+
CKPT_DIR = os.path.join(REPO_ROOT, "checkpoints")
|
| 18 |
+
MODEL_PATH = os.path.join(CKPT_DIR, f"xgboost_{MODEL_NAME}_best.json")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main():
    """Evaluate the saved XGBoost checkpoint on the held-out test split and
    print accuracy, weighted F1, and ROC-AUC."""
    # Bail out early with a hint if training has not produced a checkpoint yet.
    if not os.path.isfile(MODEL_PATH):
        print(f"No checkpoint at {MODEL_PATH}. Train first: python -m models.xgboost.train")
        return

    # Same split ratios / seed as training, so the test fold is identical.
    # scale=False matches how the XGBoost model was trained.
    splits, num_features, num_classes, _ = get_numpy_splits(
        model_name=MODEL_NAME,
        split_ratios=(0.7, 0.15, 0.15),
        seed=42,
        scale=False,
    )
    X_test, y_test = splits["X_test"], splits["y_test"]

    model = XGBClassifier()
    model.load_model(MODEL_PATH)

    probs = model.predict_proba(X_test)
    preds = model.predict(X_test)

    acc = float(np.mean(preds == y_test))
    f1 = float(f1_score(y_test, preds, average="weighted"))
    # Multiclass AUC uses one-vs-rest; binary uses the positive-class column.
    if num_classes > 2:
        auc = float(roc_auc_score(y_test, probs, multi_class="ovr", average="weighted"))
    else:
        auc = float(roc_auc_score(y_test, probs[:, 1]))

    print("XGBoost (face_orientation) — test set")
    print(" Accuracy: {:.2%}".format(acc))
    print(" F1: {:.4f}".format(f1))
    print(" ROC-AUC: {:.4f}".format(auc))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Script entry point: evaluate the checkpoint when run directly.
if __name__ == "__main__":
    main()
|